Extension Guide¶
This document explains how to add custom operators to the IR extraction framework.
1. Overview¶
The framework automatically handles all ATen operators. Custom registration is only needed in the following cases:
- Non-ATen op: Operators that cannot be resolved to torch.ops.aten.*
- Special conversion logic: When an OpNode structure different from the one produced by the default conversion (_default_conversion) is needed
- Special execution logic: When an execution method that the ATen fallback cannot handle is required
In most cases, you don't need to register anything.
2. Understanding the Operator Registry¶
2.1 Registry Structure¶
# torch_ir/ops/registry.py
# Store IR conversion functions (for custom conversion)
_CONVERSION_REGISTRY: Dict[str, Callable] = {}
# Store execution functions (for custom execution)
_EXECUTION_REGISTRY: Dict[str, Callable] = {}
2.2 Processing Priority¶
IR Conversion (converter.py):
1. Check for custom conversion function registered in _CONVERSION_REGISTRY
2. If not found, use _default_conversion() (automatically handles all ATen ops)
Execution (executor.py):
1. Check for custom execution function registered in _EXECUTION_REGISTRY
2. If not found, use _aten_fallback() (schema-based automatic ATen op execution)
2.3 Operator Name Patterns¶
ATen operator names follow the pattern aten.<op_name>.<overload>:
Examples:
- aten.conv2d.default
- aten.linear.default
- aten.add.Tensor
- aten.softmax.int
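Judging from the examples above, each name splits into three dot-separated parts: a namespace, an operator name and an overload. The sketch below parses such a string. `OpName`, `parse_op_name` and `is_aten` are hypothetical helpers written for illustration and are not part of the framework.

```python
from typing import NamedTuple


class OpName(NamedTuple):
    namespace: str  # e.g. "aten"
    op_name: str    # e.g. "add"
    overload: str   # e.g. "Tensor" or "default"


def parse_op_name(qualified: str) -> OpName:
    """Split 'namespace.op.overload' into its three parts."""
    parts = qualified.split(".")
    if len(parts) != 3:
        raise ValueError(f"expected 'namespace.op.overload', got {qualified!r}")
    return OpName(*parts)


def is_aten(qualified: str) -> bool:
    """True when the name lives in the aten namespace."""
    return qualified.startswith("aten.")


parsed = parse_op_name("aten.add.Tensor")
```

A check like `is_aten` is one way to decide whether an op needs custom registration at all, since names outside the `aten` namespace are the ones the fallback cannot resolve.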
3. Registering Custom Execution Functions¶
Custom execution functions are needed only for non-ATen ops that the ATen fallback cannot handle.
3.1 Basic Structure¶
from torch_ir.ops import register_executor
import torch
from typing import List, Dict, Any
@register_executor("my_custom_op")
def execute_my_custom_op(
    inputs: List[torch.Tensor],
    attrs: Dict[str, Any]
) -> List[torch.Tensor]:
    """Execute my_custom_op."""
    x = inputs[0]
    param = attrs.get("param", 1.0)
    # some_operation is a placeholder for the op's actual computation
    result = some_operation(x, param)
    return [result]  # Always return as list
3.2 Input/Output Rules¶
- Input: List[torch.Tensor], in the order of the IR node's inputs
- Output: List[torch.Tensor], in the order of the IR node's outputs
- A list must be returned even for a single output
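One way to catch a violation of the list rule early is a small guard around the execution function. The decorator below is a hypothetical helper, not something the framework provides, and it uses plain numbers instead of tensors so that it runs without PyTorch.

```python
from functools import wraps
from typing import Any, Callable, Dict, List


def returns_list(fn: Callable) -> Callable:
    """Hypothetical guard: fail early if an executor does not return a list."""
    @wraps(fn)
    def wrapper(inputs: List[Any], attrs: Dict[str, Any]) -> List[Any]:
        outputs = fn(inputs, attrs)
        if not isinstance(outputs, list):
            raise TypeError(
                f"{fn.__name__} must return a list, got {type(outputs).__name__}"
            )
        return outputs
    return wrapper


@returns_list
def good_executor(inputs: List[Any], attrs: Dict[str, Any]) -> List[Any]:
    return [inputs[0] + 1]  # single output, still wrapped in a list


@returns_list
def bad_executor(inputs: List[Any], attrs: Dict[str, Any]) -> Any:
    return inputs[0] + 1  # bare value: violates the rule
```

Calling `bad_executor([1], {})` raises `TypeError`, which points at the offending function rather than at a later failure somewhere in graph execution.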
3.3 Multi-output Example¶
@register_executor("my_op_with_two_outputs")
def execute_my_op(inputs: List[torch.Tensor], attrs: Dict[str, Any]) -> List[torch.Tensor]:
    x = inputs[0]
    values, indices = x.topk(attrs.get("k", 1), dim=attrs.get("dim", -1))
    return [values, indices]  # Two outputs
4. Registering Custom IR Conversion Functions (Optional)¶
The default converter (_default_conversion) is sufficient in most cases. To customize the OpNode structure, register a conversion function:
from torch_ir.ops import register_op
from torch_ir import OpNode
from torch_ir.analyzer import NodeInfo
@register_op("my_custom_op")
def convert_my_custom_op(node_info: NodeInfo) -> OpNode:
    """Custom conversion with extra processing."""
    return OpNode(
        name=node_info.name,
        op_type="my_custom_op",
        inputs=node_info.input_metas,
        outputs=node_info.output_metas,
        attrs={**node_info.attrs, "extra_info": "custom_value"},
    )
4.1 NodeInfo Structure¶
Conversion functions receive a NodeInfo object:
@dataclass
class NodeInfo:
    name: str                       # Node name (e.g., "conv2d_1")
    op: str                         # Operation type ("call_function")
    target: Any                     # Operation target (e.g., torch.ops.aten.conv2d.default)
    args: Tuple[Any, ...]           # FX node arguments
    kwargs: Dict[str, Any]          # FX node keyword arguments
    input_metas: List[TensorMeta]   # Input tensor metadata
    output_metas: List[TensorMeta]  # Output tensor metadata
    attrs: Dict[str, Any]           # Extracted attributes (auto-extracted based on schema)
4.2 Automatic Attribute Extraction¶
node_info.attrs contains all non-Tensor arguments automatically extracted based on the OpOverload schema. You can use it as-is without additional extraction.
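To make the idea concrete, the sketch below binds positional and keyword arguments to schema argument names and keeps only the non-tensor values. It is an assumption about how such extraction could work, not the framework's implementation: `FakeTensor` and `extract_attrs` are hypothetical, and nested cases such as lists of tensors are ignored.

```python
from typing import Any, Dict, Sequence, Tuple


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


def extract_attrs(
    arg_names: Sequence[str],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Bind args to schema names, then keep only non-tensor values."""
    bound = dict(zip(arg_names, args))  # positional args, in schema order
    bound.update(kwargs)                # keyword args override or extend
    return {k: v for k, v in bound.items() if not isinstance(v, FakeTensor)}


# Argument names in the order aten.softmax.int declares them.
schema_names = ["self", "dim", "dtype"]
attrs = extract_attrs(schema_names, (FakeTensor(), -1), {"dtype": None})
```

Here the tensor argument `self` is dropped and `attrs` holds the remaining `dim` and `dtype` values, which is the kind of dictionary a conversion function could pass through unchanged.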
5. Complete Example: Non-ATen Custom Op¶
# my_custom_ops.py
import torch
from typing import List, Dict, Any
from torch_ir.ops import register_op, register_executor
from torch_ir import OpNode
from torch_ir.analyzer import NodeInfo
@register_op("custom.fused_gate")
def convert_fused_gate(node_info: NodeInfo) -> OpNode:
    return OpNode(
        name=node_info.name,
        op_type="custom.fused_gate",
        inputs=node_info.input_metas,
        outputs=node_info.output_metas,
        attrs=node_info.attrs,
    )

@register_executor("custom.fused_gate")
def execute_fused_gate(
    inputs: List[torch.Tensor],
    attrs: Dict[str, Any]
) -> List[torch.Tensor]:
    """Fused gating: sigmoid(gate) * value."""
    gate = inputs[0]
    value = inputs[1]
    return [torch.sigmoid(gate) * value]
Usage¶
6. Organizing as Module¶
my_project/
├── my_ops/
│   ├── __init__.py          # Import all submodules
│   ├── custom_gate.py       # Contains @register_executor
│   └── custom_pooling.py    # Contains @register_executor
└── main.py
# my_ops/__init__.py
from . import custom_gate
from . import custom_pooling
# Automatically registered on import
7. Debugging Tips¶
7.1 Checking FX Graph¶
from torch_ir import export_model
exported = export_model(model, inputs, strict=False)
# Print FX graph
print(exported.graph_module.graph)
# Check individual nodes
for node in exported.graph_module.graph.nodes:
    if node.op == "call_function":
        print(f"Node: {node.name}")
        print(f"  Target: {node.target}")
        print(f"  Args: {node.args}")
7.2 Checking Registration¶
from torch_ir.ops.registry import get_conversion_fn, get_execution_fn
op_type = "my_custom_op"
print(f"Conversion: {get_conversion_fn(op_type)}")
print(f"Execution: {get_execution_fn(op_type)}")
7.3 Checking ATen Op Schema¶
import torch
fn = torch.ops.aten.conv2d.default
for arg in fn._schema.arguments:
    print(f"  {arg.name}: {arg.type} (kwarg_only={arg.kwarg_only})")
8. Precautions¶
8.1 ATen Ops Don't Need Registration¶
Registering an execution function with @register_executor for an ATen op (aten.*) causes the custom function to be called instead of the ATen fallback. Unless you have a specific reason, don't register ATen ops: the fallback already handles them correctly based on their schema.
8.2 Input Order¶
The input order in the FX graph must match the input order in the execution function.
8.3 Attribute Defaults¶
Provide default values in case attributes are missing:
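A minimal sketch of the pattern, reusing the `k` and `dim` attributes and their defaults from the multi-output example in section 3.3. The helper name `read_topk_attrs` is hypothetical and exists only for illustration.

```python
from typing import Any, Dict


def read_topk_attrs(attrs: Dict[str, Any]) -> Dict[str, Any]:
    """Read attributes with dict.get so a missing key falls back to a default."""
    return {
        "k": attrs.get("k", 1),       # default: top-1
        "dim": attrs.get("dim", -1),  # default: last dimension
    }


defaults_only = read_topk_attrs({})
partial = read_topk_attrs({"k": 5})
```

Using `attrs.get(name, default)` rather than `attrs[name]` avoids a `KeyError` when an IR node was produced without that attribute.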
9. Contribution Guide¶
To contribute to the framework:
- Add execution functions for non-ATen ops to torch_ir/ops/aten_impl.py
- Add tests in tests/
- Update the docs/operators.md documentation
ATen ops are automatically supported, so no separate implementation is needed.