Advanced Example¶
This guide shows how to extract static IR from a huge autoregressive MoE model with pytorch-ir, using a text-only Kimi-K2.5 workflow as a concrete example.
Why a Huge-Model Example Matters¶
For large language models, the export problem is usually not "can I load the weights?" but "can I make the graph exportable at all?"
With pytorch-ir, the model and example inputs live on the meta device, so IR extraction does not need real parameter values. This keeps the extraction stage feasible even when the original checkpoint is far too large to fit in local RAM or storage.
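The meta-device idea can be illustrated in isolation. This is a minimal sketch with a toy layer that is not part of the example model: modules built under a meta device carry shapes and dtypes but allocate no real storage.

```python
import torch
import torch.nn as nn

# Toy stand-in layer: under the meta device, construction records only
# shape/dtype metadata, so even a ~1 GiB fp32 weight costs no real memory.
with torch.device("meta"):
    layer = nn.Linear(16384, 16384)

assert layer.weight.is_meta  # metadata only, no real parameter values

x = torch.randn(1, 16384, device="meta")
y = layer(x)  # shape propagation runs; no real compute happens
assert y.shape == (1, 16384)
```

This is exactly why extraction stays feasible: tracing needs shapes and dtypes, not values.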
That does not mean extraction is cheap. For huge models, the expensive parts are still:
- torch.export tracing and normalization
- graph analysis and IR conversion
- handling large fixed-shape cache interfaces
In practice, huge-model extraction is mostly about making the forward path export-friendly and giving extract_ir() a stable, compiler-oriented callable interface.
Why Kimi-K2.5¶
Kimi-K2.5 is a good huge-model case study for inference infrastructure work because it combines several characteristics that make export non-trivial:
- around 1T total parameters with 32B activated parameters
- 256K context length
- 61 layers
- 384 routed experts with 8 selected experts per token
- multimodal top-level architecture, even though many inference stacks only need text generation
This makes it a realistic example for:
- large static-cache interfaces
- prefill/decode graph splitting
- MoE export constraints
- remote-config plus local-model patching
We use the text-only IR path here simply for convenience. The goal of this guide is to show the huge-model extraction pattern, not to cover the multimodal stack.
Case Study: Kimi-K2.5 Text-Only IR Extraction¶
The example flow uses three ideas together:
- load the remote Hugging Face config with AutoConfig(..., trust_remote_code=True)
- use a local patched text-only model for export
- split the export into separate prefill and decode graphs
Why patch the model locally instead of exporting the upstream repo model directly?
- Kimi-K2.5 is an MoE model
- the original repository code is not written for meta + torch.export extraction
- the text backbone needs export-friendly MoE and cache behavior
- the multimodal stack is unnecessary for text-only IR extraction
So the practical extraction target becomes:
- remote config from Hugging Face
- local text-only modeling code with export-friendly patches
Extraction Flow¶
flowchart TD
A["Load remote Kimi text config<br/>AutoConfig(..., trust_remote_code=True)"]
B["Normalize config<br/>force eager attention"]
C["Instantiate local patched text-only model<br/>on meta device"]
D["Build fixed-shape prefill inputs"]
E["Build fixed-shape decode inputs<br/>with flattened KV tensors"]
F["Wrap model for prefill ABI<br/>logits + flattened KV outputs"]
G["Wrap model for decode ABI<br/>flattened KV inputs/outputs"]
H["extract_ir(prefill_wrapper, ...)"]
I["extract_ir(decode_wrapper, ...)"]
J["Save prefill IR"]
K["Save decode IR"]
L["Save KV mapping JSON"]
A --> B --> C
C --> D --> F --> H --> J
C --> E --> G --> I --> K
H --> L
I --> L
Extraction Script Architecture¶
The extraction script is organized around the ABI that downstream runtimes actually need.
The canonical source lives in examples/extract_kimi_k25_text_ir.py and depends on examples/kimi_k25_text_local. You can run it directly from the repository root; the block below is the same runnable script.
from __future__ import annotations
import json
from pathlib import Path
import torch
import torch.nn as nn
from torch_ir import extract_ir
from transformers import AutoConfig
from kimi_k25_text_local import KimiK25TextForCausalLM
MODEL_ID = "moonshotai/Kimi-K2.5"
PREFILL_SEQ_LEN = 128
MAX_CACHE_LEN = 2048
OUTPUT_DIR = Path(".")
def prepare_text_config(config):
config._attn_implementation = "eager"
if getattr(config, "rope_scaling", None) is None and hasattr(config, "to_dict"):
rope_scaling = config.to_dict().get("rope_scaling")
if rope_scaling is not None:
config.rope_scaling = rope_scaling
return config
def load_default_config():
remote_config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
return prepare_text_config(remote_config.text_config)
def build_model(config, *, device: str) -> KimiK25TextForCausalLM:
with torch.device(device):
model = KimiK25TextForCausalLM(config)
return model.to(dtype=config.dtype)
def make_additive_causal_mask(
*,
query_positions: torch.Tensor,
key_length: int,
device: torch.device,
) -> torch.Tensor:
key_positions = torch.arange(key_length, device=device)
allowed = key_positions.unsqueeze(0) <= query_positions.reshape(-1, 1)
min_value = torch.tensor(torch.finfo(torch.float32).min, device=device)
mask = torch.where(allowed, torch.zeros((), device=device), min_value)
return mask.unsqueeze(0).unsqueeze(0)
class KimiPrefillWrapper(nn.Module):
def __init__(self, model: KimiK25TextForCausalLM):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask, position_ids):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
result = [outputs.logits]
for layer_kv in outputs.past_key_values:
result.append(layer_kv[0])
result.append(layer_kv[1])
return tuple(result)
class _IndexCopyCache:
def __init__(self, kv_flat, num_layers: int, seen_tokens: int, max_cache_len: int):
self.key_cache = [kv_flat[2 * i] for i in range(num_layers)]
self.value_cache = [kv_flat[2 * i + 1] for i in range(num_layers)]
self._seen_tokens = seen_tokens
self._max_cache_len = max_cache_len
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
cache_position = cache_kwargs["cache_position"]
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_copy(2, cache_position, key_states)
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_copy(2, cache_position, value_states)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def get_seq_length(self, layer_idx=0):
del layer_idx
return self._seen_tokens
def get_max_cache_shape(self, layer_idx=0):
del layer_idx
return self._max_cache_len
def __getitem__(self, idx):
return self.key_cache[idx], self.value_cache[idx]
def __iter__(self):
for idx in range(len(self.key_cache)):
yield self.key_cache[idx], self.value_cache[idx]
def __len__(self):
return len(self.key_cache)
class KimiDecodeWrapper(nn.Module):
def __init__(self, model: KimiK25TextForCausalLM, num_layers: int, seen_tokens: int, max_cache_len: int):
super().__init__()
self.model = model
self.num_layers = num_layers
self.seen_tokens = seen_tokens
self.max_cache_len = max_cache_len
def forward(self, input_ids, attention_mask, position_ids, cache_position, *past_kv_flat):
cache = _IndexCopyCache(past_kv_flat, self.num_layers, self.seen_tokens, self.max_cache_len)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=cache,
cache_position=cache_position,
use_cache=True,
)
result = [outputs.logits]
for idx in range(self.num_layers):
result.append(cache.key_cache[idx])
result.append(cache.value_cache[idx])
return tuple(result)
def build_prefill_meta_inputs(config, prefill_seq_len: int):
positions = torch.arange(prefill_seq_len, device="meta").unsqueeze(0)
attention_mask = make_additive_causal_mask(
query_positions=positions[0],
key_length=prefill_seq_len,
device=torch.device("meta"),
)
return (
torch.randint(0, config.vocab_size, (1, prefill_seq_len), device="meta"),
attention_mask,
positions,
)
def build_decode_meta_inputs(config, prefill_seq_len: int, max_cache_len: int):
positions = torch.tensor([[prefill_seq_len]], device="meta")
attention_mask = make_additive_causal_mask(
query_positions=positions[0],
key_length=max_cache_len,
device=torch.device("meta"),
)
past_kv_args = []
head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
for _ in range(config.num_hidden_layers):
past_kv_args.append(
torch.randn(1, config.num_attention_heads, max_cache_len, head_dim, device="meta", dtype=config.dtype)
)
past_kv_args.append(
torch.randn(
1,
config.num_attention_heads,
max_cache_len,
config.v_head_dim,
device="meta",
dtype=config.dtype,
)
)
return (
torch.randint(0, config.vocab_size, (1, 1), device="meta"),
attention_mask,
positions,
torch.tensor([prefill_seq_len], device="meta"),
*past_kv_args,
)
def save_kv_mapping(prefill_ir, decode_ir, num_layers: int, output_path: Path):
prefill_kv = prefill_ir.graph_outputs[1:]
fixed_inputs = {"input_ids", "attention_mask", "position_ids", "cache_position"}
decode_kv_in = [meta for meta in decode_ir.graph_inputs if meta.name not in fixed_inputs]
decode_kv_out = decode_ir.graph_outputs[1:]
layers = []
for idx in range(num_layers):
layers.append(
{
"layer": idx,
"prefill_key_output": prefill_kv[2 * idx].name,
"prefill_value_output": prefill_kv[2 * idx + 1].name,
"decode_key_input": decode_kv_in[2 * idx].name,
"decode_value_input": decode_kv_in[2 * idx + 1].name,
"decode_key_output": decode_kv_out[2 * idx].name,
"decode_value_output": decode_kv_out[2 * idx + 1].name,
}
)
with output_path.open("w") as f:
json.dump({"num_layers": num_layers, "layers": layers}, f, indent=2)
def main():
config = load_default_config()
model = build_model(config, device="meta")
model.eval()
prefill_wrapper = KimiPrefillWrapper(model)
prefill_inputs = build_prefill_meta_inputs(config, PREFILL_SEQ_LEN)
prefill_ir = extract_ir(prefill_wrapper, prefill_inputs, model_name="KimiK25_Text_Prefill")
decode_wrapper = KimiDecodeWrapper(model, config.num_hidden_layers, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
decode_inputs = build_decode_meta_inputs(config, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
decode_ir = extract_ir(decode_wrapper, decode_inputs, model_name="KimiK25_Text_Decode")
prefill_ir.save(OUTPUT_DIR / "kimi_k25_text_prefill_ir.json")
decode_ir.save(OUTPUT_DIR / "kimi_k25_text_decode_ir.json")
save_kv_mapping(
prefill_ir,
decode_ir,
config.num_hidden_layers,
OUTPUT_DIR / "kimi_k25_text_kv_mapping.json",
)
if __name__ == "__main__":
main()
Important details:
- the example uses the remote config object directly
- eager attention is forced up front to keep the exported graph predictable
- config normalization happens before model construction
- prefill uses a fixed (batch=1, seq_len=PREFILL_SEQ_LEN) interface
- decode uses a fixed (batch=1, seq_len=1, kv_len=MAX_CACHE_LEN) interface
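The additive-mask convention used by make_additive_causal_mask can be sanity-checked in isolation: 0.0 where a query may attend, a large negative value where it may not. The helper is reproduced here so the snippet is self-contained.

```python
import torch

# Same helper as in the extraction script, reproduced for a standalone check.
def make_additive_causal_mask(*, query_positions, key_length, device):
    key_positions = torch.arange(key_length, device=device)
    allowed = key_positions.unsqueeze(0) <= query_positions.reshape(-1, 1)
    min_value = torch.tensor(torch.finfo(torch.float32).min, device=device)
    mask = torch.where(allowed, torch.zeros((), device=device), min_value)
    return mask.unsqueeze(0).unsqueeze(0)

mask = make_additive_causal_mask(
    query_positions=torch.arange(4), key_length=4, device=torch.device("cpu")
)
assert mask.shape == (1, 1, 4, 4)   # (batch, heads, query, key)
assert mask[0, 0, 0, 1].item() < 0  # token 0 cannot attend to token 1
assert mask[0, 0, 3, 0].item() == 0.0  # token 3 can attend to token 0
```

Building the mask eagerly like this, rather than relying on attention-internal masking, is part of what keeps the exported graph predictable.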
Why Wrappers Are Necessary¶
extract_ir() traces the exact callable you provide. For huge autoregressive models, the raw model callable is usually not the ABI you want to hand to a compiler or runtime.
The wrapper layer solves that mismatch.
Prefill wrapper¶
class KimiPrefillWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, input_ids, attention_mask, position_ids):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=True,
)
result = [outputs.logits]
for layer_kv in outputs.past_key_values:
result.append(layer_kv[0])
result.append(layer_kv[1])
return tuple(result)
Why this wrapper exists:
- the raw model returns structured past_key_values
- compiler backends usually want explicit graph outputs
- flattening KV outputs gives a stable output ordering for all layers
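A consumer of the flattened tuple can recover per-layer tensors by simple striding. A sketch with string stand-ins for the tensors, so the ordering logic is easy to check:

```python
# The prefill wrapper emits (logits, k_0, v_0, k_1, v_1, ...). Strings stand
# in for tensors here; only the ordering convention is being demonstrated.
num_layers = 3
outputs = ("logits",) + tuple(
    f"{kind}_{i}" for i in range(num_layers) for kind in ("k", "v")
)

logits, *kv_flat = outputs
keys = kv_flat[0::2]    # every even slot is a key tensor
values = kv_flat[1::2]  # every odd slot is a value tensor
assert keys == ["k_0", "k_1", "k_2"]
assert values == ["v_0", "v_1", "v_2"]
```

Because the ordering is positional and fixed, a backend never has to introspect Python cache objects to find a given layer's KV tensors.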
Decode wrapper¶
class KimiDecodeWrapper(nn.Module):
def __init__(self, model, num_layers: int, seen_tokens: int, max_cache_len: int):
super().__init__()
self.model = model
self.num_layers = num_layers
self.seen_tokens = seen_tokens
self.max_cache_len = max_cache_len
def forward(self, input_ids, attention_mask, position_ids, cache_position, *past_kv_flat):
cache = _IndexCopyCache(past_kv_flat, self.num_layers, self.seen_tokens, self.max_cache_len)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=cache,
cache_position=cache_position,
use_cache=True,
)
result = [outputs.logits]
for idx in range(self.num_layers):
result.append(cache.key_cache[idx])
result.append(cache.value_cache[idx])
return tuple(result)
Why this wrapper exists:
- decode must accept cache tensors as explicit graph inputs
- decode must return updated cache tensors as explicit graph outputs
- flattening the cache keeps the exported signature deterministic and backend-friendly
Without wrappers, the exported graph interface is tied too closely to model-internal Python objects.
Why Decode Needs a Static Cache Wrapper¶
Decode export needs more than a plain wrapper. It needs a cache object whose updates become visible as tensor ops inside the exported graph.
class _IndexCopyCache:
def __init__(self, kv_flat, num_layers: int, seen_tokens: int, max_cache_len: int):
self.key_cache = [kv_flat[2 * i] for i in range(num_layers)]
self.value_cache = [kv_flat[2 * i + 1] for i in range(num_layers)]
self._seen_tokens = seen_tokens
self._max_cache_len = max_cache_len
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
cache_position = cache_kwargs["cache_position"]
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_copy(2, cache_position, key_states)
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_copy(2, cache_position, value_states)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
This object exists because decode export needs:
- fixed-size cache buffers
- explicit tensor updates
- no hidden Python-side mutation outside the exported graph interface
Using index_copy makes cache updates appear as normal tensor operations, which is exactly what the IR and downstream executors need.
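The index_copy pattern can be demonstrated on a toy buffer. Shapes here are small stand-ins, not the model's real dimensions:

```python
import torch

# Fixed-size KV buffer updated at an explicit position with index_copy, the
# same pattern _IndexCopyCache.update() uses. The write is an ordinary tensor
# op, so it survives torch.export as a visible graph node.
max_cache_len, num_heads, head_dim = 8, 2, 4
key_cache = torch.zeros(1, num_heads, max_cache_len, head_dim)

# One decode step writing the new key at sequence position 5 (dim 2).
cache_position = torch.tensor([5])
new_key = torch.ones(1, num_heads, 1, head_dim)
key_cache = key_cache.index_copy(2, cache_position, new_key)

assert key_cache[0, 0, 5].eq(1).all()  # position 5 now holds the new key
assert key_cache[0, 0, 4].eq(0).all()  # other positions are untouched
```

Note the out-of-place index_copy (not index_copy_): the updated buffer is a new tensor that the wrapper then returns as an explicit graph output.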
Final Extraction Calls¶
The full extraction flow exports two graphs and one mapping file.
model = build_model(config, device="meta")
model.eval()
prefill_wrapper = KimiPrefillWrapper(model)
prefill_inputs = build_prefill_meta_inputs(config, PREFILL_SEQ_LEN)
prefill_ir = extract_ir(prefill_wrapper, prefill_inputs, model_name="KimiK25_Text_Prefill")
decode_wrapper = KimiDecodeWrapper(model, config.num_hidden_layers, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
decode_inputs = build_decode_meta_inputs(config, PREFILL_SEQ_LEN, MAX_CACHE_LEN)
decode_ir = extract_ir(decode_wrapper, decode_inputs, model_name="KimiK25_Text_Decode")
After extraction, the example saves:
- prefill IR
- decode IR
- layer-by-layer KV mapping metadata
Outputs¶
The example produces three files:
- kimi_k25_text_prefill_ir.json
- kimi_k25_text_decode_ir.json
- kimi_k25_text_kv_mapping.json
Their roles are different:
- the prefill IR describes the full prompt pass and emits logits plus initial KV tensors
- the decode IR describes one-token decode with fixed cache inputs and outputs
- the kv_mapping file tells a runtime exactly how prefill KV outputs line up with decode KV inputs and outputs
For large models, this split is often much easier to integrate than trying to force one exported graph to serve both phases.
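A runtime-side sketch of how the mapping gets used, assuming nothing beyond the JSON schema that save_kv_mapping() writes. The tensor names ("k0_out" and so on) and the wire_decode_inputs() helper are illustrative, not part of the example script:

```python
# A one-layer mapping in the same schema save_kv_mapping() produces.
mapping = {
    "num_layers": 1,
    "layers": [
        {
            "layer": 0,
            "prefill_key_output": "k0_out",
            "prefill_value_output": "v0_out",
            "decode_key_input": "k0_in",
            "decode_value_input": "v0_in",
            "decode_key_output": "k0_new",
            "decode_value_output": "v0_new",
        }
    ],
}

def wire_decode_inputs(mapping, prefill_results):
    """Map each decode KV input name to the matching prefill output tensor."""
    feeds = {}
    for entry in mapping["layers"]:
        feeds[entry["decode_key_input"]] = prefill_results[entry["prefill_key_output"]]
        feeds[entry["decode_value_input"]] = prefill_results[entry["prefill_value_output"]]
    return feeds

# String stand-ins for the prefill graph's output tensors.
feeds = wire_decode_inputs(mapping, {"k0_out": "K0", "v0_out": "V0"})
assert feeds == {"k0_in": "K0", "v0_in": "V0"}
```

Subsequent decode steps would apply the same idea with the decode_key_output / decode_value_output names feeding the next step's inputs.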
What This Example Does Not Do¶
- It does not download or load actual model weights.
- It does not export the original multimodal Kimi stack.
- It does not cover the tiny verification mode.
This page is intentionally focused on the real huge-model extraction path.
Summary¶
For huge autoregressive models, successful IR extraction usually depends on three things:
- a remote config that preserves the original architecture
- local patched modeling code that is export-friendly
- wrapper-defined prefill and decode interfaces with explicit static cache tensors
That combination lets pytorch-ir extract stable, backend-oriented IR even when the original model is too large to run in a normal weight-loaded setup.