{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Control Flow Examples\n", "\n", "This notebook demonstrates how visu-hlo visualizes JAX's structured control flow primitives. JAX provides functional control flow operations that can be compiled and differentiated.\n", "\n", "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['JAX_PLATFORMS'] = 'cpu'\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from visu_hlo import show" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conditional Execution with `lax.cond`\n", "\n", "`lax.cond` selects one of two branch functions from a scalar predicate. Both branches must return values of the same shape and dtype:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def simple_conditional(x):\n", "    \"\"\"Simple conditional: square if positive, negate otherwise.\"\"\"\n", "    return jax.lax.cond(\n", "        x > 0,\n", "        lambda x: x**2,  # true branch\n", "        lambda x: -x,  # false branch\n", "        x,\n", "    )\n", "\n", "\n", "print('Simple conditional (x > 0 ? 
x² : -x):')\n", "show(simple_conditional, jnp.array(5.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Complex Conditional Logic" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def complex_conditional(x, y):\n", "    \"\"\"More complex conditional with multiple operations.\"\"\"\n", "\n", "    def true_branch(args):\n", "        x, y = args\n", "        return x * y + jnp.sin(x)\n", "\n", "    def false_branch(args):\n", "        x, y = args\n", "        return x - y + jnp.cos(y)\n", "\n", "    return jax.lax.cond(x > y, true_branch, false_branch, (x, y))\n", "\n", "\n", "print('Complex conditional with trigonometric functions:')\n", "show(complex_conditional, jnp.array(2.0), jnp.array(1.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi-way Conditionals with `lax.switch`\n", "\n", "`lax.switch` selects one of several branches by integer index (an out-of-range index is clamped to the valid range):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def multi_way_switch(index, x):\n", "    \"\"\"Multi-way conditional using lax.switch.\"\"\"\n", "    branches = [\n", "        lambda x: x + 1,  # case 0\n", "        lambda x: x * 2,  # case 1\n", "        lambda x: x**2,  # case 2\n", "        lambda x: jnp.sqrt(x),  # case 3\n", "    ]\n", "\n", "    return jax.lax.switch(index, branches, x)\n", "\n", "\n", "print('Multi-way switch (case 2: x²):')\n", "show(multi_way_switch, 2, jnp.array(4.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loops with `lax.fori_loop`\n", "\n", "`lax.fori_loop` runs a body function over a range of integer indices, threading a value through each iteration:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def simple_loop(n, init_val):\n", "    \"\"\"Simple accumulation loop.\"\"\"\n", "\n", "    def body_fun(i, val):\n", "        return val + i * 2\n", "\n", "    return jax.lax.fori_loop(0, n, body_fun, init_val)\n", "\n", "\n", "print('Simple fori_loop (accumulate i * 2):')\n", "show(simple_loop, 5, jnp.array(0.0))" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def matrix_power_loop(matrix, n):\n", " \"\"\"Compute matrix power using a loop.\"\"\"\n", "\n", " def body_fun(i, result):\n", " return jnp.dot(result, matrix)\n", "\n", " return jax.lax.fori_loop(0, n, body_fun, matrix)\n", "\n", "\n", "test_matrix = jnp.array([[1.1, 0.1], [0.1, 1.1]])\n", "print('Matrix power using fori_loop:')\n", "show(matrix_power_loop, test_matrix, 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## While Loops with `lax.while_loop`\n", "\n", "Condition-based loops:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def convergence_loop(x):\n", " \"\"\"Loop until convergence using while_loop.\"\"\"\n", "\n", " def cond_fun(val):\n", " return jnp.abs(val) > 0.01\n", "\n", " def body_fun(val):\n", " return val * 0.8\n", "\n", " return jax.lax.while_loop(cond_fun, body_fun, x)\n", "\n", "\n", "print('While loop until convergence:')\n", "show(convergence_loop, jnp.array(10.0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def newton_iteration(x):\n", " \"\"\"Newton's method for finding square root.\"\"\"\n", " target = 2.0 # Finding sqrt(2)\n", "\n", " def cond_fun(state):\n", " x, error = state\n", " return error > 1e-6\n", "\n", " def body_fun(state):\n", " x, _ = state\n", " new_x = 0.5 * (x + target / x)\n", " error = jnp.abs(new_x - x)\n", " return new_x, error\n", "\n", " init_state = (x, jnp.array(1.0))\n", " final_x, _ = jax.lax.while_loop(cond_fun, body_fun, init_state)\n", " return final_x\n", "\n", "\n", "print(\"Newton's method for square root:\")\n", "show(newton_iteration, jnp.array(1.5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Scan Operations with `lax.scan`\n", "\n", "Efficient loops that accumulate intermediate results:" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def cumulative_sum_scan(xs):\n", " \"\"\"Cumulative sum using lax.scan.\"\"\"\n", "\n", " def scan_fun(carry, x):\n", " new_carry = carry + x\n", " return new_carry, new_carry\n", "\n", " _, cumsum = jax.lax.scan(scan_fun, 0.0, xs)\n", " return cumsum\n", "\n", "\n", "test_array = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])\n", "print('Cumulative sum using scan:')\n", "show(cumulative_sum_scan, test_array)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def running_average_scan(xs):\n", " \"\"\"Running average using lax.scan.\"\"\"\n", "\n", " def scan_fun(carry, x):\n", " count, total = carry\n", " new_count = count + 1\n", " new_total = total + x\n", " avg = new_total / new_count\n", " return (new_count, new_total), avg\n", "\n", " _, averages = jax.lax.scan(scan_fun, (0.0, 0.0), xs)\n", " return averages\n", "\n", "\n", "print('Running average using scan:')\n", "show(running_average_scan, test_array)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recurrent Neural Network with Scan\n", "\n", "A simple RNN implementation using scan:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def simple_rnn(params, inputs):\n", " \"\"\"Simple RNN using lax.scan.\"\"\"\n", " W_h, W_x, b = params\n", "\n", " def rnn_step(h, x):\n", " new_h = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x) + b)\n", " return new_h, new_h\n", "\n", " h0 = jnp.zeros(W_h.shape[0])\n", " _, hidden_states = jax.lax.scan(rnn_step, h0, inputs)\n", " return hidden_states\n", "\n", "\n", "# RNN parameters\n", "hidden_size = 3\n", "input_size = 2\n", "W_h = jnp.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])\n", "W_x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n", "b = jnp.array([0.1, 0.1, 0.1])\n", "params = (W_h, W_x, b)\n", "\n", "# Input sequence\n", "inputs = jnp.array([[1.0, 0.5], [0.8, 1.2], 
[0.3, 0.9]])\n", "\n", "print('Simple RNN with scan:')\n", "show(simple_rnn, params, inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Nested Control Flow\n", "\n", "Combining different control flow primitives:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def nested_control_flow(x, condition):\n", " \"\"\"Nested conditional and loop.\"\"\"\n", "\n", " def true_branch(x):\n", " # If condition is true, apply a loop\n", " def body_fun(i, val):\n", " return val * 1.1\n", "\n", " return jax.lax.fori_loop(0, 5, body_fun, x)\n", "\n", " def false_branch(x):\n", " # If condition is false, apply a different transformation\n", " return jnp.sqrt(jnp.abs(x))\n", "\n", " return jax.lax.cond(condition, true_branch, false_branch, x)\n", "\n", "\n", "print('Nested control flow (condition=True):')\n", "show(nested_control_flow, jnp.array(2.0), True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dynamic Programming Example\n", "\n", "Computing Fibonacci numbers using scan:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit(static_argnums=0)\n", "def fibonacci(n):\n", " \"\"\"Compute Fibonacci sequence using scan.\"\"\"\n", "\n", " def step(carry, _):\n", " a, b = carry\n", " return (b, a + b), a\n", "\n", " # Initialize with F(0)=0, F(1)=1\n", " init_carry = (0, 1)\n", " _, fib_sequence = jax.lax.scan(step, init_carry, length=n)\n", " return fib_sequence\n", "\n", "\n", "print('Fibonacci sequence using scan:')\n", "show(fibonacci, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimization Loop\n", "\n", "Simple gradient descent optimization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def gradient_descent_loop(params, learning_rate, n_steps):\n", " \"\"\"Simple gradient descent using a loop.\"\"\"\n", "\n", " def objective(x):\n", 
" return (x - 2.0) ** 2 + 1.0\n", "\n", " grad_fn = jax.grad(objective)\n", "\n", " def update_step(i, params):\n", " grad = grad_fn(params)\n", " return params - learning_rate * grad\n", "\n", " return jax.lax.fori_loop(0, n_steps, update_step, params)\n", "\n", "\n", "print('Gradient descent optimization loop:')\n", "show(gradient_descent_loop, jnp.array(0.0), 0.1, 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This notebook demonstrated JAX's structured control flow primitives and their visualization:\n", "\n", "- **Conditionals**: `lax.cond` for if-then-else logic, `lax.switch` for multi-way branching\n", "- **Loops**: `lax.fori_loop` for fixed iterations, `lax.while_loop` for condition-based loops\n", "- **Scan**: `lax.scan` for efficient loops with intermediate results\n", "- **Applications**: RNNs, dynamic programming, optimization algorithms\n", "- **Nested structures**: Combining different control flow primitives\n", "\n", "All these control flow operations are:\n", "- **Functional**: No side effects, pure functions\n", "- **Compilable**: Can be JIT compiled for performance\n", "- **Differentiable**: Work with JAX's automatic differentiation\n", "- **Parallelizable**: Can be executed on GPUs and TPUs\n", "\n", "The HLO visualizations show how these high-level control structures are compiled into efficient low-level operations." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }