콘텐츠로 이동

고급 예제

이 가이드는 pytorch-ir로 초대형 autoregressive MoE 모델의 static IR를 추출하는 방법을 설명합니다. 구체적인 예제로는 text-only Kimi-K2.5 추출 흐름을 사용합니다.

왜 초대형 모델 예제가 필요한가

대형 언어 모델에서는 보통 문제의 핵심이 "weight를 올릴 수 있는가?"가 아니라 "그래프를 export 가능한 형태로 만들 수 있는가?"인 경우가 많습니다.

pytorch-ir에서는 모델과 예제 입력을 meta device에 두기 때문에 IR 추출 과정에서 실제 파라미터 값이 필요하지 않습니다. 그래서 원본 체크포인트가 로컬 메모리나 저장공간에 들어가지 않는 경우에도 추출 단계 자체는 진행할 수 있습니다.

하지만 그렇다고 추출이 가벼운 것은 아닙니다. 초대형 모델에서 여전히 비싼 부분은 다음입니다.

  • torch.export 추적과 정규화
  • 그래프 분석과 IR 변환
  • 큰 고정 shape cache 인터페이스 처리

실제로 초대형 모델 IR 추출의 핵심은 weight를 읽지 않는 것보다, forward 경로를 export-friendly 하게 만들고 extract_ir()에 안정적인 compiler/runtime용 callable ABI를 넘기는 데 있습니다.

왜 Kimi-K2.5를 선택했는가

Kimi-K2.5는 inference 인프라 관점에서 좋은 초대형 사례입니다. export를 어렵게 만드는 요소가 한 모델 안에 같이 들어 있기 때문입니다.

  • 총 1T 파라미터, 32B 활성 파라미터
  • 256K 컨텍스트 길이
  • 61개 레이어
  • 384개 routed experts, token당 8개 expert 선택
  • text generation만 필요해도 top-level은 multimodal 구조

그래서 다음 주제를 한 번에 보여주기에 적합합니다.

  • 큰 static cache 인터페이스
  • prefill/decode 그래프 분리
  • MoE export 제약
  • remote config와 local patched model의 조합

이 문서에서 text-only IR만 다루는 이유는 그냥 편의상입니다. 목적은 multimodal 전체를 설명하는 것이 아니라, huge-model 추출 패턴을 설명하는 데 있습니다.

사례: Kimi-K2.5 Text-Only IR 추출

이 예제는 세 가지 아이디어를 함께 사용합니다.

  • AutoConfig(..., trust_remote_code=True)로 원격 Hugging Face config 로드
  • export를 위해 로컬 패치된 text-only 모델 사용
  • 추출을 prefilldecode 그래프로 분리

왜 upstream 저장소의 모델 코드를 그대로 export하지 않고 로컬 패치를 쓰는가?

  • Kimi-K2.5는 MoE 모델입니다.
  • 원본 저장소 코드는 meta + torch.export 추출을 전제로 작성되어 있지 않습니다.
  • text backbone 쪽에 export-friendly 한 MoE 및 cache 처리가 필요합니다.
  • multimodal stack은 text-only IR 추출에는 필요하지 않습니다.

그래서 실제 추출 대상은 다음 조합이 됩니다.

  • Hugging Face의 remote config
  • export-friendly 하게 패치된 로컬 text-only modeling code

추출 흐름

flowchart TD
    A["remote Kimi text config 로드<br/>AutoConfig(..., trust_remote_code=True)"]
    B["config 정규화<br/>eager attention 강제"]
    C["로컬 patched text-only 모델을<br/>meta device에서 생성"]
    D["고정 shape prefill 입력 생성"]
    E["평탄화된 KV tensor를 포함한<br/>고정 shape decode 입력 생성"]
    F["prefill ABI wrapper<br/>logits + flattened KV outputs"]
    G["decode ABI wrapper<br/>flattened KV inputs/outputs"]
    H["extract_ir(prefill_wrapper, ...)"]
    I["extract_ir(decode_wrapper, ...)"]
    J["prefill IR 저장"]
    K["decode IR 저장"]
    L["KV mapping JSON 저장"]

    A --> B --> C
    C --> D --> F --> H --> J
    C --> E --> G --> I --> K
    H --> L
    I --> L

추출 스크립트 구조

이 스크립트는 downstream runtime이 실제로 필요로 하는 ABI를 기준으로 구성됩니다.

정식 소스는 examples/extract_kimi_k25_text_ir.py에 있고, examples/kimi_k25_text_local에 의존합니다. 저장소 루트에서 바로 실행할 수 있습니다.

uv run --with transformers python examples/extract_kimi_k25_text_ir.py --output-dir ./out/kimi

아래 블록은 그와 동일한 실행 가능한 스크립트입니다.

from __future__ import annotations

import json
from pathlib import Path

import torch
import torch.nn as nn
from torch_ir import extract_ir
from transformers import AutoConfig

from kimi_k25_text_local import KimiK25TextForCausalLM

MODEL_ID = "moonshotai/Kimi-K2.5"
PREFILL_SEQ_LEN = 128
MAX_CACHE_LEN = 2048
OUTPUT_DIR = Path(".")


def prepare_text_config(config):
    config._attn_implementation = "eager"
    if getattr(config, "rope_scaling", None) is None and hasattr(config, "to_dict"):
        rope_scaling = config.to_dict().get("rope_scaling")
        if rope_scaling is not None:
            config.rope_scaling = rope_scaling
    return config


def load_default_config():
    remote_config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    return prepare_text_config(remote_config.text_config)


def build_model(config, *, device: str) -> KimiK25TextForCausalLM:
    with torch.device(device):
        model = KimiK25TextForCausalLM(config)
    return model.to(dtype=config.dtype)


def make_additive_causal_mask(
    *,
    query_positions: torch.Tensor,
    key_length: int,
    device: torch.device,
) -> torch.Tensor:
    key_positions = torch.arange(key_length, device=device)
    allowed = key_positions.unsqueeze(0) <= query_positions.reshape(-1, 1)
    min_value = torch.tensor(torch.finfo(torch.float32).min, device=device)
    mask = torch.where(allowed, torch.zeros((), device=device), min_value)
    return mask.unsqueeze(0).unsqueeze(0)


class KimiPrefillWrapper(nn.Module):
    def __init__(self, model: KimiK25TextForCausalLM):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
        )
        result = [outputs.logits]
        for layer_kv in outputs.past_key_values:
            result.append(layer_kv[0])
            result.append(layer_kv[1])
        return tuple(result)


class _IndexCopyCache:
    def __init__(self, kv_flat, num_layers: int, seen_tokens: int, max_cache_len: int):
        self.key_cache = [kv_flat[2 * i] for i in range(num_layers)]
        self.value_cache = [kv_flat[2 * i + 1] for i in range(num_layers)]
        self._seen_tokens = seen_tokens
        self._max_cache_len = max_cache_len

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        cache_position = cache_kwargs["cache_position"]
        self.key_cache[layer_idx] = self.key_cache[layer_idx].index_copy(2, cache_position, key_states)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].index_copy(2, cache_position, value_states)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx=0):
        del layer_idx
        return self._seen_tokens

    def get_max_cache_shape(self, layer_idx=0):
        del layer_idx
        return self._max_cache_len

    def __getitem__(self, idx):
        return self.key_cache[idx], self.value_cache[idx]

    def __iter__(self):
        for idx in range(len(self.key_cache)):
            yield self.key_cache[idx], self.value_cache[idx]

    def __len__(self):
        return len(self.key_cache)


class KimiDecodeWrapper(nn.Module):
    def __init__(self, model: KimiK25TextForCausalLM, num_layers: int, seen_tokens: int, max_cache_len: int):
        super().__init__()
        self.model = model
        self.num_layers = num_layers
        self.seen_tokens = seen_tokens
        self.max_cache_len = max_cache_len

    def forward(self, input_ids, attention_mask, position_ids, cache_position, *past_kv_flat):
        cache = _IndexCopyCache(past_kv_flat, self.num_layers, self.seen_tokens, self.max_cache_len)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=cache,
            cache_position=cache_position,
            use_cache=True,
        )
        result = [outputs.logits]
        for idx in range(self.num_layers):
            result.append(cache.key_cache[idx])
            result.append(cache.value_cache[idx])
        return tuple(result)


def build_prefill_meta_inputs(config, prefill_seq_len: int):
    positions = torch.arange(prefill_seq_len, device="meta").unsqueeze(0)
    attention_mask = make_additive_causal_mask(
        query_positions=positions[0],
        key_length=prefill_seq_len,
        device=torch.device("meta"),
    )
    return (
        torch.randint(0, config.vocab_size, (1, prefill_seq_len), device="meta"),
        attention_mask,
        positions,
    )


def build_decode_meta_inputs(config, prefill_seq_len: int, max_cache_len: int):
    positions = torch.tensor([[prefill_seq_len]], device="meta")
    attention_mask = make_additive_causal_mask(
        query_positions=positions[0],
        key_length=max_cache_len,
        device=torch.device("meta"),
    )
    past_kv_args = []
    head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
    for _ in range(config.num_hidden_layers):
        past_kv_args.append(
            torch.randn(1, config.num_attention_heads, max_cache_len, head_dim, device="meta", dtype=config.dtype)
        )
        past_kv_args.append(
            torch.randn(
                1,
                config.num_attention_heads,
                max_cache_len,
                config.v_head_dim,
                device="meta",
                dtype=config.dtype,
            )
        )
    return (
        torch.randint(0, config.vocab_size, (1, 1), device="meta"),
        attention_mask,
        positions,
        torch.tensor([prefill_seq_len], device="meta"),
        *past_kv_args,
    )


def save_kv_mapping(prefill_ir, decode_ir, num_layers: int, output_path: Path):
    prefill_kv = prefill_ir.graph_outputs[1:]
    fixed_inputs = {"input_ids", "attention_mask", "position_ids", "cache_position"}
    decode_kv_in = [meta for meta in decode_ir.graph_inputs if meta.name not in fixed_inputs]
    decode_kv_out = decode_ir.graph_outputs[1:]

    layers = []
    for idx in range(num_layers):
        layers.append(
            {
                "layer": idx,
                "prefill_key_output": prefill_kv[2 * idx].name,
                "prefill_value_output": prefill_kv[2 * idx + 1].name,
                "decode_key_input": decode_kv_in[2 * idx].name,
                "decode_value_input": decode_kv_in[2 * idx + 1].name,
                "decode_key_output": decode_kv_out[2 * idx].name,
                "decode_value_output": decode_kv_out[2 * idx + 1].name,
            }
        )

    with output_path.open("w") as f:
        json.dump({"num_layers": num_layers, "layers": layers}, f, indent=2)


def main():
    config = load_default_config()
    model = build_model(config, device="meta")
    model.eval()

    prefill_wrapper = KimiPrefillWrapper(model)
    prefill_inputs = build_prefill_meta_inputs(config, PREFILL_SEQ_LEN)
    prefill_ir = extract_ir(prefill_wrapper, prefill_inputs, model_name="KimiK25_Text_Prefill")

    decode_wrapper = KimiDecodeWrapper(model, config.num_hidden_layers, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
    decode_inputs = build_decode_meta_inputs(config, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
    decode_ir = extract_ir(decode_wrapper, decode_inputs, model_name="KimiK25_Text_Decode")

    prefill_ir.save(OUTPUT_DIR / "kimi_k25_text_prefill_ir.json")
    decode_ir.save(OUTPUT_DIR / "kimi_k25_text_decode_ir.json")
    save_kv_mapping(
        prefill_ir,
        decode_ir,
        config.num_hidden_layers,
        OUTPUT_DIR / "kimi_k25_text_kv_mapping.json",
    )


if __name__ == "__main__":
    main()

중요한 점은 다음과 같습니다.

  • 예제는 remote config 객체를 그대로 사용합니다.
  • export된 그래프를 예측 가능하게 유지하기 위해 eager attention을 강제합니다.
  • 모델 생성 전에 config 정규화를 먼저 수행합니다.
  • prefill은 고정된 (batch=1, seq_len=PREFILL_SEQ_LEN) 인터페이스를 사용합니다.
  • decode는 고정된 (batch=1, seq_len=1, kv_len=MAX_CACHE_LEN) 인터페이스를 사용합니다.

왜 wrapper가 필요한가

extract_ir()는 전달받은 callable의 시그니처를 그대로 trace합니다. 초대형 autoregressive 모델에서는 원본 모델 callable이 곧바로 compiler나 runtime에 넘기기 좋은 ABI인 경우가 드뭅니다.

Wrapper 계층은 이 불일치를 해결합니다.

Prefill wrapper

class KimiPrefillWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, position_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=True,
        )
        result = [outputs.logits]
        for layer_kv in outputs.past_key_values:
            result.append(layer_kv[0])
            result.append(layer_kv[1])
        return tuple(result)

이 wrapper가 필요한 이유는 다음과 같습니다.

  • 원본 모델은 구조화된 past_key_values를 반환합니다.
  • compiler backend는 보통 명시적인 graph output을 원합니다.
  • KV 출력을 평탄화하면 모든 레이어에 대해 안정적인 출력 순서를 만들 수 있습니다.

Decode wrapper

class KimiDecodeWrapper(nn.Module):
    def __init__(self, model, num_layers: int, seen_tokens: int, max_cache_len: int):
        super().__init__()
        self.model = model
        self.num_layers = num_layers
        self.seen_tokens = seen_tokens
        self.max_cache_len = max_cache_len

    def forward(self, input_ids, attention_mask, position_ids, cache_position, *past_kv_flat):
        cache = _IndexCopyCache(past_kv_flat, self.num_layers, self.seen_tokens, self.max_cache_len)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=cache,
            cache_position=cache_position,
            use_cache=True,
        )
        result = [outputs.logits]
        for idx in range(self.num_layers):
            result.append(cache.key_cache[idx])
            result.append(cache.value_cache[idx])
        return tuple(result)

이 wrapper가 필요한 이유는 다음과 같습니다.

  • decode는 cache tensor를 명시적 graph input으로 받아야 합니다.
  • decode는 갱신된 cache tensor를 명시적 graph output으로 돌려줘야 합니다.
  • cache를 평탄화하면 export된 시그니처가 결정적이고 backend 친화적으로 유지됩니다.

Wrapper가 없으면 export된 그래프 인터페이스가 모델 내부 Python 객체 표현에 너무 강하게 결합됩니다.

왜 decode에는 static cache wrapper가 필요한가

Decode export에는 일반 wrapper만으로는 부족합니다. cache 갱신이 export된 그래프 안에서 tensor op으로 보이도록 해 주는 cache 객체가 필요합니다.

class _IndexCopyCache:
    def __init__(self, kv_flat, num_layers: int, seen_tokens: int, max_cache_len: int):
        self.key_cache = [kv_flat[2 * i] for i in range(num_layers)]
        self.value_cache = [kv_flat[2 * i + 1] for i in range(num_layers)]
        self._seen_tokens = seen_tokens
        self._max_cache_len = max_cache_len

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        cache_position = cache_kwargs["cache_position"]
        self.key_cache[layer_idx] = self.key_cache[layer_idx].index_copy(2, cache_position, key_states)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].index_copy(2, cache_position, value_states)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

이 객체가 필요한 이유는 다음과 같습니다.

  • cache buffer shape가 고정되어 있어야 합니다.
  • cache 갱신이 명시적 tensor 연산으로 나타나야 합니다.
  • export된 그래프 인터페이스 밖의 숨겨진 Python-side mutation이 없어야 합니다.

index_copy를 사용하면 cache 갱신이 일반 tensor op처럼 그래프에 나타나므로, IR와 downstream executor가 그대로 다룰 수 있습니다.

최종 추출 호출

전체 추출 흐름은 두 개의 그래프와 하나의 mapping 파일을 생성합니다.

with torch.device("meta"):
    model = build_model(config, device="meta")
model.eval()

prefill_wrapper = KimiPrefillWrapper(model)
prefill_inputs = build_prefill_meta_inputs(config, prefill_seq_len)
prefill_ir = extract_ir(prefill_wrapper, prefill_inputs, model_name="KimiK25_Text_Prefill")

decode_wrapper = KimiDecodeWrapper(model, config.num_hidden_layers, prefill_seq_len, max_cache_len)
decode_inputs = build_decode_meta_inputs(config, prefill_seq_len, max_cache_len)
decode_ir = extract_ir(decode_wrapper, decode_inputs, model_name="KimiK25_Text_Decode")

추출 후 예제는 다음을 저장합니다.

  • prefill IR
  • decode IR
  • 레이어별 KV mapping 메타데이터

출력 산출물

예제는 세 개의 파일을 생성합니다.

  • kimi_k25_text_prefill_ir.json
  • kimi_k25_text_decode_ir.json
  • kimi_k25_text_kv_mapping.json

각 파일의 역할은 다릅니다.

  • prefill IR은 전체 prompt pass를 설명하고, logits와 초기 KV tensor를 출력합니다.
  • decode IR은 고정된 cache 입력과 출력을 갖는 one-token decode를 설명합니다.
  • kv_mapping은 prefill KV 출력이 decode KV 입력 및 출력과 어떻게 대응되는지 runtime에 알려줍니다.

대형 모델에서는 하나의 export 그래프에 두 단계를 억지로 합치는 것보다, 이런 식으로 분리하는 편이 통합이 훨씬 쉬운 경우가 많습니다.

이 예제가 하지 않는 것

  • 실제 모델 weight를 다운로드하거나 로드하지 않습니다.
  • 원본 Kimi multimodal stack을 export하지 않습니다.
  • tiny verification mode는 이 문서에서 다루지 않습니다.

이 페이지는 의도적으로 실제 huge-model 추출 경로에만 집중합니다.

요약

초대형 autoregressive 모델에서 IR 추출이 성공하려면 보통 세 가지가 필요합니다.

  • 원본 아키텍처를 보존하는 remote config
  • export-friendly 하게 패치된 로컬 modeling code
  • 명시적인 static cache tensor를 갖는 wrapper 기반 prefill/decode 인터페이스

이 조합을 사용하면 원본 모델을 일반적인 weight-loaded 형태로 실행할 수 없는 환경에서도 pytorch-ir로 안정적이고 backend 지향적인 IR를 추출할 수 있습니다.