Example
This example shows the IR extracted from a simple MLP model in three forms: the PyTorch source code, the JSON IR, and a graph visualization.
SimpleMLP
A simple MLP with a Linear(4, 8) → ReLU → Linear(8, 2) architecture.
import torch
import torch.nn as nn
from torch_ir import extract_ir
class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear(4, 8) -> ReLU -> Linear(8, 2)."""

    def __init__(self):
        super().__init__()
        # The attribute names become the parameter names in the extracted IR
        # (e.g. fc1.weight -> p_fc1_weight), so keep them stable.
        self.fc1 = nn.Linear(4, 8)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        """Map an input whose last dimension is 4 to an output whose last dimension is 2."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Build the model on the meta device: meta tensors carry shape and dtype
# but no data, which is all the IR extraction needs.
with torch.device('meta'):
    model = SimpleMLP()
    model.eval()
example_inputs = (torch.randn(1, 4, device='meta'),)
ir = extract_ir(model, example_inputs, model_name="SimpleMLP")

# Summarize the extracted graph, then list each op with its output shapes.
print(f"Number of nodes: {len(ir.nodes)}")
print(f"Number of weights: {len(ir.weights)}")
for node in ir.nodes:
    output_shapes = [t.shape for t in node.outputs]
    print(f" {node.op_type}: {output_shapes}")
Structure of the dictionary returned by ir.to_dict() (the "pytorch_version" value below is a placeholder):
{
"model_name": "SimpleMLP",
"pytorch_version": "2.x.x",
"nodes": [
{
"name": "linear",
"op_type": "aten.linear.default",
"inputs": [
{"name": "x", "shape": [1, 4], "dtype": "float32"},
{"name": "p_fc1_weight", "shape": [8, 4], "dtype": "float32"},
{"name": "p_fc1_bias", "shape": [8], "dtype": "float32"}
],
"outputs": [
{"name": "linear", "shape": [1, 8], "dtype": "float32"}
],
"attrs": {}
},
{
"name": "relu",
"op_type": "aten.relu.default",
"inputs": [
{"name": "linear", "shape": [1, 8], "dtype": "float32"}
],
"outputs": [
{"name": "relu", "shape": [1, 8], "dtype": "float32"}
],
"attrs": {}
},
{
"name": "linear_1",
"op_type": "aten.linear.default",
"inputs": [
{"name": "relu", "shape": [1, 8], "dtype": "float32"},
{"name": "p_fc2_weight", "shape": [2, 8], "dtype": "float32"},
{"name": "p_fc2_bias", "shape": [2], "dtype": "float32"}
],
"outputs": [
{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}
],
"attrs": {}
}
],
"graph_inputs": [
{"name": "x", "shape": [1, 4], "dtype": "float32"}
],
"graph_outputs": [
{"name": "linear_1", "shape": [1, 2], "dtype": "float32"}
],
"weights": [
{"name": "p_fc1_weight", "shape": [8, 4], "dtype": "float32"},
{"name": "p_fc1_bias", "shape": [8], "dtype": "float32"},
{"name": "p_fc2_weight", "shape": [2, 8], "dtype": "float32"},
{"name": "p_fc2_bias", "shape": [2], "dtype": "float32"}
],
"weight_name_mapping": {
"p_fc1_weight": "fc1.weight",
"p_fc1_bias": "fc1.bias",
"p_fc2_weight": "fc2.weight",
"p_fc2_bias": "fc2.bias"
}
}
Mermaid diagram generated by the ir_to_mermaid() function. Weight inputs are shown as dashed edges.
flowchart TD
input_x[/"Input: x<br/>1x4"/]
op_linear["linear<br/>1x8"]
input_x -->|"1x4"| op_linear
w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
w_p_fc1_weight -.->|"8x4"| op_linear
w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
w_p_fc1_bias -.->|"8"| op_linear
op_relu["relu<br/>1x8"]
op_linear -->|"1x8"| op_relu
op_linear_1["linear<br/>1x2"]
op_relu -->|"1x8"| op_linear_1
w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
w_p_fc2_weight -.->|"2x8"| op_linear_1
w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
w_p_fc2_bias -.->|"2"| op_linear_1
output_0[\"Output<br/>1x2"/]
op_linear_1 --> output_0
Diagram legend:
- Parallelogram nodes: graph inputs and weights; the graph output is drawn as an inverted trapezoid
- Rectangular nodes: operations (showing op_type and output shape)
- Solid edges: activation data flow
- Dashed edges: weight inputs
Programmatic Visualization
Use the ir_to_mermaid() function to convert any IR to a Mermaid diagram. The snippet below reuses torch and the SimpleMLP class from the example above:
from torch_ir import extract_ir, ir_to_mermaid
# Build the model on the meta device and extract its IR
with torch.device('meta'):
    model = SimpleMLP()
    model.eval()
sample_input = torch.randn(1, 4, device='meta')
ir = extract_ir(model, (sample_input,))

# Render the whole graph as a Mermaid flowchart
mermaid_str = ir_to_mermaid(ir)
print(mermaid_str)

# For large models, cap how many nodes end up in the diagram
mermaid_str = ir_to_mermaid(ir, max_nodes=20)
ir_to_mermaid() output
flowchart TD
input_x[/"Input: x<br/>1x4"/]
op_linear["linear<br/>1x8"]
input_x -->|"1x4"| op_linear
w_p_fc1_weight[/"p_fc1_weight<br/>8x4"/]
w_p_fc1_weight -.->|"8x4"| op_linear
w_p_fc1_bias[/"p_fc1_bias<br/>8"/]
w_p_fc1_bias -.->|"8"| op_linear
op_relu["relu<br/>1x8"]
op_linear -->|"1x8"| op_relu
op_linear_1["linear<br/>1x2"]
op_relu -->|"1x8"| op_linear_1
w_p_fc2_weight[/"p_fc2_weight<br/>2x8"/]
w_p_fc2_weight -.->|"2x8"| op_linear_1
w_p_fc2_bias[/"p_fc2_bias<br/>2"/]
w_p_fc2_bias -.->|"2"| op_linear_1
output_0[\"Output<br/>1x2"/]
op_linear_1 --> output_0
You can also generate operator distribution pie charts:
from torch_ir import generate_op_distribution_pie
# Tally the IR's operators by type and print the tally as Mermaid pie-chart source
op_pie = generate_op_distribution_pie(ir)
print(op_pie)
generate_op_distribution_pie() output
pie showData
title "Operator Distribution"
"linear" : 2
"relu" : 1
See the Visualize API Reference for detailed API documentation.