Control Flow Examples
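This notebook demonstrates how visu-hlo visualizes JAX’s structured control flow primitives. Inside jax.jit, Python if and while statements cannot depend on traced values, and Python for loops are unrolled during tracing. JAX therefore provides functional control flow operations in jax.lax that are staged into the compiled program and can be differentiated, with some limits for loops.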
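The following minimal sketch shows why these primitives exist. It jits the same logic twice, and python_branch and lax_branch are illustrative names. The Python if fails during tracing because x > 0 is an abstract tracer rather than a concrete bool. The lax.cond version compiles. The exact exception class varies between JAX versions, so the sketch catches jax.errors.ConcretizationTypeError, the base class for this kind of failure.

```python
import jax
import jax.numpy as jnp


def python_branch(x):
    # A Python `if` needs a concrete bool; under jit, `x > 0` is a tracer.
    if x > 0:
        return x**2
    return -x


def lax_branch(x):
    # lax.cond stages both branches into the compiled program.
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


try:
    jax.jit(python_branch)(jnp.array(3.0))
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX raises TracerBoolConversionError, a subclass of this error.
    python_if_failed = True

print(python_if_failed)                      # True
print(jax.jit(lax_branch)(jnp.array(3.0)))   # 9.0
print(jax.jit(lax_branch)(jnp.array(-3.0)))  # 3.0
```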
Setup
First, let’s import the necessary libraries:
```python
import os

# Force the CPU backend; set this before importing jax.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp
from visu_hlo import show
```
Conditional Execution with lax.cond
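JAX’s lax.cond(pred, true_fun, false_fun, *operands) takes a scalar boolean predicate and applies one of the two branch functions to the operands. Both branches must return outputs with the same shapes and dtypes: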
```python
@jax.jit
def simple_conditional(x):
    """Simple conditional: square if positive, negate otherwise."""
    return jax.lax.cond(
        x > 0,
        lambda x: x**2,  # true branch
        lambda x: -x,  # false branch
        x,
    )


print('Simple conditional (x > 0 ? x² : -x):')
show(simple_conditional, jnp.array(5.0))
```
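lax.cond is differentiable, so jax.grad returns the derivative of whichever branch the predicate selects: 2x for x > 0 and -1 otherwise. The self-contained sketch below redefines the same function and checks one input on each side of zero:

```python
import jax


@jax.jit
def simple_conditional(x):
    # Same logic as above: x**2 if x > 0, otherwise -x.
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


grad_fn = jax.grad(simple_conditional)

for x in (5.0, -3.0):
    # Value and derivative of the branch that the predicate selects.
    print(x, float(simple_conditional(x)), float(grad_fn(x)))
# 5.0 25.0 10.0
# -3.0 3.0 -1.0
```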
Complex Conditional Logic
```python
@jax.jit
def complex_conditional(x, y):
    """More complex conditional with multiple operations."""

    def true_branch(args):
        x, y = args
        return x * y + jnp.sin(x)

    def false_branch(args):
        x, y = args
        return x - y + jnp.cos(y)

    return jax.lax.cond(x > y, true_branch, false_branch, (x, y))


print('Complex conditional with trigonometric functions:')
show(complex_conditional, jnp.array(2.0), jnp.array(1.0))
```
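lax.cond also accepts several operands directly and passes each one as a separate argument to the branch functions, so packing them into a tuple is optional. The sketch below checks that both calling styles agree with a plain-Python evaluation of the selected branch. The names with_tuple and with_varargs are illustrative.

```python
import math

import jax
import jax.numpy as jnp


def with_tuple(x, y):
    # Operands packed into one tuple; each branch indexes into it.
    return jax.lax.cond(
        x > y,
        lambda args: args[0] * args[1] + jnp.sin(args[0]),
        lambda args: args[0] - args[1] + jnp.cos(args[1]),
        (x, y),
    )


def with_varargs(x, y):
    # Operands passed separately; each branch takes (a, b) directly.
    return jax.lax.cond(
        x > y,
        lambda a, b: a * b + jnp.sin(a),
        lambda a, b: a - b + jnp.cos(b),
        x,
        y,
    )


x, y = jnp.array(2.0), jnp.array(1.0)
a = jax.jit(with_tuple)(x, y)
b = jax.jit(with_varargs)(x, y)
expected = 2.0 * 1.0 + math.sin(2.0)  # x > y, so the true branch runs
print(a, b, expected)
```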
Multi-way Conditionals with lax.switch
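For more than two branches, lax.switch(index, branches, *operands) picks a branch by integer index. An out-of-range index is clamped to the first or last branch: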
```python
@jax.jit
def multi_way_switch(index, x):
    """Multi-way conditional using lax.switch."""
    branches = [
        lambda x: x + 1,  # case 0
        lambda x: x * 2,  # case 1
        lambda x: x**2,  # case 2
        lambda x: jnp.sqrt(x),  # case 3
    ]
    return jax.lax.switch(index, branches, x)


print('Multi-way switch (case 2: x²):')
show(multi_way_switch, 2, jnp.array(4.0))
```
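Because lax.switch clamps out-of-range indices instead of raising an error, index -1 selects branch 0 and index 10 selects the last branch. The self-contained sketch below evaluates every case on x = 4.0:

```python
import jax
import jax.numpy as jnp

branches = [
    lambda x: x + 1,  # case 0
    lambda x: x * 2,  # case 1
    lambda x: x**2,  # case 2
    lambda x: jnp.sqrt(x),  # case 3
]


@jax.jit
def multi_way_switch(index, x):
    return jax.lax.switch(index, branches, x)


x = jnp.array(4.0)
# Indices -1 and 10 are outside [0, 3] and get clamped to 0 and 3.
results = {i: float(multi_way_switch(i, x)) for i in (-1, 0, 1, 2, 3, 10)}
print(results)
```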
Loops with lax.fori_loop
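lax.fori_loop(lower, upper, body_fun, init_val) runs body_fun(i, val) for each i in range(lower, upper) and returns the final val. In the examples below, the loop bound is an argument of the jitted function. It is therefore traced rather than known at trace time, and the loop is lowered to a while loop: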
```python
@jax.jit
def simple_loop(n, init_val):
    """Simple accumulation loop."""

    def body_fun(i, val):
        return val + i * 2

    return jax.lax.fori_loop(0, n, body_fun, init_val)


print('Simple fori_loop (accumulate i * 2):')
show(simple_loop, 5, jnp.array(0.0))
```
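The loop above adds 2·i for i = 0 to n-1, so for n = 5 it returns 0 + 2 + 4 + 6 + 8 = 20. The sketch below compares lax.fori_loop against a plain-Python loop with the same body. fori_loop_reference is a hypothetical helper written only to spell out the semantics. It is not a JAX API.

```python
import jax
import jax.numpy as jnp


def fori_loop_reference(lower, upper, body_fun, init_val):
    # Plain-Python equivalent of lax.fori_loop's semantics.
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body_fun(i, val):
    return val + i * 2


@jax.jit
def simple_loop(n, init_val):
    return jax.lax.fori_loop(0, n, body_fun, init_val)


print(simple_loop(5, jnp.array(0.0)))            # 20.0
print(fori_loop_reference(0, 5, body_fun, 0.0))  # 20.0
```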
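```python
@jax.jit
def matrix_power_loop(matrix, n):
    """Compute matrix**n by repeated multiplication in a loop."""

    def body_fun(i, result):
        return jnp.dot(result, matrix)

    # Start from the identity so that n iterations give matrix**n
    # (starting from `matrix` would give matrix**(n + 1)).
    identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
    return jax.lax.fori_loop(0, n, body_fun, identity)


test_matrix = jnp.array([[1.1, 0.1], [0.1, 1.1]])
print('Matrix power using fori_loop:')
show(matrix_power_loop, test_matrix, 3)
```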
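Each iteration multiplies the running result by matrix once. n iterations therefore yield matrix**n only if the loop starts from the identity. Starting from matrix itself yields matrix**(n + 1). The self-contained check below initializes the loop from the identity and compares it against jnp.linalg.matrix_power for n = 0 to 4:

```python
import jax
import jax.numpy as jnp


@jax.jit
def matrix_power_loop(matrix, n):
    # Start from the identity: after n multiplications the result is matrix**n.
    identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
    return jax.lax.fori_loop(0, n, lambda i, result: result @ matrix, identity)


test_matrix = jnp.array([[1.1, 0.1], [0.1, 1.1]])
for n in range(5):
    loop_result = matrix_power_loop(test_matrix, n)
    reference = jnp.linalg.matrix_power(test_matrix, n)
    print(n, bool(jnp.allclose(loop_result, reference)))
```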
While Loops with lax.while_loop
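lax.while_loop(cond_fun, body_fun, init_val) applies body_fun to the loop state for as long as cond_fun returns True, so the number of iterations can depend on the data: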
```python
@jax.jit
def convergence_loop(x):
    """Loop until convergence using while_loop."""

    def cond_fun(val):
        return jnp.abs(val) > 0.01

    def body_fun(val):
        return val * 0.8

    return jax.lax.while_loop(cond_fun, body_fun, x)


print('While loop until convergence:')
show(convergence_loop, jnp.array(10.0))
```
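Carrying an iteration counter alongside the value shows how many times the loop body ran. Mathematically, 10 · 0.8^k first drops to 0.01 or below at k = 31 (10 · 0.8^30 ≈ 0.0124, while 10 · 0.8^31 ≈ 0.0099). The sketch below returns both parts of the loop state. convergence_loop_with_count is an illustrative name.

```python
import jax
import jax.numpy as jnp


@jax.jit
def convergence_loop_with_count(x):
    # The loop state is a (value, iteration count) pair.
    def cond_fun(state):
        val, _ = state
        return jnp.abs(val) > 0.01

    def body_fun(state):
        val, count = state
        return val * 0.8, count + 1

    return jax.lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))


val, count = convergence_loop_with_count(jnp.array(10.0))
print(val, count)
```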
```python
@jax.jit
def newton_iteration(x):
    """Newton's method for finding square root."""
    target = 2.0  # Finding sqrt(2)

    def cond_fun(state):
        x, error = state
        return error > 1e-6

    def body_fun(state):
        x, _ = state
        new_x = 0.5 * (x + target / x)
        error = jnp.abs(new_x - x)
        return new_x, error

    init_state = (x, jnp.array(1.0))
    final_x, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return final_x


print("Newton's method for square root:")
show(newton_iteration, jnp.array(1.5))
```
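The following is a hedged generalization. newton_sqrt is a name chosen here and is not part of visu-hlo or JAX. It takes the target as an argument, so one compiled function serves any target. It starts the step size at infinity so the loop always runs at least once. The update x ← (x + target / x) / 2 is Newton's method applied to f(x) = x² - target.

```python
import math

import jax
import jax.numpy as jnp


@jax.jit
def newton_sqrt(target, x0):
    # The loop state is (current estimate, size of the last step).
    def cond_fun(state):
        _, step = state
        return step > 1e-6

    def body_fun(state):
        x, _ = state
        new_x = 0.5 * (x + target / x)
        return new_x, jnp.abs(new_x - x)

    # An initial step of inf guarantees at least one iteration.
    init_state = (x0, jnp.full_like(x0, jnp.inf))
    final_x, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)
    return final_x


for target in (2.0, 9.0, 10.0):
    print(target, newton_sqrt(jnp.array(target), jnp.array(1.5)), math.sqrt(target))
```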
Scan Operations with lax.scan
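lax.scan(f, init, xs) loops over the leading axis of xs and threads a carry through the loop. f(carry, x) returns (new_carry, y), and scan returns the final carry together with every per-step y stacked into an array: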
```python
@jax.jit
def cumulative_sum_scan(xs):
    """Cumulative sum using lax.scan."""

    def scan_fun(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    _, cumsum = jax.lax.scan(scan_fun, 0.0, xs)
    return cumsum


test_array = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print('Cumulative sum using scan:')
show(cumulative_sum_scan, test_array)
```
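lax.scan behaves like the plain-Python loop below, which follows the reference semantics given in the lax.scan documentation. scan_reference is a helper written for this illustration, and it handles only a single array xs. The sketch checks the real scan against both the reference loop and jnp.cumsum.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scan_reference(f, init, xs):
    # Plain-Python sketch of lax.scan's semantics for a single array xs.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def scan_fun(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final, cumsum = jax.lax.scan(scan_fun, 0.0, xs)
ref_final, ref_cumsum = scan_reference(scan_fun, 0.0, np.asarray(xs))
print(final, cumsum)
print(ref_final, ref_cumsum)
```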
```python
@jax.jit
def running_average_scan(xs):
    """Running average using lax.scan."""

    def scan_fun(carry, x):
        count, total = carry
        new_count = count + 1
        new_total = total + x
        avg = new_total / new_count
        return (new_count, new_total), avg

    _, averages = jax.lax.scan(scan_fun, (0.0, 0.0), xs)
    return averages


print('Running average using scan:')
show(running_average_scan, test_array)
```
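The running average equals the cumulative sum divided by the number of elements seen so far. For [1, 2, 3, 4, 5] that gives [1, 1.5, 2, 2.5, 3]. The self-contained check below compares the scan against that vectorized closed form:

```python
import jax
import jax.numpy as jnp


@jax.jit
def running_average_scan(xs):
    # Carry (count, total); emit the mean of the elements seen so far.
    def scan_fun(carry, x):
        count, total = carry
        count, total = count + 1, total + x
        return (count, total), total / count

    _, averages = jax.lax.scan(scan_fun, (0.0, 0.0), xs)
    return averages


xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
# Vectorized equivalent: cumulative sum divided by 1, 2, ..., len(xs).
closed_form = jnp.cumsum(xs) / jnp.arange(1, xs.shape[0] + 1)
print(running_average_scan(xs), closed_form)
```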
Recurrent Neural Network with Scan
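A simple RNN implemented with scan. The hidden state h is the carry and is updated at each step as h = tanh(W_h·h + W_x·x + b). Every intermediate state is returned: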
```python
@jax.jit
def simple_rnn(params, inputs):
    """Simple RNN using lax.scan."""
    W_h, W_x, b = params

    def rnn_step(h, x):
        new_h = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x) + b)
        return new_h, new_h

    h0 = jnp.zeros(W_h.shape[0])
    _, hidden_states = jax.lax.scan(rnn_step, h0, inputs)
    return hidden_states


# RNN parameters
hidden_size = 3
input_size = 2
W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
b = jnp.array([0.1, 0.1, 0.1])
params = (W_h, W_x, b)

# Input sequence
inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], [0.3, 0.9]])

print('Simple RNN with scan:')
show(simple_rnn, params, inputs)
```
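The scan-based RNN is equivalent to a Python loop over the time steps. The self-contained check below uses the same parameters, and rnn_step, rnn_scan, and rnn_python_loop are illustrative names. One design point matters here. scan traces the step once and compiles a single loop. The Python loop, if jitted, would instead be unrolled into one copy of the step per input.

```python
import jax
import jax.numpy as jnp


def rnn_step(params, h, x):
    # One RNN update: h_new = tanh(W_h @ h + W_x @ x + b).
    W_h, W_x, b = params
    return jnp.tanh(W_h @ h + W_x @ x + b)


@jax.jit
def rnn_scan(params, inputs):
    # The hidden state is the scan carry; every new state is also emitted.
    def step(h, x):
        new_h = rnn_step(params, h, x)
        return new_h, new_h

    h0 = jnp.zeros(params[0].shape[0])
    _, hidden_states = jax.lax.scan(step, h0, inputs)
    return hidden_states


def rnn_python_loop(params, inputs):
    # Python loop with the same semantics, for comparison.
    h = jnp.zeros(params[0].shape[0])
    states = []
    for x in inputs:
        h = rnn_step(params, h, x)
        states.append(h)
    return jnp.stack(states)


W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
b = jnp.array([0.1, 0.1, 0.1])
params = (W_h, W_x, b)
inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], [0.3, 0.9]])

scanned = rnn_scan(params, inputs)
looped = rnn_python_loop(params, inputs)
print(scanned.shape)  # (3, 3): one hidden state per time step
print(bool(jnp.allclose(scanned, looped)))
```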
Nested Control Flow
Control flow primitives can be nested. Here, one branch of a lax.cond runs a lax.fori_loop:
```python
@jax.jit
def nested_control_flow(x, condition):
    """Nested conditional and loop."""

    def true_branch(x):
        # If condition is true, apply a loop
        def body_fun(i, val):
            return val * 1.1

        return jax.lax.fori_loop(0, 5, body_fun, x)

    def false_branch(x):
        # If condition is false, apply a different transformation
        return jnp.sqrt(jnp.abs(x))

    return jax.lax.cond(condition, true_branch, false_branch, x)


print('Nested control flow (condition=True):')
show(nested_control_flow, jnp.array(2.0), True)
```
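Here condition is an ordinary traced argument, so both branches are compiled into the same program and the predicate is evaluated at run time. The self-contained check below covers both paths. The true branch multiplies by 1.1 five times, giving x · 1.1^5. The false branch returns sqrt(|x|).

```python
import math

import jax
import jax.numpy as jnp


@jax.jit
def nested_control_flow(x, condition):
    def true_branch(x):
        # Five iterations of val * 1.1, i.e. x * 1.1**5.
        return jax.lax.fori_loop(0, 5, lambda i, val: val * 1.1, x)

    def false_branch(x):
        return jnp.sqrt(jnp.abs(x))

    return jax.lax.cond(condition, true_branch, false_branch, x)


x = jnp.array(2.0)
print(nested_control_flow(x, True), 2.0 * 1.1**5)
print(nested_control_flow(x, False), math.sqrt(2.0))
```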
Dynamic Programming Example
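Computing the first n Fibonacci numbers using scan. n determines the length of the output array, so it must be a static argument of the jitted function: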
```python
import functools


# n must be static: it sets the length, and hence the shape, of the output.
@functools.partial(jax.jit, static_argnums=0)
def fibonacci(n):
    """Compute Fibonacci sequence using scan."""

    def step(carry, _):
        a, b = carry
        return (b, a + b), a

    # Initialize with F(0)=0, F(1)=1
    init_carry = (0, 1)
    _, fib_sequence = jax.lax.scan(step, init_carry, length=n)
    return fib_sequence


print('Fibonacci sequence using scan:')
show(fibonacci, 10)
```
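If only the last value is needed, a fori_loop that carries the pair (F(i), F(i+1)) avoids materializing the sequence. n then no longer has to be static, because the output shape does not depend on it. Both functions below are self-contained sketches, and fibonacci_nth is an illustrative name. With JAX’s default 32-bit integers, values past F(46) overflow.

```python
import functools

import jax


@functools.partial(jax.jit, static_argnums=0)
def fibonacci(n):
    # Emit F(0), ..., F(n-1); n is static because it fixes the output length.
    def step(carry, _):
        a, b = carry
        return (b, a + b), a

    _, fib_sequence = jax.lax.scan(step, (0, 1), length=n)
    return fib_sequence


@jax.jit
def fibonacci_nth(n):
    # Only F(n) is kept, so n can stay traced (the loop lowers to a while loop).
    def body(i, carry):
        a, b = carry
        return b, a + b

    a, _ = jax.lax.fori_loop(0, n, body, (0, 1))
    return a


print(fibonacci(10))
print(fibonacci_nth(10))
```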
Optimization Loop
Simple gradient descent on f(x) = (x - 2)² + 1, whose minimum is at x = 2:
```python
@jax.jit
def gradient_descent_loop(params, learning_rate, n_steps):
    """Simple gradient descent using a loop."""

    def objective(x):
        return (x - 2.0) ** 2 + 1.0

    grad_fn = jax.grad(objective)

    def update_step(i, params):
        grad = grad_fn(params)
        return params - learning_rate * grad

    return jax.lax.fori_loop(0, n_steps, update_step, params)


print('Gradient descent optimization loop:')
show(gradient_descent_loop, jnp.array(0.0), 0.1, 10)
```
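Each update maps x - 2 to (1 - 2·lr)(x - 2). With x0 = 0 and lr = 0.1, the k-th iterate is therefore 2 - 2·0.8^k, and ten steps give about 1.785. The sketch below records the whole trajectory with lax.scan instead of fori_loop and checks it against that closed form. gradient_descent_trajectory is an illustrative name. The 10-step length is fixed because scan needs a static length.

```python
import jax
import jax.numpy as jnp


def objective(x):
    return (x - 2.0) ** 2 + 1.0


@jax.jit
def gradient_descent_trajectory(x0, learning_rate):
    # Record the iterate after each of 10 gradient steps.
    def step(x, _):
        x = x - learning_rate * jax.grad(objective)(x)
        return x, x

    _, xs = jax.lax.scan(step, x0, length=10)
    return xs


xs = gradient_descent_trajectory(jnp.array(0.0), 0.1)
# Closed form for this objective and step size: x_k = 2 - 2 * 0.8**k.
closed_form = 2.0 - 2.0 * 0.8 ** jnp.arange(1, 11)
print(xs[-1])
```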
Summary
This notebook demonstrated JAX’s structured control flow primitives and their visualization:
Conditionals: lax.cond for if-then-else logic, lax.switch for multi-way branching
Loops: lax.fori_loop for fixed iterations, lax.while_loop for condition-based loops
Scan: lax.scan for efficient loops with intermediate results
Applications: RNNs, dynamic programming, optimization algorithms
Nested structures: combining different control flow primitives
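All these control flow operations are:
Functional: branch and loop bodies must be pure functions without side effects
Compilable: they can be JIT-compiled, so an entire branch or loop runs inside one compiled XLA program
Differentiable: all of them support forward-mode differentiation, and all except lax.while_loop (and lax.fori_loop with traced bounds) also support reverse-mode differentiation such as jax.grad
Portable: the compiled programs run on CPUs, GPUs, and TPUs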
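The differentiability caveat can be checked directly. In the sketch below, the function names are illustrative. Forward-mode jax.jvp works through a lax.while_loop, but jax.grad on the same loop raises a ValueError. jax.grad does work on a lax.fori_loop whose bounds are Python integers, because that loop is lowered to a scan.

```python
import jax
import jax.numpy as jnp


def shrink_until_small(x):
    # while_loop: the trip count depends on the data.
    return jax.lax.while_loop(lambda v: jnp.abs(v) > 0.01, lambda v: v * 0.8, x)


def scale_five_times(x):
    # fori_loop with Python-int bounds: the trip count is static.
    return jax.lax.fori_loop(0, 5, lambda i, v: v * 1.1, x)


x = jnp.array(10.0)

# Forward mode works: starting from 10.0 the loop runs 31 times,
# so the derivative of the result with respect to x is 0.8**31.
_, tangent = jax.jvp(shrink_until_small, (x,), (jnp.array(1.0),))

# Reverse mode does not; JAX raises a ValueError when transposing the loop.
try:
    jax.grad(shrink_until_small)(x)
    while_grad_failed = False
except ValueError:
    while_grad_failed = True

# A static-trip-count fori_loop supports jax.grad: d/dx (x * 1.1**5) = 1.1**5.
fori_grad = jax.grad(scale_five_times)(x)
print(tangent, while_grad_failed, fori_grad)
```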
The HLO visualizations show how these high-level control structures are lowered: conditionals and loops typically appear as conditional and while operations in the compiled graph.