Skip to content

Executor

Module that executes IR graphs with actual weights.

executor

IR Executor for running IR graphs with actual weights.

ExecutionError

Bases: Exception

Raised when IR execution fails.

TensorRegistry

TensorRegistry()

Name-to-tensor mapping used during IR execution.

Stores weights, graph inputs, and intermediate results so that each node can retrieve its input tensors by name.

Source code in torch_ir/executor.py
def __init__(self):
    """Create an empty name-to-tensor mapping."""
    # Keyed by tensor name; populated with weights, graph inputs and
    # intermediate node outputs as execution proceeds.
    self._tensors: Dict[str, torch.Tensor] = dict()

register

register(name: str, tensor: Tensor) -> None

Register a tensor under the given name.

Source code in torch_ir/executor.py
def register(self, name: str, tensor: torch.Tensor) -> None:
    """Store *tensor* under *name*, replacing any tensor already stored there."""
    tensors = self._tensors
    tensors[name] = tensor

get

get(name: str) -> Optional[torch.Tensor]

Return the tensor for name, or None if not found.

Source code in torch_ir/executor.py
def get(self, name: str) -> Optional[torch.Tensor]:
    """Look up *name*; give back ``None`` when nothing is registered under it."""
    try:
        return self._tensors[name]
    except KeyError:
        return None

has

has(name: str) -> bool

Return True if name is registered.

Source code in torch_ir/executor.py
def has(self, name: str) -> bool:
    """Tell whether a tensor has been registered under *name*."""
    registered_names = self._tensors.keys()
    return name in registered_names

clear

clear() -> None

Remove all registered tensors.

Source code in torch_ir/executor.py
def clear(self) -> None:
    """Drop every registered tensor, emptying the existing mapping in place."""
    tensors = self._tensors
    tensors.clear()

IRExecutor

IRExecutor(ir: IR, weights: Optional[Dict[str, Tensor]] = None, constants: Optional[Dict[str, Tensor]] = None)

Executes an IR graph with actual weight tensors.

Typical usage:

executor = IRExecutor(ir)
executor.load_weights_from_state_dict(state_dict)
outputs = executor.execute((input_tensor,))

Initialize the executor.

Parameters:

Name Type Description Default
ir IR

The IR graph to execute.

required
weights Optional[Dict[str, Tensor]]

Optional pre-loaded weights. If None, call load_weights() or load_weights_from_state_dict() before execute().

None
constants Optional[Dict[str, Tensor]]

Optional lifted tensor constants. When provided, these override ir.constants and are used for constant placeholders that are not part of the model's state_dict (e.g., index tensors assigned as plain attributes). This is needed when the IR was extracted on the meta device, where constant values are unavailable.

None
Source code in torch_ir/executor.py
def __init__(
    self,
    ir: IR,
    weights: Optional[Dict[str, torch.Tensor]] = None,
    constants: Optional[Dict[str, torch.Tensor]] = None,
):
    """Set up an executor for *ir*.

    Args:
        ir: The IR graph that ``execute()`` will run.
        weights: Weight tensors keyed by name. May be left as ``None`` and
            supplied later through ``load_weights()`` or
            ``load_weights_from_state_dict()``; ``execute()`` refuses to
            run while no weights are set.
        constants: Optional lifted tensor constants. When given, these take
            precedence over ``ir.constants`` and are used for constant
            placeholders that are not part of the model's ``state_dict``
            (for example index tensors assigned as plain attributes). This
            is needed when the IR was extracted on the meta device, where
            constant values are unavailable.
    """
    # Arguments are stored as given; nothing is copied or validated here.
    self.ir, self.weights, self.constants = ir, weights, constants
    # Holds named tensors (weights, inputs, intermediates) during execution.
    self.registry = TensorRegistry()

load_weights

load_weights(path: Union[str, Path]) -> None

Load weights from a file.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the weight file (.pt or .safetensors).

required
Source code in torch_ir/executor.py
def load_weights(self, path: Union[str, Path]) -> None:
    """Read weights from *path* and make them this executor's weights.

    Any previously set weights are replaced.

    Args:
        path: Location of the weight file (.pt or .safetensors).
    """
    # The bare name resolves to the module-level ``load_weights`` helper,
    # not to this method.
    loaded = load_weights(path)
    self.weights = loaded

load_weights_from_state_dict

load_weights_from_state_dict(state_dict: Dict[str, Tensor]) -> None

Use an existing state dict as weights.

Parameters:

Name Type Description Default
state_dict Dict[str, Tensor]

The state dict to use.

required
Source code in torch_ir/executor.py
def load_weights_from_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
    """Adopt *state_dict* as the weights used by ``execute()``.

    The mapping is stored by reference, not copied, so later changes to it
    are visible to this executor.

    Args:
        state_dict: Weight tensors keyed by name.
    """
    new_weights = state_dict
    self.weights = new_weights

