
API Reference

Public API documentation for the IR extraction framework.

Main Entry Point

extract_ir

extract_ir(model: Module, example_inputs: Tuple[Any, ...], *, model_name: Optional[str] = None, strict: bool = True) -> IR

Extract IR from a PyTorch model.

This is the main entry point for IR extraction. For weight-free extraction, the model should be on the meta device.
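A meta-device tensor records shape and dtype but allocates no storage, which is what makes weight-free extraction possible. The sketch below uses only standard PyTorch calls to build a model and an example input on the meta device, and to move an already-constructed model there with `Module.to`. The layer sizes are arbitrary.

```python
import torch
import torch.nn as nn

# Construct parameters directly on the meta device: no weight memory is allocated.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU())

# An existing CPU model can also be moved to meta (its weight values are discarded).
cpu_model = nn.Linear(4, 2)
meta_model = cpu_model.to("meta")

# Example inputs only need a shape and dtype, so torch.empty is sufficient.
example_inputs = (torch.empty(1, 10, device="meta"),)

weight = model[0].weight
print(weight.is_meta, tuple(weight.shape), weight.dtype)  # → True (5, 10) torch.float32
```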

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | The PyTorch model to extract IR from. Should be on the meta device. | required |
| `example_inputs` | `Tuple[Any, ...]` | Example inputs for tracing. Should be on the meta device. | required |
| `model_name` | `Optional[str]` | Optional name for the model. If `None`, the class name is used. | `None` |
| `strict` | `bool` | If `True`, validate meta-device placement and raise errors for unsupported ops. | `True` |
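With `strict=True`, `extract_ir` validates meta-device placement. How torch_ir performs that check is not shown on this page. The following is a minimal, hypothetical sketch of such a check: the helper name `find_non_meta_tensors` is illustrative and not part of the torch_ir API. It lists every parameter, buffer, or tensor input that is not on the meta device.

```python
from typing import Any, List, Tuple

import torch
import torch.nn as nn


def find_non_meta_tensors(model: nn.Module, example_inputs: Tuple[Any, ...]) -> List[str]:
    """Hypothetical helper (not torch_ir API): name every tensor not on the meta device."""
    offenders = []
    # Buffers (e.g. BatchNorm running statistics) are checked alongside parameters.
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        if not tensor.is_meta:
            offenders.append(name)
    # Only tensor inputs are checked; non-tensor example inputs are ignored.
    for i, value in enumerate(example_inputs):
        if isinstance(value, torch.Tensor) and not value.is_meta:
            offenders.append(f"example_inputs[{i}]")
    return offenders


with torch.device("meta"):
    model = nn.Linear(4, 2)
meta_inputs = (torch.empty(2, 4, device="meta"),)
cpu_inputs = (torch.randn(2, 4),)  # created on CPU by mistake

print(find_non_meta_tensors(model, meta_inputs))  # → []
print(find_non_meta_tensors(model, cpu_inputs))   # → ['example_inputs[0]']
```

Returning a list instead of raising on the first mismatch lets a caller report every misplaced tensor in a single error message.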

Returns:

| Type | Description |
|---|---|
| `IR` | The extracted IR. |

Raises:

| Type | Description |
|---|---|
| `ExportError` | If model export fails. |
| `ConversionError` | If IR conversion fails (in strict mode). |

Example:

```python
import torch
from torch_ir import extract_ir

with torch.device('meta'):
    model = torch.nn.Linear(10, 5)
inputs = (torch.randn(1, 10, device='meta'),)
ir = extract_ir(model, inputs)
print(f"Extracted {len(ir.nodes)} nodes")
```
Source code in `torch_ir/__init__.py`:

````python
def extract_ir(
    model: nn.Module,
    example_inputs: Tuple[Any, ...],
    *,
    model_name: Optional[str] = None,
    strict: bool = True,
) -> IR:
    """Extract IR from a PyTorch model.

    This is the main entry point for IR extraction. The model should be on
    meta device for weight-free extraction.

    Args:
        model: The PyTorch model to extract IR from. Should be on meta device.
        example_inputs: Example inputs for tracing. Should be on meta device.
        model_name: Optional name for the model. If None, uses class name.
        strict: If True, validate meta device and raise errors for unsupported ops.

    Returns:
        The extracted IR.

    Raises:
        ExportError: If model export fails.
        ConversionError: If IR conversion fails (in strict mode).

    Example:
        ```python
        import torch
        from torch_ir import extract_ir

        with torch.device('meta'):
            model = torch.nn.Linear(10, 5)
        inputs = (torch.randn(1, 10, device='meta'),)
        ir = extract_ir(model, inputs)
        print(f"Extracted {len(ir.nodes)} nodes")
        ```
    """
    # Export model using torch.export
    exported = export_model(model, example_inputs, strict=strict)

    # Get model name
    if model_name is None:
        model_name = get_model_name(model)

    # Convert to IR
    ir = convert_exported_program(exported, model_name=model_name, strict=False)

    # Capture lifted tensor constants from export
    if hasattr(exported, "constants") and exported.constants:
        meta_constants = [
            name
            for name, tensor in exported.constants.items()
            if isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"
        ]
        if meta_constants:
            warnings.warn(
                f"Skipping {len(meta_constants)} meta-device constant(s) "
                f"(shape/dtype already in weights): {meta_constants}"
            )
        ir.constants = {
            k: v for k, v in exported.constants.items() if isinstance(v, torch.Tensor) and v.device.type != "meta"
        }

    return ir
````
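The tail of `extract_ir` keeps only lifted constants that hold real data and warns about meta-device ones. The same filtering can be sketched in isolation as follows. `partition_constants` is a hypothetical helper written for illustration, and it operates on a plain dict rather than an actual `ExportedProgram.constants` mapping.

```python
import warnings
from typing import Any, Dict, List, Tuple

import torch


def partition_constants(constants: Dict[str, Any]) -> Tuple[Dict[str, torch.Tensor], List[str]]:
    """Hypothetical helper mirroring the constant filtering at the end of extract_ir."""
    # Tensors with real storage are kept.
    kept = {
        k: v for k, v in constants.items()
        if isinstance(v, torch.Tensor) and v.device.type != "meta"
    }
    # Meta tensors carry no data, so they are reported and skipped.
    skipped = [
        k for k, v in constants.items()
        if isinstance(v, torch.Tensor) and v.device.type == "meta"
    ]
    if skipped:
        warnings.warn(f"Skipping {len(skipped)} meta-device constant(s): {skipped}")
    # Non-tensor entries are dropped silently, as in extract_ir.
    return kept, skipped


constants = {
    "scale": torch.tensor([0.5, 2.0]),         # real data: kept
    "mask": torch.empty(3, 3, device="meta"),  # shape only: skipped with a warning
    "note": "not a tensor",                    # neither kept nor reported
}
kept, skipped = partition_constants(constants)
print(sorted(kept), skipped)  # → ['scale'] ['mask']
```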

Module List

| Module | Key classes and functions |
|---|---|
| IR Data Structures | `TensorMeta`, `OpNode`, `IR` |
| Exporter | `export_model`, `ExportError` |
| Converter | `IRConverter`, `ConversionError`, `convert_exported_program` |
| Executor | `IRExecutor`, `execute_ir`, `ExecutionError` |
| Serializer | `serialize_ir`, `save_ir`, `load_ir`, `IRSerializer` |
| Verifier | `verify_ir`, `verify_ir_with_state_dict`, `IRVerifier` |
| Weight Loader | `load_weights`, `WeightLoader`, `WeightLoadError` |
| Ops | `register_op`, `register_executor` |
| Visualize | `ir_to_mermaid`, `generate_op_distribution_pie` |
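The signature of `generate_op_distribution_pie` is not documented on this page. As a rough illustration of the idea only, the stdlib sketch below tallies op names and renders them as Mermaid pie-chart syntax. It assumes Mermaid output (as the name `ir_to_mermaid` suggests) and a plain list of op-name strings as input. The function `op_distribution_pie` and the `aten.*` names are hypothetical, not torch_ir API or real IR output.

```python
from collections import Counter
from typing import Iterable


def op_distribution_pie(op_names: Iterable[str], title: str = "Op distribution") -> str:
    """Hypothetical sketch: tally op names and render Mermaid pie-chart syntax."""
    counts = Counter(op_names)
    lines = [f"pie title {title}"]
    # Most frequent ops first; ties keep first-seen order (Counter.most_common).
    for op, n in counts.most_common():
        lines.append(f'    "{op}" : {n}')
    return "\n".join(lines)


ops = ["aten.linear", "aten.relu", "aten.linear", "aten.add"]
print(op_distribution_pie(ops))
```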