Documentation
visu-hlo is a Python package that displays the HLO (High Level Operations) representation of JAX functions as SVG visualizations. It helps developers understand the computational graphs that XLA produces when it compiles those functions.
Features
🎯 Easy Visualization: Display HLO graphs with a single function call
⚡ JIT Support: Works with both regular and jitted JAX functions
🖼️ SVG Output: High-quality vector graphics that scale perfectly
🖥️ Cross-Platform: Supports Linux, macOS, and Windows
📦 Lightweight: Minimal dependencies, just JAX and Graphviz
Quick Example
import jax.numpy as jnp
from visu_hlo import show
# Display optimized HLO (default)
show(lambda x: 3 * x * 2, jnp.ones(10))
To display the non-optimized HLO:
show(lambda x: 3 * x * 2, jnp.ones(10), jit=False)
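To make the optimized versus non-optimized distinction concrete, the sketch below prints both textual forms using JAX's own ahead-of-time lowering API rather than visu-hlo. Whether visu-hlo obtains its graphs in exactly this way is an assumption; the function `f` and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Same computation as in the quick example above.
    return 3 * x * 2


x = jnp.ones(10)

# Lower the function for this input shape without running XLA's optimizer.
lowered = jax.jit(f).lower(x)
unoptimized_text = lowered.as_text()

# Compile the lowered module; the result exposes the optimized HLO text.
optimized_text = lowered.compile().as_text()

print(unoptimized_text)
print(optimized_text)
```

Comparing the two printouts typically shows XLA simplifying or fusing operations in the compiled version, which is the kind of difference the two graph views are meant to reveal.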
To save as an SVG file:
from visu_hlo import write_svg
write_svg('graph.svg', lambda x: 3 * x * 2, jnp.ones(10))
Installation
pip install visu-hlo
System dependency: Install Graphviz
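A quick way to confirm the system dependency is in place is to check that a Graphviz executable is on the PATH. The sketch below assumes that visu-hlo relies on the standard `dot` command shipped with Graphviz; the helper name `graphviz_available` is hypothetical and not part of visu-hlo.

```python
import shutil


def graphviz_available(executable: str = "dot") -> bool:
    """Return True if the given Graphviz executable is found on PATH."""
    return shutil.which(executable) is not None


if graphviz_available():
    print("Graphviz found")
else:
    print("Graphviz not found; install it with your system package manager")
```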
Content
User Guide
- Installation
- Getting Started
- Matrix Examples
- Control Flow Examples
  - Setup
  - Conditional Execution with lax.cond
  - Complex Conditional Logic
  - Multi-way Conditionals with lax.switch
  - Loops with lax.fori_loop
  - While Loops with lax.while_loop
  - Scan Operations with lax.scan
  - Recurrent Neural Network with Scan
  - Nested Control Flow
  - Dynamic Programming Example
  - Optimization Loop
  - Summary
- JAX Transformation Examples
  - Setup
  - Automatic Differentiation with grad
  - Multivariate Gradients
  - Partial Derivatives with argnums
  - Higher-Order Derivatives
  - value_and_grad for Efficiency
  - Vectorization with vmap
  - Advanced vmap with in_axes
  - Nested vmap for Multiple Batch Dimensions
  - Combining grad and vmap
  - Combining Multiple Transformations
  - Advanced: Jacobian Computation
  - Hessian Computation
  - Custom Transformations with custom_vjp
  - Optimization with Transformations
  - Summary
Reference Guide
Developer Guide