콘텐츠로 이동

사용 가이드

이 문서는 IR를 추출하고, 확인하고, 저장하고, 검증하는 주요 사용 흐름을 설명합니다.

1. 기본 워크플로우

1.1 IR 추출 기본 흐름

import torch
import torch.nn as nn
from torch_ir import extract_ir

# 1. 모델 정의
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3, padding=1)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64 * 32 * 32, 10)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        x = x.flatten(1)
        return self.fc(x)

# 2. Meta device에서 모델 생성
with torch.device('meta'):
    model = MyModel()

# 3. 모델을 eval 모드로 설정 (중요!)
model.eval()

# 4. Example inputs 준비
example_inputs = (torch.randn(1, 3, 32, 32, device='meta'),)

# 5. IR 추출
ir = extract_ir(model, example_inputs, model_name="MyModel")

# 6. 결과 확인
print(ir)

1.2 중요 사항

  • eval() 모드: BatchNorm, Dropout 등이 올바르게 동작하려면 필수
  • meta device: 모델과 inputs 모두 meta device에 있어야 함
  • 정적 shape: 동적 shape는 지원하지 않음

2. IR 분석

2.1 IR 구조 탐색

# IR 기본 정보
print(f"Model: {ir.model_name}")
print(f"PyTorch version: {ir.pytorch_version}")
print(f"Total nodes: {len(ir.nodes)}")
print(f"Total weights: {len(ir.weights)}")

# 그래프 입출력
print("\nGraph Inputs:")
for inp in ir.graph_inputs:
    print(f"  {inp.name}: shape={inp.shape}, dtype={inp.dtype}")

print("\nGraph Outputs:")
for out in ir.graph_outputs:
    print(f"  {out.name}: shape={out.shape}, dtype={out.dtype}")

2.2 노드 분석

# 모든 노드 순회
for node in ir.nodes:
    print(f"\nNode: {node.name}")
    print(f"  Op type: {node.op_type}")
    print(f"  Inputs:")
    for inp in node.inputs:
        print(f"    - {inp.name}: {inp.shape}")
    print(f"  Outputs:")
    for out in node.outputs:
        print(f"    - {out.name}: {out.shape}")
    if node.attrs:
        print(f"  Attrs: {node.attrs}")

2.3 연산자 통계

from collections import Counter

# 연산자 종류별 개수
op_counts = Counter(node.op_type for node in ir.nodes)

print("Operation counts:")
for op_type, count in op_counts.most_common():
    print(f"  {op_type}: {count}")

2.4 Weight 정보

# Weight 메타데이터
print("Weights:")
for weight in ir.weights:
    print(f"  {weight.name}: shape={weight.shape}, dtype={weight.dtype}")

# Weight 이름 매핑 (placeholder → state_dict key)
print("\nWeight name mapping:")
for placeholder, sd_key in ir.weight_name_mapping.items():
    print(f"  {placeholder}{sd_key}")

3. IR 저장 및 로드

3.1 JSON 파일로 저장

# 저장
ir.save("model_ir.json")

# 또는 serializer 사용
from torch_ir import save_ir, serialize_ir

save_ir(ir, "model_ir.json")

# JSON 문자열로 직렬화
json_str = serialize_ir(ir)

3.2 JSON 파일에서 로드

from torch_ir import load_ir, deserialize_ir

# 파일에서 로드
loaded_ir = load_ir("model_ir.json")

# 또는 IR.load() 사용
from torch_ir import IR
loaded_ir = IR.load("model_ir.json")

# JSON 문자열에서 역직렬화
ir = deserialize_ir(json_str)

3.3 IR 검증

from torch_ir import validate_ir

# IR 구조 검증
try:
    validate_ir(ir)
    print("IR is valid")
except Exception as e:
    print(f"IR validation failed: {e}")

4. IR 실행 및 검증

4.1 IR 실행

IR을 실제 weight와 함께 실행하여 결과를 얻을 수 있습니다.

from torch_ir import IRExecutor, execute_ir

# 원본 모델에서 weight 가져오기
original_model = MyModel()
original_model.load_state_dict(torch.load('weights.pt'))
state_dict = original_model.state_dict()

# 방법 1: IRExecutor 사용
executor = IRExecutor(ir)
executor.load_weights_from_state_dict(state_dict)

test_input = torch.randn(1, 3, 32, 32)
outputs = executor.execute((test_input,))

# 방법 2: execute_ir 함수 사용
outputs = execute_ir(ir, (test_input,), weights=state_dict)

print(f"Output shape: {outputs[0].shape}")

4.2 원본 모델과 비교 검증

from torch_ir import verify_ir_with_state_dict, verify_ir

# 원본 모델 준비
original_model = MyModel()
original_model.load_state_dict(torch.load('weights.pt'))
original_model.eval()

# 테스트 입력
test_inputs = (torch.randn(1, 3, 32, 32),)

# 검증 (state_dict 사용)
is_valid, report = verify_ir_with_state_dict(
    ir=ir,
    state_dict=original_model.state_dict(),
    original_model=original_model,
    test_inputs=test_inputs,
    rtol=1e-5,  # 상대 오차 허용치
    atol=1e-5,  # 절대 오차 허용치
)

print(f"Verification: {'PASSED' if is_valid else 'FAILED'}")
print(report)

# 검증 (파일 경로 사용)
torch.save(original_model.state_dict(), 'weights.pt')
is_valid, report = verify_ir(
    ir=ir,
    weights_path='weights.pt',
    original_model=original_model,
    test_inputs=test_inputs,
)

4.3 검증 결과 분석

if not is_valid:
    print(f"Max difference: {report.max_diff}")
    print(f"Mean difference: {report.mean_diff}")
    print(f"Error message: {report.error_message}")

    # 개별 출력 분석
    for detail in report.output_details:
        print(f"  Output {detail['index']}: "
              f"shape={detail['shape']}, "
              f"is_close={detail['is_close']}, "
              f"max_diff={detail['max_diff']:.2e}")

5. 다양한 모델 타입

5.1 CNN 모델

import torchvision.models as models

# Meta device에서 ResNet 생성
with torch.device('meta'):
    model = models.resnet18()
model.eval()

inputs = (torch.randn(1, 3, 224, 224, device='meta'),)
ir = extract_ir(model, inputs, model_name="ResNet18")

print(f"ResNet18 IR: {len(ir.nodes)} nodes, {len(ir.weights)} weights")

5.2 Sequential 모델

with torch.device('meta'):
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )
model.eval()

inputs = (torch.randn(1, 784, device='meta'),)
ir = extract_ir(model, inputs, model_name="MLP")

5.3 Transformer 모델

with torch.device('meta'):
    encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
    model = nn.TransformerEncoder(encoder_layer, num_layers=6)
model.eval()

inputs = (torch.randn(10, 1, 512, device='meta'),)  # (seq_len, batch, d_model)
ir = extract_ir(model, inputs, model_name="TransformerEncoder")

5.4 다중 입력 모델

class MultiInputModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(5, 20)
        self.fc3 = nn.Linear(40, 10)

    def forward(self, x1, x2):
        h1 = self.fc1(x1)
        h2 = self.fc2(x2)
        return self.fc3(torch.cat([h1, h2], dim=1))

with torch.device('meta'):
    model = MultiInputModel()
model.eval()

inputs = (
    torch.randn(1, 10, device='meta'),
    torch.randn(1, 5, device='meta'),
)
ir = extract_ir(model, inputs)

6. Weight 관리

6.1 Weight 로드

from torch_ir import load_weights, load_weights_pt, load_weights_safetensors

# 자동 포맷 감지
weights = load_weights('model.pt')

# 특정 포맷 지정
weights = load_weights_pt('model.pt')
weights = load_weights_safetensors('model.safetensors')

