def cmd_info(args: argparse.Namespace) -> None:
    """Handle the 'info' subcommand: summarize an IR file.

    Loads the IR named by ``args.ir_file`` and builds a summary of nodes,
    graph inputs/outputs, weights and op distribution. Emits it either as
    indented JSON (when ``args.json`` is truthy) or as human-readable text.

    Args:
        args: Parsed CLI arguments. Reads ``ir_file`` (passed to
            ``_load_ir_file``), ``json`` (output format flag) and ``output``
            (destination passed to ``_write_output``).
    """
    ir = _load_ir_file(args.ir_file)
    info = _summarize_ir(ir)
    output = json.dumps(info, indent=2) if args.json else _format_info_text(info)
    _write_output(output, args.output)


def _summarize_ir(ir) -> dict:
    """Collect node, I/O, weight and op-distribution statistics for *ir*.

    NOTE(review): assumes every ``w.shape`` holds numeric dims; a symbolic
    or ``None`` dim would make the parameter product fail — confirm.
    """
    import math
    from collections import Counter

    op_counts = Counter(node.op_type for node in ir.nodes)
    # A weight's parameter count is the product of its dims. The product of an
    # empty shape is 1, so a scalar weight counts as one parameter.
    total_params = sum(math.prod(w.shape) for w in ir.weights)
    return {
        "model_name": ir.model_name,
        "num_nodes": len(ir.nodes),
        "num_inputs": len(ir.graph_inputs),
        "num_outputs": len(ir.graph_outputs),
        "num_weights": len(ir.weights),
        "total_parameters": total_params,
        "input_shapes": {inp.name: list(inp.shape) for inp in ir.graph_inputs},
        "output_shapes": {out.name: list(out.shape) for out in ir.graph_outputs},
        # most_common() orders by count descending; equal counts keep the
        # order in which op types were first encountered.
        "op_distribution": dict(op_counts.most_common()),
    }


def _format_info_text(info: dict) -> str:
    """Render the summary dict built by ``_summarize_ir`` as plain text."""
    lines = [
        f"Model: {info['model_name']}",
        f"Nodes: {info['num_nodes']}",
        f"Inputs: {info['num_inputs']}",
        f"Outputs: {info['num_outputs']}",
        f"Weights: {info['num_weights']}",
        f"Total parameters: {info['total_parameters']:,}",
        "",
        "Input shapes:",
    ]
    lines.extend(f" {name}: {shape}" for name, shape in info["input_shapes"].items())
    lines += ["", "Output shapes:"]
    lines.extend(f" {name}: {shape}" for name, shape in info["output_shapes"].items())
    lines += ["", "Op distribution:"]
    lines.extend(f" {op}: {count}" for op, count in info["op_distribution"].items())
    return "\n".join(lines)