{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Matrix Examples\n", "\n", "This notebook demonstrates how visu-hlo visualizes matrix operations and linear algebra computations in JAX.\n", "\n", "## Setup\n", "\n", "First, let's import the necessary libraries:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Force the CPU backend; set this before importing jax so it takes effect\n", "os.environ['JAX_PLATFORMS'] = 'cpu'\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "from visu_hlo import show" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Matrix-Vector Multiplication\n", "\n", "Let's start with one of the most common operations, the matrix-vector product $y = Wx$:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def matrix_vector_multiply(W, x):\n", " \"\"\"Matrix-vector multiplication: y = Wx\"\"\"\n", " return jnp.dot(W, x)\n", "\n", "\n", "# Weight matrix (2x3) and input vector (3,)\n", "W = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n", "x = jnp.array([0.5, 1.0, 1.5])\n", "\n", "print('Weight matrix shape:', W.shape)\n", "print('Input vector shape:', x.shape)\n", "print('\\nVisualization of matrix-vector multiplication:')\n", "show(matrix_vector_multiply, W, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Matrix Multiplication\n", "\n", "Next, a plain matrix-matrix product of a (3, 4) and a (4, 2) matrix:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def matrix_multiply(A, B):\n", " \"\"\"Simple matrix multiplication using jnp.dot.\"\"\"\n", " return jnp.dot(A, B)\n", "\n", "\n", "# Create sample matrices\n", "A = jnp.ones((3, 4))\n", "B = jnp.ones((4, 2))\n", "\n", "# Compute the result shape instead of hard-coding it\n", "print('Matrix shapes: A', A.shape, '× B', B.shape, '= result', matrix_multiply(A, B).shape)\n", "print('\\nVisualization of matrix multiplication:')\n", "show(matrix_multiply, A, B)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch Matrix Operations\n", "\n", "Working with batches of matrices:" ] }, { 
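"cell_type": "markdown", "metadata": {}, "source": [ "One way to handle a batch is to reuse the single-matrix `matrix_multiply` defined above and map it over the leading (batch) axis with `jax.vmap`. The cell below is a minimal sketch: `vmapped_matrix_multiply`, `demo_A`, and `demo_B` are illustrative names, and the random inputs only serve to check the result against `jnp.matmul`, which also treats leading axes as batch dimensions. The `einsum` version further down spells out the batch dimension explicitly instead." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch: jax.vmap maps matrix_multiply over axis 0 of both inputs\n", "vmapped_matrix_multiply = jax.jit(jax.vmap(matrix_multiply))\n", "\n", "# Random batches of three 2x2 matrices, shape (3, 2, 2)\n", "key_a, key_b = jax.random.split(jax.random.PRNGKey(0))\n", "demo_A = jax.random.normal(key_a, (3, 2, 2))\n", "demo_B = jax.random.normal(key_b, (3, 2, 2))\n", "\n", "# The vmapped product should agree with jnp.matmul's batched product\n", "print('Matches jnp.matmul:', jnp.allclose(vmapped_matrix_multiply(demo_A, demo_B), jnp.matmul(demo_A, demo_B)))\n", "\n", "print('\\nVisualization of vmapped matrix multiplication:')\n", "show(vmapped_matrix_multiply, demo_A, demo_B)" ] }, { 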
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def batch_matrix_multiply(batch_A, batch_B):\n", " \"\"\"Multiply batches of matrices.\"\"\"\n", " return jnp.einsum('...mn,...np->...mp', batch_A, batch_B)\n", "\n", "\n", "# Batch of 2x2 matrices\n", "batch_A = jnp.array(\n", " [\n", " [[1.0, 2.0], [3.0, 4.0]],\n", " [[5.0, 6.0], [7.0, 8.0]],\n", " [[9.0, 10.0], [11.0, 12.0]],\n", " ]\n", ")\n", "batch_B = jnp.array(\n", " [\n", " [[0.1, 0.2], [0.3, 0.4]],\n", " [[0.5, 0.6], [0.7, 0.8]],\n", " [[0.9, 1.0], [1.1, 1.2]],\n", " ]\n", ")\n", "\n", "print('Batch shape:', batch_A.shape)\n", "print('\\nVisualization of batch matrix multiplication:')\n", "show(batch_matrix_multiply, batch_A, batch_B)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Matrix Operations with Broadcasting\n", "\n", "JAX's broadcasting capabilities in matrix operations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def broadcasted_operations(matrix, vector):\n", " \"\"\"Matrix operations with broadcasting.\"\"\"\n", " # Add vector to each row of matrix\n", " added = matrix + vector\n", " # Element-wise multiplication\n", " multiplied = matrix * vector\n", " return added, multiplied\n", "\n", "\n", "matrix = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])\n", "vector = jnp.array([0.1, 0.2, 0.3])\n", "\n", "print('Matrix shape:', matrix.shape)\n", "print('Vector shape:', vector.shape)\n", "print('\\nVisualization of broadcasted operations:')\n", "show(broadcasted_operations, matrix, vector)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Matrix Norms" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def matrix_norms(matrix):\n", " \"\"\"Compute various matrix norms.\"\"\"\n", " frobenius = jnp.linalg.norm(matrix, 'fro')\n", " spectral = jnp.linalg.norm(matrix, 2)\n", " return frobenius, 
spectral\n", "\n", "\n", "test_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])\n", "\n", "print('Test matrix:')\n", "print(test_matrix)\n", "print('\\nVisualization of matrix norm computation:')\n", "show(matrix_norms, test_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Matrix Inverse\n", "\n", "Inverting a small, well-conditioned matrix:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def matrix_inverse(matrix):\n", " \"\"\"Compute matrix inverse.\"\"\"\n", " return jnp.linalg.inv(matrix)\n", "\n", "\n", "# Well-conditioned matrix: identity plus a small constant offset\n", "well_conditioned = jnp.eye(3) + 0.1 * jnp.ones((3, 3))\n", "\n", "print('Well-conditioned matrix:')\n", "print(well_conditioned)\n", "print('\\nVisualization of matrix inversion:')\n", "show(matrix_inverse, well_conditioned)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear System Solving\n", "\n", "Solving systems of linear equations $Ax = b$ directly, without forming $A^{-1}$ explicitly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def solve_linear_system(A, b):\n", " \"\"\"Solve Ax = b using JAX's linear algebra solver.\"\"\"\n", " return jnp.linalg.solve(A, b)\n", "\n", "\n", "# Create an invertible matrix and target vector\n", "A = jnp.array([[3.0, 1.0, 2.0], [1.0, 4.0, 1.0], [2.0, 1.0, 3.0]])\n", "b = jnp.array([1.0, 2.0, 3.0])\n", "\n", "print('System matrix A:')\n", "print(A)\n", "print('\\nTarget vector b:', b)\n", "print('\\nVisualization of linear system solver:')\n", "show(solve_linear_system, A, b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Eigenvalue Decomposition\n", "\n", "Computing eigenvalues and eigenvectors of symmetric matrices:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def compute_eigenvalues(matrix):\n", " \"\"\"Compute the (real) eigenvalues of a symmetric matrix.\"\"\"\n", " return jnp.linalg.eigvalsh(matrix)\n", "\n", "\n", "# Symmetric matrix\n", 
"sym_matrix = jnp.array([[4.0, 1.0, 2.0], [1.0, 3.0, 1.0], [2.0, 1.0, 5.0]])\n", "\n", "print('Symmetric matrix:')\n", "print(sym_matrix)\n", "print('\\nVisualization of eigenvalue computation:')\n", "show(compute_eigenvalues, sym_matrix)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def compute_eigenvectors(matrix):\n", " \"\"\"Compute both eigenvalues and eigenvectors.\"\"\"\n", " eigenvals, eigenvecs = jnp.linalg.eigh(matrix)\n", " return eigenvals, eigenvecs\n", "\n", "\n", "print('Visualization of full eigendecomposition:')\n", "show(compute_eigenvectors, sym_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Matrix Decompositions\n", "\n", "### QR Decomposition" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def qr_decomposition(matrix):\n", " \"\"\"QR decomposition of a matrix.\"\"\"\n", " Q, R = jnp.linalg.qr(matrix)\n", " return Q, R\n", "\n", "\n", "# Rectangular matrix for QR decomposition\n", "rect_matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n", "\n", "print('Input matrix for QR decomposition:')\n", "print(rect_matrix)\n", "print('\\nVisualization:')\n", "show(qr_decomposition, rect_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Singular Value Decomposition (SVD)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def svd_decomposition(matrix):\n", " \"\"\"Singular Value Decomposition.\"\"\"\n", " U, s, Vt = jnp.linalg.svd(matrix, full_matrices=False)\n", " return U, s, Vt\n", "\n", "\n", "print('Visualization of SVD:')\n", "show(svd_decomposition, rect_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "We have demonstrated various matrix operations in JAX and how their computational graphs are visualized with visu-hlo:\n", "\n", "- **Basic operations**: Matrix 
multiplication, matrix-vector products, batched operations, and broadcasting\n", "- **Linear algebra**: Norms, matrix inverse, linear system solving, and eigendecomposition\n", "- **Matrix decompositions**: QR and SVD\n", "\n", "Each visualization shows the HLO graph that JAX produces for these operations. Simple operations such as matrix products map onto individual HLO instructions (e.g. `dot`), while decompositions such as QR, SVD, and eigendecomposition typically lower to custom calls into backend linear-algebra routines. Inspecting the graphs gives insight into the computational structure and potential optimization opportunities." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }