Matrix Examples
This notebook demonstrates how visu-hlo visualizes matrix operations and linear algebra computations in JAX.
Setup
First, let’s import the necessary libraries:
[1]:
import os
# Run JAX on the CPU backend; this must be set before jax is imported
os.environ['JAX_PLATFORMS'] = 'cpu'
import jax
import jax.numpy as jnp
from visu_hlo import show
Matrix-Vector Multiplication
Let’s start with a common operation, the matrix-vector product y = Wx (the core of a dense neural-network layer):
[2]:
@jax.jit
def matrix_vector_multiply(W, x):
    """Matrix-vector multiplication: y = Wx"""
    return jnp.dot(W, x)
# Weight matrix and input vector
W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])
print('Weight matrix shape:', W.shape)
print('Input vector shape:', x.shape)
print('\nVisualization:')
show(matrix_vector_multiply, W, x)
Weight matrix shape: (2, 3)
Input vector shape: (3,)
Visualization:
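The same structure can also be inspected as text. What follows is a minimal sketch that uses only standard JAX APIs and does not use visu-hlo. It relies on JAX's ahead-of-time API: `.lower(...)` on a jitted function followed by `.as_text()` returns the lowered StableHLO module, where `jnp.dot` shows up as a single `dot_general` operation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def matrix_vector_multiply(W, x):
    """Matrix-vector multiplication: y = Wx"""
    return jnp.dot(W, x)

W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])

# Each output entry is a row of W dotted with x:
# 1*0.5 + 2*1 + 3*1.5 = 7 and 4*0.5 + 5*1 + 6*1.5 = 16
y = matrix_vector_multiply(W, x)
print(y)

# Lower the jitted function for these argument shapes and dump its StableHLO text
hlo_text = matrix_vector_multiply.lower(W, x).as_text()
print(hlo_text)
```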
Basic Matrix Multiplication
Next, a plain matrix-matrix multiplication of a (3, 4) matrix by a (4, 2) matrix:
[3]:
@jax.jit
def matrix_multiply(A, B):
    """Simple matrix multiplication using jnp.dot."""
    return jnp.dot(A, B)
# Create sample matrices
A = jnp.ones((3, 4))
B = jnp.ones((4, 2))
print('Matrix shapes: A', A.shape, '× B', B.shape, '= result', matrix_multiply(A, B).shape)
print('\nVisualization of matrix multiplication:')
show(matrix_multiply, A, B)
Matrix shapes: A (3, 4) × B (4, 2) = result (3, 2)
Visualization of matrix multiplication:
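For 2-D inputs, `jnp.dot`, `jnp.matmul`, and the `@` operator compute the same product. The quick check below, a sketch using the same shapes as above, confirms this. Because every entry of A and B is 1, each output entry is a sum of four products 1 × 1, so every entry equals 4.0.

```python
import jax.numpy as jnp

A = jnp.ones((3, 4))
B = jnp.ones((4, 2))

# Three spellings of the same 2-D matrix product
via_dot = jnp.dot(A, B)
via_matmul = jnp.matmul(A, B)
via_operator = A @ B

# The shared dimension of size 4 is contracted away
print(via_dot.shape)  # (3, 2)
print(via_dot)
```

For arrays with more than two dimensions the three are no longer interchangeable: `jnp.matmul` and `@` broadcast over leading batch dimensions, while `jnp.dot` follows different contraction rules.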
Batch Matrix Operations
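Working with batches of matrices: the ellipsis (...) in the einsum subscripts lets any leading dimensions act as batch dimensions, so each pair of 2x2 matrices is multiplied independently: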
[4]:
@jax.jit
def batch_matrix_multiply(batch_A, batch_B):
    """Multiply batches of matrices."""
    return jnp.einsum('...mn,...np->...mp', batch_A, batch_B)
# Batches of three 2x2 matrices each (batch dimension first)
batch_A = jnp.array(
    [
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]],
        [[9.0, 10.0], [11.0, 12.0]],
    ]
)
batch_B = jnp.array(
    [
        [[0.1, 0.2], [0.3, 0.4]],
        [[0.5, 0.6], [0.7, 0.8]],
        [[0.9, 1.0], [1.1, 1.2]],
    ]
)
print('Batch shape:', batch_A.shape)
print('\nVisualization of batch matrix multiplication:')
show(batch_matrix_multiply, batch_A, batch_B)
Batch shape: (3, 2, 2)
Visualization of batch matrix multiplication:
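The einsum expression is only one way to write a batched product. The sketch below compares it with two standard alternatives. `jnp.matmul` treats every dimension except the last two as a batch dimension. `jax.vmap` maps an ordinary 2-D `jnp.dot` over axis 0 of both inputs. All three should agree.

```python
import jax
import jax.numpy as jnp

batch_A = jnp.array([[[1.0, 2.0], [3.0, 4.0]],
                     [[5.0, 6.0], [7.0, 8.0]],
                     [[9.0, 10.0], [11.0, 12.0]]])
batch_B = jnp.array([[[0.1, 0.2], [0.3, 0.4]],
                     [[0.5, 0.6], [0.7, 0.8]],
                     [[0.9, 1.0], [1.1, 1.2]]])

# 1) einsum: '...' covers the leading batch dimension, n is contracted
via_einsum = jnp.einsum('...mn,...np->...mp', batch_A, batch_B)
# 2) matmul: all but the last two dimensions are batch dimensions
via_matmul = jnp.matmul(batch_A, batch_B)
# 3) vmap: map a plain 2-D product over axis 0 of both inputs
via_vmap = jax.vmap(jnp.dot)(batch_A, batch_B)

print(via_einsum.shape)  # (3, 2, 2)
# First batch element: [[1, 2], [3, 4]] @ [[0.1, 0.2], [0.3, 0.4]] = [[0.7, 1.0], [1.5, 2.2]]
print(via_einsum[0])
```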
Matrix Operations with Broadcasting
JAX follows NumPy broadcasting rules, so a vector of shape (3,) combined element-wise with a (2, 3) matrix is applied to every row:
[5]:
@jax.jit
def broadcasted_operations(matrix, vector):
    """Matrix operations with broadcasting."""
    # Add the vector to each row of the matrix
    added = matrix + vector
    # Element-wise multiplication: column j is scaled by vector[j]
    multiplied = matrix * vector
    return added, multiplied
matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
vector = jnp.array([0.1, 0.2, 0.3])
print('Matrix shape:', matrix.shape)
print('Vector shape:', vector.shape)
print('\nVisualization of broadcasted operations:')
show(broadcasted_operations, matrix, vector)
Matrix shape: (2, 3)
Vector shape: (3,)
Visualization of broadcasted operations:
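Broadcasting aligns shapes from the right. A (3,) vector therefore lines up with the columns of a (2, 3) matrix and is repeated for each row. The sketch below spells out that repetition with `jnp.broadcast_to`. It also shows the usual fix when a per-row vector of shape (2,) is wanted instead: add a trailing axis so its shape becomes (2, 1). Without that axis, `matrix + row_offsets` would fail with a broadcasting error. The name `row_offsets` is only illustrative.

```python
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
vector = jnp.array([0.1, 0.2, 0.3])

# Implicit broadcasting: (2, 3) + (3,) repeats the vector for each row
implicit = matrix + vector
# The same repetition made explicit
explicit = matrix + jnp.broadcast_to(vector, matrix.shape)

# Hypothetical per-row offsets: shape (2,) must become (2, 1) to line up with rows
row_offsets = jnp.array([10.0, 20.0])
per_row = matrix + row_offsets[:, None]

print(implicit)
print(per_row)
```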
Matrix Norms
Computing the Frobenius norm and the spectral norm (the largest singular value):
[6]:
@jax.jit
def matrix_norms(matrix):
    """Compute various matrix norms."""
    frobenius = jnp.linalg.norm(matrix, ord='fro')
    spectral = jnp.linalg.norm(matrix, ord=2)  # largest singular value
    return frobenius, spectral
test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print('Test matrix:')
print(test_matrix)
print('\nVisualization of matrix norm computation:')
show(matrix_norms, test_matrix)
Test matrix:
[[1. 2.]
[3. 4.]]
Visualization of matrix norm computation:
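Both norms have short direct definitions, and the sketch below checks `jnp.linalg.norm` against them. The Frobenius norm is the square root of the sum of squared entries, here sqrt(1 + 4 + 9 + 16) = sqrt(30) ≈ 5.477. The spectral norm is the largest singular value, about 5.465 for this matrix. It can never exceed the Frobenius norm.

```python
import jax.numpy as jnp

test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Frobenius norm from its definition
fro_direct = jnp.sqrt(jnp.sum(test_matrix ** 2))
fro_library = jnp.linalg.norm(test_matrix, ord='fro')

# Spectral norm: largest singular value (svd returns them in descending order)
singular_values = jnp.linalg.svd(test_matrix, compute_uv=False)
spec_direct = singular_values[0]
spec_library = jnp.linalg.norm(test_matrix, ord=2)

print(fro_direct, fro_library)
print(spec_direct, spec_library)
```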
Matrix Inverse
[7]:
@jax.jit
def matrix_inverse(matrix):
    """Compute matrix inverse."""
    return jnp.linalg.inv(matrix)
# Well-conditioned matrix
well_conditioned = jnp.eye(3) + 0.1 * jnp.ones((3, 3))
print('Well-conditioned matrix:')
print(well_conditioned)
print('\nVisualization of matrix inversion:')
show(matrix_inverse, well_conditioned)
Well-conditioned matrix:
[[1.1 0.1 0.1]
[0.1 1.1 0.1]
[0.1 0.1 1.1]]
Visualization of matrix inversion:
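The test matrix has the form I + cJ, where J is the 3 × 3 all-ones matrix and c = 0.1. That structure allows an independent check. The Sherman–Morrison formula gives (I + cJ)^-1 = I - c / (1 + nc) · J with n = 3. The eigenvalues of I + cJ are 1 + nc = 1.3 (once) and 1 (twice), so its 2-norm condition number is 1.3, which is what "well-conditioned" means here. The sketch below verifies all three facts with `jnp.linalg.inv` and `jnp.linalg.cond`.

```python
import jax.numpy as jnp

n, c = 3, 0.1
ones = jnp.ones((n, n))
well_conditioned = jnp.eye(n) + c * ones

inv_library = jnp.linalg.inv(well_conditioned)
# Sherman-Morrison closed form for I + c * (all-ones matrix)
inv_closed_form = jnp.eye(n) - (c / (1 + n * c)) * ones

# A correct inverse gives M @ M^-1 = I up to float32 rounding
residual = well_conditioned @ inv_library - jnp.eye(n)

# Ratio of largest to smallest singular value: 1.3 / 1.0 for this matrix
condition_number = jnp.linalg.cond(well_conditioned)
print(inv_library)
print(condition_number)
```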
Linear System Solving
Solving systems of linear equations Ax = b:
[8]:
@jax.jit
def solve_linear_system(A, b):
    """Solve Ax = b using JAX's linear algebra solver."""
    return jnp.linalg.solve(A, b)
# Create an invertible matrix and target vector
A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])
print('System matrix A:')
print(A)
print('\nTarget vector b:', b)
print('\nVisualization of linear system solver:')
show(solve_linear_system, A, b)
System matrix A:
[[3. 1. 2.]
[1. 4. 1.]
[2. 1. 3.]]
Target vector b: [1. 2. 3.]
Visualization of linear system solver:
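Two checks help interpret the solver's output. First, the residual Ax - b should be close to zero. By Cramer's rule (det A = 18), the exact solution is x = (-2/3, 1/3, 4/3). Second, this particular A is symmetric and its leading principal minors (3, 11, 18) are all positive, so it is positive definite. A Cholesky-based solve from `jax.scipy.linalg` is therefore a valid alternative. The sketch below compares the two approaches. It shows only that they agree here, not that either is faster.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

# General solver used in the notebook
x_general = jnp.linalg.solve(A, b)

# A is symmetric positive definite, so a Cholesky factorization also applies
factor = cho_factor(A)
x_cholesky = cho_solve(factor, b)

# Residual of the original system; should be near zero
residual = A @ x_general - b
print(x_general)
print(x_cholesky)
```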
Eigenvalue Decomposition
Computing eigenvalues and eigenvectors of symmetric matrices:
[9]:
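@jax.jit
def compute_eigenvalues(matrix):
    """Compute eigenvalues with the general (nonsymmetric) solver."""
    # eigvals returns complex values even for symmetric input;
    # jnp.linalg.eigvalsh is the symmetric-specific alternative
    return jnp.linalg.eigvals(matrix)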
# Symmetric matrix
sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])
print('Symmetric matrix:')
print(sym_matrix)
print('\nVisualization of eigenvalue computation:')
show(compute_eigenvalues, sym_matrix)
Symmetric matrix:
[[4. 1. 2.]
[1. 3. 1.]
[2. 1. 5.]]
Visualization of eigenvalue computation:
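For a symmetric matrix, `jnp.linalg.eigvalsh` is the specialized routine: it returns real eigenvalues in ascending order. `jnp.linalg.eigvals` uses the general solver and returns complex values. The sketch below compares the two and adds two sanity checks. The eigenvalues sum to the trace (4 + 3 + 5 = 12) and multiply to the determinant (43 for this matrix). It sets the CPU backend the same way the notebook does.

```python
import os
os.environ.setdefault('JAX_PLATFORMS', 'cpu')  # same CPU-only setup as the notebook

import jax.numpy as jnp

sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])

# General solver: complex dtype, no guaranteed ordering
general = jnp.linalg.eigvals(sym_matrix)
# Symmetric solver: real dtype, ascending order
symmetric = jnp.linalg.eigvalsh(sym_matrix)

# For symmetric input the imaginary parts are numerically zero,
# so the sorted real parts should match eigvalsh
general_sorted = jnp.sort(general.real)

# Sum of eigenvalues = trace = 12; product = determinant = 43
print(symmetric, jnp.sum(symmetric), jnp.prod(symmetric))
```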
[10]:
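@jax.jit
def compute_eigenvectors(matrix):
    """Compute eigenvalues (ascending) and eigenvectors of a symmetric matrix."""
    # eigh assumes symmetric (Hermitian) input; eigenvectors are the columns of eigenvecs
    eigenvals, eigenvecs = jnp.linalg.eigh(matrix)
    return eigenvals, eigenvecs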
print('Visualization of full eigendecomposition:')
show(compute_eigenvectors, sym_matrix)
Visualization of full eigendecomposition:
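`eigh` returns the eigenvectors as the columns of V. For symmetric input V is orthogonal, so the matrix can be rebuilt as A = V diag(w) V^T. The sketch below checks the defining property A v_i = w_i v_i for every column at once, then checks orthonormality and the reconstruction.

```python
import jax.numpy as jnp

sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])
eigenvals, eigenvecs = jnp.linalg.eigh(sym_matrix)

# A v_i = w_i v_i for every column i; broadcasting eigenvals over the
# last axis scales column i of eigenvecs by eigenvals[i]
lhs = sym_matrix @ eigenvecs
rhs = eigenvecs * eigenvals

# Orthonormal eigenvectors: V^T V = I
gram = eigenvecs.T @ eigenvecs

# Reconstruction: A = V diag(w) V^T
reconstructed = eigenvecs @ jnp.diag(eigenvals) @ eigenvecs.T
print(eigenvals)
```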
Matrix Decompositions
QR Decomposition
[11]:
@jax.jit
def qr_decomposition(matrix):
    """QR decomposition of a matrix."""
    Q, R = jnp.linalg.qr(matrix)
    return Q, R
# Rectangular matrix for QR decomposition
rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print('Input matrix for QR decomposition:')
print(rect_matrix)
print('\nVisualization:')
show(qr_decomposition, rect_matrix)
Input matrix for QR decomposition:
[[1. 2.]
[3. 4.]
[5. 6.]]
Visualization:
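In its default reduced mode, `jnp.linalg.qr` factors a 3 × 2 input into a 3 × 2 Q with orthonormal columns and a 2 × 2 upper-triangular R. With `mode='complete'`, Q is a square 3 × 3 matrix and R is 3 × 2. The sketch below checks these shapes and the three defining properties.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

Q, R = jnp.linalg.qr(rect_matrix)  # mode='reduced' by default
Q_full, R_full = jnp.linalg.qr(rect_matrix, mode='complete')

print(Q.shape, R.shape)            # (3, 2) (2, 2)
print(Q_full.shape, R_full.shape)  # (3, 3) (3, 2)

# Columns of Q are orthonormal: Q^T Q = I (2 x 2)
gram = Q.T @ Q
# R is upper triangular: entries below the diagonal are zero
below_diag = jnp.tril(R, k=-1)
# The factors reproduce the input
reconstructed = Q @ R
```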
Singular Value Decomposition (SVD)
[12]:
@jax.jit
def svd_decomposition(matrix):
    """Singular Value Decomposition."""
    U, s, Vt = jnp.linalg.svd(matrix, full_matrices=False)
    return U, s, Vt
print('Visualization of SVD:')
show(svd_decomposition, rect_matrix)
Visualization of SVD:
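With `full_matrices=False`, a 3 × 2 input gives U of shape (3, 2), s of shape (2,) in descending order, and Vt of shape (2, 2). The factors rebuild the input as U diag(s) Vt. The sketch below also illustrates the Eckart–Young theorem. Keeping only the largest singular value gives the best rank-1 approximation, and its spectral-norm error equals the discarded singular value s[1]. The largest singular value s[0] is the spectral norm computed in the norms section.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
U, s, Vt = jnp.linalg.svd(rect_matrix, full_matrices=False)

print(U.shape, s.shape, Vt.shape)  # (3, 2) (2,) (2, 2)

# Reconstruction: U * s scales column i of U by s[i], i.e. U @ diag(s)
reconstructed = (U * s) @ Vt

# Best rank-1 approximation: keep only the largest singular triple
rank1 = s[0] * jnp.outer(U[:, 0], Vt[0, :])
# Its spectral-norm error equals the discarded singular value s[1]
rank1_error = jnp.linalg.norm(rect_matrix - rank1, ord=2)
print(s, rank1_error)
```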
Summary
This notebook demonstrated a range of matrix operations in JAX and how visu-hlo visualizes their computation graphs:
- Basic operations: matrix multiplication, matrix-vector products, batched operations, and broadcasting
- Linear algebra: norms, matrix inversion, linear system solving, and eigendecomposition
- Matrix decompositions: QR and SVD
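Each visualization shows how JAX lowers these high-level linear algebra operations to HLO, giving insight into their computational structure and possible optimization opportunities. On the CPU backend, factorizations such as QR, eigh, and SVD may appear as calls into library kernels rather than as a graph of primitive operations.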