{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Getting Started"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Basic Usage\n",
    "\n",
    "The main entry point is the `show()` function, which takes a function followed by the arguments to call it with:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "from visu_hlo import show\n",
    "\n",
    "\n",
    "def simple_function(x):\n",
    "    return 10 * x + 2 + 3\n",
    "\n",
    "\n",
    "show(simple_function, jnp.array([1.0, 2.0, 3.0]));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This will:\n",
    "1. JIT-compile the function for the given arguments\n",
    "2. Generate a DOT graph representation from the optimized HLO\n",
    "3. Convert the DOT graph to SVG\n",
    "4. Display the SVG"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Optimized vs Non-Optimized HLO\n",
    "\n",
    "By default (`jit=True`), `show()` displays the optimized HLO produced by XLA compilation. To view the non-optimized HLO instead, pass `jit=False`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimized HLO (default)\n",
    "\n",
    "The optimized version shows the computation graph after XLA's optimization passes have run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(simple_function, jnp.ones(10))  # jit=True is the default"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Non-Optimized HLO\n",
    "\n",
    "The non-optimized version shows the HLO lowered from the function's jaxpr, before any XLA optimizations are applied:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(simple_function, jnp.ones(10), jit=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Key Differences\n",
    "\n",
    "When comparing the two visualizations:\n",
    "\n",
    "1. **Fusion**: In the optimized version, operations may be merged into fused computations\n",
    "2. **Constant Folding**: Constant subexpressions such as `2 + 3` can be pre-computed (folded into `5`)\n",
    "3. **Memory Layout**: Layout and memory-access details chosen by the compiler may be visible\n",
    "4. **Operation Count**: The optimized version usually has fewer nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function Arguments\n",
    "\n",
    "You can pass multiple positional arguments as well as keyword arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_arg_func(x, y, scale=1.0):\n",
    "    return (x + y) * scale\n",
    "\n",
    "\n",
    "show(multi_arg_func, jnp.ones(5), jnp.zeros(5), scale=2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving to Files\n",
    "\n",
    "You can save the HLO graph to SVG or DOT files using `write_svg()` and `write_dot()`:\n",
    "\n",
    "```python\n",
    "from visu_hlo import write_svg, write_dot\n",
    "\n",
    "# Save as SVG\n",
    "write_svg('graph.svg', simple_function, jnp.ones(10))\n",
    "\n",
    "# Save as DOT (Graphviz format)\n",
    "write_dot('graph.dot', simple_function, jnp.ones(10))\n",
    "\n",
    "# Save non-optimized HLO\n",
    "write_svg('graph_unoptimized.svg', simple_function, jnp.ones(10), jit=False)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Understanding the Output\n",
    "\n",
    "The generated SVG shows:\n",
    "- **Nodes**: Operations (add, multiply, etc.)\n",
    "- **Edges**: Data flow between operations\n",
    "- **Colors**: Distinguish operation types\n",
    "- **Labels**: Operation names and shapes\n",
    "\n",
    "Each node contains:\n",
    "- Operation name (e.g., `add.1`, `multiply.2`)\n",
    "- Input/output shapes\n",
    "- Source location in your code (when available)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}