Getting Started

Basic Usage

The main entry point is the show() function, which takes a function followed by the arguments to call it with:

import jax.numpy as jnp

from visu_hlo import show


def simple_function(x):
    return 10 * x + 2 + 3


show(simple_function, jnp.array([1.0, 2.0, 3.0]));
[Figure: optimized HLO graph of simple_function]

This will:

  1. JIT-compile the function

  2. Generate a DOT graph representation from the optimized HLO

  3. Convert it to SVG format

  4. Display it
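If you want to inspect what the first two steps produce before any drawing happens, JAX's public ahead-of-time API gives you the optimized HLO as text. This is a sketch using standard JAX calls (jax.jit(...).lower(...).compile().as_text()); it is not visu_hlo's internal code, and the DOT and SVG steps are left to the library.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


x = jnp.ones(10)

# Trace the function for this argument's shape/dtype and lower it to HLO.
lowered = jax.jit(simple_function).lower(x)

# Run XLA's optimization passes and compile for the default backend.
compiled = lowered.compile()

# The optimized HLO module as human-readable text.
optimized_hlo = compiled.as_text()
print(optimized_hlo)
```

The printed module is the textual counterpart of the graph that show() draws with the default jit=True.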

Optimized vs Non-Optimized HLO

By default, show() displays the optimized HLO after XLA compilation (jit=True). You can also view the non-optimized HLO by setting jit=False:

Optimized HLO (default)

The optimized version shows the computation graph after XLA optimizations:

show(simple_function, jnp.ones(10))  # jit=True is the default
[Figure: optimized HLO graph of simple_function for a 10-element input]

Non-Optimized HLO

The non-optimized version shows the function's jaxpr lowered to HLO, before any XLA optimization passes run:

show(simple_function, jnp.ones(10), jit=False)
[Figure: non-optimized HLO graph of simple_function]
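The two representations behind the non-optimized view can also be printed directly with public JAX APIs: the jaxpr from jax.make_jaxpr, and the lowered module (before XLA optimization) from Lowered.as_text(). A minimal sketch; whether visu_hlo reads exactly this lowered text is an assumption.

```python
import jax
import jax.numpy as jnp


def simple_function(x):
    return 10 * x + 2 + 3


x = jnp.ones(10)

# JAX's own IR: one equation per primitive (e.g. mul, add).
jaxpr = jax.make_jaxpr(simple_function)(x)
print(jaxpr)

# The lowered module before XLA optimizations
# (StableHLO text in recent JAX versions, MHLO in older ones).
unoptimized = jax.jit(simple_function).lower(x).as_text()
print(unoptimized)
```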

Key Differences

When comparing the two visualizations:

  1. Fusion: The optimized version shows operations fused together into larger nodes

  2. Constant Folding: Constant subexpressions, such as the 2 + 3 in simple_function, are pre-computed (folded into 5)

  3. Memory Layout: Different memory access patterns may be visible

  4. Operation Count: The optimized version usually has fewer nodes
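The constant-folding point hides a subtlety: Python parses 10 * x + 2 + 3 as ((10 * x) + 2) + 3, so 2 and 3 are never operands of the same add, and a compiler must reassociate before it can fold them. The toy folder below illustrates that on a hand-built expression tree. Const, Var, BinOp and fold are hypothetical names, and this is not how XLA implements the pass; note that for floating point, reassociation can change rounding, which real compilers have to account for and this toy ignores.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BinOp:
    op: str  # "add" or "mul"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Var, BinOp]


def fold(e: Expr) -> Expr:
    """Bottom-up constant folding plus one reassociation rule for add."""
    if not isinstance(e, BinOp):
        return e
    lhs, rhs = fold(e.lhs), fold(e.rhs)
    # Both operands are known constants: evaluate now.
    if isinstance(lhs, Const) and isinstance(rhs, Const):
        v = lhs.value + rhs.value if e.op == "add" else lhs.value * rhs.value
        return Const(v)
    # (a + c1) + c2  ->  a + (c1 + c2), so the two constants fold together.
    if (e.op == "add" and isinstance(rhs, Const) and isinstance(lhs, BinOp)
            and lhs.op == "add" and isinstance(lhs.rhs, Const)):
        return BinOp("add", lhs.lhs, Const(lhs.rhs.value + rhs.value))
    return BinOp(e.op, lhs, rhs)


def count_ops(e: Expr) -> int:
    """Number of operation nodes in the tree."""
    if not isinstance(e, BinOp):
        return 0
    return 1 + count_ops(e.lhs) + count_ops(e.rhs)


# 10 * x + 2 + 3, built with Python's left-to-right grouping.
expr = BinOp("add",
             BinOp("add", BinOp("mul", Const(10.0), Var("x")), Const(2.0)),
             Const(3.0))
folded = fold(expr)
print(count_ops(expr), "->", count_ops(folded))  # 3 -> 2
print(folded)
```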

Function Arguments

You can pass several positional arguments as well as keyword arguments; show() forwards them to the function:

def multi_arg_func(x, y, scale=1.0):
    return (x + y) * scale


show(multi_arg_func, jnp.ones(5), jnp.zeros(5), scale=2.0)
[Figure: optimized HLO graph of multi_arg_func]
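With plain jax.jit, keyword arguments are traced just like positional ones, so scale may appear in the graph as an extra input rather than as the constant 2.0; how show() treats it is an assumption here. If you would rather see it folded in as a constant, one option is to bind it with functools.partial first. The difference is easy to check with jax.make_jaxpr:

```python
import functools

import jax
import jax.numpy as jnp


def multi_arg_func(x, y, scale=1.0):
    return (x + y) * scale


x, y = jnp.ones(5), jnp.zeros(5)

# scale passed as a traced argument: it becomes a third input.
traced = jax.make_jaxpr(lambda a, b, s: multi_arg_func(a, b, scale=s))(x, y, 2.0)

# scale bound with functools.partial: it is baked in as a constant.
baked = jax.make_jaxpr(functools.partial(multi_arg_func, scale=2.0))(x, y)

print(len(traced.jaxpr.invars), len(baked.jaxpr.invars))  # 3 2
print(baked)
```

If show() accepts any callable (the examples here only pass plain functions, so treat this as an assumption), show(functools.partial(multi_arg_func, scale=2.0), jnp.ones(5), jnp.zeros(5)) would then draw scale as a constant instead of an input node.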

Saving to Files

You can save the HLO graph to SVG or DOT files using write_svg() and write_dot():

import jax.numpy as jnp
from visu_hlo import write_svg, write_dot

# Save as SVG
write_svg('graph.svg', simple_function, jnp.ones(10))

# Save as DOT (Graphviz format)
write_dot('graph.dot', simple_function, jnp.ones(10))

# Save non-optimized HLO
write_svg('graph_unoptimized.svg', simple_function, jnp.ones(10), jit=False)
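Because write_dot() produces Graphviz DOT, you can also render the file yourself, for example to PNG or PDF. A minimal sketch, assuming the Graphviz dot executable is installed (it is a system tool, not a Python package); render_dot is a hypothetical helper, not part of visu_hlo.

```python
import pathlib
import shutil
import subprocess


def render_dot(dot_path, out_path, fmt="svg"):
    """Render a DOT file with the Graphviz `dot` CLI (hypothetical helper)."""
    exe = shutil.which("dot")
    if exe is None:
        raise RuntimeError("Graphviz 'dot' executable not found on PATH")
    # Equivalent to running: dot -T<fmt> <dot_path> -o <out_path>
    subprocess.run([exe, f"-T{fmt}", str(dot_path), "-o", str(out_path)], check=True)
    return pathlib.Path(out_path)


# Usage, after write_dot('graph.dot', ...) has produced the file:
# render_dot('graph.dot', 'graph.png', fmt='png')
```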

Understanding the Output

The generated SVG shows:

  • Nodes: Operations (add, multiply, etc.)

  • Edges: Data flow between operations

  • Colors: Distinguish operation types

  • Labels: Operation names and shapes

Each node contains:

  • Operation name (e.g., add.1, multiply.2)

  • Input/output shapes

  • Source location in your code (when available)
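The node and edge structure described above maps directly onto the DOT language. As an illustration, here is a tiny emitter over a hand-written graph loosely modeled on HLO for 10 * x + 5; the node list, color scheme and to_dot helper are hypothetical, and visu_hlo's actual DOT output will differ.

```python
# Hypothetical graph: (name, opcode, output shape, operand names).
NODES = [
    ("x", "parameter", "f32[10]", []),
    ("constant.1", "constant", "f32[]", []),
    ("broadcast.2", "broadcast", "f32[10]", ["constant.1"]),
    ("multiply.3", "multiply", "f32[10]", ["broadcast.2", "x"]),
    ("constant.4", "constant", "f32[]", []),
    ("broadcast.5", "broadcast", "f32[10]", ["constant.4"]),
    ("add.6", "add", "f32[10]", ["multiply.3", "broadcast.5"]),
]

# Fill color per operation type; anything not listed gets a default.
COLORS = {"parameter": "lightblue", "constant": "lightgray"}


def to_dot(nodes):
    lines = ["digraph hlo {", "  node [shape=box, style=filled];"]
    for name, opcode, shape, _ in nodes:
        color = COLORS.get(opcode, "khaki")
        # Label: operation name on the first line, output shape on the second.
        lines.append(f'  "{name}" [label="{name}\\n{shape}", fillcolor={color}];')
    for name, _, _, operands in nodes:
        for src in operands:
            # Edges follow data flow: operand -> consumer.
            lines.append(f'  "{src}" -> "{name}";')
    lines.append("}")
    return "\n".join(lines)


dot = to_dot(NODES)
print(dot)
```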