execute

execute(inputs: Tuple[Tensor, ...]) -> Tuple[torch.Tensor, ...]

Execute the IR graph.

Parameters:

Name Type Description Default
inputs Tuple[Tensor, ...]

Input tensors matching graph_inputs.

required

Returns:

Type Description
Tuple[Tensor, ...]

Output tensors matching graph_outputs.

Raises:

Type Description
ExecutionError

If execution fails.

Source code in torch_ir/executor.py
def execute(self, inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
    """Execute the IR graph.

    Args:
        inputs: Input tensors, positionally matching ``self.ir.graph_inputs``.

    Returns:
        Output tensors, in the order of ``self.ir.graph_outputs``.

    Raises:
        ExecutionError: If weights have not been loaded, if a graph output
            was never produced by any node, or if execution otherwise fails.
    """
    if self.weights is None:
        raise ExecutionError("Weights not loaded. Call load_weights() or load_weights_from_state_dict() first.")

    self._prepare(inputs)

    # Track node outputs by node name for producer-based lookup
    output_map: Dict[str, List[torch.Tensor]] = {}

    # Register graph inputs in output_map so producer references work.
    # NOTE(review): zip() stops at the shorter sequence, so a count mismatch
    # between ``inputs`` and ``graph_inputs`` is not detected here — confirm
    # whether _prepare() validates it.
    for input_meta, tensor in zip(self.ir.graph_inputs, inputs):
        output_map[input_meta.name] = [tensor]

    # Execute nodes in order
    for node in self.ir.nodes:
        # Gather inputs using producer references
        input_tensors = _get_input_tensors(node, self.registry, output_map)

        # Execute node
        output_tensors = _execute_node(node, input_tensors)

        # Store in output_map for downstream producer references
        output_map[node.name] = output_tensors

        # Also register in registry for backward compat and graph output lookup
        for output_meta, tensor in zip(node.outputs, output_tensors):
            self.registry.register(output_meta.name, tensor)

    # Gather graph outputs through the registry's public has()/get() API.
    # TensorRegistry's documented interface is register/get/has/clear and
    # exposes no __contains__/__getitem__, so ``name in registry`` and
    # ``registry[name]`` would raise TypeError instead of working.
    outputs = []
    for output_meta in self.ir.graph_outputs:
        if not self.registry.has(output_meta.name):
            raise ExecutionError(f"Graph output '{output_meta.name}' not found in registry")
        outputs.append(self.registry.get(output_meta.name))

    return tuple(outputs)

execute_ir

execute_ir(ir: IR, inputs: Tuple[Tensor, ...], weights: Optional[Dict[str, Tensor]] = None, weights_path: Optional[Union[str, Path]] = None, constants: Optional[Dict[str, Tensor]] = None) -> Tuple[torch.Tensor, ...]

Execute an IR graph (convenience function).

Parameters:

Name Type Description Default
ir IR

The IR graph to execute.

required
inputs Tuple[Tensor, ...]

Input tensors.

required
weights Optional[Dict[str, Tensor]]

Pre-loaded weights dict.

None
weights_path Optional[Union[str, Path]]

Path to a weight file (alternative to weights). If both are given, the weights loaded from this file take precedence.

None
constants Optional[Dict[str, Tensor]]

Optional lifted tensor constants for meta-device-extracted IRs. See IRExecutor for details.

None

Returns:

Type Description
Tuple[Tensor, ...]

Output tensors.

Raises:

Type Description
ExecutionError

If execution fails.

ValueError

If neither weights nor weights_path is provided.

Source code in torch_ir/executor.py
def execute_ir(
    ir: IR,
    inputs: Tuple[torch.Tensor, ...],
    weights: Optional[Dict[str, torch.Tensor]] = None,
    weights_path: Optional[Union[str, Path]] = None,
    constants: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run *ir* once on *inputs* (convenience wrapper around ``IRExecutor``).

    Args:
        ir: The IR graph to execute.
        inputs: Input tensors.
        weights: Pre-loaded weights dict.
        weights_path: Path to a weight file (alternative to *weights*). If
            both are given, the weights loaded from this file replace
            *weights*.
        constants: Optional lifted tensor constants for meta-device-extracted
            IRs; forwarded to ``IRExecutor``.

    Returns:
        Output tensors.

    Raises:
        ExecutionError: If execution fails.
        ValueError: If neither weights nor weights_path is provided.
    """
    has_weight_source = weights is not None or weights_path is not None
    if not has_weight_source:
        raise ValueError("Either weights or weights_path must be provided")

    executor = IRExecutor(ir, weights, constants=constants)
    if weights_path is None:
        return executor.execute(inputs)

    # A weight file, when given, overrides any in-memory weights.
    executor.load_weights(weights_path)
    return executor.execute(inputs)