Skip to content

Example

Compare the IR extraction result for a simple MLP model across three representations: PyTorch source code, JSON IR, and graph visualization.

SimpleMLP

A simple MLP model with Linear(4, 8) → ReLU → Linear(8, 2) architecture.

import torch
import torch.nn as nn
from torch_ir import extract_ir

class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear(4, 8) -> ReLU -> Linear(8, 2)."""

    def __init__(self):
        super().__init__()
        # Submodule names (fc1/relu/fc2) determine the weight names that
        # appear in the extracted IR (e.g. "fc1.weight" -> "p_fc1_weight").
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        # Pipe the input through both linear layers with a ReLU in between.
        return self.fc2(self.relu(self.fc1(x)))

# Build the model on the meta device so no real tensor storage is
# allocated; eval() switches off training-only behaviour before tracing.
with torch.device('meta'):
    model = SimpleMLP().eval()

# Trace with a meta example input and capture the IR.
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs, model_name="SimpleMLP")

# Summarize the extracted graph: node/weight counts plus per-node output shapes.
print(f"Number of nodes: {len(ir.nodes)}")
print(f"Number of weights: {len(ir.weights)}")
for node in ir.nodes:
    print(f"  {node.op_type}: {[t.shape for t in node.outputs]}")

Structure of the ir.to_dict() result:

{
  "model_name": "SimpleMLP",
  "pytorch_version": "2.x.x",
  "nodes": [
    {
      "name": "linear",
      "op_type": "aten.linear.default",
      "inputs": [
        {"name": "x", "shape": [1, 4], "dtype": "float32"},
        {"name": "p_fc1_weight", "shape": [8, 4], "dtype": "float32"},
        {"name": "p_fc1_bias", "shape": [8], "dtype": "float32"}
      ],
      "outputs": [
        {"name": "linear", "shape": [1, 8], "dtype": "float32"}
      ],
      "attrs": {}
    },
    {
      "name": "relu",
      "op_type": "aten.relu.default",
      "inputs": [
        {"name": "linear", "shape": [1, 8], "dtype": "float32"}
      ],
      "outputs": [
        {"name": "relu", "shape": [1, 8], "dtype": "float32"}
      ],
      "attrs": {}
    },
    {
      "name": "linear_1",
      "op_type": "aten.linear.default",
      "inputs": [
        {"name": "relu", "shape": [1, 8], "dtype": "float32"},
        {"name": "p_fc2_weight", "shape": [2, 8], "dtype": "float32"},
        {"name": "p_fc2_bias", "shape": [2], "dtype": "float32"}
      ],
      "outputs": [
        {"name": "linear_1", "shape": [1, 2], "dtype": "float32"}
      ],
      "attrs": {}
    }
  ],
  "graph_inputs": [
    {"name": "x", "shape": [1, 4], "dtype": "float32"}
  ],
  "graph_outputs": [
    {"name": "linear_1", "shape": [1, 2], "dtype": "float32"}
  ],
  "weights": [
    {"name": "p_fc1_weight", "shape": [8, 4], "dtype": "float32"},
    {"name": "p_fc1_bias", "shape": [8], "dtype": "float32"},
    {"name": "p_fc2_weight", "shape": [2, 8], "dtype": "float32"},
    {"name": "p_fc2_bias", "shape": [2], "dtype": "float32"}
  ],
  "weight_name_mapping": {
    "p_fc1_weight": "fc1.weight",
    "p_fc1_bias": "fc1.bias",
    "p_fc2_weight": "fc2.weight",
    "p_fc2_bias": "fc2.bias"
  }
}

Mermaid diagram generated by the ir_to_mermaid() function. Weight inputs are shown as dashed edges.

flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0

Diagram legend:

  • Parallelogram nodes: graph inputs/outputs and weights
  • Rectangular nodes: operations (showing op_type and output shape)
  • Solid edges: activation data flow
  • Dashed edges: weight inputs

Programmatic Visualization

Use the ir_to_mermaid() function to convert any IR to a Mermaid diagram:

from torch_ir import extract_ir, ir_to_mermaid

# Re-create the model on the meta device and pull out its IR.
with torch.device('meta'):
    model = SimpleMLP()
model.eval()

example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs)

# Render the whole graph as a Mermaid flowchart.
mermaid_str = ir_to_mermaid(ir)
print(mermaid_str)

# Cap the node count when the model is too large to draw in full.
mermaid_str = ir_to_mermaid(ir, max_nodes=20)
ir_to_mermaid() output
flowchart TD
    input_x[/"Input: x<br/>1x4"/]
    op_linear["linear<br/>1x8"]
    input_x -->|"1x4"| op_linear
    w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
    w_p_fc1_weight -.->|"8x4"| op_linear
    w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
    w_p_fc1_bias -.->|"8"| op_linear
    op_relu["relu<br/>1x8"]
    op_linear -->|"1x8"| op_relu
    op_linear_1["linear<br/>1x2"]
    op_relu -->|"1x8"| op_linear_1
    w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
    w_p_fc2_weight -.->|"2x8"| op_linear_1
    w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
    w_p_fc2_bias -.->|"2"| op_linear_1
    output_0[\"Output<br/>1x2"/]
    op_linear_1 --> output_0

You can also generate operator distribution pie charts:

from torch_ir import generate_op_distribution_pie

# Render the per-op-type node distribution as a Mermaid pie chart.
print(generate_op_distribution_pie(ir))
generate_op_distribution_pie() output
pie showData
    title "Operator Distribution"
    "linear" : 2
    "relu" : 1

See the Visualize API Reference for detailed API documentation.