Verifier

Module that verifies IR execution results against the original model output.

verifier

Verifier for comparing original model output vs IR execution output.

VerificationReport dataclass

VerificationReport(is_valid: bool, max_diff: float = 0.0, mean_diff: float = 0.0, num_outputs: int = 0, output_details: List[Dict[str, Any]] = list(), error_message: Optional[str] = None)

Report from verification comparison.

Attributes:

is_valid (bool): Whether all outputs are within tolerance.

max_diff (float): Maximum absolute difference across all output tensors.

mean_diff (float): Maximum of per-output mean absolute differences.

num_outputs (int): Number of output tensors compared.

output_details (List[Dict[str, Any]]): Per-output comparison details (index, shape, is_close, max_diff, mean_diff).

error_message (Optional[str]): Human-readable error description when verification fails. None on success.
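
The rendered signature shows the default of output_details as list(), which for a dataclass presumably corresponds to field(default_factory=list). The sketch below is a stand-in that mirrors the documented fields, not the torch_ir source; summarize is a hypothetical helper added only to show how a report might be read.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class VerificationReport:
    """Stand-in mirroring the documented fields; not the torch_ir source."""

    is_valid: bool
    max_diff: float = 0.0
    mean_diff: float = 0.0
    num_outputs: int = 0
    # Assumption: the rendered default `list()` is a default_factory,
    # so every report gets its own empty list.
    output_details: List[Dict[str, Any]] = field(default_factory=list)
    error_message: Optional[str] = None


def summarize(report: VerificationReport) -> str:
    """Hypothetical helper: build a one-line summary of a report."""
    if report.is_valid:
        return f"OK: {report.num_outputs} output(s), max_diff={report.max_diff:.3e}"
    return f"FAILED: {report.error_message or 'outputs differ'}"


ok = VerificationReport(is_valid=True, max_diff=2e-7, mean_diff=5e-8, num_outputs=1)
failed = VerificationReport(is_valid=False, error_message="Verification error: shape mismatch")
print(summarize(ok))
print(summarize(failed))
```

Only is_valid is required; every other field falls back to its documented default.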

verify_ir

verify_ir(ir: IR, weights_path: Union[str, Path], original_model: Module, test_inputs: Tuple[Tensor, ...], rtol: float = 1e-05, atol: float = 1e-05, constants: Optional[Dict[str, Tensor]] = None) -> Tuple[bool, VerificationReport]

Verify that IR execution matches original model output.

Parameters:

ir (IR, required): The IR graph to verify.

weights_path (Union[str, Path], required): Path to the weight file.

original_model (Module, required): The original PyTorch model (with weights loaded).

test_inputs (Tuple[Tensor, ...], required): Test input tensors.

rtol (float, default 1e-05): Relative tolerance for torch.allclose.

atol (float, default 1e-05): Absolute tolerance for torch.allclose.

constants (Optional[Dict[str, Tensor]], default None): Optional lifted tensor constants for meta-device-extracted IRs.

Returns:

Tuple[bool, VerificationReport]: Tuple of (is_valid, report).
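
rtol and atol are documented as tolerances for torch.allclose, which treats two finite values as close when |actual - expected| <= atol + rtol * |expected|. The pure-Python sketch below applies that criterion to flat lists of floats so the effect of the defaults is easy to see. The names is_close and compare are hypothetical, and treating the original model's output as the reference side of the comparison is an assumption; this page does not show how the internal comparison orders its arguments.

```python
from typing import Dict, Sequence, Union


def is_close(actual: float, expected: float,
             rtol: float = 1e-5, atol: float = 1e-5) -> bool:
    """Element-wise closeness criterion of torch.allclose (finite values)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def compare(actual: Sequence[float], expected: Sequence[float],
            rtol: float = 1e-5, atol: float = 1e-5) -> Dict[str, Union[bool, float]]:
    """Hypothetical per-output comparison: closeness plus max/mean abs diff."""
    if len(actual) != len(expected) or not actual:
        raise ValueError("inputs must be non-empty and of equal length")
    diffs = [abs(a - e) for a, e in zip(actual, expected)]
    return {
        "is_close": all(is_close(a, e, rtol, atol) for a, e in zip(actual, expected)),
        "max_diff": max(diffs),
        "mean_diff": sum(diffs) / len(diffs),
    }


# The relative term dominates for large values: an error of 5e-4 passes at 100.0.
print(is_close(100.0005, 100.0))
# Near zero only atol applies: an error of 5e-5 exceeds the 1e-5 default.
print(is_close(0.00005, 0.0))
print(compare([1.0, 2.0], [1.0, 2.5]))
```

Because the relative term scales with the magnitude of the reference value, outputs near zero are effectively judged by atol alone, which is worth keeping in mind when loosening only rtol.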

Source code in torch_ir/verifier.py
def verify_ir(
    ir: IR,
    weights_path: Union[str, Path],
    original_model: nn.Module,
    test_inputs: Tuple[torch.Tensor, ...],
    rtol: float = 1e-5,
    atol: float = 1e-5,
    constants: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[bool, VerificationReport]:
    """Verify that IR execution matches original model output.

    Args:
        ir: The IR graph to verify.
        weights_path: Path to the weight file.
        original_model: The original PyTorch model (with weights loaded).
        test_inputs: Test input tensors.
        rtol: Relative tolerance for torch.allclose.
        atol: Absolute tolerance for torch.allclose.
        constants: Optional lifted tensor constants for meta-device-extracted IRs.

    Returns:
        Tuple of (is_valid, report).
    """
    try:
        original_outputs = _run_original_model(original_model, test_inputs)
        weights = load_weights(weights_path)
        ir_outputs = execute_ir(ir, test_inputs, weights=weights, constants=constants)
        return _verify_outputs(original_outputs, ir_outputs, rtol, atol)
    except Exception as e:
        return False, VerificationReport(
            is_valid=False,
            error_message=f"Verification error: {str(e)}",
        )
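
As the listing shows, verify_ir catches every exception and returns it inside the report, so a failure while loading weights or executing the IR surfaces as is_valid=False with a "Verification error:" prefix rather than as a raised exception. The sketch below reproduces that shape with stand-ins: Report, run_guarded, missing_weights, and the file name weights.bin are all hypothetical. What message, if any, is written for a plain numerical mismatch is not shown on this page.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass
class Report:
    """Hypothetical, trimmed stand-in for VerificationReport."""

    is_valid: bool
    error_message: Optional[str] = None


def run_guarded(check: Callable[[], Tuple[bool, Report]]) -> Tuple[bool, Report]:
    """Mirror the try/except shape of the listing: never raise, always report."""
    try:
        return check()
    except Exception as e:
        return False, Report(is_valid=False, error_message=f"Verification error: {e}")


def missing_weights() -> Tuple[bool, Report]:
    # Simulates a failure such as an unreadable weight file (hypothetical name).
    raise FileNotFoundError("weights.bin not found")


ok, report = run_guarded(missing_weights)
# Exceptions captured this way carry the prefix used in the listing.
crashed = (report.error_message or "").startswith("Verification error:")
if not ok:
    print("crashed" if crashed else "mismatch", "-", report.error_message)
```

A caller that needs the original traceback has to reproduce the failing step outside the verifier, since only the exception's string form is kept.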

verify_ir_with_state_dict

verify_ir_with_state_dict(ir: IR, state_dict: Dict[str, Tensor], original_model: Module, test_inputs: Tuple[Tensor, ...], rtol: float = 1e-05, atol: float = 1e-05, constants: Optional[Dict[str, Tensor]] = None) -> Tuple[bool, VerificationReport]

Verify IR execution using an in-memory state dict instead of a weight file.

Parameters:

ir (IR, required): The IR graph to verify.

state_dict (Dict[str, Tensor], required): The weight state dict.

original_model (Module, required): The original PyTorch model (with weights loaded).

test_inputs (Tuple[Tensor, ...], required): Test input tensors.

rtol (float, default 1e-05): Relative tolerance.

atol (float, default 1e-05): Absolute tolerance.

constants (Optional[Dict[str, Tensor]], default None): Optional lifted tensor constants for meta-device-extracted IRs.

Returns:

Tuple[bool, VerificationReport]: Tuple of (is_valid, report).

Source code in torch_ir/verifier.py
def verify_ir_with_state_dict(
    ir: IR,
    state_dict: Dict[str, torch.Tensor],
    original_model: nn.Module,
    test_inputs: Tuple[torch.Tensor, ...],
    rtol: float = 1e-5,
    atol: float = 1e-5,
    constants: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[bool, VerificationReport]:
    """Verify IR execution using a state dict instead of file.

    Args:
        ir: The IR graph to verify.
        state_dict: The weight state dict.
        original_model: The original PyTorch model (with weights loaded).
        test_inputs: Test input tensors.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        constants: Optional lifted tensor constants for meta-device-extracted IRs.

    Returns:
        Tuple of (is_valid, report).
    """
    try:
        original_outputs = _run_original_model(original_model, test_inputs)
        ir_outputs = execute_ir(ir, test_inputs, weights=state_dict, constants=constants)
        return _verify_outputs(original_outputs, ir_outputs, rtol, atol)
    except Exception as e:
        return False, VerificationReport(
            is_valid=False,
            error_message=f"Verification error: {str(e)}",
        )