JAX Transformation Examples

This notebook demonstrates how visu-hlo visualizes the computational graphs created by JAX’s function transformations. JAX provides composable function transformations for automatic differentiation, vectorization, just-in-time compilation, and parallelization.

Setup

First, let’s import the necessary libraries:

[1]:
import os

# Run JAX on the CPU backend; this must be set before jax is imported
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp

from visu_hlo import show

Automatic Differentiation with grad

JAX’s grad transformation computes gradients of scalar-valued functions:

[2]:
def simple_function(x):
    """Simple quadratic function."""
    return x**2 + 3 * x + 1


# Gradient function
grad_fn = jax.grad(simple_function)

print('Original function f(x) = x² + 3x + 1:')
show(simple_function, jnp.array(2.0))

print("\nGradient f'(x) = 2x + 3:")
show(grad_fn, jnp.array(2.0))
Original function f(x) = x² + 3x + 1:
../_images/user-guide_transformations_3_1.svg

Gradient f'(x) = 2x + 3:
../_images/user-guide_transformations_3_3.svg
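As a quick numerical check (a minimal sketch that uses only jax, not visu_hlo), the gradient function can be called directly and compared with the analytical derivative. The second part illustrates that jax.grad only accepts scalar-valued functions; the flag `raised` is just an illustrative variable.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return x**2 + 3 * x + 1


grad_fn = jax.grad(simple_function)

# f'(x) = 2x + 3, so f'(2) = 7
value = grad_fn(jnp.array(2.0))
print(float(value))  # 7.0

# jax.grad rejects functions with a non-scalar output by raising TypeError
raised = False
try:
    jax.grad(lambda v: v**2)(jnp.array([1.0, 2.0]))
except TypeError as err:
    raised = True
    print('TypeError:', err)
```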

Multivariate Gradients

[3]:
def multivariate_function(params):
    """Function of multiple variables."""
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)


grad_multivariate = jax.grad(multivariate_function)

test_params = jnp.array([1.0, 2.0, 3.0])
print('Multivariate function gradient:')
show(grad_multivariate, test_params)
Multivariate function gradient:
../_images/user-guide_transformations_5_1.svg
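The gradient of a function of an array has the same shape as that array, with one partial derivative per element. A small sketch comparing jax.grad against hand-derived partials at the same test point:

```python
import jax
import jax.numpy as jnp


def multivariate_function(params):
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)


test_params = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(multivariate_function)(test_params)

# Hand-derived partial derivatives:
# df/dx = 2x + y*cos(xy), df/dy = z + x*cos(xy), df/dz = y
x, y, z = test_params
expected = jnp.array([2 * x + y * jnp.cos(x * y), z + x * jnp.cos(x * y), y])

print(grads)
print(jnp.allclose(grads, expected))  # True
```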

Partial Derivatives with argnums

[4]:
def loss_function(params, x, y):
    """Loss function with respect to parameters."""
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2


# Gradient with respect to parameters (argnums=0)
grad_wrt_params = jax.grad(loss_function, argnums=0)

# Gradient with respect to input x (argnums=1)
grad_wrt_x = jax.grad(loss_function, argnums=1)

params = jnp.array([2.0, 1.0])  # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)

print('Gradient with respect to parameters:')
show(grad_wrt_params, params, x, y)

print('\nGradient with respect to input x:')
show(grad_wrt_x, params, x, y)
Gradient with respect to parameters:
../_images/user-guide_transformations_7_1.svg

Gradient with respect to input x:
../_images/user-guide_transformations_7_3.svg
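argnums also accepts a tuple, in which case jax.grad returns a tuple of gradients, one per listed argument. With w=2, b=1, x=3, y=5 the prediction is 7 and the residual is 2, so the expected values can be worked out by hand, as in this sketch:

```python
import jax
import jax.numpy as jnp


def loss_function(params, x, y):
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2


params = jnp.array([2.0, 1.0])  # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)

# argnums=(0, 1) differentiates with respect to params and x in a single call
d_params, d_x = jax.grad(loss_function, argnums=(0, 1))(params, x, y)

# residual = (2*3 + 1) - 5 = 2
# dL/dw = 2*residual*x = 12, dL/db = 2*residual = 4, dL/dx = 2*residual*w = 8
print(d_params, d_x)
```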

Higher-Order Derivatives

[5]:
def polynomial(x):
    """Polynomial function for higher-order derivatives."""
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1


# First derivative
first_deriv = jax.grad(polynomial)
# Second derivative
second_deriv = jax.grad(jax.grad(polynomial))
# Third derivative
third_deriv = jax.grad(jax.grad(jax.grad(polynomial)))

x = jnp.array(2.0)

print('Second derivative:')
show(second_deriv, x)

print('\nThird derivative:')
show(third_deriv, x)
Second derivative:
../_images/user-guide_transformations_9_1.svg

Third derivative:
../_images/user-guide_transformations_9_3.svg
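Because jax.grad returns an ordinary function, nesting can be written as a loop. The helper nth_derivative below is hypothetical (not part of JAX). The sketch checks the results against the hand-derived values p'(2) = 45, p''(2) = 66, p'''(2) = 60 and p''''(x) = 24:

```python
import jax
import jax.numpy as jnp


def polynomial(x):
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1


def nth_derivative(f, n):
    """Hypothetical helper: apply jax.grad to f n times."""
    for _ in range(n):
        f = jax.grad(f)
    return f


x = jnp.array(2.0)
derivs = [float(nth_derivative(polynomial, n)(x)) for n in range(1, 5)]
print(derivs)
```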

value_and_grad for Efficiency

[6]:
def expensive_function(x):
    """Function where we want both value and gradient."""
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))


# Get both value and gradient in one pass
value_and_grad_fn = jax.value_and_grad(expensive_function)

x = jnp.array([1.0, 2.0, 3.0])
print('value_and_grad (more efficient than separate calls):')
show(value_and_grad_fn, x)
value_and_grad (more efficient than separate calls):
../_images/user-guide_transformations_11_1.svg
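value_and_grad returns a (value, gradient) pair. It matches what separate calls to the function and to jax.grad would return, but it avoids evaluating the forward pass a second time. A short sketch that also compares against the analytic gradient 3x² + cos(x):

```python
import jax
import jax.numpy as jnp


def expensive_function(x):
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))


x = jnp.array([1.0, 2.0, 3.0])

# One call gives both the function value and its gradient
value, gradient = jax.value_and_grad(expensive_function)(x)

# Same results as two separate calls
print(jnp.allclose(value, expensive_function(x)))
print(jnp.allclose(gradient, jax.grad(expensive_function)(x)))
# Analytic gradient: 3x^2 + cos(x)
print(jnp.allclose(gradient, 3 * x**2 + jnp.cos(x)))
```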

Vectorization with vmap

Automatically vectorize functions to work on batches:

[7]:
def single_example_function(x):
    """Function that works on a single example."""
    return jnp.sum(x**2) + jnp.mean(x)


# Vectorize to work on batches
batched_function = jax.vmap(single_example_function)

# Single example
single_input = jnp.array([1.0, 2.0, 3.0])
print('Single example function:')
show(single_example_function, single_input)

# Batch of examples
batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
print('\nVectorized function (vmap):')
show(batched_function, batch_input)
Single example function:
../_images/user-guide_transformations_13_1.svg

Vectorized function (vmap):
../_images/user-guide_transformations_13_3.svg
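By default vmap maps over axis 0 of every argument. The vectorized function should therefore agree with an explicit Python loop that calls the single-example function once per row, as this sketch checks:

```python
import jax
import jax.numpy as jnp


def single_example_function(x):
    return jnp.sum(x**2) + jnp.mean(x)


batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

# vmap maps over axis 0 of the input by default (in_axes=0)
batched = jax.vmap(single_example_function)(batch_input)

# Equivalent Python loop: one call per row, results stacked
looped = jnp.stack([single_example_function(row) for row in batch_input])

print(batched)  # one scalar per row
print(jnp.allclose(batched, looped))
```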

Advanced vmap with in_axes

[8]:
def matrix_vector_mult(matrix, vector):
    """Matrix-vector multiplication."""
    return jnp.dot(matrix, vector)


# Vectorize over the vector argument (axis 0) but not the matrix
batch_matvec = jax.vmap(matrix_vector_mult, in_axes=(None, 0))

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])

print('Vectorized matrix-vector multiplication:')
print(f'Matrix shape: {matrix.shape}, Vectors shape: {vectors.shape}')
show(batch_matvec, matrix, vectors)
Vectorized matrix-vector multiplication:
Matrix shape: (2, 2), Vectors shape: (3, 2)
../_images/user-guide_transformations_15_1.svg
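in_axes=None broadcasts an argument unchanged to every call, and out_axes controls where the mapped axis appears in the result. A sketch relating both layouts to plain matrix products:

```python
import jax
import jax.numpy as jnp


def matrix_vector_mult(matrix, vector):
    return jnp.dot(matrix, vector)


matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])

# Share `matrix`, map over the rows of `vectors`; results stack along axis 0
rows = jax.vmap(matrix_vector_mult, in_axes=(None, 0))(matrix, vectors)
print(rows.shape)  # (3, 2)

# out_axes=1 stacks the per-vector results along axis 1 instead
cols = jax.vmap(matrix_vector_mult, in_axes=(None, 0), out_axes=1)(matrix, vectors)
print(cols.shape)  # (2, 3)

# Both layouts agree with ordinary matrix products
print(jnp.allclose(rows, vectors @ matrix.T), jnp.allclose(cols, matrix @ vectors.T))
```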

Nested vmap for Multiple Batch Dimensions

[9]:
def pairwise_distance(x, y):
    """Euclidean distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))


# Inner vmap maps over the points in y; outer vmap maps over the points in x
vectorized_distances = jax.vmap(jax.vmap(pairwise_distance, in_axes=(None, 0)), in_axes=(0, None))

points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])

print('Nested vmap for pairwise distances:')
print(f'Points X shape: {points_x.shape}, Points Y shape: {points_y.shape}')
show(vectorized_distances, points_x, points_y)
Nested vmap for pairwise distances:
Points X shape: (3, 2), Points Y shape: (2, 2)
../_images/user-guide_transformations_17_1.svg
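The nested vmap yields a distance matrix with entry [i, j] equal to the distance between x_i and y_j. The same matrix can be computed with explicit NumPy-style broadcasting, which this sketch uses as a cross-check:

```python
import jax
import jax.numpy as jnp


def pairwise_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# Inner vmap: one x against every y. Outer vmap: repeat for every x.
inner = jax.vmap(pairwise_distance, in_axes=(None, 0))
distances = jax.vmap(inner, in_axes=(0, None))(points_x, points_y)
print(distances.shape)  # (3, 2)

# Same matrix via broadcasting: (3, 1, 2) - (1, 2, 2) -> (3, 2, 2)
diff = points_x[:, None, :] - points_y[None, :, :]
broadcast = jnp.sqrt(jnp.sum(diff**2, axis=-1))
print(jnp.allclose(distances, broadcast))
```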

Combining grad and vmap

Vectorized gradients for batch processing:

[10]:
def loss_per_example(params, x, y):
    """Loss for a single example."""
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2


# Gradient for a single example
grad_single = jax.grad(loss_per_example)

# Vectorized gradient for batch of examples
grad_batch = jax.vmap(grad_single, in_axes=(None, 0, 0))

params = (jnp.array([1.0, 2.0]), jnp.array(0.5))  # (w, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])

print('Vectorized gradients for batch:')
show(grad_batch, params, batch_x, batch_y)
Vectorized gradients for batch:
../_images/user-guide_transformations_19_1.svg
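vmap(grad(f)) yields one gradient per example, mirroring the structure of params with an extra leading batch axis. Because differentiation is linear, averaging these per-example gradients should match the gradient of the mean loss. Note that grad(vmap(f)) needs an explicit reduction, since grad requires a scalar output. In this sketch, mean_loss is an illustrative helper, not part of the original notebook:

```python
import jax
import jax.numpy as jnp


def loss_per_example(params, x, y):
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2


params = (jnp.array([1.0, 2.0]), jnp.array(0.5))  # (w, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])

# Per-example gradients: params are shared, x and y are mapped over axis 0
dw, db = jax.vmap(jax.grad(loss_per_example), in_axes=(None, 0, 0))(params, batch_x, batch_y)
print(dw.shape, db.shape)  # (3, 2) (3,)


def mean_loss(params, batch_x, batch_y):
    """Illustrative helper: reduce the vmapped losses to a scalar so grad applies."""
    losses = jax.vmap(loss_per_example, in_axes=(None, 0, 0))(params, batch_x, batch_y)
    return jnp.mean(losses)


dw_mean, db_mean = jax.grad(mean_loss)(params, batch_x, batch_y)

# Averaging per-example gradients matches the gradient of the averaged loss
print(jnp.allclose(dw.mean(axis=0), dw_mean), jnp.allclose(db.mean(), db_mean))
```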

Combining Multiple Transformations

JAX transformations are composable:

[11]:
def neural_network_layer(params, x):
    """Simple neural network layer."""
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)


def loss_fn(params, batch_x, batch_y):
    """Loss function for the neural network."""
    # Vectorize over the batch
    batch_predictions = jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)
    # Mean squared error
    return jnp.mean((batch_predictions - batch_y) ** 2)


# Combine jit and grad (vmap is applied inside loss_fn)
jit_grad_loss = jax.jit(jax.grad(loss_fn))

# Parameters
W = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b = jnp.array([0.1, 0.1])
params = (W, b)

# Batch data
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])

print('JIT + grad + vmap combination:')
show(jit_grad_loss, params, batch_x, batch_y)
JIT + grad + vmap combination:
../_images/user-guide_transformations_21_1.svg
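Since params is a (W, b) tuple, the gradient comes back as a matching (dW, db) tuple. The sketch below also compares jit(grad(f)), which compiles the whole gradient computation, with grad(jit(f)), which differentiates through a jitted forward function. The two should agree numerically:

```python
import jax
import jax.numpy as jnp


def neural_network_layer(params, x):
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)


def loss_fn(params, batch_x, batch_y):
    batch_predictions = jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)
    return jnp.mean((batch_predictions - batch_y) ** 2)


W = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b = jnp.array([0.1, 0.1])
params = (W, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])

grads_a = jax.jit(jax.grad(loss_fn))(params, batch_x, batch_y)
grads_b = jax.grad(jax.jit(loss_fn))(params, batch_x, batch_y)

# The gradient mirrors the structure of params
dW, db = grads_a
print(dW.shape, db.shape)  # (2, 2) (2,)
print(all(jnp.allclose(ga, gb) for ga, gb in zip(grads_a, grads_b)))
```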

Advanced: Jacobian Computation

[12]:
def vector_function(x):
    """Vector-valued function for Jacobian computation."""
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])


# Jacobian using forward-mode autodiff (jacfwd)
jacobian_fn = jax.jacfwd(vector_function)

x = jnp.array([2.0, 3.0])
print('Jacobian computation:')
show(jax.jit(jacobian_fn), x)
Jacobian computation:
../_images/user-guide_transformations_23_1.svg
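jacfwd (forward mode) and jacrev (reverse mode) compute the same Jacobian. Forward mode needs one JVP per input dimension and reverse mode one VJP per output dimension, so the better choice depends on the function's shape. A sketch checking both against the hand-derived Jacobian [[2x0, 1], [x1, x0], [0, 2x1]]:

```python
import jax
import jax.numpy as jnp


def vector_function(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])


x = jnp.array([2.0, 3.0])

J_fwd = jax.jacfwd(vector_function)(x)  # forward mode
J_rev = jax.jacrev(vector_function)(x)  # reverse mode

# Hand-derived Jacobian evaluated at x = (2, 3)
J_exact = jnp.array([[4.0, 1.0], [3.0, 2.0], [0.0, 6.0]])
print(J_fwd.shape)  # (3, 2): outputs x inputs
print(jnp.allclose(J_fwd, J_exact), jnp.allclose(J_rev, J_exact))
```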

Hessian Computation

[13]:
def scalar_function_for_hessian(x):
    """Scalar function for Hessian computation."""
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]


# Hessian: jax.hessian composes jacfwd over jacrev
hessian_fn = jax.hessian(scalar_function_for_hessian)

x = jnp.array([1.0, 2.0])
print('Hessian computation:')
show(jax.jit(hessian_fn), x)
Hessian computation:
../_images/user-guide_transformations_25_1.svg
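The JAX documentation describes jax.hessian as jacfwd(jacrev(f)): reverse mode produces the gradient, and forward mode differentiates it again. A sketch confirming that the explicit composition matches jax.hessian and the hand-derived Hessian [[6x0, 1], [1, 2]]:

```python
import jax
import jax.numpy as jnp


def scalar_function_for_hessian(x):
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]


x = jnp.array([1.0, 2.0])

H = jax.hessian(scalar_function_for_hessian)(x)
H_composed = jax.jacfwd(jax.jacrev(scalar_function_for_hessian))(x)

# Hand-derived Hessian evaluated at x = (1, 2)
H_exact = jnp.array([[6.0, 1.0], [1.0, 2.0]])
print(jnp.allclose(H, H_exact), jnp.allclose(H_composed, H_exact))
```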

Custom Gradients with custom_vjp

Defining a custom vector-Jacobian product (VJP) rule:

[14]:
@jax.custom_vjp
def smooth_abs(x):
    """Smooth approximation to absolute value."""
    return jnp.sqrt(x**2 + 1e-8)


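def smooth_abs_fwd(x):
    """Forward pass: return the output and the residual (x) saved for the backward pass."""
    y = smooth_abs(x)
    return y, x


def smooth_abs_bwd(x, g):
    """Backward pass: map the residual x and cotangent g to a tuple of input cotangents."""
    return (g * x / jnp.sqrt(x**2 + 1e-8),)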


smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)


# Use in a function with grad
def function_with_custom_grad(x):
    return jnp.sum(smooth_abs(x))


grad_custom = jax.grad(function_with_custom_grad)

x = jnp.array([-1.0, 0.5, 2.0])
print('Function with custom VJP:')
show(jax.jit(grad_custom), x)
Function with custom VJP:
../_images/user-guide_transformations_27_1.svg
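The backward rule above happens to equal the true derivative x / sqrt(x² + 1e-8), so jax.grad with the custom rule should agree with ordinary autodiff of the same formula. Away from zero, both are close to sign(x). This sketch is a sanity check only; a different formula in smooth_abs_bwd would change the gradient while leaving the forward value unchanged:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def smooth_abs(x):
    return jnp.sqrt(x**2 + 1e-8)


def smooth_abs_fwd(x):
    # Return the output plus the residual needed by the backward pass
    return smooth_abs(x), x


def smooth_abs_bwd(x, g):
    # One cotangent per primal input, returned as a tuple
    return (g * x / jnp.sqrt(x**2 + 1e-8),)


smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)

x = jnp.array([-1.0, 0.5, 2.0])
custom = jax.grad(lambda v: jnp.sum(smooth_abs(v)))(x)

# Reference: ordinary autodiff of the same formula, without the custom rule
plain = jax.grad(lambda v: jnp.sum(jnp.sqrt(v**2 + 1e-8)))(x)

print(custom)
print(jnp.allclose(custom, plain), jnp.allclose(custom, jnp.sign(x)))
```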

Optimization with Transformations

A complete optimization example combining multiple transformations:

[15]:
def optimization_step(params, batch_x, batch_y, learning_rate):
    """Single optimization step combining multiple transformations."""

    def model(params, x):
        W1, b1, W2, b2 = params
        h = jnp.tanh(jnp.dot(W1, x) + b1)
        return jnp.dot(W2, h) + b2

    def batch_loss(params, batch_x, batch_y):
        # Vectorize model over batch
        predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
        return jnp.mean((predictions - batch_y) ** 2)

    # Get both loss and gradients efficiently
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)

    # Update parameters
    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)

    return new_params, loss


# Initialize parameters
W1 = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b1 = jnp.array([0.0, 0.0])
W2 = jnp.array([[0.5, 0.6]])
b2 = jnp.array([0.0])
params = (W1, b1, W2, b2)

# Batch data
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])
learning_rate = 0.01

print('Complete optimization step (JIT + value_and_grad + vmap):')
show(jax.jit(optimization_step), params, batch_x, batch_y, learning_rate)
Complete optimization step (JIT + value_and_grad + vmap):
../_images/user-guide_transformations_29_1.svg
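To use the step in practice, it can be jitted once and called in a loop. The first call triggers compilation and later calls with the same shapes reuse the compiled function. This is a minimal sketch assuming the same model and data as above; the step count of 100 is an arbitrary choice:

```python
import jax
import jax.numpy as jnp


def model(params, x):
    W1, b1, W2, b2 = params
    h = jnp.tanh(jnp.dot(W1, x) + b1)
    return jnp.dot(W2, h) + b2


def batch_loss(params, batch_x, batch_y):
    predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
    return jnp.mean((predictions - batch_y) ** 2)


@jax.jit
def optimization_step(params, batch_x, batch_y, learning_rate):
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)
    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss


params = (
    jnp.array([[0.1, 0.2], [0.3, 0.4]]),  # W1
    jnp.array([0.0, 0.0]),  # b1
    jnp.array([[0.5, 0.6]]),  # W2
    jnp.array([0.0]),  # b2
)
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])

# Each call returns updated params and the loss before that update
losses = []
for _ in range(100):
    params, loss = optimization_step(params, batch_x, batch_y, 0.01)
    losses.append(float(loss))

print(f'loss: {losses[0]:.4f} -> {losses[-1]:.4f}')
```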

Summary

This notebook demonstrated JAX’s powerful function transformations and their visualizations:

Core Transformations:

  • grad: Automatic differentiation for computing gradients

  • vmap: Automatic vectorization for batch processing

  • jit: Just-in-time compilation for performance

  • pmap: Parallel mapping across devices (not demonstrated in this notebook)

Advanced Features:

  • Higher-order derivatives: Nested grad calls

  • value_and_grad: Computes the value and the gradient in a single call

  • Jacobians and Hessians: jacfwd, jacrev, hessian

  • Custom gradients: custom_vjp for user-defined differentiation rules

Composition:

  • Transformations are composable: jit(grad(vmap(...)))

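  • Order matters: vmap(grad(f)) produces per-example gradients, whereas grad requires a scalar output, so grad(vmap(f)) only works after a reduction such as a sum or mean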

  • Can be combined for complex workflows

Key Benefits:

  • Functional: Transformations operate on pure functions without side effects

  • Performant: JIT compilation and vectorization

  • Flexible: Works with ordinary Python functions written with jax.numpy

  • Scalable: Efficient batch processing and parallelization

The HLO visualizations show how these high-level transformations are lowered into computational graphs, which the XLA compiler then optimizes for execution.