Skip to content

Converter

Module that converts ExportedProgram to IR.

converter

IR Builder/Converter - converts ExportedProgram to IR.

ConversionError

Bases: Exception

Raised when IR conversion fails.

convert_node

convert_node(node_info: NodeInfo) -> OpNode

Convert a single FX node to an OpNode.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `node_info` | `NodeInfo` | The analyzed node information. | required |

Returns:

| Type | Description |
| --- | --- |
| `OpNode` | The converted OpNode. |

Source code in torch_ir/converter.py
def convert_node(node_info: NodeInfo) -> OpNode:
    """Translate one analyzed FX node into its IR ``OpNode``.

    The op type derived from ``node_info.target`` is used to look up a
    registered conversion function. When the lookup returns ``None``, the
    default conversion is applied instead.

    Args:
        node_info: The analyzed node information.

    Returns:
        The OpNode produced by the registered conversion function, or by the
        default conversion when no function is registered for the op type.
    """
    handler = get_conversion_fn(get_op_type(node_info.target))

    if handler is None:
        # Unregistered op: fall back to the default conversion.
        handler = _default_conversion

    return handler(node_info)

convert_exported_program

convert_exported_program(exported: ExportedProgram, model_name: str = '', strict: bool = False) -> IR

Convert an ExportedProgram to IR.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `exported` | `ExportedProgram` | The exported program from torch.export. | required |
| `model_name` | `str` | Optional name for the model. | `''` |
| `strict` | `bool` | If True, raise an error for unsupported ops. If False, use the default conversion. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `IR` | The converted IR. |

Raises:

| Type | Description |
| --- | --- |
| `ConversionError` | If strict mode is enabled and an unsupported operation is encountered. |

Source code in torch_ir/converter.py
def convert_exported_program(
    exported: ExportedProgram,
    model_name: str = "",
    strict: bool = False,
) -> IR:
    """Build an IR from an ExportedProgram.

    The program is first passed to ``_validate_static_shapes``. A
    ``GraphAnalyzer`` then supplies the graph inputs, outputs, weights and
    weight-name mapping, and every call_function node it reports is converted
    with ``convert_node``.

    Args:
        exported: The exported program from torch.export.
        model_name: Optional name for the model; stored on the returned IR.
        strict: If True, any exception raised while converting a node is
            re-raised as a ConversionError. If False, the node is converted
            again with the default conversion instead.

    Returns:
        The converted IR, tagged with ``torch.__version__`` as its
        ``pytorch_version``.

    Raises:
        ConversionError: In strict mode, when converting a node raises.

    NOTE(review): ``convert_node`` already falls back to the default
    conversion for ops without a registered conversion function, so strict
    mode only triggers when a conversion actually raises — confirm this is
    the intended meaning of "unsupported operation".
    """
    _validate_static_shapes(exported)
    analyzer = GraphAnalyzer(exported)

    # Gather graph-level metadata before walking the nodes.
    inputs = analyzer.get_graph_inputs()
    outputs = analyzer.get_graph_outputs()
    weight_table = analyzer.get_weights()
    name_mapping = analyzer.get_weight_name_mapping()

    converted = []
    for node_info in analyzer.get_call_function_nodes():
        try:
            op_node = convert_node(node_info)
        except Exception as e:
            if strict:
                op_type = get_op_type(node_info.target)
                raise ConversionError(f"Failed to convert node '{node_info.name}' with op '{op_type}': {e}") from e
            # Non-strict: retry with the default conversion as a fallback.
            op_node = _default_conversion(node_info)
        converted.append(op_node)

    return IR(
        nodes=converted,
        graph_inputs=inputs,
        graph_outputs=outputs,
        weights=weight_table,
        weight_name_mapping=name_mapping,
        model_name=model_name,
        pytorch_version=torch.__version__,
    )