Skip to content

Weight Loader

Module for loading weights from .pt and .safetensors files.

weight_loader

Weight loader for .pt and .safetensors files.

WeightLoadError

Bases: Exception

Raised when weight loading fails.

load_weights_pt

load_weights_pt(path: Union[str, Path]) -> Dict[str, torch.Tensor]

Load weights from a .pt file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Union[str, Path]` | Path to the .pt file. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
| --- | --- |
| `WeightLoadError` | If loading fails. |

Source code in torch_ir/weight_loader.py
def load_weights_pt(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
    """Load weights from a .pt file.

    The file is deserialized onto the CPU with ``torch.load(...,
    weights_only=True)``. This restricts unpickling to tensors and plain
    containers rather than arbitrary Python objects.

    Args:
        path: Path to the .pt file.

    Returns:
        Dictionary mapping parameter names to tensors.

    Raises:
        WeightLoadError: If the file does not exist, cannot be deserialized,
            or does not contain a dictionary (e.g. a bare tensor was saved).
    """
    path = Path(path)

    if not path.exists():
        raise WeightLoadError(f"Weight file not found: {path}")

    # Keep the try body to the deserialization call only, so the payload
    # check below is not re-wrapped as a generic "Failed to load" error.
    try:
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
    except Exception as e:
        raise WeightLoadError(f"Failed to load weights from {path}: {e}") from e

    # torch.load may return a bare tensor, list, scalar, etc. The declared
    # return type is a name -> tensor mapping, so reject anything else here
    # with a clear error instead of letting callers fail later on .keys().
    if not isinstance(state_dict, dict):
        raise WeightLoadError(
            f"Expected a dict of tensors in {path}, got {type(state_dict).__name__}"
        )

    return state_dict

load_weights_safetensors

load_weights_safetensors(path: Union[str, Path]) -> Dict[str, torch.Tensor]

Load weights from a .safetensors file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Union[str, Path]` | Path to the .safetensors file. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
| --- | --- |
| `WeightLoadError` | If loading fails. |
| `ImportError` | If safetensors is not installed. |

