Weight Loader
Module for loading weights from .pt and .safetensors files.
weight_loader
Weight loader for .pt and .safetensors files.
WeightLoadError
Bases: Exception
Raised when weight loading fails.
load_weights_pt
Load weights from a .pt file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path]` | Path to the .pt file. | required |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
|---|---|
| `WeightLoadError` | If loading fails. |
Source code in torch_ir/weight_loader.py
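The contract described for `load_weights_pt` (a name-to-tensor dictionary on success, `WeightLoadError` on any failure) can be sketched as follows. This is a minimal stand-in under stated assumptions, not the code in torch_ir/weight_loader.py: the name `load_weights_pt_sketch` is hypothetical, the exception class is redefined locally so the snippet is self-contained, and the file is assumed to hold a plain state dict readable by `torch.load`.

```python
from pathlib import Path
from typing import Any, Dict, Union


class WeightLoadError(Exception):
    """Local stand-in for the documented exception."""


def load_weights_pt_sketch(path: Union[str, Path]) -> Dict[str, Any]:
    """Hypothetical helper: load a state dict, wrapping every failure in WeightLoadError."""
    path = Path(path)
    # Fail early with a clear message before touching torch at all.
    if not path.is_file():
        raise WeightLoadError(f"Weight file not found: {path}")
    try:
        import torch  # imported lazily so the missing-file check works without torch

        state = torch.load(path, map_location="cpu")
    except Exception as exc:
        # Assumption: all lower-level errors are surfaced as one exception type.
        raise WeightLoadError(f"Failed to load {path}: {exc}") from exc
    if not isinstance(state, dict):
        raise WeightLoadError(
            f"Expected a dict of tensors in {path}, got {type(state).__name__}"
        )
    return state
```

Callers then need only one `except WeightLoadError` clause, whichever lower-level error caused the failure.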
load_weights_safetensors
Load weights from a .safetensors file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path]` | Path to the .safetensors file. | required |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
|---|---|
| `WeightLoadError` | If loading fails. |
| `ImportError` | If safetensors is not installed. |
Source code in torch_ir/weight_loader.py
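The `ImportError` case suggests that safetensors is treated as an optional dependency. One common way to guard such an import is sketched below. The helper names `require_module` and `load_safetensors_sketch` are hypothetical and this is not the library's own implementation. The only third-party call used is `safetensors.torch.load_file`, and it runs only if the package is installed.

```python
import importlib
from pathlib import Path
from types import ModuleType
from typing import Union


def require_module(name: str, install_hint: str) -> ModuleType:
    """Hypothetical helper: import an optional dependency or raise a clear ImportError."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise ImportError(
            f"'{name}' is required for this operation. Install it with: {install_hint}"
        ) from exc


def load_safetensors_sketch(path: Union[str, Path]):
    """Hypothetical helper: read a .safetensors file into a name-to-tensor dict."""
    # The import happens at call time, so the rest of the module works without safetensors.
    st_torch = require_module("safetensors.torch", "pip install safetensors")
    return st_torch.load_file(str(path))
```

Deferring the import to call time keeps .pt loading usable in environments where safetensors is absent.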
load_weights
Load weights from a file (auto-detect format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Union[str, Path]` | Path to the weight file (.pt or .safetensors). | required |

Returns:

| Type | Description |
|---|---|
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
|---|---|
| `WeightLoadError` | If loading fails or format is unknown. |
Source code in torch_ir/weight_loader.py
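Format auto-detection of this kind is typically a lookup on the file suffix. The sketch below illustrates that idea only. The helper `pick_loader_name` and the suffix table are hypothetical, and the case-insensitive match is an assumption that the reference above does not confirm.

```python
from pathlib import Path
from typing import Union


class WeightLoadError(Exception):
    """Local stand-in for the documented exception."""


# Assumed mapping from file suffix to the documented loader name.
_LOADER_BY_SUFFIX = {
    ".pt": "load_weights_pt",
    ".safetensors": "load_weights_safetensors",
}


def pick_loader_name(path: Union[str, Path]) -> str:
    """Hypothetical helper: choose a loader from the file suffix (case-insensitive)."""
    suffix = Path(path).suffix.lower()
    try:
        return _LOADER_BY_SUFFIX[suffix]
    except KeyError:
        # An unknown format is reported with the same exception type as a failed load.
        raise WeightLoadError(
            f"Unknown weight format '{suffix}' for {path}; expected .pt or .safetensors"
        ) from None
```

A table-driven dispatch keeps the unknown-format error in one place and makes adding a format a one-line change.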
validate_weights_against_ir
Validate that loaded weights match the IR weight metadata.
Checks for missing weights and shape/dtype mismatches between the loaded tensors and the metadata recorded during IR extraction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `weights` | `Dict[str, Tensor]` | Loaded weight dictionary. | required |
| `ir` | `IR` | The IR containing weight metadata. | required |

Returns:

| Type | Description |
|---|---|
| `List[str]` | List of human-readable validation error strings. Empty when all weights match. |
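The checks described for `validate_weights_against_ir` (missing weights, shape mismatches, dtype mismatches, all collected into a list of strings) can be sketched without torch. Everything here is an assumption made for illustration: `WeightMeta`, `FakeTensor`, and `validate_sketch` are hypothetical names, the real `IR` class may store its metadata differently, and the wording of the error strings is invented.

```python
from dataclasses import dataclass
from typing import List, Mapping, Tuple


@dataclass(frozen=True)
class WeightMeta:
    """Hypothetical stand-in for one IR weight record (expected shape and dtype)."""
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class FakeTensor:
    """Minimal tensor-like object; a real torch.Tensor also exposes .shape and .dtype."""
    shape: Tuple[int, ...]
    dtype: str


def validate_sketch(
    weights: Mapping[str, FakeTensor], expected: Mapping[str, WeightMeta]
) -> List[str]:
    """Collect every problem instead of stopping at the first one."""
    errors: List[str] = []
    for name, meta in expected.items():
        tensor = weights.get(name)
        if tensor is None:
            errors.append(f"Missing weight: {name}")
            continue
        if tuple(tensor.shape) != tuple(meta.shape):
            errors.append(
                f"Shape mismatch for {name}: expected {tuple(meta.shape)}, "
                f"got {tuple(tensor.shape)}"
            )
        if str(tensor.dtype) != str(meta.dtype):
            errors.append(
                f"Dtype mismatch for {name}: expected {meta.dtype}, got {tensor.dtype}"
            )
    return errors
```

Returning a list rather than raising lets the caller report all mismatches in a single pass. This sketch ignores extra entries in `weights` that have no metadata. The reference above does not say how the real function treats them.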