# Usage Guide
This document explains the main workflows for extracting, inspecting, saving, and validating IR.
## 1. Basic Workflow

### 1.1 Basic IR Extraction Flow
```python
import torch
import torch.nn as nn

from torch_ir import extract_ir

# 1. Define model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64 * 32 * 32, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = x.flatten(1)
        return self.fc(x)

# 2. Create model on meta device
with torch.device('meta'):
    model = MyModel()

# 3. Set model to eval mode (important!)
model.eval()

# 4. Prepare example inputs
example_inputs = (torch.randn(1, 3, 32, 32, device='meta'),)

# 5. Extract IR
ir = extract_ir(model, example_inputs, model_name="MyModel")

# 6. Check results
print(ir)
```
### 1.2 Important Notes

- eval() mode: required so layers such as BatchNorm and Dropout behave correctly during export
- meta device: both the model and the example inputs must be on the meta device (see the sketch below)
- static shapes: dynamic shapes are not supported
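If you already have a model instantiated on CPU, you can move it to the meta device before extraction. A minimal sketch using only standard PyTorch calls; note that `.to('meta')` replaces parameters and buffers with meta tensors, discarding their values:

```python
import torch
import torch.nn as nn

# Sketch: convert an existing CPU model to the meta device for extraction.
# .to('meta') swaps all parameters/buffers for meta tensors (values are lost),
# so keep a separate copy if you still need the real weights.
model = nn.Linear(4, 2).to('meta')
model.eval()  # required so BatchNorm/Dropout run in inference mode

inputs = (torch.randn(1, 4, device='meta'),)
assert all(p.device.type == 'meta' for p in model.parameters())
```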
## 2. IR Analysis

### 2.1 Exploring IR Structure
```python
# Basic IR information
print(f"Model: {ir.model_name}")
print(f"PyTorch version: {ir.pytorch_version}")
print(f"Total nodes: {len(ir.nodes)}")
print(f"Total weights: {len(ir.weights)}")

# Graph inputs and outputs
print("\nGraph Inputs:")
for inp in ir.graph_inputs:
    print(f"  {inp.name}: shape={inp.shape}, dtype={inp.dtype}")

print("\nGraph Outputs:")
for out in ir.graph_outputs:
    print(f"  {out.name}: shape={out.shape}, dtype={out.dtype}")
```
### 2.2 Node Analysis
```python
# Iterate through all nodes
for node in ir.nodes:
    print(f"\nNode: {node.name}")
    print(f"  Op type: {node.op_type}")
    print("  Inputs:")
    for inp in node.inputs:
        print(f"    - {inp.name}: {inp.shape}")
    print("  Outputs:")
    for out in node.outputs:
        print(f"    - {out.name}: {out.shape}")
    if node.attrs:
        print(f"  Attrs: {node.attrs}")
```
### 2.3 Operator Statistics
```python
from collections import Counter

# Count nodes by operator type
op_counts = Counter(node.op_type for node in ir.nodes)
print("Operation counts:")
for op_type, count in op_counts.most_common():
    print(f"  {op_type}: {count}")
```
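The same per-node fields support simple filtering. For example, a small sketch that collects every convolution node (op names follow the `aten.*` convention shown in the CLI output in Section 9):

```python
# Sketch: select nodes by operator name, e.g. aten.conv2d.default
conv_nodes = [node for node in ir.nodes if "conv" in node.op_type]
print(f"Found {len(conv_nodes)} convolution nodes")
```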
### 2.4 Weight Information
```python
# Weight metadata
print("Weights:")
for weight in ir.weights:
    print(f"  {weight.name}: shape={weight.shape}, dtype={weight.dtype}")

# Weight name mapping (placeholder → state_dict key)
print("\nWeight name mapping:")
for placeholder, sd_key in ir.weight_name_mapping.items():
    print(f"  {placeholder} → {sd_key}")
```
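A typical use of this mapping is resolving each IR placeholder to its tensor in the original model's state_dict. A sketch, assuming every mapped key exists in the state_dict (as the placeholder → state_dict-key mapping implies):

```python
# Sketch: build a placeholder -> tensor dict from the original state_dict
state_dict = MyModel().state_dict()
weights_by_placeholder = {
    placeholder: state_dict[sd_key]
    for placeholder, sd_key in ir.weight_name_mapping.items()
}
```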
## 3. Saving and Loading IR

### 3.1 Save to JSON File
```python
# Save via the IR object
ir.save("model_ir.json")

# Or use the serializer
from torch_ir import save_ir, serialize_ir

save_ir(ir, "model_ir.json")

# Serialize to a JSON string
json_str = serialize_ir(ir)
```
### 3.2 Load from JSON File
```python
from torch_ir import load_ir, deserialize_ir

# Load from file
loaded_ir = load_ir("model_ir.json")

# Or use IR.load()
from torch_ir import IR

loaded_ir = IR.load("model_ir.json")

# Deserialize from a JSON string
ir = deserialize_ir(json_str)
```
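A quick round-trip check confirms that saving and loading preserves the IR's structure (a sketch using only the calls shown above):

```python
# Sketch: round-trip an IR through JSON and compare basic structure
ir.save("model_ir.json")
loaded_ir = load_ir("model_ir.json")
assert loaded_ir.model_name == ir.model_name
assert len(loaded_ir.nodes) == len(ir.nodes)
assert len(loaded_ir.weights) == len(ir.weights)
```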
### 3.3 IR Validation
```python
from torch_ir import validate_ir

# Validate IR structure
try:
    validate_ir(ir)
    print("IR is valid")
except Exception as e:
    print(f"IR validation failed: {e}")
```
## 4. IR Execution and Verification

### 4.1 IR Execution
You can execute the IR with actual weights to obtain results.
```python
from torch_ir import IRExecutor, execute_ir

# Get weights from the original model
original_model = MyModel()
original_model.load_state_dict(torch.load('weights.pt'))
state_dict = original_model.state_dict()

# Method 1: use IRExecutor
executor = IRExecutor(ir)
executor.load_weights_from_state_dict(state_dict)
test_input = torch.randn(1, 3, 32, 32)
outputs = executor.execute((test_input,))

# Method 2: use the execute_ir function
outputs = execute_ir(ir, (test_input,), weights=state_dict)
print(f"Output shape: {outputs[0].shape}")
```
### 4.2 Verification Against Original Model
```python
from torch_ir import verify_ir_with_state_dict, verify_ir

# Prepare the original model
original_model = MyModel()
original_model.load_state_dict(torch.load('weights.pt'))
original_model.eval()

# Test inputs
test_inputs = (torch.randn(1, 3, 32, 32),)

# Verify using a state_dict
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=test_inputs,
    rtol=1e-5,  # relative error tolerance
    atol=1e-5,  # absolute error tolerance
)
print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
print(report)

# Verify using a weights file path
torch.save(original_model.state_dict(), 'weights.pt')
is_valid, report = verify_ir(
    ir=ir,
    weights_path='weights.pt',
    original_model=original_model,
    test_inputs=test_inputs,
)
```
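Conceptually, verification runs both the original model and the IR on the same inputs and compares the outputs element-wise within `rtol`/`atol`. A hand-rolled sketch of that comparison, reusing `execute_ir` from above:

```python
# Sketch: the element-wise comparison that verification performs, by hand
with torch.no_grad():
    expected = original_model(*test_inputs)
actual = execute_ir(ir, test_inputs, weights=original_model.state_dict())[0]
print(torch.allclose(expected, actual, rtol=1e-5, atol=1e-5))
```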
### 4.3 Analyzing Verification Results
```python
if not is_valid:
    print(f"Max difference: {report.max_diff}")
    print(f"Mean difference: {report.mean_diff}")
    print(f"Error message: {report.error_message}")

    # Analyze individual outputs
    for detail in report.output_details:
        print(f"  Output {detail['index']}: "
              f"shape={detail['shape']}, "
              f"is_close={detail['is_close']}, "
              f"max_diff={detail['max_diff']:.2e}")
```
## 5. Various Model Types

### 5.1 CNN Models
```python
import torchvision.models as models

# Create ResNet on the meta device
with torch.device('meta'):
    model = models.resnet18()
model.eval()

inputs = (torch.randn(1, 3, 224, 224, device='meta'),)
ir = extract_ir(model, inputs, model_name="ResNet18")
print(f"ResNet18 IR: {len(ir.nodes)} nodes, {len(ir.weights)} weights")
```
### 5.2 Sequential Models
```python
with torch.device('meta'):
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
model.eval()

inputs = (torch.randn(1, 784, device='meta'),)
ir = extract_ir(model, inputs, model_name="MLP")
```
### 5.3 Transformer Models
```python
with torch.device('meta'):
    encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    model = nn.TransformerEncoder(encoder_layer, num_layers=6)
model.eval()

inputs = (torch.randn(10, 1, 512, device='meta'),)  # (seq_len, batch, d_model)
ir = extract_ir(model, inputs, model_name="TransformerEncoder")
```
### 5.4 Multi-Input Models
```python
class MultiInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(5, 20)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x1, x2):
        h1 = self.fc1(x1)
        h2 = self.fc2(x2)
        return self.fc3(torch.cat([h1, h2], dim=1))

with torch.device('meta'):
    model = MultiInputModel()
model.eval()

inputs = (
    torch.randn(1, 10, device='meta'),
    torch.randn(1, 5, device='meta'),
)
ir = extract_ir(model, inputs)
```
## 6. Weight Management

### 6.1 Loading Weights
```python
from torch_ir import load_weights, load_weights_pt, load_weights_safetensors

# Auto-detect format
weights = load_weights('model.pt')

# Or specify the format explicitly
weights = load_weights_pt('model.pt')
weights = load_weights_safetensors('model.safetensors')
```
### 6.2 Weight Validation
```python
from torch_ir.weight_loader import validate_weights_against_ir

# Validate that the weights match the IR
errors = validate_weights_against_ir(weights, ir)
if errors:
    print("Weight validation errors:")
    for error in errors:
        print(f"  - {error}")
else:
    print("Weights are valid")
```
## 7. Advanced Usage

### 7.1 Strict Mode
```python
# Strict mode: raises an exception on conversion errors
try:
    ir = extract_ir(model, inputs, strict=True)
except Exception as e:
    print(f"Conversion error: {e}")

# Non-strict (default): falls back to the default converter on conversion errors
ir = extract_ir(model, inputs, strict=False)
```
Because all ATen operators are supported automatically, most models convert correctly even with strict=False.
### 7.2 Custom Model Name
```python
ir = extract_ir(model, inputs, model_name="MyCustomModel_v2")
print(ir.model_name)  # "MyCustomModel_v2"
```
### 7.3 Using IRConverter Directly
```python
from torch_ir import export_model
from torch_ir.converter import IRConverter, convert_exported_program

# Call torch.export directly
exported = export_model(model, inputs, strict=False)

# Use the converter class
converter = IRConverter(strict=False)
ir = converter.convert(exported, model_name="MyModel")

# Or use the function directly
ir = convert_exported_program(exported, model_name="MyModel")
```
### 7.4 Using GraphAnalyzer Directly
```python
from torch_ir import export_model
from torch_ir.analyzer import GraphAnalyzer

exported = export_model(model, inputs, strict=False)
analyzer = GraphAnalyzer(exported)

# Extract individual pieces of information
graph_inputs = analyzer.get_graph_inputs()
graph_outputs = analyzer.get_graph_outputs()
weights = analyzer.get_weights()
weight_mapping = analyzer.get_weight_name_mapping()
nodes = analyzer.get_call_function_nodes()
```
## 8. Common Patterns

### 8.1 Full Pipeline Example
```python
import torch

from torch_ir import extract_ir, verify_ir_with_state_dict

def full_pipeline(model_class, input_shape, weights_path):
    """Full pipeline from IR extraction to verification."""
    # 1. Original model (for verification)
    original = model_class()
    original.load_state_dict(torch.load(weights_path))
    original.eval()

    # 2. Meta model (for IR extraction)
    with torch.device('meta'):
        meta_model = model_class()
    meta_model.eval()

    # 3. Extract IR
    inputs = (torch.randn(*input_shape, device='meta'),)
    ir = extract_ir(meta_model, inputs)

    # 4. Save IR
    ir.save(f"{model_class.__name__}_ir.json")

    # 5. Verify
    test_inputs = (torch.randn(*input_shape),)
    is_valid, report = verify_ir_with_state_dict(
        ir=ir,
        state_dict=original.state_dict(),
        original_model=original,
        test_inputs=test_inputs,
    )
    return ir, is_valid, report

# Usage example
ir, valid, report = full_pipeline(MyModel, (1, 3, 32, 32), 'weights.pt')
```
### 8.2 Batch IR Extraction
```python
def extract_multiple_models(model_configs):
    """Extract IR for multiple models at once."""
    results = {}
    for name, (model_class, input_shape) in model_configs.items():
        with torch.device('meta'):
            model = model_class()
        model.eval()
        inputs = (torch.randn(*input_shape, device='meta'),)
        ir = extract_ir(model, inputs, model_name=name)
        ir.save(f"{name}_ir.json")
        results[name] = ir
        print(f"{name}: {len(ir.nodes)} nodes")
    return results

# Usage example
configs = {
    "ResNet18": (models.resnet18, (1, 3, 224, 224)),
    "ResNet50": (models.resnet50, (1, 3, 224, 224)),
    "VGG16": (models.vgg16, (1, 3, 224, 224)),
}
irs = extract_multiple_models(configs)
```
## 9. CLI Tools

The `pytorch-ir` CLI lets you inspect and visualize IR files directly from the terminal, without writing Python code.

### 9.1 IR Summary
```bash
# Display IR summary
pytorch-ir info model_ir.json

# JSON format output
pytorch-ir info model_ir.json --json

# Save to file
pytorch-ir info model_ir.json --json -o summary.json
```
For example, a DeepResNet model with 3 residual blocks:
```text
Model: DeepResNet
Nodes: 27
Inputs: 1
Outputs: 1
Weights: 51
Total parameters: 57,617

Input shapes:
  x: [1, 3, 32, 32]

Output shapes:
  linear: [1, 10]

Op distribution:
  aten.conv2d.default: 7
  aten.batch_norm.default: 7
  aten.relu.default: 7
  aten.add.Tensor: 3
  aten.adaptive_avg_pool2d.default: 1
  aten.flatten.using_ints: 1
  aten.linear.default: 1
```
### 9.2 Graph Visualization
```bash
# Print Mermaid diagram to stdout
pytorch-ir visualize model_ir.json

# Save as a Mermaid text file
pytorch-ir visualize model_ir.json -o graph.mmd

# Render as a PNG/SVG image (requires: pip install pytorch-ir[rendering])
pytorch-ir visualize model_ir.json -o graph.png
pytorch-ir visualize model_ir.json -o graph.svg

# Limit displayed nodes for large graphs
pytorch-ir visualize model_ir.json --max-nodes 50
```
Below is the actual IR graph of a TransformerBlock (self-attention + FFN + residual connections).
The Q/K/V projections branch out in parallel, and add.Tensor nodes represent residual connections:
```mermaid
flowchart TD
    input_x[/"Input: x<br/>1x16x64"/]
    op_linear["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear
    w_p_q_proj_weight[/"p_q_proj_weight<br/>64x64"/]
    w_p_q_proj_weight -.->|"64x64"| op_linear
    w_p_q_proj_bias[/"p_q_proj_bias<br/>64"/]
    w_p_q_proj_bias -.->|"64"| op_linear
    op_view["view<br/>1x16x4x16"]
    op_linear -->|"1x16x64"| op_view
    op_transpose["transpose.int<br/>1x4x16x16"]
    op_view -->|"1x16x4x16"| op_transpose
    op_linear_1["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear_1
    w_p_k_proj_weight[/"p_k_proj_weight<br/>64x64"/]
    w_p_k_proj_weight -.->|"64x64"| op_linear_1
    w_p_k_proj_bias[/"p_k_proj_bias<br/>64"/]
    w_p_k_proj_bias -.->|"64"| op_linear_1
    op_view_1["view<br/>1x16x4x16"]
    op_linear_1 -->|"1x16x64"| op_view_1
    op_transpose_1["transpose.int<br/>1x4x16x16"]
    op_view_1 -->|"1x16x4x16"| op_transpose_1
    op_linear_2["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear_2
    w_p_v_proj_weight[/"p_v_proj_weight<br/>64x64"/]
    w_p_v_proj_weight -.->|"64x64"| op_linear_2
    w_p_v_proj_bias[/"p_v_proj_bias<br/>64"/]
    w_p_v_proj_bias -.->|"64"| op_linear_2
    op_view_2["view<br/>1x16x4x16"]
    op_linear_2 -->|"1x16x64"| op_view_2
    op_transpose_2["transpose.int<br/>1x4x16x16"]
    op_view_2 -->|"1x16x4x16"| op_transpose_2
    op_transpose_3["transpose.int<br/>1x4x16x16"]
    op_transpose_1 -->|"1x4x16x16"| op_transpose_3
    op_matmul["matmul<br/>1x4x16x16"]
    op_transpose -->|"1x4x16x16"| op_matmul
    op_transpose_3 -->|"1x4x16x16"| op_matmul
    op_div["div.Tensor<br/>1x4x16x16"]
    op_matmul -->|"1x4x16x16"| op_div
    op_softmax["softmax.int<br/>1x4x16x16"]
    op_div -->|"1x4x16x16"| op_softmax
    op_matmul_1["matmul<br/>1x4x16x16"]
    op_softmax -->|"1x4x16x16"| op_matmul_1
    op_transpose_2 -->|"1x4x16x16"| op_matmul_1
    op_transpose_4["transpose.int<br/>1x16x4x16"]
    op_matmul_1 -->|"1x4x16x16"| op_transpose_4
    op_contiguous["contiguous<br/>1x16x4x16"]
    op_transpose_4 -->|"1x16x4x16"| op_contiguous
    op_view_3["view<br/>1x16x64"]
    op_contiguous -->|"1x16x4x16"| op_view_3
    op_linear_3["linear<br/>1x16x64"]
    op_view_3 -->|"1x16x64"| op_linear_3
    w_p_out_proj_weight[/"p_out_proj_weight<br/>64x64"/]
    w_p_out_proj_weight -.->|"64x64"| op_linear_3
    w_p_out_proj_bias[/"p_out_proj_bias<br/>64"/]
    w_p_out_proj_bias -.->|"64"| op_linear_3
    op_add["add.Tensor<br/>1x16x64"]
    input_x -->|"1x16x64"| op_add
    op_linear_3 -->|"1x16x64"| op_add
    op_layer_norm["layer_norm<br/>1x16x64"]
    op_add -->|"1x16x64"| op_layer_norm
    w_p_norm1_weight[/"p_norm1_weight<br/>64"/]
    w_p_norm1_weight -.->|"64"| op_layer_norm
    w_p_norm1_bias[/"p_norm1_bias<br/>64"/]
    w_p_norm1_bias -.->|"64"| op_layer_norm
    op_linear_4["linear<br/>1x16x256"]
    op_layer_norm -->|"1x16x64"| op_linear_4
    w_p_ffn_0_weight[/"p_ffn_0_weight<br/>256x64"/]
    w_p_ffn_0_weight -.->|"256x64"| op_linear_4
    w_p_ffn_0_bias[/"p_ffn_0_bias<br/>256"/]
    w_p_ffn_0_bias -.->|"256"| op_linear_4
    op_gelu["gelu<br/>1x16x256"]
    op_linear_4 -->|"1x16x256"| op_gelu
    op_linear_5["linear<br/>1x16x64"]
    op_gelu -->|"1x16x256"| op_linear_5
    w_p_ffn_2_weight[/"p_ffn_2_weight<br/>64x256"/]
    w_p_ffn_2_weight -.->|"64x256"| op_linear_5
    w_p_ffn_2_bias[/"p_ffn_2_bias<br/>64"/]
    w_p_ffn_2_bias -.->|"64"| op_linear_5
    op_add_1["add.Tensor<br/>1x16x64"]
    op_layer_norm -->|"1x16x64"| op_add_1
    op_linear_5 -->|"1x16x64"| op_add_1
    op_layer_norm_1["layer_norm<br/>1x16x64"]
    op_add_1 -->|"1x16x64"| op_layer_norm_1
    w_p_norm2_weight[/"p_norm2_weight<br/>64"/]
    w_p_norm2_weight -.->|"64"| op_layer_norm_1
    w_p_norm2_bias[/"p_norm2_bias<br/>64"/]
    w_p_norm2_bias -.->|"64"| op_layer_norm_1
    output_0[\"Output<br/>1x16x64"/]
    op_layer_norm_1 --> output_0
```
For full CLI documentation, see the CLI Reference.
## 10. Next Steps
- API Reference - Detailed API documentation
- Operator Support - List of supported operators
- Extension Guide - Adding custom operators