{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# JAX Transformation Examples\n", "\n", "This notebook demonstrates how visu-hlo visualizes the computational graphs created by JAX's powerful function transformations. JAX provides composable function transformations for automatic differentiation, vectorization, and parallelization.\n", "\n", "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "os.environ['JAX_PLATFORMS'] = 'cpu'\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from visu_hlo import show" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Automatic Differentiation with `grad`\n", "\n", "JAX's `grad` transformation computes gradients of scalar-valued functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def simple_function(x):\n", " \"\"\"Simple quadratic function.\"\"\"\n", " return x**2 + 3 * x + 1\n", "\n", "\n", "# Gradient function\n", "grad_fn = jax.grad(simple_function)\n", "\n", "print('Original function f(x) = x² + 3x + 1:')\n", "show(simple_function, jnp.array(2.0))\n", "\n", "print(\"\\nGradient f'(x) = 2x + 3:\")\n", "show(grad_fn, jnp.array(2.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multivariate Gradients" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def multivariate_function(params):\n", " \"\"\"Function of multiple variables.\"\"\"\n", " x, y, z = params\n", " return x**2 + y * z + jnp.sin(x * y)\n", "\n", "\n", "grad_multivariate = jax.grad(multivariate_function)\n", "\n", "test_params = jnp.array([1.0, 2.0, 3.0])\n", "print('Multivariate function gradient:')\n", "show(grad_multivariate, test_params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Partial Derivatives with `argnums`" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_function(params, x, y):\n", " \"\"\"Loss function with respect to parameters.\"\"\"\n", " w, b = params\n", " prediction = w * x + b\n", " return (prediction - y) ** 2\n", "\n", "\n", "# Gradient with respect to parameters (argnums=0)\n", "grad_wrt_params = jax.grad(loss_function, argnums=0)\n", "\n", "# Gradient with respect to input x (argnums=1)\n", "grad_wrt_x = jax.grad(loss_function, argnums=1)\n", "\n", "params = jnp.array([2.0, 1.0]) # w=2, b=1\n", "x = jnp.array(3.0)\n", "y = jnp.array(5.0)\n", "\n", "print('Gradient with respect to parameters:')\n", "show(grad_wrt_params, params, x, y)\n", "\n", "print('\\nGradient with respect to input x:')\n", "show(grad_wrt_x, params, x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Higher-Order Derivatives" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def polynomial(x):\n", " \"\"\"Polynomial function for higher-order derivatives.\"\"\"\n", " return x**4 + 2 * x**3 - 3 * x**2 + x + 1\n", "\n", "\n", "# First derivative\n", "first_deriv = jax.grad(polynomial)\n", "# Second derivative\n", "second_deriv = jax.grad(jax.grad(polynomial))\n", "# Third derivative\n", "third_deriv = jax.grad(jax.grad(jax.grad(polynomial)))\n", "\n", "x = jnp.array(2.0)\n", "\n", "print('Second derivative:')\n", "show(second_deriv, x)\n", "\n", "print('\\nThird derivative:')\n", "show(third_deriv, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `value_and_grad` for Efficiency" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def expensive_function(x):\n", " \"\"\"Function where we want both value and gradient.\"\"\"\n", " return jnp.sum(x**3) + jnp.sum(jnp.sin(x))\n", "\n", "\n", "# Get both value and gradient in one pass\n", "value_and_grad_fn = jax.value_and_grad(expensive_function)\n", "\n", "x = jnp.array([1.0, 2.0, 
3.0])\n", "print('value_and_grad (more efficient than separate calls):')\n", "show(value_and_grad_fn, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Vectorization with `vmap`\n", "\n", "Automatically vectorize functions to work on batches:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def single_example_function(x):\n", " \"\"\"Function that works on a single example.\"\"\"\n", " return jnp.sum(x**2) + jnp.mean(x)\n", "\n", "\n", "# Vectorize to work on batches\n", "batched_function = jax.vmap(single_example_function)\n", "\n", "# Single example\n", "single_input = jnp.array([1.0, 2.0, 3.0])\n", "print('Single example function:')\n", "show(single_example_function, single_input)\n", "\n", "# Batch of examples\n", "batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])\n", "print('\\nVectorized function (vmap):')\n", "show(batched_function, batch_input)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced `vmap` with `in_axes`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def matrix_vector_mult(matrix, vector):\n", " \"\"\"Matrix-vector multiplication.\"\"\"\n", " return jnp.dot(matrix, vector)\n", "\n", "\n", "# Vectorize over the vector argument (axis 0) but not the matrix\n", "batch_matvec = jax.vmap(matrix_vector_mult, in_axes=(None, 0))\n", "\n", "matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n", "vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])\n", "\n", "print('Vectorized matrix-vector multiplication:')\n", "print(f'Matrix shape: {matrix.shape}, Vectors shape: {vectors.shape}')\n", "show(batch_matvec, matrix, vectors)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Nested `vmap` for Multiple Batch Dimensions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def pairwise_distance(x, y):\n", " \"\"\"Euclidean distance between 
two points.\"\"\"\n", " return jnp.sqrt(jnp.sum((x - y) ** 2))\n", "\n", "\n", "# First vmap over y, then over x\n", "vectorized_distances = jax.vmap(jax.vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))\n", "\n", "points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])\n", "points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])\n", "\n", "print('Nested vmap for pairwise distances:')\n", "print(f'Points X shape: {points_x.shape}, Points Y shape: {points_y.shape}')\n", "show(vectorized_distances, points_x, points_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combining `grad` and `vmap`\n", "\n", "Vectorized gradients for batch processing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_per_example(params, x, y):\n", " \"\"\"Loss for a single example.\"\"\"\n", " w, b = params\n", " prediction = jnp.dot(w, x) + b\n", " return (prediction - y) ** 2\n", "\n", "\n", "# Gradient for a single example\n", "grad_single = jax.grad(loss_per_example)\n", "\n", "# Vectorized gradient for batch of examples\n", "grad_batch = jax.vmap(grad_single, in_axes=(None, 0, 0))\n", "\n", "params = jnp.array([1.0, 2.0]), jnp.array(0.5) # w, b\n", "batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n", "batch_y = jnp.array([3.0, 7.0, 11.0])\n", "\n", "print('Vectorized gradients for batch:')\n", "show(grad_batch, params, batch_x, batch_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Combining Multiple Transformations\n", "\n", "JAX transformations are composable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def neural_network_layer(params, x):\n", " \"\"\"Simple neural network layer.\"\"\"\n", " W, b = params\n", " return jnp.tanh(jnp.dot(W, x) + b)\n", "\n", "\n", "def loss_fn(params, batch_x, batch_y):\n", " \"\"\"Loss function for the neural network.\"\"\"\n", " # Vectorize over the batch\n", " batch_predictions = 
jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)\n", " # Mean squared error\n", " return jnp.mean((batch_predictions - batch_y) ** 2)\n", "\n", "\n", "# Combine JIT and grad (loss_fn applies vmap internally)\n", "jit_grad_loss = jax.jit(jax.grad(loss_fn))\n", "\n", "# Parameters\n", "W = jnp.array([[0.1, 0.2], [0.3, 0.4]])\n", "b = jnp.array([0.1, 0.1])\n", "params = (W, b)\n", "\n", "# Batch data\n", "batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n", "batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])\n", "\n", "print('JIT + grad + vmap combination:')\n", "show(jit_grad_loss, params, batch_x, batch_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced: Jacobian Computation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def vector_function(x):\n", " \"\"\"Vector-valued function for Jacobian computation.\"\"\"\n", " return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])\n", "\n", "\n", "# Jacobian via forward-mode autodiff (jax.jacfwd)\n", "jacobian_fn = jax.jacfwd(vector_function)\n", "\n", "x = jnp.array([2.0, 3.0])\n", "print('Jacobian computation:')\n", "show(jax.jit(jacobian_fn), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hessian Computation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def scalar_function_for_hessian(x):\n", " \"\"\"Scalar function for Hessian computation.\"\"\"\n", " return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]\n", "\n", "\n", "# Hessian via jax.hessian (forward-over-reverse autodiff)\n", "hessian_fn = jax.hessian(scalar_function_for_hessian)\n", "\n", "x = jnp.array([1.0, 2.0])\n", "print('Hessian computation:')\n", "show(jax.jit(hessian_fn), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom Transformations with `custom_vjp`\n", "\n", "Defining custom vector-Jacobian products:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.custom_vjp\n", "def smooth_abs(x):\n", " \"\"\"Smooth 
approximation to absolute value.\"\"\"\n", " return jnp.sqrt(x**2 + 1e-8)\n", "\n", "\n", "def smooth_abs_fwd(x):\n", " \"\"\"Forward pass.\"\"\"\n", " y = smooth_abs(x)\n", " return y, x\n", "\n", "\n", "def smooth_abs_bwd(x, g):\n", " \"\"\"Backward pass with custom gradient.\"\"\"\n", " return (g * x / jnp.sqrt(x**2 + 1e-8),)\n", "\n", "\n", "smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)\n", "\n", "\n", "# Use in a function with grad\n", "def function_with_custom_grad(x):\n", " return jnp.sum(smooth_abs(x))\n", "\n", "\n", "grad_custom = jax.grad(function_with_custom_grad)\n", "\n", "x = jnp.array([-1.0, 0.5, 2.0])\n", "print('Function with custom VJP:')\n", "show(jax.jit(grad_custom), x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimization with Transformations\n", "\n", "A complete optimization example combining multiple transformations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def optimization_step(params, batch_x, batch_y, learning_rate):\n", "    \"\"\"Single optimization step combining multiple transformations.\"\"\"\n", "\n", "    def model(params, x):\n", "        W1, b1, W2, b2 = params\n", "        h = jnp.tanh(jnp.dot(W1, x) + b1)\n", "        return jnp.dot(W2, h) + b2\n", "\n", "    def batch_loss(params, batch_x, batch_y):\n", "        # Vectorize model over batch\n", "        predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)\n", "        return jnp.mean((predictions - batch_y) ** 2)\n", "\n", "    # Get both loss and gradients efficiently\n", "    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)\n", "\n", "    # Update parameters\n", "    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)\n", "\n", "    return new_params, loss\n", "\n", "\n", "# Initialize parameters\n", "W1 = jnp.array([[0.1, 0.2], [0.3, 0.4]])\n", "b1 = jnp.array([0.0, 0.0])\n", "W2 = jnp.array([[0.5, 0.6]])\n", "b2 = jnp.array([0.0])\n", "params = (W1, b1, W2, b2)\n", "\n", "# Batch data\n", 
"batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])\n", "batch_y = jnp.array([[0.7], [0.9]])\n", "learning_rate = 0.01\n", "\n", "print('Complete optimization step (JIT + value_and_grad + vmap):')\n", "show(jax.jit(optimization_step), params, batch_x, batch_y, learning_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This notebook demonstrated JAX's powerful function transformations and their visualizations:\n", "\n", "### Core Transformations:\n", "- **`grad`**: Automatic differentiation for computing gradients\n", "- **`vmap`**: Automatic vectorization for batch processing\n", "- **`jit`**: Just-in-time compilation for performance\n", "- **`pmap`**: Parallel mapping across devices (conceptual)\n", "\n", "### Advanced Features:\n", "- **Higher-order derivatives**: Nested `grad` calls\n", "- **`value_and_grad`**: Efficient computation of both value and gradient\n", "- **Jacobians and Hessians**: `jacfwd`, `jacrev`, `hessian`\n", "- **Custom transformations**: `custom_vjp` for specialized gradients\n", "\n", "### Composition:\n", "- Transformations are **composable**: `jit(grad(vmap(...)))`\n", "- Order matters: `vmap(grad(...))` vs `grad(vmap(...))`\n", "- Can be combined for complex workflows\n", "\n", "### Key Benefits:\n", "- **Functional**: No side effects, pure transformations\n", "- **Performant**: JIT compilation and vectorization\n", "- **Flexible**: Works with arbitrary Python functions\n", "- **Scalable**: Efficient batch processing and parallelization\n", "\n", "The HLO visualizations reveal how these high-level transformations are compiled into efficient computational graphs, showing the automatic optimizations performed by JAX's compiler." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }