#!/usr/bin/env python3
"""
export_to_onnx.py
-----------------
Export a trained PyTorch TrackCaloRegressor to ONNX and dump StandardScaler to JSON,
using ABSOLUTE PATHS so it works no matter where you run it from.

python export_to_onnx.py \
    --model-dir   /abs/path/to/your/code/dir \
    --weights     /abs/path/to/model_weight/best_model_pt_0to10.pt \
    --scaler      /abs/path/to/model_weight/scaler_0to10.joblib \
    --out         /abs/path/to/out/model.onnx \
    --json        /abs/path/to/out/scaler.json \
    --in-dim 7

Examples:
    python export_to_onnx.py \
        --model-dir   /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2 \
        --weights     /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt \
        --scaler      /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version2/model_weight/scaler_pt_0.0_10.0_INTT_CaloIwoE.pkl \
        --out         /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEMD.onnx \
        --json        /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLEMD.json \
        --in-dim 7

    python export_to_onnx.py \
        --model-dir   /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4 \
        --weights     /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/model_weight/best_model_pt_0.0_10.0_INTT_CaloIwoE.pt \
        --scaler      /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/version4/ML_Weight_Scaler/scaler_identity.pkl \
        --out         /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLEproj.onnx \
        --json        /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLEproj.json \
        --in-dim 2

    python export_to_onnx.py \
        --model-dir   /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate \
        --weights     /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate/model_weight/best_model_combined.pt \
        --scaler      /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/combine2_gate/ML_Weight_Scaler/scaler_identity.pkl \
        --out         /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/model_MLCombined.onnx \
        --json        /mnt/e/sphenix/INTT-EMCAL/InttSeedingTrackDev/ML4Reco/Implement/ML_Weight_Scaler/scaler_MLCombined.json \
        --in-dim 2

Notes:
- --model-dir should be the directory that contains the model module. The script
  currently imports model_combined.py and defaults to the FusionRegressor class
  (override with --class-name).
- All other paths should be absolute. The script will create parent folders for outputs.
- The layout of the scaler JSON is sketched in the comment right below this docstring.
"""
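# Sketch of the scaler JSON written by this script (field names follow the payload
# assembled in main(); the list lengths equal the scaler's n_features_in_):
#
#   {
#     "mean":       [m0, m1, ...],   # StandardScaler.mean_
#     "scale":      [s0, s1, ...],   # StandardScaler.scale_
#     "n_features": 7
#   }
#
# The presumption is that downstream code standardizes each input row as
# x_scaled = (x - mean) / scale before feeding it to the exported ONNX model.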
import argparse
import json
import sys
from pathlib import Path

def abspath(p: str, must_exist: bool = False) -> Path:
    """Expand '~', make the path absolute, and optionally require that it exists."""
    path = Path(p).expanduser().absolute()
    if must_exist and not path.exists():
        raise FileNotFoundError(f"Path does not exist: {path}")
    return path

def main():
    ap = argparse.ArgumentParser(description="Export TrackCaloRegressor to ONNX (absolute paths)")
    ap.add_argument("--model-dir", required=True, help="Absolute path to directory containing the model module")
    ap.add_argument("--weights",   required=True, help="Absolute path to .pt/.pth weights (state_dict)")
    ap.add_argument("--scaler",    required=True, help="Absolute path to StandardScaler .joblib/.pkl")
    ap.add_argument("--out",       required=True, help="Absolute path to output model.onnx")
    ap.add_argument("--json",      required=True, help="Absolute path to output scaler.json")
    ap.add_argument("--in-dim",    type=int, default=7, help="Input feature dimension (default 7)")
    ap.add_argument("--hidden-dim", type=int, default=256, help="Hidden size, if customized")
    # ap.add_argument("--class-name", default="TrackCaloRegressor", help="Model class name in the model module")
    ap.add_argument("--class-name", default="FusionRegressor", help="Model class name in the model module")
    args = ap.parse_args()

    # Resolve and validate paths
    model_dir = abspath(args.model_dir, must_exist=True)
    weights   = abspath(args.weights, must_exist=True)
    scalerjb  = abspath(args.scaler,  must_exist=False)
    out_onnx  = abspath(args.out,     must_exist=False)
    out_json  = abspath(args.json,    must_exist=False)

    # Ensure output directories exist
    out_onnx.parent.mkdir(parents=True, exist_ok=True)
    out_json.parent.mkdir(parents=True, exist_ok=True)

    # Add model_dir to sys.path BEFORE importing the model module
    if str(model_dir) not in sys.path:
        sys.path.insert(0, str(model_dir))

    # Now import the model class dynamically
    try:
        import importlib
        # model_module = importlib.import_module("model")
        model_module = importlib.import_module("model_combined")
        ModelClass = getattr(model_module, args.class_name)
    except Exception as e:
        print(f"[ERROR] Failed to import '{args.class_name}' from {model_dir}/model_combined.py: {e}", file=sys.stderr)
        sys.exit(2)

    # Torch & IO
    import torch, joblib, numpy as np

    # Build the model; pass hidden_dim only if the constructor accepts it
    if "hidden_dim" in ModelClass.__init__.__code__.co_varnames:
        model = ModelClass(input_dim=args.in_dim, hidden_dim=args.hidden_dim)
    else:
        model = ModelClass(input_dim=args.in_dim)

    # Load weights (either a checkpoint dict with "state_dict" or a raw state_dict)
    try:
        ckpt = torch.load(str(weights), map_location="cpu")
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        elif isinstance(ckpt, dict):
            # Likely a raw state_dict
            state_dict = ckpt
        else:
            raise RuntimeError(f"Unsupported checkpoint object: {type(ckpt)}")
        model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        print(f"[ERROR] Failed to load weights: {e}", file=sys.stderr)
        sys.exit(3)
    # eval() puts dropout/batch-norm layers into inference mode before export
    model.eval()

    # Export ONNX (dynamic batch)
    try:
        dummy = torch.randn(1, args.in_dim, dtype=torch.float32)
        torch.onnx.export(
            model, dummy, str(out_onnx),
            input_names=["x"], output_names=["pt_hat"],
            dynamic_axes={"x": {0: "N"}, "pt_hat": {0: "N"}},
            opset_version=13,
        )
        print(f"[OK] Exported ONNX -> {out_onnx}")
    except Exception as e:
        print(f"[ERROR] Failed to export ONNX: {e}", file=sys.stderr)
        sys.exit(4)

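    # Optional sanity check of the exported graph (a sketch, left commented out
    # because onnxruntime is not a dependency of this script; it assumes the
    # package is installed and reuses the dummy tensor and output name from above):
    #   import onnxruntime as ort
    #   sess = ort.InferenceSession(str(out_onnx))
    #   (pt_hat,) = sess.run(["pt_hat"], {"x": dummy.numpy()})
    #   print(f"[CHECK] ONNX forward pass OK, output shape {pt_hat.shape}")
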
    # Dump scaler JSON
    if not scalerjb.exists():
        print(f"[INFO] Scaler file not found: {scalerjb}. Skipping scaler JSON.")
        # Skip the entire scaler-handling step
        return

    try:
        scaler = joblib.load(str(scalerjb))
        required = ["mean_", "scale_", "n_features_in_"]
        for a in required:
            if not hasattr(scaler, a):
                raise AttributeError(f"Scaler missing '{a}'. Is it a sklearn StandardScaler?")
        payload = {
            "mean": np.asarray(scaler.mean_, dtype=float).tolist(),
            "scale": np.asarray(scaler.scale_, dtype=float).tolist(),
            "n_features": int(scaler.n_features_in_),
        }
        out_json.write_text(json.dumps(payload, indent=2))
        print(f"[OK] Wrote scaler JSON -> {out_json} (n_features={payload['n_features']})")
        if payload["n_features"] != args.in_dim:
            print(f"[WARN] scaler n_features ({payload['n_features']}) != --in-dim ({args.in_dim})", file=sys.stderr)
    except Exception as e:
        print(f"[ERROR] Failed to write scaler JSON: {e}", file=sys.stderr)
        sys.exit(5)

if __name__ == "__main__":
    main()