Skip to content

Ops

Operator registry and utility module.

ops

Operator registry and implementations.

This module provides:

- `registry`: operator registration for IR conversion and execution
- `aten_ops`: ATen operator mappings for IR conversion
- `aten_impl`: ATen operator implementations for IR execution

Custom operators can be registered using decorators:

from torch_ir.ops import register_op, register_executor

# Conversion hook: registered under "custom.my_op" for IR conversion.
# It builds and returns the IR node (OpNode) for the operation; the
# OpNode(...) arguments are elided in this example.
@register_op("custom.my_op")
def convert_my_op(node_info):
    return OpNode(...)

# Execution hook: registered under "custom.my_op" for IR execution.
# It receives the input tensors and the node's attributes and returns the
# outputs as a list. `result_tensor` is a placeholder for the computed output.
@register_executor("custom.my_op")
def execute_my_op(inputs, attrs):
    return [result_tensor]

get_conversion_fn

get_conversion_fn(op_type: str) -> Optional[Callable]

Get the conversion function for an operator type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op_type` | `str` | The operator type string | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Optional[Callable]` | The conversion function if registered, `None` otherwise. |

Source code in torch_ir/ops/registry.py
def get_conversion_fn(op_type: str) -> Optional[Callable]:
    """Look up the IR conversion function registered for ``op_type``.

    The lookup is delegated to ``_lookup_registry`` over the conversion
    registry that :func:`register_op` populates.

    Args:
        op_type: Operator type string to resolve.

    Returns:
        The registered conversion function, or ``None`` when nothing is
        registered for ``op_type``.
    """
    conversion_fn = _lookup_registry(_CONVERSION_REGISTRY, op_type)
    return conversion_fn

get_execution_fn

get_execution_fn(op_type: str) -> Optional[Callable]

Get the execution function for an operator type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op_type` | `str` | The operator type string | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Optional[Callable]` | The execution function if registered, `None` otherwise. |

Source code in torch_ir/ops/registry.py
def get_execution_fn(op_type: str) -> Optional[Callable]:
    """Look up the execution function registered for ``op_type``.

    The lookup is delegated to ``_lookup_registry`` over the execution
    registry that :func:`register_executor` populates.

    Args:
        op_type: Operator type string to resolve.

    Returns:
        The registered execution function, or ``None`` when nothing is
        registered for ``op_type``.
    """
    execution_fn = _lookup_registry(_EXECUTION_REGISTRY, op_type)
    return execution_fn

register_executor

register_executor(op_pattern: str)

Decorator to register an execution function for an operator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op_pattern` | `str` | The operator pattern to match (e.g., `"aten.conv2d.default"`) | *required* |
Example

@register_executor("aten.conv2d.default")
def execute_conv2d(inputs: List[torch.Tensor], attrs: Dict) -> List[torch.Tensor]:
    ...

Source code in torch_ir/ops/registry.py
def register_executor(op_pattern: str):
    """Build a decorator that records an execution function for an operator.

    The decorated function is stored in the execution registry under
    ``op_pattern`` and returned unchanged, so it remains directly callable.
    Because this is a plain item assignment, registering the same pattern
    again replaces the earlier entry.

    Args:
        op_pattern: Operator pattern to register under
            (e.g., "aten.conv2d.default").

    Example:
        @register_executor("aten.conv2d.default")
        def execute_conv2d(inputs: List[torch.Tensor], attrs: Dict) -> List[torch.Tensor]:
            ...
    """

    def _register(execution_fn: Callable) -> Callable:
        _EXECUTION_REGISTRY[op_pattern] = execution_fn
        return execution_fn

    return _register

register_op

register_op(op_pattern: str)

Decorator to register an IR conversion function for an operator.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `op_pattern` | `str` | The operator pattern to match (e.g., `"aten.conv2d.default"`) | *required* |
Example

@register_op("aten.conv2d.default")
def convert_conv2d(node_info: NodeInfo) -> OpNode:
    ...

Source code in torch_ir/ops/registry.py
def register_op(op_pattern: str):
    """Build a decorator that records an IR conversion function.

    The decorated function is stored in the conversion registry under
    ``op_pattern`` and returned unchanged, so it remains directly callable.
    Because this is a plain item assignment, registering the same pattern
    again replaces the earlier entry.

    Args:
        op_pattern: Operator pattern to register under
            (e.g., "aten.conv2d.default").

    Example:
        @register_op("aten.conv2d.default")
        def convert_conv2d(node_info: NodeInfo) -> OpNode:
            ...
    """

    def _register(conversion_fn: Callable) -> Callable:
        _CONVERSION_REGISTRY[op_pattern] = conversion_fn
        return conversion_fn

    return _register