Skip to content

Commit

Permalink
Adding Hlo module parsing util functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jul 19, 2024
1 parent 4276eea commit f6d22c8
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax_scalify/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from . import core, lax, ops, tree
from . import core, lax, ops, tree, utils
from .core import ( # noqa: F401
Pow2RoundMode,
ScaledArray,
Expand Down
2 changes: 2 additions & 0 deletions jax_scalify/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from .hlo import parse_hlo_module, print_hlo_module # noqa: F401
107 changes: 107 additions & 0 deletions jax_scalify/utils/hlo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import json
from dataclasses import dataclass
from typing import Any, Dict, List
import textwrap

from jax.stages import Compiled, Lowered


@dataclass
class HloOperationInfo:
"""HLO module operation (raw) info.
Parse from raw `as_text` compiled HloModule.
Args:
cmd: Raw HLO operation (function + inputs/outputs).
metadata: JAX metadata (line, ...)
backend_config: Optional backend config dictionary.
"""

cmd: str
indent: int = 0
metadata: str | None = None
backend_config: Dict[Any, Any] | None = None

def as_text(self, metadata: bool = False, backend_cfg: bool = False, indent: int = 2) -> str:
"""Convert to raw text, with formatting issues."""
indent_txt = " " * (indent * self.indent)
line = indent_txt + self.cmd
if backend_cfg and self.backend_config:
# A bit hacky text formating of backend config!
backend_cfg_raw = json.dumps(self.backend_config, indent=indent)
backend_cfg_raw = "backend_cfg: " + backend_cfg_raw
backend_cfg_raw = textwrap.indent(backend_cfg_raw, indent_txt + " " * indent)
line += "\n" + backend_cfg_raw
return line


def parse_hlo_operation_raw_line(raw_line: str) -> HloOperationInfo:
"""Very crude and ugly parsing of an Hlo operation raw line!
Returns:
Parsed Hlo operation line.
"""
metadata: str | None = None
backend_cfg = None

# Parse "metadata={...}" block.
metadata_prefix = ", metadata={"
lidx = raw_line.find(metadata_prefix)
if lidx >= 0:
ridx = raw_line[lidx:].find("}") + lidx
metadata = raw_line[lidx : ridx + 1]
raw_line = raw_line.replace(metadata, "")
metadata = metadata[2:]

# Parse "backend_config={...}" block.
backend_cfg_prefix = ", backend_config="
lidx = raw_line.find(backend_cfg_prefix)
if lidx >= 0:
backend_cfg_str = raw_line[lidx + len(backend_cfg_prefix) :]
# TODO: deal with exception raised.
backend_cfg = json.loads(backend_cfg_str)
raw_line = raw_line[:lidx]

# Clean the raw line.
raw_line = raw_line.rstrip()
size = len(raw_line)
raw_line = raw_line.lstrip()
indent = (size - len(raw_line)) // 2
return HloOperationInfo(raw_line, indent, metadata, backend_cfg)


def parse_hlo_module(module: Lowered | Compiled) -> List[HloOperationInfo]:
"""Parse an Hlo module, to be human-readable!
Note: `m.hlo_modules()[0].computations()[0].render_html()`
is also generating a nice HTML output!
Args:
module: HLO module or JAX stages compiled instance.
Returns:
List of HLO operation info.
"""
assert isinstance(module, (Lowered, Compiled))
if isinstance(module, Lowered):
module = module.compile()
module_raw_txt = module.as_text()
module_lines = module_raw_txt.split("\n")
ops = [parse_hlo_operation_raw_line(line) for line in module_lines]
return ops


def print_hlo_module(
module: Lowered | Compiled, metadata: bool = False, backend_cfg: bool = False, indent: int = 2
) -> None:
"""Human-readable Hlo module printing.
Args:
module: AOT Lowered or Compiled JAX module.
metadata: Print op metadata as well.
backend_cfg: Print op backend config as well.
"""
cmds = parse_hlo_module(module)
for c in cmds:
print(c.as_text(metadata=metadata, backend_cfg=backend_cfg, indent=indent))
23 changes: 23 additions & 0 deletions tests/utils/test_hlo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
import chex
import jax
import jax.numpy as jnp

import jax_scalify as jsa


class ScalifyHloUtilsTests(chex.TestCase):
def test__hlo_util__parse_hlo_module(self):
def matmul_fn(a, b):
return jax.lax.dot(a, b.T)

a = jax.core.ShapedArray((16, 48), jnp.float16)
b = jax.core.ShapedArray((32, 48), jnp.float16)

matmul_lowered = jax.jit(matmul_fn).lower(a, b)
matmul_compiled = matmul_lowered.compile()

ops = jsa.utils.parse_hlo_module(matmul_compiled)
assert len(ops) >= 6
# TODO: other tests???
# jsa.utils.print_hlo_module(matmul_compiled, backend_cfg=True, indent=2)

0 comments on commit f6d22c8

Please sign in to comment.