Control Flow Examples

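This notebook demonstrates how visu-hlo visualizes JAX’s structured control flow primitives. Inside a jax.jit-compiled function, Python if and while statements cannot branch on traced values, because those values are not known until the compiled program runs. The jax.lax primitives shown here stage branches and loops into the compiled program instead. As a result they can be JIT-compiled and differentiated (lax.while_loop supports forward-mode differentiation only).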

Setup

First, let’s import the necessary libraries:

[1]:
import os

os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp

from visu_hlo import show

Conditional Execution with lax.cond

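jax.lax.cond(pred, true_fun, false_fun, *operands) takes a scalar boolean predicate, traces both branch functions, and executes the one selected by pred at run time. Both branches must return outputs with the same structure, shapes, and dtypes: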

[2]:
@jax.jit
def simple_conditional(x):
    """Simple conditional: square if positive, negate if negative."""
    return jax.lax.cond(
        x > 0,
        lambda x: x**2,  # true branch
        lambda x: -x,  # false branch
        x,
    )


print('Simple conditional (x > 0 ? x² : -x):')
show(simple_conditional, jnp.array(5.0))
Simple conditional (x > 0 ? x² : -x):
../_images/user-guide_control-flow_3_1.svg
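Why not use a plain Python if? The minimal sketch below uses two hypothetical helpers, python_if and with_cond. Under jax.jit, x > 0 is a traced value with no concrete truth value, so a Python if raises a concretization error. lax.cond, by contrast, evaluates the predicate inside the compiled program.

```python
import jax
import jax.numpy as jnp


def python_if(x):
    # Hypothetical helper: a Python `if` needs a concrete bool,
    # but under jit `x > 0` is an abstract tracer.
    if x > 0:
        return x**2
    return -x


def with_cond(x):
    # Hypothetical helper: lax.cond stages both branches and selects one at run time.
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


try:
    jax.jit(python_if)(jnp.array(5.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # TracerBoolConversionError, raised by recent JAX versions, subclasses this error.
    python_if_failed = True

results = [float(jax.jit(with_cond)(jnp.array(v))) for v in (5.0, -3.0)]
print(python_if_failed, results)  # True [25.0, 3.0]
```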

Complex Conditional Logic

[3]:
@jax.jit
def complex_conditional(x, y):
    """More complex conditional with multiple operations."""

    def true_branch(args):
        x, y = args
        return x * y + jnp.sin(x)

    def false_branch(args):
        x, y = args
        return x - y + jnp.cos(y)

    return jax.lax.cond(x > y, true_branch, false_branch, (x, y))


print('Complex conditional with trigonometric functions:')
show(complex_conditional, jnp.array(2.0), jnp.array(1.0))
Complex conditional with trigonometric functions:
../_images/user-guide_control-flow_5_1.svg
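lax.cond accepts any number of operands, including pytrees such as the tuple above, and it supports reverse-mode differentiation. The minimal sketch below reuses this cell's branch formulas but passes x and y as separate operands. It then takes jax.grad through the branch selected at x = 2, y = 1.

```python
import jax
import jax.numpy as jnp


def complex_conditional(x, y):
    # Same branches as the cell above; lax.cond forwards *operands to the chosen branch.
    return jax.lax.cond(
        x > y,
        lambda a, b: a * b + jnp.sin(a),  # taken when x > y
        lambda a, b: a - b + jnp.cos(b),  # taken otherwise
        x,
        y,
    )


x, y = jnp.array(2.0), jnp.array(1.0)
dx, dy = jax.jit(jax.grad(complex_conditional, argnums=(0, 1)))(x, y)

# For x > y the true branch is active: d/dx = y + cos(x), d/dy = x.
print(float(dx), float(1.0 + jnp.cos(2.0)))
print(float(dy))  # 2.0
```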

Multi-way Conditionals with lax.switch

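lax.switch(index, branches, *operands) generalizes cond to several branches chosen by an integer index. An index outside [0, len(branches) - 1] is clamped into that range: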

[4]:
@jax.jit
def multi_way_switch(index, x):
    """Multi-way conditional using lax.switch."""
    branches = [
        lambda x: x + 1,  # case 0
        lambda x: x * 2,  # case 1
        lambda x: x**2,  # case 2
        lambda x: jnp.sqrt(x),  # case 3
    ]

    return jax.lax.switch(index, branches, x)


print('Multi-way switch (case 2: x²):')
show(multi_way_switch, 2, jnp.array(4.0))
Multi-way switch (case 2: x²):
../_images/user-guide_control-flow_7_1.svg
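The following short sketch illustrates the index semantics. lax.switch clamps an out-of-range index instead of raising, so index -1 runs case 0 and index 7 runs case 3. With x = 4, the six calls below exercise every branch, and all of them go through a single compiled program.

```python
import jax
import jax.numpy as jnp

branches = [
    lambda x: x + 1,  # case 0
    lambda x: x * 2,  # case 1
    lambda x: x**2,  # case 2
    lambda x: jnp.sqrt(x),  # case 3
]


@jax.jit
def multi_way_switch(index, x):
    return jax.lax.switch(index, branches, x)


x = jnp.array(4.0)
# -1 is clamped to 0 and 7 is clamped to 3.
outputs = [float(multi_way_switch(i, x)) for i in (-1, 0, 1, 2, 3, 7)]
print(outputs)  # [5.0, 5.0, 8.0, 16.0, 2.0, 2.0]
```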

Loops with lax.fori_loop

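Counted loops: lax.fori_loop(lower, upper, body_fun, init_val) applies val = body_fun(i, val) for i = lower, ..., upper - 1 and returns the final val. The bounds may be traced values, as n is here, in which case the trip count is decided at run time: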

[5]:
@jax.jit
def simple_loop(n, init_val):
    """Simple accumulation loop."""

    def body_fun(i, val):
        return val + i * 2

    return jax.lax.fori_loop(0, n, body_fun, init_val)


print('Simple fori_loop (accumulate i * 2):')
show(simple_loop, 5, jnp.array(0.0))
Simple fori_loop (accumulate i * 2):
../_images/user-guide_control-flow_9_1.svg
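The trip count determines how fori_loop is lowered. Per the JAX documentation, it becomes a scan when both bounds are known at trace time and a while_loop otherwise. simple_loop falls into the second case because n is a jitted argument. The sketch below uses jax.make_jaxpr to list the traced primitives; the helper top_level_primitives is hypothetical.

```python
import jax
import jax.numpy as jnp


def body_fun(i, val):
    return val + i * 2


def static_trip(val):
    # Bounds are Python ints, so the trip count is known while tracing.
    return jax.lax.fori_loop(0, 5, body_fun, val)


def traced_trip(n, val):
    # The upper bound is an argument, so it is traced and only known at run time.
    return jax.lax.fori_loop(0, n, body_fun, val)


def top_level_primitives(fn, *args):
    # Hypothetical helper: names of the primitives in the outermost jaxpr.
    return {eqn.primitive.name for eqn in jax.make_jaxpr(fn)(*args).jaxpr.eqns}


init = jnp.array(0.0)
print(top_level_primitives(static_trip, init))  # includes 'scan'
print(top_level_primitives(traced_trip, 5, init))  # includes 'while'

# Both agree with the plain Python loop: 0 + 2 + 4 + 6 + 8 = 20.
print(float(static_trip(init)), float(jax.jit(traced_trip)(5, init)))
```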
[6]:
@jax.jit
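def matrix_power_loop(matrix, n):
    """Compute matrix**n (n >= 0) by repeated multiplication."""

    def body_fun(_, result):
        return jnp.dot(result, matrix)

    # Start from the identity so that n multiplications give matrix**n;
    # starting from `matrix` itself would give matrix**(n + 1).
    identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
    return jax.lax.fori_loop(0, n, body_fun, identity)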


test_matrix = jnp.array([[1.1, 0.1], [0.1, 1.1]])
print('Matrix power using fori_loop:')
show(matrix_power_loop, test_matrix, 3)
Matrix power using fori_loop:
../_images/user-guide_control-flow_10_1.svg
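The initial value of the loop matters. n multiplications starting from the identity give matrix**n, while starting from matrix itself gives matrix**(n + 1). The quick check below compares both against jnp.linalg.matrix_power; power_from is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def power_from(init, matrix, n):
    # Hypothetical helper: right-multiply `init` by `matrix` n times.
    return jax.lax.fori_loop(0, n, lambda _, result: result @ matrix, init)


m = jnp.array([[1.1, 0.1], [0.1, 1.1]])
n = 3

from_identity = power_from(jnp.eye(2, dtype=m.dtype), m, n)  # m**3
from_matrix = power_from(m, m, n)  # m**4: one factor too many

print(jnp.allclose(from_identity, jnp.linalg.matrix_power(m, n)))  # True
print(jnp.allclose(from_matrix, jnp.linalg.matrix_power(m, n + 1)))  # True
```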

While Loops with lax.while_loop

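Condition-based loops: lax.while_loop(cond_fun, body_fun, init_val) applies body_fun for as long as cond_fun returns True. The number of iterations need not be known in advance, but the carried value must keep the same shape and dtype across iterations: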

[7]:
@jax.jit
def convergence_loop(x):
    """Loop until convergence using while_loop."""

    def cond_fun(val):
        return jnp.abs(val) > 0.01

    def body_fun(val):
        return val * 0.8

    return jax.lax.while_loop(cond_fun, body_fun, x)


print('While loop until convergence:')
show(convergence_loop, jnp.array(10.0))
While loop until convergence:
../_images/user-guide_control-flow_12_1.svg
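Because the number of iterations is decided at run time, it can be useful to carry a counter alongside the value. The sketch below is a hypothetical variant of the cell above, convergence_with_count, that returns both. Starting from 10.0 and shrinking by 0.8 per step, the value first drops below 0.01 after 31 steps, since 10 * 0.8**30 ≈ 0.0124 and 10 * 0.8**31 ≈ 0.0099.

```python
import jax
import jax.numpy as jnp


@jax.jit
def convergence_with_count(x):
    # Hypothetical variant: carry (value, iteration count) through the loop.
    def cond_fun(state):
        val, _ = state
        return jnp.abs(val) > 0.01

    def body_fun(state):
        val, count = state
        return val * 0.8, count + 1

    init_state = (x, jnp.array(0, dtype=jnp.int32))
    return jax.lax.while_loop(cond_fun, body_fun, init_state)


final_val, steps = convergence_with_count(jnp.array(10.0))
print(float(final_val), int(steps))  # roughly 0.0099 and 31
```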
[8]:
@jax.jit
def newton_iteration(x):
    """Newton's method for finding square root."""
    target = 2.0  # Finding sqrt(2)

    def cond_fun(state):
        x, error = state
        return error > 1e-6

    def body_fun(state):
        x, _ = state
        new_x = 0.5 * (x + target / x)
        error = jnp.abs(new_x - x)
        return new_x, error

    init_state = (x, jnp.array(1.0))
    final_x, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return final_x


print("Newton's method for square root:")
show(newton_iteration, jnp.array(1.5))
Newton's method for square root:
../_images/user-guide_control-flow_13_1.svg
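The JAX documentation notes that while_loop supports forward-mode differentiation but not reverse mode. The reason is that the number of iterations, and therefore the memory needed to store them, is unknown at compile time. In the sketch below, sqrt_newton is a hypothetical variant of the cell above that takes the target as an argument. jax.jvp recovers d sqrt(t)/dt = 1 / (2 sqrt(t)), while jax.grad raises an error.

```python
import jax
import jax.numpy as jnp


def sqrt_newton(target):
    # Hypothetical variant: Newton iteration x <- (x + target / x) / 2 from x = 1,
    # stopping once the update is smaller than 1e-6.
    def cond_fun(state):
        _, step = state
        return step > 1e-6

    def body_fun(state):
        x, _ = state
        new_x = 0.5 * (x + target / x)
        return new_x, jnp.abs(new_x - x)

    x, _ = jax.lax.while_loop(cond_fun, body_fun, (jnp.float32(1.0), jnp.float32(1.0)))
    return x


t = jnp.float32(2.0)
# Forward mode differentiates through the loop iterations.
value, slope = jax.jvp(sqrt_newton, (t,), (jnp.float32(1.0),))
print(float(value), float(slope))  # about 1.41421 and 0.35355

# Reverse mode needs a transpose of while_loop, which JAX does not provide.
try:
    jax.grad(sqrt_newton)(t)
    grad_failed = False
except ValueError:
    grad_failed = True
print(grad_failed)  # True
```

If reverse-mode gradients are required, the usual workaround is a loop with a static trip count (lax.scan or fori_loop with Python-int bounds).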

Scan Operations with lax.scan

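lax.scan(f, init, xs) loops over the leading axis of xs while threading a carry. At each step, f(carry, x) returns (new_carry, y), and scan returns the final carry together with every y stacked along a new leading axis: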

[9]:
@jax.jit
def cumulative_sum_scan(xs):
    """Cumulative sum using lax.scan."""

    def scan_fun(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    _, cumsum = jax.lax.scan(scan_fun, 0.0, xs)
    return cumsum


test_array = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print('Cumulative sum using scan:')
show(cumulative_sum_scan, test_array)
Cumulative sum using scan:
../_images/user-guide_control-flow_15_1.svg
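lax.scan behaves like the Python loop below, except that it is traced once and compiled as a single loop. The sketch uses scan_reference, a hypothetical helper modeled on the semantics given in the lax.scan documentation. It also keeps the final carry, which the cell above discards.

```python
import jax
import jax.numpy as jnp


def scan_reference(f, init, xs):
    # Hypothetical helper: plain-Python model of lax.scan.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)


def scan_fun(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final, cumsum = jax.lax.scan(scan_fun, 0.0, xs)
ref_final, ref_cumsum = scan_reference(scan_fun, 0.0, xs)

print(float(final), cumsum.tolist())  # 15.0 [1.0, 3.0, 6.0, 10.0, 15.0]
```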
[10]:
@jax.jit
def running_average_scan(xs):
    """Running average using lax.scan."""

    def scan_fun(carry, x):
        count, total = carry
        new_count = count + 1
        new_total = total + x
        avg = new_total / new_count
        return (new_count, new_total), avg

    _, averages = jax.lax.scan(scan_fun, (0.0, 0.0), xs)
    return averages


print('Running average using scan:')
show(running_average_scan, test_array)
Running average using scan:
../_images/user-guide_control-flow_16_1.svg
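The carry can be any pytree, here a (count, total) tuple, and scan returns every element of it as the final carry. The sketch below checks the scan against a loop-free formula: prefix sums divided by 1, 2, ..., n.

```python
import jax
import jax.numpy as jnp


def scan_fun(carry, x):
    # The carry is a (count, total) tuple.
    count, total = carry
    new_count = count + 1
    new_total = total + x
    return (new_count, new_total), new_total / new_count


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
(count, total), averages = jax.lax.scan(scan_fun, (0.0, 0.0), xs)

# Same result without a loop: prefix sums divided by 1, 2, ..., len(xs).
closed_form = jnp.cumsum(xs) / jnp.arange(1, xs.shape[0] + 1)
print(float(count), float(total), averages.tolist())  # 5.0 15.0 [1.0, 1.5, 2.0, 2.5, 3.0]
```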

Recurrent Neural Network with Scan

A simple RNN using scan: the hidden state is the carry, and each row of inputs is one time step:

[11]:
@jax.jit
def simple_rnn(params, inputs):
    """Simple RNN using lax.scan."""
    W_h, W_x, b = params

    def rnn_step(h, x):
        new_h = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x) + b)
        return new_h, new_h

    h0 = jnp.zeros(W_h.shape[0])
    _, hidden_states = jax.lax.scan(rnn_step, h0, inputs)
    return hidden_states


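# RNN parameters: W_h has shape (hidden_size, hidden_size),
# W_x has shape (hidden_size, input_size), and b has shape (hidden_size,)
hidden_size = 3
input_size = 2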
W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
b = jnp.array([0.1, 0.1, 0.1])
params = (W_h, W_x, b)

# Input sequence
inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], [0.3, 0.9]])

print('Simple RNN with scan:')
show(simple_rnn, params, inputs)
Simple RNN with scan:
../_images/user-guide_control-flow_18_1.svg
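One way to check the scan is to unroll the same recurrence, h_t = tanh(W_h h_{t-1} + W_x x_t + b), in Python. Tracing such a Python loop under jit would inline one copy of the step per time step. scan instead compiles a single loop body, so the HLO graph stays the same size regardless of sequence length. The sketch below reuses this cell's parameters.

```python
import jax
import jax.numpy as jnp

W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
b = jnp.array([0.1, 0.1, 0.1])
inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], [0.3, 0.9]])  # (time steps, input_size)


def rnn_step(h, x):
    new_h = jnp.tanh(W_h @ h + W_x @ x + b)
    return new_h, new_h


# lax.scan version: the hidden state is the carry, one input row per step.
_, scanned = jax.lax.scan(rnn_step, jnp.zeros(3), inputs)

# Unrolled Python loop computing the same recurrence.
h = jnp.zeros(3)
unrolled = []
for x in inputs:
    h, _ = rnn_step(h, x)
    unrolled.append(h)
unrolled = jnp.stack(unrolled)

print(scanned.shape)  # (3, 3): one hidden state per time step
```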

Nested Control Flow

Combining different control flow primitives: here a lax.fori_loop runs inside one branch of a lax.cond:

[12]:
@jax.jit
def nested_control_flow(x, condition):
    """Nested conditional and loop."""

    def true_branch(x):
        # If condition is true, apply a loop
        def body_fun(i, val):
            return val * 1.1

        return jax.lax.fori_loop(0, 5, body_fun, x)

    def false_branch(x):
        # If condition is false, apply a different transformation
        return jnp.sqrt(jnp.abs(x))

    return jax.lax.cond(condition, true_branch, false_branch, x)


print('Nested control flow (condition=True):')
show(nested_control_flow, jnp.array(2.0), True)
Nested control flow (condition=True):
../_images/user-guide_control-flow_20_1.svg
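Because condition is a traced argument, a single compiled program serves both values, and only the selected branch runs. The sketch below compares both paths with their closed forms: 2 * 1.1**5 for the loop and sqrt(2) for the other branch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def nested_control_flow(x, condition):
    def true_branch(x):
        # Five multiplications by 1.1 inside the branch.
        return jax.lax.fori_loop(0, 5, lambda i, val: val * 1.1, x)

    def false_branch(x):
        return jnp.sqrt(jnp.abs(x))

    return jax.lax.cond(condition, true_branch, false_branch, x)


x = jnp.array(2.0)
taken = nested_control_flow(x, True)  # 2 * 1.1**5, about 3.22102
skipped = nested_control_flow(x, False)  # sqrt(2), about 1.41421
print(float(taken), float(skipped))
```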

Dynamic Programming Example

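Computing Fibonacci numbers using scan. Without an xs argument, scan needs an explicit length, which must be known at trace time, so n is marked static: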

[13]:
from functools import partial


@partial(jax.jit, static_argnums=0)
def fibonacci(n):
    """Compute Fibonacci sequence using scan."""

    def step(carry, _):
        a, b = carry
        return (b, a + b), a

    # Initialize with F(0)=0, F(1)=1
    init_carry = (0, 1)
    _, fib_sequence = jax.lax.scan(step, init_carry, length=n)
    return fib_sequence


print('Fibonacci sequence using scan:')
show(fibonacci, 10)
Fibonacci sequence using scan:
../_images/user-guide_control-flow_22_1.svg
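The sketch below is the same function, but it also returns the final carry, (F(n), F(n+1)). The step function emits a before advancing, so the sequence runs from F(0) to F(n-1). Values are 32-bit integers by default, so anything beyond F(46) would overflow unless 64-bit mode is enabled.

```python
from functools import partial

import jax


@partial(jax.jit, static_argnums=0)
def fibonacci(n):
    def step(carry, _):
        a, b = carry
        return (b, a + b), a

    # No xs is given, so `length` sets the number of steps; it must be a Python int.
    final_carry, fib_sequence = jax.lax.scan(step, (0, 1), length=n)
    return fib_sequence, final_carry


seq, (f_n, f_n_plus_1) = fibonacci(10)
print(seq.tolist())  # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
print(int(f_n), int(f_n_plus_1))  # 55 89
```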

Optimization Loop

Simple gradient descent on f(x) = (x - 2)² + 1, with jax.grad evaluated inside the fori_loop body:

[14]:
@jax.jit
def gradient_descent_loop(params, learning_rate, n_steps):
    """Simple gradient descent using a loop."""

    def objective(x):
        return (x - 2.0) ** 2 + 1.0

    grad_fn = jax.grad(objective)

    def update_step(i, params):
        grad = grad_fn(params)
        return params - learning_rate * grad

    return jax.lax.fori_loop(0, n_steps, update_step, params)


print('Gradient descent optimization loop:')
show(gradient_descent_loop, jnp.array(0.0), 0.1, 10)
Gradient descent optimization loop:
../_images/user-guide_control-flow_24_1.svg
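Because the objective is quadratic, the loop has a closed form to test against. With f'(x) = 2(x - 2), each update multiplies the distance to the minimum by (1 - 2 * learning_rate). The sketch below compares the two for the arguments used above: x0 = 0, learning rate 0.1, and 10 steps.

```python
import jax
import jax.numpy as jnp


def objective(x):
    return (x - 2.0) ** 2 + 1.0


@jax.jit
def gradient_descent_loop(x0, learning_rate, n_steps):
    def update_step(i, x):
        return x - learning_rate * jax.grad(objective)(x)

    return jax.lax.fori_loop(0, n_steps, update_step, x0)


# One step maps (x - 2) to (1 - 2 * lr) * (x - 2),
# so after k steps x_k = 2 + (1 - 2 * lr) ** k * (x0 - 2).
lr, k, x0 = 0.1, 10, 0.0
numeric = gradient_descent_loop(jnp.array(x0), lr, k)
closed_form = 2.0 + (1.0 - 2.0 * lr) ** k * (x0 - 2.0)  # 2 - 2 * 0.8**10, about 1.78525
print(float(numeric), closed_form)
```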

Summary

This notebook demonstrated JAX’s structured control flow primitives and their visualization:

  • Conditionals: lax.cond for if-then-else logic, lax.switch for multi-way branching

  • Loops: lax.fori_loop for fixed iterations, lax.while_loop for condition-based loops

  • Scan: lax.scan for efficient loops with intermediate results

  • Applications: RNNs, dynamic programming, optimization algorithms

  • Nested structures: Combining different control flow primitives

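All these control flow operations are:

  • Functional: branches and loop bodies must be pure functions without side effects, and loop state is passed explicitly as a carry with a fixed shape and dtype

  • Compilable: Can be JIT compiled for performance

  • Differentiable: lax.cond, lax.switch, lax.scan, and lax.fori_loop with static bounds support both forward- and reverse-mode differentiation; lax.while_loop (and fori_loop with traced bounds) supports forward mode only

  • Portable: the same code compiles through XLA for CPU, GPU, and TPU backends (this notebook pins JAX to CPU)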

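The HLO visualizations show how these high-level control structures are compiled. lax.cond and lax.switch typically appear as conditional operations whose branches are separate subcomputations. lax.fori_loop, lax.while_loop, and lax.scan typically appear as while loops with their own condition and body subcomputations. XLA optimization passes may simplify or restructure either kind.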
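You can also read the lowering as text rather than as a graph. JAX's ahead-of-time API prints the module for a jitted function. The sketch below, which is independent of visu-hlo, does this for the cumulative-sum example.

```python
import jax
import jax.numpy as jnp


def cumulative_sum_scan(xs):
    def scan_fun(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    _, cumsum = jax.lax.scan(scan_fun, 0.0, xs)
    return cumsum


lowered = jax.jit(cumulative_sum_scan).lower(jnp.arange(5.0))
stablehlo_text = lowered.as_text()  # StableHLO module before XLA optimizations
print("while" in stablehlo_text)  # True: the scan becomes a single while loop
print(stablehlo_text)
```

lowered.compile().as_text() returns the HLO after XLA's optimization passes instead, which may differ from the pre-optimization module.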