Matrix Examples

This notebook demonstrates how visu-hlo visualizes matrix operations and linear algebra computations in JAX.

Setup

First, let’s import the necessary libraries. Setting JAX_PLATFORMS to cpu before importing jax makes JAX use the CPU backend:

[1]:
import os

os.environ['JAX_PLATFORMS'] = 'cpu'

import jax
import jax.numpy as jnp

from visu_hlo import show

Matrix-Vector Multiplication

Let’s start with a common operation, the matrix-vector product y = Wx:

[2]:
@jax.jit
def matrix_vector_multiply(W, x):
    """Matrix-vector multiplication: y = Wx"""
    return jnp.dot(W, x)


# Weight matrix and input vector
W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])

print('Weight matrix shape:', W.shape)
print('Input vector shape:', x.shape)
print('\nVisualization:')
show(matrix_vector_multiply, W, x)
Weight matrix shape: (2, 3)
Input vector shape: (3,)

Visualization:
../_images/user-guide_matrix-operations_3_1.svg
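
As a textual cross-check on the diagram, JAX itself can print the lowered program for the same function. The sketch below assumes only standard JAX calls (`lower(...)` on a jitted function, then `as_text()`). It does not use visu-hlo, and the exact text printed depends on the installed JAX version.

```python
import jax
import jax.numpy as jnp


@jax.jit
def matrix_vector_multiply(W, x):
    """Matrix-vector multiplication: y = Wx"""
    return jnp.dot(W, x)


W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
x = jnp.array([0.5, 1.0, 1.5])

# Numerical result: each output entry is one row of W dotted with x,
# so the two entries are 7.0 and 16.0.
y = matrix_vector_multiply(W, x)
print(y)

# Textual form of the lowered (not yet compiled) program.
lowered_text = matrix_vector_multiply.lower(W, x).as_text()
print(lowered_text)
```

The printed module should contain a single dot_general operation, which is the primitive that jnp.dot lowers to.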

Basic Matrix Multiplication

Next, a matrix-matrix product of a (3, 4) matrix with a (4, 2) matrix:

[3]:
@jax.jit
def matrix_multiply(A, B):
    """Simple matrix multiplication using jnp.dot."""
    return jnp.dot(A, B)


# Create sample matrices
A = jnp.ones((3, 4))
B = jnp.ones((4, 2))

print('Matrix shapes: A', A.shape, '× B', B.shape, '= result', (3, 2))
print('\nVisualization of matrix multiplication:')
show(matrix_multiply, A, B)
Matrix shapes: A (3, 4) × B (4, 2) = result (3, 2)

Visualization of matrix multiplication:
../_images/user-guide_matrix-operations_5_1.svg
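
The cell above prints the result shape as a hard-coded tuple. As a small sketch, the shape can instead be derived with `jax.eval_shape`, which traces the function abstractly without doing any arithmetic. The variable name `out` is only illustrative.

```python
import jax
import jax.numpy as jnp


def matrix_multiply(A, B):
    """Simple matrix multiplication using jnp.dot."""
    return jnp.dot(A, B)


A = jnp.ones((3, 4))
B = jnp.ones((4, 2))

# eval_shape returns a ShapeDtypeStruct describing the output;
# no matrix product is actually computed here.
out = jax.eval_shape(matrix_multiply, A, B)
print('Matrix shapes: A', A.shape, '× B', B.shape, '= result', out.shape)
```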

Batch Matrix Operations

The einsum pattern '...mn,...np->...mp' multiplies matching pairs of matrices across any leading batch dimensions:

[4]:
@jax.jit
def batch_matrix_multiply(batch_A, batch_B):
    """Multiply batches of matrices."""
    return jnp.einsum('...mn,...np->...mp', batch_A, batch_B)


# Batch of 2x2 matrices
batch_A = jnp.array(
    [
        [[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]],
        [[9.0, 10.0], [11.0, 12.0]],
    ]
)
batch_B = jnp.array(
    [
        [[0.1, 0.2], [0.3, 0.4]],
        [[0.5, 0.6], [0.7, 0.8]],
        [[0.9, 1.0], [1.1, 1.2]],
    ]
)

print('Batch shape:', batch_A.shape)
print('\nVisualization of batch matrix multiplication:')
show(batch_matrix_multiply, batch_A, batch_B)
Batch shape: (3, 2, 2)

Visualization of batch matrix multiplication:
../_images/user-guide_matrix-operations_7_1.svg
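
To make the einsum pattern concrete, the sketch below compares it with `jnp.matmul`, which also treats leading dimensions as batch dimensions. The input values are arbitrary placeholders chosen for this example, not the ones from the cell above.

```python
import jax.numpy as jnp

# Arbitrary batch of three 2x2 matrices (illustrative values only).
batch_A = jnp.arange(12.0).reshape(3, 2, 2)
batch_B = jnp.ones((3, 2, 2))

# '...' stands for any number of leading batch axes; n is summed over.
via_einsum = jnp.einsum('...mn,...np->...mp', batch_A, batch_B)

# matmul applies the same rule: the last two axes are multiplied,
# the leading axes are treated as a batch.
via_matmul = jnp.matmul(batch_A, batch_B)

print(via_einsum.shape)
print(jnp.allclose(via_einsum, via_matmul))
```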

Matrix Operations with Broadcasting

JAX follows NumPy-style broadcasting: a vector of shape (3,) combined with a matrix of shape (2, 3) is applied to every row of the matrix:

[5]:
@jax.jit
def broadcasted_operations(matrix, vector):
    """Matrix operations with broadcasting."""
    # Add vector to each row of matrix
    added = matrix + vector
    # Element-wise multiplication
    multiplied = matrix * vector
    return added, multiplied


matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
vector = jnp.array([0.1, 0.2, 0.3])

print('Matrix shape:', matrix.shape)
print('Vector shape:', vector.shape)
print('\nVisualization of broadcasted operations:')
show(broadcasted_operations, matrix, vector)
Matrix shape: (2, 3)
Vector shape: (3,)

Visualization of broadcasted operations:
../_images/user-guide_matrix-operations_9_1.svg
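
The implicit broadcast can be spelled out with `jnp.broadcast_to`. This is only a sketch to show what the shorter expressions mean. In practice the implicit form is what one would normally write.

```python
import jax.numpy as jnp

matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
vector = jnp.array([0.1, 0.2, 0.3])

# Shapes are aligned from the trailing axis: (3,) is treated as (1, 3)
# and then repeated along the first axis to match (2, 3).
expanded = jnp.broadcast_to(vector, matrix.shape)

added = matrix + vector          # implicit broadcast
added_explicit = matrix + expanded  # same result, broadcast written out

print(expanded.shape)
print(jnp.allclose(added, added_explicit))
```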

Matrix Norms

[6]:
@jax.jit
def matrix_norms(matrix):
    """Compute various matrix norms."""
    frobenius = jnp.linalg.norm(matrix, 'fro')
    spectral = jnp.linalg.norm(matrix, 2)
    return frobenius, spectral


test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])

print('Test matrix:')
print(test_matrix)
print('\nVisualization of matrix norm computation:')
show(matrix_norms, test_matrix)
Test matrix:
[[1. 2.]
 [3. 4.]]

Visualization of matrix norm computation:
../_images/user-guide_matrix-operations_11_1.svg
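
The two norms have quite different definitions, which is one reason their graphs look different. The sketch below recomputes both by hand: the Frobenius norm as the square root of the sum of squared entries, and the spectral norm as the largest singular value.

```python
import jax.numpy as jnp

test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])

frobenius = jnp.linalg.norm(test_matrix, 'fro')
spectral = jnp.linalg.norm(test_matrix, 2)

# Frobenius norm: sqrt(1 + 4 + 9 + 16) = sqrt(30).
frobenius_manual = jnp.sqrt(jnp.sum(test_matrix**2))

# Spectral norm: largest singular value of the matrix.
singular_values = jnp.linalg.svd(test_matrix, compute_uv=False)
spectral_manual = jnp.max(singular_values)

print(frobenius, frobenius_manual)
print(spectral, spectral_manual)
```

The spectral norm never exceeds the Frobenius norm, since the latter is the square root of the sum of all squared singular values.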

Matrix Inverse

[7]:
@jax.jit
def matrix_inverse(matrix):
    """Compute matrix inverse."""
    return jnp.linalg.inv(matrix)


# Well-conditioned matrix
well_conditioned = jnp.eye(3) + 0.1 * jnp.ones((3, 3))

print('Well-conditioned matrix:')
print(well_conditioned)
print('\nVisualization of matrix inversion:')
show(matrix_inverse, well_conditioned)
Well-conditioned matrix:
[[1.1 0.1 0.1]
 [0.1 1.1 0.1]
 [0.1 0.1 1.1]]

Visualization of matrix inversion:
../_images/user-guide_matrix-operations_13_1.svg
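
A quick sanity check on the computed inverse is to multiply it back with the original matrix. The tolerance below is an assumption suited to JAX's default float32 precision.

```python
import jax.numpy as jnp

well_conditioned = jnp.eye(3) + 0.1 * jnp.ones((3, 3))
inverse = jnp.linalg.inv(well_conditioned)

# For an exact inverse this product would be the identity matrix;
# in float32 it matches up to small rounding errors.
product = well_conditioned @ inverse
print(jnp.allclose(product, jnp.eye(3), atol=1e-5))
```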

Linear System Solving

Solving systems of linear equations Ax = b:

[8]:
@jax.jit
def solve_linear_system(A, b):
    """Solve Ax = b using JAX's linear algebra solver."""
    return jnp.linalg.solve(A, b)


# Create an invertible matrix and target vector
A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

print('System matrix A:')
print(A)
print('\nTarget vector b:', b)
print('\nVisualization of linear system solver:')
show(solve_linear_system, A, b)
System matrix A:
[[3. 1. 2.]
 [1. 4. 1.]
 [2. 1. 3.]]

Target vector b: [1. 2. 3.]

Visualization of linear system solver:
../_images/user-guide_matrix-operations_15_1.svg
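
The solution can be checked by looking at the residual Ax - b. The sketch also compares against multiplying by an explicit inverse, which gives the same answer here but is generally the less preferred route numerically. The tolerances are assumptions suited to float32.

```python
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

x = jnp.linalg.solve(A, b)

# Residual of the solved system; should be close to zero.
residual = A @ x - b

# Alternative route through the explicit inverse, for comparison only.
x_via_inv = jnp.linalg.inv(A) @ b

print('Solution x:', x)
print('Residual:', residual)
```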

Eigenvalue Decomposition

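Computing eigenvalues and eigenvectors of a symmetric matrix. The first cell uses the general solver jnp.linalg.eigvals; the second uses jnp.linalg.eigh, which is specialized for symmetric (Hermitian) matrices: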

[9]:
@jax.jit
def compute_eigenvalues(matrix):
    """Compute eigenvalues of a symmetric matrix."""
    return jnp.linalg.eigvals(matrix)


# Symmetric matrix
sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])

print('Symmetric matrix:')
print(sym_matrix)
print('\nVisualization of eigenvalue computation:')
show(compute_eigenvalues, sym_matrix)
Symmetric matrix:
[[4. 1. 2.]
 [1. 3. 1.]
 [2. 1. 5.]]

Visualization of eigenvalue computation:
../_images/user-guide_matrix-operations_17_1.svg

[10]:
@jax.jit
def compute_eigenvectors(matrix):
    """Compute both eigenvalues and eigenvectors."""
    eigenvals, eigenvecs = jnp.linalg.eigh(matrix)
    return eigenvals, eigenvecs


print('Visualization of full eigendecomposition:')
show(compute_eigenvectors, sym_matrix)
Visualization of full eigendecomposition:
../_images/user-guide_matrix-operations_18_1.svg
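
The two routines return results in different forms: `jnp.linalg.eigvals` yields complex values in no guaranteed order, while `jnp.linalg.eigh` yields real eigenvalues in ascending order. The sketch below, which assumes the CPU backend as in the setup cell, checks that they agree for a symmetric input and that the eigenvectors reconstruct the matrix.

```python
import os

os.environ.setdefault('JAX_PLATFORMS', 'cpu')

import jax.numpy as jnp

sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])

# General solver: complex dtype, order not guaranteed.
general = jnp.linalg.eigvals(sym_matrix)

# Symmetric solver: real eigenvalues in ascending order,
# eigenvectors stored as the columns of eigenvecs.
eigenvals, eigenvecs = jnp.linalg.eigh(sym_matrix)

# For a symmetric matrix the imaginary parts are zero,
# so sorting the real parts should reproduce eigh's values.
general_sorted = jnp.sort(general.real)

# Reconstruction: A = V diag(w) V^T.
reconstructed = (eigenvecs * eigenvals) @ eigenvecs.T

print(general_sorted)
print(eigenvals)
```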

Matrix Decompositions

QR Decomposition

[11]:
@jax.jit
def qr_decomposition(matrix):
    """QR decomposition of a matrix."""
    Q, R = jnp.linalg.qr(matrix)
    return Q, R


# Rectangular matrix for QR decomposition
rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

print('Input matrix for QR decomposition:')
print(rect_matrix)
print('\nVisualization:')
show(qr_decomposition, rect_matrix)
Input matrix for QR decomposition:
[[1. 2.]
 [3. 4.]
 [5. 6.]]

Visualization:
../_images/user-guide_matrix-operations_20_1.svg
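
The defining properties of the factors can be verified directly. With the default reduced mode, a 3x2 input gives a 3x2 Q with orthonormal columns and a 2x2 upper-triangular R. The tolerances in this sketch are assumptions suited to float32.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
Q, R = jnp.linalg.qr(rect_matrix)

print('Q shape:', Q.shape, 'R shape:', R.shape)

# Q has orthonormal columns: Q^T Q is the 2x2 identity.
orthonormal = jnp.allclose(Q.T @ Q, jnp.eye(2), atol=1e-4)

# R is upper triangular: entries below the diagonal are zero.
upper_triangular = jnp.allclose(R, jnp.triu(R), atol=1e-6)

# The product of the factors recovers the input.
recovered = jnp.allclose(Q @ R, rect_matrix, atol=1e-4)

print(orthonormal, upper_triangular, recovered)
```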

Singular Value Decomposition (SVD)

[12]:
@jax.jit
def svd_decomposition(matrix):
    """Singular Value Decomposition."""
    U, s, Vt = jnp.linalg.svd(matrix, full_matrices=False)
    return U, s, Vt


print('Visualization of SVD:')
show(svd_decomposition, rect_matrix)
Visualization of SVD:
../_images/user-guide_matrix-operations_22_1.svg
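
With full_matrices=False, the factors of a 3x2 matrix have shapes (3, 2), (2,) and (2, 2). The following sketch checks those shapes and rebuilds the input from the factors. The tolerance is an assumption suited to float32.

```python
import jax.numpy as jnp

rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
U, s, Vt = jnp.linalg.svd(rect_matrix, full_matrices=False)

print('U:', U.shape, 's:', s.shape, 'Vt:', Vt.shape)

# Scaling the columns of U by s is equivalent to U @ diag(s).
reconstructed = (U * s) @ Vt

print(jnp.allclose(reconstructed, rect_matrix, atol=1e-4))
```

The singular values in s are returned in descending order, so s[0] is also the spectral norm of the matrix.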

Summary

We have demonstrated various matrix operations in JAX and how their computational graphs are visualized with visu-hlo:

  • Basic operations: matrix-vector products, matrix multiplication, batched multiplication, broadcasting

  • Linear algebra: norms, matrix inverse, linear system solving, eigendecomposition

  • Matrix decompositions: QR, SVD

Each visualization shows how JAX decomposes these high-level linear algebra operations into primitive HLO operations, providing insight into the computational structure and potential optimization opportunities.