APIs

show()

visu_hlo.show(f, *args, jit=True, **keywords)

Display the HLO representation of a function as an SVG graph.

Parameters:
  • f (Callable[[...], Any] | str) – Function to be displayed. It can be a callable (jitted or not), or an HLO or StableHLO representation given as a string.

  • *args (Any) – Arguments to be passed to f.

  • jit (bool) – If True (the default), display the optimized HLO (after XLA compilation). If False, display the non-optimized HLO. Because jit follows *args in the signature, it must be passed as a keyword argument.

  • **keywords (Any) – Keyword arguments (other than jit) to be passed to f.

Return type:

None

Example

To display the HLO representation of a function, either XLA-optimized (the default) or non-optimized:

>>> import jax.numpy as jnp
>>> from visu_hlo import show
>>> def func(x):
...     return 3 * x * 2
>>> show(func, jnp.ones(10))  # Display optimized HLO (default)
>>> show(func, jnp.ones(10), jit=False)  # Display non-optimized HLO

write_svg()

visu_hlo.write_svg(path, f, *args, jit=True, **keywords)

Write the HLO representation of a function to an SVG file.

Parameters:
  • path (str | Path) – Path to the SVG file.

  • f (Callable[[...], Any] | str) – Function whose HLO representation is written. It can be a callable (jitted or not), or an HLO or StableHLO representation given as a string.

  • *args (Any) – Arguments to be passed to f.

  • jit (bool) – If True (the default), write the optimized HLO (after XLA compilation). If False, write the non-optimized HLO. Because jit follows *args in the signature, it must be passed as a keyword argument.

  • **keywords (Any) – Keyword arguments (other than jit) to be passed to f.

Return type:

None

Example

To write the XLA-optimized HLO representation of a function as an SVG file:

>>> import jax.numpy as jnp
>>> from visu_hlo import write_svg
>>> def func(x):
...     return 3 * x * 2
>>> write_svg('graph.svg', func, jnp.ones(10))

write_dot()

visu_hlo.write_dot(path, f, *args, jit=True, **keywords)

Write the HLO representation of a function to a DOT file.

Parameters:
  • path (str | Path) – Path to the DOT file.

  • f (Callable[[...], Any] | str) – Function whose HLO representation is written. It can be a callable (jitted or not), or an HLO or StableHLO representation given as a string.

  • *args (Any) – Arguments to be passed to f.

  • jit (bool) – If True (the default), write the optimized HLO (after XLA compilation). If False, write the non-optimized HLO. Because jit follows *args in the signature, it must be passed as a keyword argument.

  • **keywords (Any) – Keyword arguments (other than jit) to be passed to f.

Return type:

None

Example

To write the XLA-optimized HLO representation of a function as a DOT file:

>>> import jax.numpy as jnp
>>> from visu_hlo import write_dot
>>> def func(x):
...     return 3 * x * 2
>>> write_dot('graph.dot', func, jnp.ones(10))