Source code in torch_ir/weight_loader.py
def load_weights_safetensors(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
    """Load weights from a .safetensors file.

    ``safetensors`` is imported lazily on each call. A missing package is
    therefore reported before the path is checked.

    Args:
        path: Path to the .safetensors file.

    Returns:
        Dictionary mapping parameter names to tensors.

    Raises:
        WeightLoadError: If the file does not exist or loading fails.
        ImportError: If safetensors is not installed (the original import
            error is attached as ``__cause__``).
    """
    try:
        from safetensors.torch import load_file  # ty: ignore[unresolved-import]
    except ImportError as e:
        # Chain explicitly so the underlying import failure (missing package
        # vs. a broken install) stays visible in the traceback.
        raise ImportError(
            "safetensors package is required for loading .safetensors files.\nInstall it with: pip install safetensors"
        ) from e

    path = Path(path)

    if not path.exists():
        raise WeightLoadError(f"Weight file not found: {path}")

    try:
        return load_file(path)
    except Exception as e:
        raise WeightLoadError(f"Failed to load weights from {path}: {e}") from e

load_weights

load_weights(path: Union[str, Path]) -> Dict[str, torch.Tensor]

Load weights from a file (auto-detect format).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Union[str, Path]` | Path to the weight file (.pt or .safetensors). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, Tensor]` | Dictionary mapping parameter names to tensors. |

Raises:

| Type | Description |
| --- | --- |
| `WeightLoadError` | If loading fails or format is unknown. |

Source code in torch_ir/weight_loader.py
def load_weights(path: Union[str, Path]) -> Dict[str, torch.Tensor]:
    """Load weights from a file (auto-detect format).

    The loader is chosen from the file extension, compared case-insensitively:

    * ``.safetensors`` is handled by ``load_weights_safetensors``.
    * ``.pt``, ``.pth`` and ``.bin`` are handled by ``load_weights_pt``.
    * Any other extension is tried as a PyTorch file as a fallback.

    Args:
        path: Path to the weight file (.pt or .safetensors).

    Returns:
        Dictionary mapping parameter names to tensors.

    Raises:
        WeightLoadError: If the file does not exist, loading fails, or the
            format is unknown.
        ImportError: If the file is ``.safetensors`` and safetensors is not
            installed (propagated from ``load_weights_safetensors``).
    """
    path = Path(path)
    # Case-insensitive so e.g. "model.SAFETENSORS" reaches the safetensors
    # loader instead of the PyTorch fallback (which would reject it).
    suffix = path.suffix.lower()

    if suffix == ".safetensors":
        return load_weights_safetensors(path)
    if suffix in (".pt", ".pth", ".bin"):
        return load_weights_pt(path)

    # Report a missing file as such. Otherwise the fallback below would
    # relabel the not-found error as an unknown format.
    if not path.exists():
        raise WeightLoadError(f"Weight file not found: {path}")

    # Unknown extension: try the PyTorch format as the default.
    try:
        return load_weights_pt(path)
    except WeightLoadError as e:
        raise WeightLoadError(
            f"Unknown weight file format: {path.suffix}\nSupported formats: .pt, .pth, .bin, .safetensors"
        ) from e

validate_weights_against_ir

validate_weights_against_ir(weights: Dict[str, Tensor], ir: IR) -> List[str]

Validate that loaded weights match the IR weight metadata.

Checks for missing weights and shape/dtype mismatches between the loaded tensors and the metadata recorded during IR extraction.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `weights` | `Dict[str, Tensor]` | Loaded weight dictionary (`state_dict`). | *required* |
| `ir` | `IR` | The IR containing weight metadata. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `List[str]` | List of human-readable validation error strings. Empty when all weights match. |

Source code in torch_ir/weight_loader.py
def validate_weights_against_ir(
    weights: Dict[str, torch.Tensor],
    ir: IR,
) -> List[str]:
    """Validate that loaded weights match the IR weight metadata.

    Checks for missing weights and shape/dtype mismatches between the loaded
    tensors and the metadata recorded during IR extraction. Extra entries in
    ``weights`` that the IR does not mention are not reported.

    Args:
        weights: Loaded weight dictionary (``state_dict``).
        ir: The IR containing weight metadata. Only ``ir.weights`` is read;
            each entry must provide ``name``, ``shape`` and ``dtype``.

    Returns:
        List of human-readable validation error strings. Empty when all weights match.
        Missing names are listed in sorted order so the message is stable
        from run to run.
    """
    errors: List[str] = []

    ir_weight_names = {w.name for w in ir.weights}
    loaded_weight_names = set(weights.keys())

    # Check for missing weights. Sort them: set iteration order for strings
    # depends on hash randomization, which would make this message differ
    # between runs.
    missing = ir_weight_names - loaded_weight_names
    if missing:
        errors.append(f"Missing weights: {sorted(missing)}")

    # Check for shape/dtype mismatches
    for weight_meta in ir.weights:
        if weight_meta.name not in weights:
            # Already reported as missing above.
            continue

        loaded_tensor = weights[weight_meta.name]
        expected_shape = tuple(weight_meta.shape)
        actual_shape = tuple(loaded_tensor.shape)

        if expected_shape != actual_shape:
            errors.append(f"Shape mismatch for '{weight_meta.name}': expected {expected_shape}, got {actual_shape}")

        # _dtype_str_to_torch is defined elsewhere in this module; presumably
        # it maps the IR's dtype string to a torch.dtype (TODO confirm).
        expected_dtype = _dtype_str_to_torch(weight_meta.dtype)
        if loaded_tensor.dtype != expected_dtype:
            errors.append(
                f"Dtype mismatch for '{weight_meta.name}': expected {expected_dtype}, got {loaded_tensor.dtype}"
            )

    return errors