Skip to content

Exporter

Module that wraps torch.export to export models.

exporter

Model exporter using torch.export.

ExportError

Bases: Exception

Raised when model export fails.

is_meta_tensor

is_meta_tensor(tensor: Tensor) -> bool

Check if a tensor is on the meta device.

Parameters:

Name Type Description Default
tensor Tensor

The tensor to check.

required

Returns:

Type Description
bool

True if the tensor's device type is "meta".

Source code in torch_ir/exporter.py
def is_meta_tensor(tensor: torch.Tensor) -> bool:
    """Determine whether a tensor lives on the meta device.

    Args:
        tensor: The tensor to inspect.

    Returns:
        ``True`` when the tensor's device type equals ``"meta"``.
    """
    device_kind = tensor.device.type
    return device_kind == "meta"

is_meta_module

is_meta_module(module: Module) -> bool

Check if a module has all parameters and buffers on meta device.

Parameters:

Name Type Description Default
module Module

The PyTorch module to check.

required

Returns:

Type Description
bool

True if every parameter and buffer is on the meta device.

Source code in torch_ir/exporter.py
def is_meta_module(module: nn.Module) -> bool:
    """Determine whether every tensor owned by a module is on the meta device.

    Args:
        module: The PyTorch module to inspect.

    Returns:
        ``True`` when all parameters and buffers are meta tensors; an empty
        module (no parameters, no buffers) vacuously satisfies this.
    """
    owned_tensors = (*module.parameters(), *module.buffers())
    return all(is_meta_tensor(t) for t in owned_tensors)

validate_meta_device

validate_meta_device(model: Module) -> None

Validate that the model is on meta device.

Raises:

Type Description
ExportError

If the model is not on meta device.

Source code in torch_ir/exporter.py
def validate_meta_device(model: nn.Module) -> None:
    """Ensure the model resides entirely on the meta device.

    Args:
        model: The PyTorch model to validate.

    Raises:
        ExportError: If any parameter or buffer is not on the meta device.
    """
    if is_meta_module(model):
        return
    # Guide the user toward the two standard ways of getting a meta model.
    raise ExportError(
        "Model must be on meta device for weight-free IR extraction.\n"
        "Convert your model using one of:\n"
        "  1. model = model.to('meta')\n"
        "  2. with torch.device('meta'):\n"
        "         model = MyModel()\n"
    )

validate_inputs_meta

validate_inputs_meta(inputs: Tuple[Any, ...]) -> None

Validate that all tensor inputs are on meta device.

Raises:

Type Description
ExportError

If any tensor input is not on meta device.

Source code in torch_ir/exporter.py
def validate_inputs_meta(inputs: Tuple[Any, ...]) -> None:
    """Ensure every tensor among the example inputs is on the meta device.

    Non-tensor inputs (ints, strings, etc.) are ignored.

    Args:
        inputs: Positional example inputs destined for tracing.

    Raises:
        ExportError: If any tensor input is not on meta device.
    """
    for index, candidate in enumerate(inputs):
        if not isinstance(candidate, torch.Tensor):
            continue
        if is_meta_tensor(candidate):
            continue
        raise ExportError(
            f"Input tensor at index {index} must be on meta device.\n"
            f"Current device: {candidate.device}\n"
            "Use: torch.randn(..., device='meta')"
        )

export_model

export_model(model: Module, example_inputs: Tuple[Any, ...], *, strict: bool = True) -> ExportedProgram

Export a model using torch.export.

Parameters:

Name Type Description Default
model Module

The PyTorch model to export (must be on meta device).

required
example_inputs Tuple[Any, ...]

Example inputs for tracing (must be on meta device).

required
strict bool

If True, validate meta device. Set False for testing with real tensors.

True

Returns:

Type Description
ExportedProgram

ExportedProgram containing the traced graph.

Raises:

Type Description
ExportError

If validation fails or export encounters an error.

Source code in torch_ir/exporter.py
def export_model(
    model: nn.Module,
    example_inputs: Tuple[Any, ...],
    *,
    strict: bool = True,
) -> ExportedProgram:
    """Export a model with ``torch.export``.

    Args:
        model: The PyTorch model to export (must be on meta device).
        example_inputs: Example inputs used for tracing (must be on meta device).
        strict: When ``True``, enforce the meta-device checks up front; pass
            ``False`` only for testing with real tensors.

    Returns:
        ExportedProgram containing the traced graph.

    Raises:
        ExportError: If validation fails or export encounters an error.
    """
    if strict:
        validate_meta_device(model)
        validate_inputs_meta(example_inputs)

    # torch.export traces with FakeTensor (a meta-tensor subclass) internally,
    # so no real weights are needed for the graph capture.
    try:
        return torch.export.export(model, example_inputs)
    except Exception as exc:
        # Wrap any tracing failure in the package's own exception, keeping the
        # original as __cause__ for debugging.
        raise ExportError(f"Failed to export model: {exc}") from exc

get_model_name

get_model_name(model: Module) -> str

Extract model name from module class.

Source code in torch_ir/exporter.py
def get_model_name(model: nn.Module) -> str:
    """Return the class name of the given module."""
    return type(model).__name__