6.2 Weight 검증

from torch_ir.weight_loader import validate_weights_against_ir

# Weight가 IR과 일치하는지 검증
errors = validate_weights_against_ir(weights, ir)
if errors:
    print("Weight validation errors:")
    for error in errors:
        print(f"  - {error}")
else:
    print("Weights are valid")

7. 고급 사용법

7.1 Strict 모드

# Strict 모드: 변환 중 오류 발생 시 예외 발생
try:
    ir = extract_ir(model, inputs, strict=True)
except Exception as e:
    print(f"Conversion error: {e}")

# Non-strict (기본): 변환 오류 시 기본 변환기로 대체
ir = extract_ir(model, inputs, strict=False)

모든 ATen 연산자는 자동으로 지원되므로 strict=False로도 대부분의 모델이 정상 동작합니다.

7.2 커스텀 모델 이름

ir = extract_ir(model, inputs, model_name="MyCustomModel_v2")
print(ir.model_name)  # "MyCustomModel_v2"

7.3 IRConverter 직접 사용

from torch_ir import export_model
from torch_ir.converter import IRConverter, convert_exported_program

# torch.export 직접 호출
exported = export_model(model, inputs, strict=False)

# Converter 사용
converter = IRConverter(strict=False)
ir = converter.convert(exported, model_name="MyModel")

# 또는 함수 직접 사용
ir = convert_exported_program(exported, model_name="MyModel")

7.4 GraphAnalyzer 직접 사용

from torch_ir import export_model
from torch_ir.analyzer import GraphAnalyzer

exported = export_model(model, inputs, strict=False)
analyzer = GraphAnalyzer(exported)

# 개별 정보 추출
graph_inputs = analyzer.get_graph_inputs()
graph_outputs = analyzer.get_graph_outputs()
weights = analyzer.get_weights()
weight_mapping = analyzer.get_weight_name_mapping()
nodes = analyzer.get_call_function_nodes()

8. 일반적인 패턴

8.1 전체 파이프라인 예제

import torch
from torch_ir import extract_ir, verify_ir_with_state_dict

def full_pipeline(model_class, input_shape, weights_path):
    """IR 추출부터 검증까지 전체 파이프라인"""

    # 1. 원본 모델 (검증용)
    original = model_class()
    original.load_state_dict(torch.load(weights_path))
    original.eval()

    # 2. Meta 모델 (IR 추출용)
    with torch.device('meta'):
        meta_model = model_class()
    meta_model.eval()

    # 3. IR 추출
    inputs = (torch.randn(*input_shape, device='meta'),)
    ir = extract_ir(meta_model, inputs)

    # 4. IR 저장
    ir.save(f"{model_class.__name__}_ir.json")

    # 5. 검증
    test_inputs = (torch.randn(*input_shape),)
    is_valid, report = verify_ir_with_state_dict(
        ir=ir,
        state_dict=original.state_dict(),
        original_model=original,
        test_inputs=test_inputs,
    )

    return ir, is_valid, report

# 사용 예
ir, valid, report = full_pipeline(MyModel, (1, 3, 32, 32), 'weights.pt')

8.2 배치 IR 추출

def extract_multiple_models(model_configs):
    """여러 모델의 IR을 한번에 추출"""
    results = {}

    for name, (model_class, input_shape) in model_configs.items():
        with torch.device('meta'):
            model = model_class()
        model.eval()

        inputs = (torch.randn(*input_shape, device='meta'),)
        ir = extract_ir(model, inputs, model_name=name)
        ir.save(f"{name}_ir.json")

        results[name] = ir
        print(f"{name}: {len(ir.nodes)} nodes")

    return results

# 사용 예
configs = {
    "ResNet18": (models.resnet18, (1, 3, 224, 224)),
    "ResNet50": (models.resnet50, (1, 3, 224, 224)),
    "VGG16": (models.vgg16, (1, 3, 224, 224)),
}
irs = extract_multiple_models(configs)

