"""
export_to_onnx.py
-----------------
Export a trained PyTorch TrackCaloRegressor to ONNX and dump the StandardScaler
to JSON, using ABSOLUTE PATHS so it works no matter where you run it from.

Usage:
python export_to_onnx.py \
    --model-dir /abs/path/to/your/code/dir \
    --weights /abs/path/to/model_weight/best_model_pt_0to10.pt \
    --scaler /abs/path/to/model_weight/scaler_0to10.joblib \
    --out /abs/path/to/out/model.onnx \
    --json /abs/path/to/out/scaler.json \
    --in-dim 7

Examples:
python export_to_onnx.py \
    --model-dir /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2 \
    --weights /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt \
    --scaler /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/model_weight/scaler_pt_0.0_10.0_INTT_CaloIwoE.pkl \
    --out /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEMD.onnx \
    --json /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLEMD.json \
    --in-dim 7

python export_to_onnx.py \
    --model-dir /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4 \
    --weights /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt \
    --scaler /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/ML_Weight_Scaler/scaler_identity.pkl \
    --out /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEproj.onnx \
    --json /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLEproj.json \
    --in-dim 2

python export_to_onnx.py \
    --model-dir /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate \
    --weights /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate/model_weight/best_model_combined.pt \
    --scaler /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate/ML_Weight_Scaler/scaler_identity.pkl \
    --out /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLCombined.onnx \
    --json /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLCombined.json \
    --in-dim 2

Notes:
- --model-dir should be the directory that contains the model module
  (the script imports the module named "model_combined" from there).
- All other paths should be absolute. The script will create parent folders for outputs.
"""
import argparse
import importlib
import inspect
import json
import sys
from pathlib import Path
def abspath(p: str, must_exist: bool = False) -> Path:
    """Return *p* expanded (``~``) and made absolute.

    Args:
        p: Path string, possibly relative or using ``~``.
        must_exist: When True, require the resulting path to exist.

    Raises:
        FileNotFoundError: If ``must_exist`` is True and the path is absent.
    """
    candidate = Path(p).expanduser().absolute()
    # Guard-inverted check: return immediately unless existence is demanded
    # and the filesystem disagrees.
    if not must_exist or candidate.exists():
        return candidate
    raise FileNotFoundError(f"Path does not exist: {candidate}")
