JAX Transformation Examples¶
This notebook demonstrates how visu-hlo visualizes the computational graphs created by JAX’s function transformations. JAX provides composable function transformations for automatic differentiation (`grad`), vectorization (`vmap`), just-in-time compilation (`jit`), and parallelization (`pmap`).
Setup¶
First, let’s import the necessary libraries:
[1]:
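import os
# Force the CPU backend; this must be set before jax is imported
os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
import jax.numpy as jnp
# visu-hlo's show() renders the HLO graph of a function for the given inputs
from visu_hlo import show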
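If you want to look at a compiled program as text instead of a rendered graph, JAX's ahead-of-time API can print it. This is a minimal sketch, assuming a recent JAX version where `jax.jit(f).lower(...)` returns a `Lowered` object with `as_text()`, and `.compile().as_text()` returns the HLO after XLA's optimization passes. The function `f` is a hypothetical example, and the exact dialect (StableHLO or HLO) and formatting of the text vary between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical example function (same form as the quadratic used below)
    return x**2 + 3 * x + 1

x = jnp.array(2.0)

lowered = jax.jit(f).lower(x)                # trace and lower, no compilation yet
lowered_text = lowered.as_text()             # module text before XLA optimization
compiled_text = lowered.compile().as_text()  # HLO text after XLA optimization

print(lowered_text)
print(compiled_text)
```

Comparing the two printouts is a quick way to see which operations XLA fused or simplified.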
Automatic Differentiation with grad¶
JAX’s grad transformation computes gradients of scalar-valued functions:
[2]:
def simple_function(x):
    """Simple quadratic function."""
    return x**2 + 3 * x + 1
# Gradient function
grad_fn = jax.grad(simple_function)
print('Original function f(x) = x² + 3x + 1:')
show(simple_function, jnp.array(2.0))
print("\nGradient f'(x) = 2x + 3:")
show(grad_fn, jnp.array(2.0))
Original function f(x) = x² + 3x + 1:
Gradient f'(x) = 2x + 3:
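The graph above should compute f'(2) = 2·2 + 3 = 7. A minimal sketch that checks the `grad` result against the hand-derived derivative and an independent central finite difference (evaluated with plain Python floats), and prints the jaxpr that JAX traces before lowering to HLO. The step size `h` is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def simple_function(x):
    """f(x) = x^2 + 3x + 1, as in the cell above."""
    return x**2 + 3 * x + 1

grad_fn = jax.grad(simple_function)
x0 = 2.0

# Reverse-mode gradient computed by JAX
g_jax = float(grad_fn(jnp.array(x0)))

# Hand-derived derivative f'(x) = 2x + 3
g_exact = 2 * x0 + 3

# Central finite difference in Python floats (float64) as an independent check
h = 1e-5
g_fd = (simple_function(x0 + h) - simple_function(x0 - h)) / (2 * h)

print(g_jax, g_exact, g_fd)  # all approximately 7.0

# The traced intermediate program (jaxpr) for the gradient function
print(jax.make_jaxpr(grad_fn)(jnp.array(x0)))
```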
Multivariate Gradients¶
[3]:
def multivariate_function(params):
    """Function of multiple variables."""
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)
grad_multivariate = jax.grad(multivariate_function)
test_params = jnp.array([1.0, 2.0, 3.0])
print('Multivariate function gradient:')
show(grad_multivariate, test_params)
Multivariate function gradient:
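Because `params` is a single array, `grad` returns an array of the same shape whose entries are the partial derivatives. A minimal sketch comparing the result with the hand-derived partials ∂f/∂x = 2x + y·cos(xy), ∂f/∂y = z + x·cos(xy), ∂f/∂z = y at (1, 2, 3):

```python
import jax
import jax.numpy as jnp

def multivariate_function(params):
    x, y, z = params
    return x**2 + y * z + jnp.sin(x * y)

p = jnp.array([1.0, 2.0, 3.0])
g = jax.grad(multivariate_function)(p)

# Hand-derived partial derivatives evaluated at the same point
x, y, z = p
expected = jnp.array([
    2 * x + y * jnp.cos(x * y),  # df/dx
    z + x * jnp.cos(x * y),      # df/dy
    y,                           # df/dz
])

print(g)         # one entry per element of params
print(expected)
```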
Partial Derivatives with argnums¶
[4]:
def loss_function(params, x, y):
    """Loss function with respect to parameters."""
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2
# Gradient with respect to parameters (argnums=0)
grad_wrt_params = jax.grad(loss_function, argnums=0)
# Gradient with respect to input x (argnums=1)
grad_wrt_x = jax.grad(loss_function, argnums=1)
params = jnp.array([2.0, 1.0]) # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)
print('Gradient with respect to parameters:')
show(grad_wrt_params, params, x, y)
print('\nGradient with respect to input x:')
show(grad_wrt_x, params, x, y)
Gradient with respect to parameters:
Gradient with respect to input x:
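`argnums` also accepts a tuple, in which case `grad` returns one gradient per selected argument. A minimal sketch with the same inputs; by hand, the residual is r = w·x + b − y = 2, so ∂L/∂w = 2rx = 12, ∂L/∂b = 2r = 4 and ∂L/∂x = 2rw = 8.

```python
import jax
import jax.numpy as jnp

def loss_function(params, x, y):
    w, b = params
    prediction = w * x + b
    return (prediction - y) ** 2

params = jnp.array([2.0, 1.0])  # w=2, b=1
x = jnp.array(3.0)
y = jnp.array(5.0)

# argnums=(0, 1): differentiate w.r.t. params and x in a single call
g_params, g_x = jax.grad(loss_function, argnums=(0, 1))(params, x, y)

print(g_params)  # gradient w.r.t. (w, b)
print(g_x)       # gradient w.r.t. x
```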
Higher-Order Derivatives¶
[5]:
def polynomial(x):
    """Polynomial function for higher-order derivatives."""
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1
# First derivative
first_deriv = jax.grad(polynomial)
# Second derivative
second_deriv = jax.grad(jax.grad(polynomial))
# Third derivative
third_deriv = jax.grad(jax.grad(jax.grad(polynomial)))
x = jnp.array(2.0)
print('Second derivative:')
show(second_deriv, x)
print('\nThird derivative:')
show(third_deriv, x)
Second derivative:
Third derivative:
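Nesting `grad` generalizes to any order. The helper `nth_derivative` below is hypothetical (not part of JAX or visu-hlo); it simply applies `jax.grad` n times. The hand-derived values at x = 2 are p' = 45, p'' = 66, p''' = 60 and p'''' = 24.

```python
import jax
import jax.numpy as jnp

def polynomial(x):
    return x**4 + 2 * x**3 - 3 * x**2 + x + 1

def nth_derivative(f, n):
    """Hypothetical helper: apply jax.grad n times to a scalar function."""
    for _ in range(n):
        f = jax.grad(f)
    return f

x = jnp.array(2.0)
derivs = [float(nth_derivative(polynomial, n)(x)) for n in range(1, 5)]

# By hand at x = 2:
#   p'(x)    = 4x^3 + 6x^2 - 6x + 1
#   p''(x)   = 12x^2 + 12x - 6
#   p'''(x)  = 24x + 12
#   p''''(x) = 24
print(derivs)
```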
value_and_grad for Efficiency¶
[6]:
def expensive_function(x):
    """Function where we want both value and gradient."""
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))
# Get both value and gradient in one pass
value_and_grad_fn = jax.value_and_grad(expensive_function)
x = jnp.array([1.0, 2.0, 3.0])
print('value_and_grad (more efficient than separate calls):')
show(value_and_grad_fn, x)
value_and_grad (more efficient than separate calls):
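`value_and_grad` returns the same numbers as calling the function and `grad` separately; the saving is that the forward pass is evaluated once instead of twice. A minimal sketch that checks this, together with the hand-derived gradient 3x² + cos(x):

```python
import jax
import jax.numpy as jnp

def expensive_function(x):
    return jnp.sum(x**3) + jnp.sum(jnp.sin(x))

x = jnp.array([1.0, 2.0, 3.0])

# One call: value and gradient share a single forward pass
value, grad = jax.value_and_grad(expensive_function)(x)

# Two separate calls: the forward pass is evaluated twice
value_sep = expensive_function(x)
grad_sep = jax.grad(expensive_function)(x)

# Hand-derived gradient
grad_exact = 3 * x**2 + jnp.cos(x)
```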
Vectorization with vmap¶
Automatically vectorize functions to work on batches:
[7]:
def single_example_function(x):
    """Function that works on a single example."""
    return jnp.sum(x**2) + jnp.mean(x)
# Vectorize to work on batches
batched_function = jax.vmap(single_example_function)
# Single example
single_input = jnp.array([1.0, 2.0, 3.0])
print('Single example function:')
show(single_example_function, single_input)
# Batch of examples
batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
print('\nVectorized function (vmap):')
show(batched_function, batch_input)
Single example function:
Vectorized function (vmap):
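Semantically, `vmap` behaves like applying the function to each row and stacking the results; the difference is that the batch dimension is pushed into the operations themselves rather than run as a Python loop. A minimal sketch of that equivalence on the batch above (expected values 16, 82 and 202):

```python
import jax
import jax.numpy as jnp

def single_example_function(x):
    return jnp.sum(x**2) + jnp.mean(x)

batch_input = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])

vmapped = jax.vmap(single_example_function)(batch_input)

# Reference: explicit Python loop over rows, then stack
looped = jnp.stack([single_example_function(row) for row in batch_input])

print(vmapped)
```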
Advanced vmap with in_axes¶
[8]:
def matrix_vector_mult(matrix, vector):
    """Matrix-vector multiplication."""
    return jnp.dot(matrix, vector)
# Vectorize over the vector argument (axis 0) but not the matrix
batch_matvec = jax.vmap(matrix_vector_mult, in_axes=(None, 0))
matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])
print('Vectorized matrix-vector multiplication:')
print(f'Matrix shape: {matrix.shape}, Vectors shape: {vectors.shape}')
show(batch_matvec, matrix, vectors)
Vectorized matrix-vector multiplication:
Matrix shape: (2, 2), Vectors shape: (3, 2)
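With `in_axes=(None, 0)` the matrix is shared and the rows of `vectors` are mapped over, so the result equals the dense product `vectors @ matrix.T`. A minimal sketch checking that, plus `out_axes=1`, which places the batch dimension last instead of first:

```python
import jax
import jax.numpy as jnp

def matrix_vector_mult(matrix, vector):
    return jnp.dot(matrix, vector)

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 1.0], [2.0, 3.0], [0.5, 1.5]])

# Share the matrix (None), map over axis 0 of vectors
batched = jax.vmap(matrix_vector_mult, in_axes=(None, 0))(matrix, vectors)

# Dense equivalent: row i of the result is matrix @ vectors[i]
dense = vectors @ matrix.T

# out_axes=1: batch dimension goes last, giving shape (2, 3)
batched_t = jax.vmap(matrix_vector_mult, in_axes=(None, 0), out_axes=1)(matrix, vectors)
```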
Nested vmap for Multiple Batch Dimensions¶
[9]:
def pairwise_distance(x, y):
    """Euclidean distance between two points."""
    return jnp.sqrt(jnp.sum((x - y) ** 2))
# Inner vmap maps over the points in y; outer vmap maps over the points in x
vectorized_distances = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),
    in_axes=(0, None),
)
points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])
print('Nested vmap for pairwise distances:')
print(f'Points X shape: {points_x.shape}, Points Y shape: {points_y.shape}')
show(vectorized_distances, points_x, points_y)
Nested vmap for pairwise distances:
Points X shape: (3, 2), Points Y shape: (2, 2)
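The nested `vmap` produces a (3, 2) distance matrix, one row per point in X and one column per point in Y. A minimal sketch confirming that it matches the usual hand-written broadcasting formulation:

```python
import jax
import jax.numpy as jnp

def pairwise_distance(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))

points_x = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
points_y = jnp.array([[1.0, 0.0], [0.0, 1.0]])

inner = jax.vmap(pairwise_distance, in_axes=(None, 0))  # one x against every y
outer = jax.vmap(inner, in_axes=(0, None))               # every x against every y
dist_vmap = outer(points_x, points_y)                    # shape (3, 2)

# Same matrix via explicit broadcasting: (3, 1, 2) - (1, 2, 2) -> (3, 2, 2)
diff = points_x[:, None, :] - points_y[None, :, :]
dist_bcast = jnp.sqrt(jnp.sum(diff**2, axis=-1))
```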
Combining grad and vmap¶
Vectorized gradients for batch processing:
[10]:
def loss_per_example(params, x, y):
    """Loss for a single example."""
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2
# Gradient for a single example
grad_single = jax.grad(loss_per_example)
# Vectorized gradient for batch of examples
grad_batch = jax.vmap(grad_single, in_axes=(None, 0, 0))
params = jnp.array([1.0, 2.0]), jnp.array(0.5) # w, b
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])
print('Vectorized gradients for batch:')
show(grad_batch, params, batch_x, batch_y)
Vectorized gradients for batch:
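`vmap(grad(...))` returns per-example gradients with the same pytree structure as `params`, each leaf gaining a leading batch axis. Since the gradient is linear, averaging them should reproduce the gradient of the mean batch loss. A minimal sketch of that check; `mean_loss` is a hypothetical helper introduced for the comparison.

```python
import jax
import jax.numpy as jnp

def loss_per_example(params, x, y):
    w, b = params
    prediction = jnp.dot(w, x) + b
    return (prediction - y) ** 2

params = (jnp.array([1.0, 2.0]), jnp.array(0.5))  # (w, b)
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch_y = jnp.array([3.0, 7.0, 11.0])

# Per-example gradients: grad_w has shape (3, 2), grad_b has shape (3,)
grad_w, grad_b = jax.vmap(jax.grad(loss_per_example), in_axes=(None, 0, 0))(
    params, batch_x, batch_y
)

def mean_loss(params):
    """Hypothetical helper: mean of the per-example losses."""
    losses = jax.vmap(loss_per_example, in_axes=(None, 0, 0))(params, batch_x, batch_y)
    return jnp.mean(losses)

mean_grad_w, mean_grad_b = jax.grad(mean_loss)(params)
```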
Combining Multiple Transformations¶
JAX transformations are composable:
[11]:
def neural_network_layer(params, x):
    """Simple neural network layer."""
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)
def loss_fn(params, batch_x, batch_y):
    """Loss function for the neural network."""
    # Vectorize over the batch
    batch_predictions = jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)
    # Mean squared error
    return jnp.mean((batch_predictions - batch_y) ** 2)
# Combine JIT and grad
jit_grad_loss = jax.jit(jax.grad(loss_fn))
# Parameters
W = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b = jnp.array([0.1, 0.1])
params = (W, b)
# Batch data
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])
print('JIT + grad + vmap combination:')
show(jit_grad_loss, params, batch_x, batch_y)
JIT + grad + vmap combination:
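Where `jit` sits in the composition changes what is compiled as one unit, but not the numerical result (up to floating-point rounding). A minimal sketch comparing eager `grad`, `jit(grad(...))` and `grad(jit(...))` on the same inputs; the gradient is a tuple with the same structure as `params = (W, b)`.

```python
import jax
import jax.numpy as jnp

def neural_network_layer(params, x):
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)

def loss_fn(params, batch_x, batch_y):
    batch_predictions = jax.vmap(neural_network_layer, in_axes=(None, 0))(params, batch_x)
    return jnp.mean((batch_predictions - batch_y) ** 2)

params = (jnp.array([[0.1, 0.2], [0.3, 0.4]]), jnp.array([0.1, 0.1]))
batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[0.5, 0.8], [0.2, 0.9]])

g_eager = jax.grad(loss_fn)(params, batch_x, batch_y)
g_jit_outer = jax.jit(jax.grad(loss_fn))(params, batch_x, batch_y)  # compile the gradient
g_jit_inner = jax.grad(jax.jit(loss_fn))(params, batch_x, batch_y)  # differentiate a jitted loss

# Flatten each gradient pytree to a list of arrays: [dL/dW, dL/db]
leaves = [jax.tree_util.tree_leaves(g) for g in (g_eager, g_jit_outer, g_jit_inner)]
```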
Advanced: Jacobian Computation¶
[12]:
def vector_function(x):
    """Vector-valued function for Jacobian computation."""
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])
# Jacobian via forward-mode autodiff (jacfwd)
jacobian_fn = jax.jacfwd(vector_function)
x = jnp.array([2.0, 3.0])
print('Jacobian computation:')
show(jax.jit(jacobian_fn), x)
Jacobian computation:
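`jacfwd` (forward mode) and `jacrev` (reverse mode) compute the same matrix; forward mode is usually the better fit for "tall" Jacobians with more outputs than inputs, as here (3 outputs, 2 inputs). A minimal sketch checking both against the hand-derived Jacobian [[2x₀, 1], [x₁, x₀], [0, 2x₁]] at (2, 3):

```python
import jax
import jax.numpy as jnp

def vector_function(x):
    return jnp.array([x[0] ** 2 + x[1], x[0] * x[1], x[1] ** 2])

x = jnp.array([2.0, 3.0])

J_fwd = jax.jacfwd(vector_function)(x)  # forward mode: one JVP per input
J_rev = jax.jacrev(vector_function)(x)  # reverse mode: one VJP per output

# Rows are outputs, columns are inputs
J_exact = jnp.array([[4.0, 1.0],
                     [3.0, 2.0],
                     [0.0, 6.0]])
```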
Hessian Computation¶
[13]:
def scalar_function_for_hessian(x):
    """Scalar function for Hessian computation."""
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]
# Hessian via jax.hessian (forward-over-reverse: jacfwd of jacrev)
hessian_fn = jax.hessian(scalar_function_for_hessian)
x = jnp.array([1.0, 2.0])
print('Hessian computation:')
show(jax.jit(hessian_fn), x)
Hessian computation:
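`jax.hessian` is a forward-over-reverse composition, so spelling out `jacfwd(jacrev(f))` should give the same matrix. A minimal sketch comparing both with the hand-derived Hessian [[6x₀, 1], [1, 2]] at (1, 2):

```python
import jax
import jax.numpy as jnp

def scalar_function_for_hessian(x):
    return x[0] ** 3 + x[1] ** 2 + x[0] * x[1]

x = jnp.array([1.0, 2.0])

H = jax.hessian(scalar_function_for_hessian)(x)

# Explicit forward-over-reverse composition
H_fwd_rev = jax.jacfwd(jax.jacrev(scalar_function_for_hessian))(x)

H_exact = jnp.array([[6.0, 1.0],
                     [1.0, 2.0]])
```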
Custom Transformations with custom_vjp¶
Defining a custom vector-Jacobian product (VJP), i.e. a hand-written reverse-mode gradient rule:
[14]:
@jax.custom_vjp
def smooth_abs(x):
    """Smooth approximation to absolute value."""
    return jnp.sqrt(x**2 + 1e-8)
def smooth_abs_fwd(x):
    """Forward pass: return the output plus residuals (here, x) saved for the backward pass."""
    y = smooth_abs(x)
    return y, x
def smooth_abs_bwd(x, g):
    """Backward pass: map the output cotangent g to a cotangent for x using the saved residual."""
    return (g * x / jnp.sqrt(x**2 + 1e-8),)
smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)
# Use in a function with grad
def function_with_custom_grad(x):
    return jnp.sum(smooth_abs(x))
grad_custom = jax.grad(function_with_custom_grad)
x = jnp.array([-1.0, 0.5, 2.0])
print('Function with custom VJP:')
show(jax.jit(grad_custom), x)
Function with custom VJP:
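The custom rule can also be called directly through `jax.vjp`, which returns the primal output and a function that pulls an output cotangent back to the input. A minimal sketch, assuming the same ε = 1e-8 as above: for this particular function the hand-written rule agrees with ordinary autodiff of the same formula, and for nonzero inputs the gradient is approximately sign(x).

```python
import jax
import jax.numpy as jnp

EPS = 1e-8

@jax.custom_vjp
def smooth_abs(x):
    return jnp.sqrt(x**2 + EPS)

def smooth_abs_fwd(x):
    # Primal output plus residuals saved for the backward pass
    return smooth_abs(x), x

def smooth_abs_bwd(x, g):
    # Returns a tuple with one cotangent per primal input
    return (g * x / jnp.sqrt(x**2 + EPS),)

smooth_abs.defvjp(smooth_abs_fwd, smooth_abs_bwd)

x = jnp.array([-1.0, 0.5, 2.0])

# Pull back a cotangent of ones through the custom rule
y, vjp_fn = jax.vjp(smooth_abs, x)
(x_bar,) = vjp_fn(jnp.ones_like(x))

# Ordinary autodiff of the same formula, without the custom rule
auto = jax.grad(lambda v: jnp.sum(jnp.sqrt(v**2 + EPS)))(x)
```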
Optimization with Transformations¶
A complete optimization example combining multiple transformations:
[15]:
def optimization_step(params, batch_x, batch_y, learning_rate):
    """Single optimization step combining multiple transformations."""
    def model(params, x):
        W1, b1, W2, b2 = params
        h = jnp.tanh(jnp.dot(W1, x) + b1)
        return jnp.dot(W2, h) + b2
    def batch_loss(params, batch_x, batch_y):
        # Vectorize model over batch
        predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
        return jnp.mean((predictions - batch_y) ** 2)
    # Get both loss and gradients efficiently
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)
    # Update parameters
    new_params = jax.tree.map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss
# Initialize parameters
W1 = jnp.array([[0.1, 0.2], [0.3, 0.4]])
b1 = jnp.array([0.0, 0.0])
W2 = jnp.array([[0.5, 0.6]])
b2 = jnp.array([0.0])
params = (W1, b1, W2, b2)
# Batch data
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])
learning_rate = 0.01
print('Complete optimization step (JIT + value_and_grad + vmap):')
show(jax.jit(optimization_step), params, batch_x, batch_y, learning_rate)
Complete optimization step (JIT + value_and_grad + vmap):
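In practice the jitted step is called repeatedly in a training loop. A minimal sketch with the same model and data; the learning rate of 0.1 and the 100 iterations are arbitrary choices for illustration. Because shapes and dtypes stay the same across iterations, the step is traced and compiled once and then reused.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    W1, b1, W2, b2 = params
    h = jnp.tanh(jnp.dot(W1, x) + b1)
    return jnp.dot(W2, h) + b2

def batch_loss(params, batch_x, batch_y):
    predictions = jax.vmap(model, in_axes=(None, 0))(params, batch_x)
    return jnp.mean((predictions - batch_y) ** 2)

@jax.jit
def optimization_step(params, batch_x, batch_y, learning_rate):
    loss, grads = jax.value_and_grad(batch_loss)(params, batch_x, batch_y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss

params = (
    jnp.array([[0.1, 0.2], [0.3, 0.4]]),  # W1
    jnp.array([0.0, 0.0]),                # b1
    jnp.array([[0.5, 0.6]]),              # W2
    jnp.array([0.0]),                     # b2
)
batch_x = jnp.array([[1.0, 0.5], [0.8, 1.2]])
batch_y = jnp.array([[0.7], [0.9]])

losses = []
for _ in range(100):
    # Loss reported by each step is the loss *before* that step's update
    params, loss = optimization_step(params, batch_x, batch_y, 0.1)
    losses.append(float(loss))

print(losses[0], losses[-1])
```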
Summary¶
This notebook demonstrated JAX’s powerful function transformations and their visualizations:
Core Transformations:¶
``grad``: Automatic differentiation for computing gradients
``vmap``: Automatic vectorization for batch processing
``jit``: Just-in-time compilation for performance
``pmap``: Parallel mapping across devices (not demonstrated in this notebook)
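Since `pmap` is only mentioned conceptually above, here is a minimal sketch of its calling convention that also runs on a single CPU device: the leading axis of the input must equal the number of devices being mapped over, and each device receives one slice along that axis. The lambda is a hypothetical example.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()  # typically 1 on a plain CPU setup

# Leading axis = number of devices; each device gets one row of length 3
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape(n_devices, 3)

# Hypothetical per-device computation; the output keeps the leading device axis
doubled = jax.pmap(lambda row: 2 * row)(x)
```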
Advanced Features:¶
Higher-order derivatives: Nested ``grad`` calls
``value_and_grad``: Efficient computation of both value and gradient
Jacobians and Hessians: ``jacfwd``, ``jacrev``, ``hessian``
Custom transformations: ``custom_vjp`` for specialized gradients
Composition:¶
Transformations are composable: ``jit(grad(vmap(...)))``
Order matters: ``vmap(grad(...))`` vs ``grad(vmap(...))``
Can be combined for complex workflows
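To make "order matters" concrete: `vmap(grad(f))` gives one gradient per example, while `grad` of a vmapped function needs a scalar output, so the batch of losses has to be reduced first (a sum here, as an assumption), and the result is a single gradient for the whole batch. A minimal sketch with a hypothetical per-example loss:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss for one example
    return jnp.sum((w * x - 1.0) ** 2)

w = jnp.array([0.5, -1.0])
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# vmap(grad(...)): per-example gradients, shape (batch, len(w))
per_example = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, X)

# grad(vmap(...)): reduce the vector of losses to a scalar, then differentiate;
# the result is one gradient for the whole batch, shape (len(w),)
def total_loss(w):
    return jnp.sum(jax.vmap(loss, in_axes=(None, 0))(w, X))

batch_grad = jax.grad(total_loss)(w)
```

With a sum reduction the two are related by summing the per-example gradients over the batch axis.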
Key Benefits:¶
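Functional: Transformations take pure functions and return new pure functions; Python side effects are not tracked
Performant: JIT compilation and vectorization
Flexible: Works with ordinary Python functions written with ``jax.numpy``, as long as they are pure
Scalable: Efficient batch processing and parallelization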
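The purity requirement has a visible consequence under `jit`: Python side effects run while JAX traces the function, not on every call. A minimal sketch; `impure_scale` and `trace_log` are hypothetical names. A call with the same input shape and dtype hits the compilation cache, while a new shape triggers a retrace.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def impure_scale(x):
    # Python side effect: executes during tracing only
    trace_log.append("traced")
    return 2.0 * x

impure_scale(jnp.ones(3))  # first call: traced and compiled
impure_scale(jnp.ones(3))  # same shape/dtype: cached, no retrace
impure_scale(jnp.ones(4))  # new shape: traced again
```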
The HLO visualizations reveal how these high-level transformations are compiled into efficient computational graphs, showing the automatic optimizations performed by XLA, the compiler JAX uses.