9. CLI 도구

pytorch-ir CLI를 사용하면 Python 코드 작성 없이 터미널에서 IR 파일을 조회하고 시각화할 수 있습니다.

9.1 IR 요약 정보

# IR 요약 표시
pytorch-ir info model_ir.json

# JSON 형식 출력
pytorch-ir info model_ir.json --json

# 파일로 저장
pytorch-ir info model_ir.json --json -o summary.json

예를 들어, 3개의 잔차 블록을 가진 DeepResNet 모델의 IR 요약:

Model: DeepResNet
Nodes: 27
Inputs: 1
Outputs: 1
Weights: 51
Total parameters: 57,617

Input shapes:
  x: [1, 3, 32, 32]

Output shapes:
  linear: [1, 10]

Op distribution:
  aten.conv2d.default: 7
  aten.batch_norm.default: 7
  aten.relu.default: 7
  aten.add.Tensor: 3
  aten.adaptive_avg_pool2d.default: 1
  aten.flatten.using_ints: 1
  aten.linear.default: 1

9.2 그래프 시각화

# Mermaid 다이어그램을 stdout으로 출력
pytorch-ir visualize model_ir.json

# Mermaid 텍스트 파일로 저장
pytorch-ir visualize model_ir.json -o graph.mmd

# PNG/SVG 이미지로 렌더링 (필요: pip install pytorch-ir[rendering])
pytorch-ir visualize model_ir.json -o graph.png
pytorch-ir visualize model_ir.json -o graph.svg

# 대규모 그래프의 표시 노드 수 제한
pytorch-ir visualize model_ir.json --max-nodes 50

아래는 TransformerBlock (self-attention + FFN + residual connections) 모델의 실제 IR 그래프입니다. 왼쪽 경로에서 Q/K/V 프로젝션 → attention → output projection이 진행되고, add.Tensor 노드에서 잔차 연결이 합류합니다:

flowchart TD
    input_x[/"Input: x<br/>1x16x64"/]
    op_linear["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear
    w_p_q_proj_weight[/"p_q_proj_weight<br/>64x64"/]
    w_p_q_proj_weight -.->|"64x64"| op_linear
    w_p_q_proj_bias[/"p_q_proj_bias<br/>64"/]
    w_p_q_proj_bias -.->|"64"| op_linear
    op_view["view<br/>1x16x4x16"]
    op_linear -->|"1x16x64"| op_view
    op_transpose["transpose.int<br/>1x4x16x16"]
    op_view -->|"1x16x4x16"| op_transpose
    op_linear_1["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear_1
    w_p_k_proj_weight[/"p_k_proj_weight<br/>64x64"/]
    w_p_k_proj_weight -.->|"64x64"| op_linear_1
    w_p_k_proj_bias[/"p_k_proj_bias<br/>64"/]
    w_p_k_proj_bias -.->|"64"| op_linear_1
    op_view_1["view<br/>1x16x4x16"]
    op_linear_1 -->|"1x16x64"| op_view_1
    op_transpose_1["transpose.int<br/>1x4x16x16"]
    op_view_1 -->|"1x16x4x16"| op_transpose_1
    op_linear_2["linear<br/>1x16x64"]
    input_x -->|"1x16x64"| op_linear_2
    w_p_v_proj_weight[/"p_v_proj_weight<br/>64x64"/]
    w_p_v_proj_weight -.->|"64x64"| op_linear_2
    w_p_v_proj_bias[/"p_v_proj_bias<br/>64"/]
    w_p_v_proj_bias -.->|"64"| op_linear_2
    op_view_2["view<br/>1x16x4x16"]
    op_linear_2 -->|"1x16x64"| op_view_2
    op_transpose_2["transpose.int<br/>1x4x16x16"]
    op_view_2 -->|"1x16x4x16"| op_transpose_2
    op_transpose_3["transpose.int<br/>1x4x16x16"]
    op_transpose_1 -->|"1x4x16x16"| op_transpose_3
    op_matmul["matmul<br/>1x4x16x16"]
    op_transpose -->|"1x4x16x16"| op_matmul
    op_transpose_3 -->|"1x4x16x16"| op_matmul
    op_div["div.Tensor<br/>1x4x16x16"]
    op_matmul -->|"1x4x16x16"| op_div
    op_softmax["softmax.int<br/>1x4x16x16"]
    op_div -->|"1x4x16x16"| op_softmax
    op_matmul_1["matmul<br/>1x4x16x16"]
    op_softmax -->|"1x4x16x16"| op_matmul_1
    op_transpose_2 -->|"1x4x16x16"| op_matmul_1
    op_transpose_4["transpose.int<br/>1x16x4x16"]
    op_matmul_1 -->|"1x4x16x16"| op_transpose_4
    op_contiguous["contiguous<br/>1x16x4x16"]
    op_transpose_4 -->|"1x16x4x16"| op_contiguous
    op_view_3["view<br/>1x16x64"]
    op_contiguous -->|"1x16x4x16"| op_view_3
    op_linear_3["linear<br/>1x16x64"]
    op_view_3 -->|"1x16x64"| op_linear_3
    w_p_out_proj_weight[/"p_out_proj_weight<br/>64x64"/]
    w_p_out_proj_weight -.->|"64x64"| op_linear_3
    w_p_out_proj_bias[/"p_out_proj_bias<br/>64"/]
    w_p_out_proj_bias -.->|"64"| op_linear_3
    op_add["add.Tensor<br/>1x16x64"]
    input_x -->|"1x16x64"| op_add
    op_linear_3 -->|"1x16x64"| op_add
    op_layer_norm["layer_norm<br/>1x16x64"]
    op_add -->|"1x16x64"| op_layer_norm
    w_p_norm1_weight[/"p_norm1_weight<br/>64"/]
    w_p_norm1_weight -.->|"64"| op_layer_norm
    w_p_norm1_bias[/"p_norm1_bias<br/>64"/]
    w_p_norm1_bias -.->|"64"| op_layer_norm
    op_linear_4["linear<br/>1x16x256"]
    op_layer_norm -->|"1x16x64"| op_linear_4
    w_p_ffn_0_weight[/"p_ffn_0_weight<br/>256x64"/]
    w_p_ffn_0_weight -.->|"256x64"| op_linear_4
    w_p_ffn_0_bias[/"p_ffn_0_bias<br/>256"/]
    w_p_ffn_0_bias -.->|"256"| op_linear_4
    op_gelu["gelu<br/>1x16x256"]
    op_linear_4 -->|"1x16x256"| op_gelu
    op_linear_5["linear<br/>1x16x64"]
    op_gelu -->|"1x16x256"| op_linear_5
    w_p_ffn_2_weight[/"p_ffn_2_weight<br/>64x256"/]
    w_p_ffn_2_weight -.->|"64x256"| op_linear_5
    w_p_ffn_2_bias[/"p_ffn_2_bias<br/>64"/]
    w_p_ffn_2_bias -.->|"64"| op_linear_5
    op_add_1["add.Tensor<br/>1x16x64"]
    op_layer_norm -->|"1x16x64"| op_add_1
    op_linear_5 -->|"1x16x64"| op_add_1
    op_layer_norm_1["layer_norm<br/>1x16x64"]
    op_add_1 -->|"1x16x64"| op_layer_norm_1
    w_p_norm2_weight[/"p_norm2_weight<br/>64"/]
    w_p_norm2_weight -.->|"64"| op_layer_norm_1
    w_p_norm2_bias[/"p_norm2_bias<br/>64"/]
    w_p_norm2_bias -.->|"64"| op_layer_norm_1
    output_0[\"Output<br/>1x16x64"/]
    op_layer_norm_1 --> output_0

전체 CLI 문서는 CLI 레퍼런스를 참고하세요.

10. 다음 단계