def main():
    """Export the trained regressor to ONNX and dump its scaler to JSON.

    Exits with a non-zero status on any failure:
        2 - model class import failed
        3 - weight loading failed
        4 - ONNX export failed
        5 - scaler JSON dump failed
    """
    ap = argparse.ArgumentParser(description="Export TrackCaloRegressor to ONNX (absolute paths)")
    ap.add_argument("--model-dir", required=True, help="Absolute path to directory containing the model module")
    ap.add_argument("--weights", required=True, help="Absolute path to .pt/.pth weights (state_dict)")
    ap.add_argument("--scaler", required=True, help="Absolute path to StandardScaler .joblib/.pkl")
    ap.add_argument("--out", required=True, help="Absolute path to output model.onnx")
    ap.add_argument("--json", required=True, help="Absolute path to output scaler.json")
    ap.add_argument("--in-dim", type=int, default=7, help="Input feature dimension (default 7)")
    ap.add_argument("--hidden-dim", type=int, default=256, help="Hidden size, if customized")
    ap.add_argument("--module-name", default="model_combined",
                    help="Python module inside --model-dir that defines the model class")
    ap.add_argument("--class-name", default="FusionRegressor", help="Model class name in the model module")
    args = ap.parse_args()

    model_dir = abspath(args.model_dir, must_exist=True)
    weights = abspath(args.weights, must_exist=True)
    scaler_path = abspath(args.scaler, must_exist=False)  # optional: missing -> skip JSON dump
    out_onnx = abspath(args.out, must_exist=False)
    out_json = abspath(args.json, must_exist=False)

    # Create output folders up front so the export/write steps cannot fail on
    # a missing directory.
    out_onnx.parent.mkdir(parents=True, exist_ok=True)
    out_json.parent.mkdir(parents=True, exist_ok=True)

    # Make the model module importable regardless of the current working dir.
    if str(model_dir) not in sys.path:
        sys.path.insert(0, str(model_dir))

    try:
        model_module = importlib.import_module(args.module_name)
        ModelClass = getattr(model_module, args.class_name)
    except Exception as e:
        # Report the module actually imported (not a hard-coded "model.py").
        print(f"[ERROR] Failed to import '{args.class_name}' from "
              f"{model_dir}/{args.module_name}.py: {e}", file=sys.stderr)
        sys.exit(2)

    # Heavy imports are deferred so argparse errors / --help stay fast.
    import torch
    import joblib
    import numpy as np

    # Pass hidden_dim only when the constructor actually accepts it;
    # inspect.signature is robust where __code__.co_varnames is not
    # (decorators, *args/**kwargs, locals leaking into co_varnames).
    ctor_params = inspect.signature(ModelClass.__init__).parameters
    if "hidden_dim" in ctor_params:
        model = ModelClass(input_dim=args.in_dim, hidden_dim=args.hidden_dim)
    else:
        model = ModelClass(input_dim=args.in_dim)

    try:
        ckpt = torch.load(str(weights), map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            # Lightning-style checkpoint: weights nested under "state_dict".
            state_dict = ckpt["state_dict"]
        elif isinstance(ckpt, dict):
            # Plain state_dict saved directly via torch.save(model.state_dict()).
            state_dict = ckpt
        else:
            raise RuntimeError(f"Unsupported checkpoint object: {type(ckpt)}")
        # strict=False tolerates partial checkpoints, but silently ignoring
        # mismatches would hide a wrong-file error -- surface them as a warning.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing or unexpected:
            print(f"[WARN] load_state_dict: missing={missing} unexpected={unexpected}",
                  file=sys.stderr)
    except Exception as e:
        print(f"[ERROR] Failed to load weights: {e}", file=sys.stderr)
        sys.exit(3)
    model.eval()

    try:
        # Batch dim is exported as dynamic ("N"); feature dim is fixed to in_dim.
        dummy = torch.randn(1, args.in_dim, dtype=torch.float32)
        torch.onnx.export(
            model, dummy, str(out_onnx),
            input_names=["x"], output_names=["pt_hat"],
            dynamic_axes={"x": {0: "N"}, "pt_hat": {0: "N"}},
            opset_version=13,
        )
        print(f"[OK] Exported ONNX -> {out_onnx}")
    except Exception as e:
        print(f"[ERROR] Failed to export ONNX: {e}", file=sys.stderr)
        sys.exit(4)

    # The scaler is optional: a missing file just skips the JSON dump.
    if not scaler_path.exists():
        print(f"[INFO] Scaler file not found: {scaler_path}. Skip writing scaler JSON.")
        return

    try:
        scaler = joblib.load(str(scaler_path))
        for attr in ("mean_", "scale_", "n_features_in_"):
            if not hasattr(scaler, attr):
                raise AttributeError(f"Scaler missing '{attr}'. Is it sklearn StandardScaler?")
        payload = {
            "mean": np.asarray(scaler.mean_, dtype=float).tolist(),
            "scale": np.asarray(scaler.scale_, dtype=float).tolist(),
            "n_features": int(scaler.n_features_in_),
        }
        out_json.write_text(json.dumps(payload, indent=2))
        print(f"[OK] Wrote scaler JSON -> {out_json} (n_features={payload['n_features']})")
        if payload["n_features"] != args.in_dim:
            print(f"[WARN] scaler n_features ({payload['n_features']}) != --in-dim ({args.in_dim})",
                  file=sys.stderr)
    except Exception as e:
        print(f"[ERROR] Failed to write scaler JSON: {e}", file=sys.stderr)
        sys.exit(5)
# Entry point guard: run the exporter only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()