Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2026-04-05 08:11:42

0001 #include "PtCalculator.h"
0002 #include <cmath>
0003 #include <stdexcept>
0004 #include <limits>
0005 #include <array>
0006 #include <memory>
0007 #include <iostream>
0008 
0009 #include <onnxruntime_cxx_api.h>   // ONNX Runtime C++ API
0010 #include <fstream>
0011 #include <nlohmann/json.hpp>       // Choose any json library (or the one you already have)
0012 
namespace {

// Map the 7 raw MLEproj inputs onto the 3 engineered features the model consumes.
//
// Raw layout (same order as documented in PtCalculator.h):
//   [0] R34, [1] Z34, [2] R56, [3] Z56, [4] Rcalo, [5] Zcalo, [6] E
// Output layout:
//   [0] eta_dir : asinh(dZ / (|dR| + eps)) of the chord (R34,Z34) -> (Rcalo,Zcalo) in R-Z
//   [1] logE    : log(E + 1e-6)
//   [2] kink    : angle (radians) between segments (R34,Z34)->(R56,Z56) and
//                 (R56,Z56)->(Rcalo,Zcalo) in R-Z
// Fewer than 7 entries makes vector::at throw std::out_of_range.
// NOTE(review): the formulas are meant to mirror the training-side data.py — keep in sync.
std::vector<float> MLEprojRaw7_to_Feature3(const std::vector<float>& raw7)
{
    constexpr double kEps = 1e-12;

    const double r34    = raw7.at(0);
    const double z34    = raw7.at(1);
    const double r56    = raw7.at(2);
    const double z56    = raw7.at(3);
    const double rcalo  = raw7.at(4);
    const double zcalo  = raw7.at(5);
    const double energy = raw7.at(6);

    // Direction pseudorapidity of the end-to-end chord; in R-Z the transverse
    // component is taken as |dR|.
    const double chord_r = rcalo - r34;
    const double chord_z = zcalo - z34;
    const double eta_dir = std::asinh(chord_z / (std::fabs(chord_r) + kEps));

    // Kink angle between the two consecutive segments.
    const double seg1_r = r56 - r34;
    const double seg1_z = z56 - z34;
    const double seg2_r = rcalo - r56;
    const double seg2_z = zcalo - z56;

    const double len1 = std::sqrt(seg1_r*seg1_r + seg1_z*seg1_z);
    const double len2 = std::sqrt(seg2_r*seg2_r + seg2_z*seg2_z);

    double cos_kink = (seg1_r*seg2_r + seg1_z*seg2_z) / ((len1 * len2) + kEps);
    // Rounding can push the cosine marginally outside [-1, 1]; clamp before acos.
    // Plain comparisons (not fmin/fmax) so a NaN cosine still yields a NaN kink.
    if (cos_kink > 1.0)
    {
        cos_kink = 1.0;
    }
    else if (cos_kink < -1.0)
    {
        cos_kink = -1.0;
    }

    // Same offset as the training-side log(E + 1e-6).
    const double log_energy = std::log(energy + 1e-6);

    return { static_cast<float>(eta_dir),
             static_cast<float>(log_energy),
             static_cast<float>(std::acos(cos_kink)) };
}

} // anonymous namespace
0055 
0056 
0057 namespace SiCaloPt {
0058 
0059 PtCalculator::PtCalculator(const PtCalculatorConfig& cfg)
0060 : m_cfg(cfg) {}
0061 
0062 void PtCalculator::setConfig(const PtCalculatorConfig& cfg) {
0063     m_cfg = cfg;
0064 }
0065 
0066 bool PtCalculator::init(std::string* err)
0067 {
0068     // EMD ML Model loading
0069     if (m_cfg.mlEMD_model_path) 
0070     {
0071         std::string err_string;
0072         auto fn = MakeOnnxInfer(*m_cfg.mlEMD_model_path, &err_string);
0073         if (!fn) 
0074         { 
0075             if (err) *err = "ONNX load mlEMD failed: " + err_string; 
0076             return false; 
0077         }
0078         setMLEMDInfer(std::move(fn));
0079         // scaler (optional)
0080         if (m_cfg.mlEMD_scaler_json) 
0081         {
0082             std::vector<float> mean, scale;
0083             if (!LoadScalerJson(*m_cfg.mlEMD_scaler_json, mean, scale, &err_string)) 
0084             {
0085                 if (err) *err = "Load mlEMD scaler failed: " + err_string; 
0086                 return false;
0087             }
0088             setMLEMDStandardizer(std::move(mean), std::move(scale));
0089         }
0090     }
0091 
0092     // Eproj ML model loading
0093     if (m_cfg.mlEproj_model_path) 
0094     {
0095         std::string err_string;
0096         auto fn = MakeOnnxInfer(*m_cfg.mlEproj_model_path, &err_string);
0097         if (!fn)
0098         { 
0099             if (err) *err = "ONNX load mlEproj failed: " + err_string; 
0100             return false; 
0101         }
0102         setMLEprojInfer(std::move(fn));
0103         if (m_cfg.mlEproj_scaler_json) 
0104         {
0105             std::vector<float> mean, scale;
0106             if (!LoadScalerJson(*m_cfg.mlEproj_scaler_json, mean, scale, &err_string)) 
0107             {
0108                 if (err) *err = "Load mlEproj scaler failed: " + err_string; 
0109                 return false;
0110             }
0111             setMLEprojStandardizer(std::move(mean), std::move(scale));
0112         }
0113     }
0114 
0115     // Combined Gate ML model loading
0116     if (m_cfg.mlCombined_model_path) 
0117     {
0118         std::string err_string;
0119         auto fn = MakeOnnxInfer(*m_cfg.mlCombined_model_path, &err_string);
0120         if (!fn) 
0121         { 
0122             if (err) *err = "ONNX load mlCombined failed: " + err_string; 
0123             return false; 
0124         }
0125         setMLCombinedInfer(std::move(fn));
0126         if (m_cfg.mlCombined_scaler_json) 
0127         {
0128             std::vector<float> mean, scale;
0129             if (!LoadScalerJson(*m_cfg.mlCombined_scaler_json, mean, scale, &err_string)) 
0130             {
0131                 if (err) *err = "Load mlCombined scaler failed: " + err_string; 
0132                 return false;
0133             }
0134             setMLCombinedStandardizer(std::move(mean), std::move(scale));
0135         }
0136     }
0137 
0138     return true;
0139 }
0140 
0141 
0142 PtResult PtCalculator::ComputePt(Method method, const AnyInput& input) const 
0143 {
0144     try 
0145     {
0146         switch (method) 
0147         {
0148             case Method::MethodEMD:
0149                 return ComputeEMD(std::get<InputEMD>(input));
0150             case Method::MethodEproj:
0151                 return ComputeEproj(std::get<InputEproj>(input));
0152             case Method::MethodMLEMD:
0153                 return ComputeMLEMD(std::get<InputMLEMD>(input));
0154             case Method::MethodMLEproj:
0155                 return ComputeMLEproj(std::get<InputMLEproj>(input));
0156             case Method::MethodMLCombined:
0157                 return ComputeMLCombined(std::get<InputMLCombined>(input));
0158             default:
0159                 return PtResult{.pt_reco = NAN, .ok = false, .err = "Unknown method"};
0160         }
0161     } 
0162     catch (const std::bad_variant_access&) 
0163     {
0164         return PtResult{.pt_reco = NAN, .ok = false, .err = "Input type does not match method"};
0165     }
0166 }
0167 
0168 PtResult PtCalculator::ComputeEMD(const InputEMD& in) const 
0169 {
0170     if (in.EMD_Angle == 0.f) 
0171     {
0172         return PtResult{.pt_reco = NAN, .ok = false, .err = "EMD_Angle is zero"};
0173     }
0174 
0175     int scenario = getScenario(in.EMD_Angle);
0176     if (scenario == -999)
0177     {
0178         return PtResult{.pt_reco = NAN, .ok = false, .err = "Invalid scenario"};
0179     }
0180 
0181     if (consider_eta_dependence_on_EMDcompute)
0182     {
0183         switch (scenario)
0184         {
0185             case -10:
0186                 // negative charge + projected cluster
0187                 m_par_Ceta = 0.199854 + (0.00971966*x_eta*x_eta) + (-0.0177071*x_eta*x_eta*x_eta*x_eta);
0188                 m_par_Power = -1.0;
0189                 break;
0190 
0191             case -11:
0192                 // negative charge + inner face
0193                 m_par_Ceta = 0.198211 + (0.013064*x_eta*x_eta) + (-0.009812*x_eta*x_eta*x_eta*x_eta);
0194                 m_par_Power = -1.0;
0195                 break;
0196 
0197             case -12:
0198                 // negative charge + detail
0199                 m_par_Ceta = 0.197232 + (0.014244*x_eta*x_eta) + (-0.0188948*x_eta*x_eta*x_eta*x_eta);
0200                 m_par_Power = -1.0;
0201                 break;
0202 
0203             case 10:
0204                 // positive charge + projected cluster
0205                 m_par_Ceta = 0.203717 + (-0.000515*x_eta*x_eta) + (-0.00131*x_eta*x_eta*x_eta*x_eta);
0206                 m_par_Power = -1.0;
0207                 break;
0208 
0209             case 11:
0210                 // positive charge + inner face
0211                 m_par_Ceta = 0.197297 + (0.01297*x_eta*x_eta) + (-0.00979*x_eta*x_eta*x_eta*x_eta);
0212                 m_par_Power = -1.0;
0213                 break;
0214 
0215             case 12:
0216                 // positive charge + detail
0217                 m_par_Ceta = 0.200697 + (-0.0040371*x_eta*x_eta) + (-0.000426*x_eta*x_eta*x_eta*x_eta);
0218                 m_par_Power = -1.0;
0219                 break;
0220 
0221             default:
0222                 return PtResult{.pt_reco = NAN, .ok = false, .err = "Unknown scenario"};
0223         }
0224     }
0225     const float pt = m_par_Ceta * std::pow(std::fabs(in.EMD_Angle), m_par_Power);
0226     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0227 }
0228 
0229 
0230 PtResult PtCalculator::ComputeEproj(const InputEproj& in) const 
0231 {
0232     const float Distance_Z = std::fabs(in.Z_Calo - in.Z_vertex);
0233     const float Distance_R = std::fabs(in.Radius_Calo - in.Radius_vertex);
0234     const float Distance   = std::sqrt((Distance_R*Distance_R) + (Distance_Z*Distance_Z));
0235 
0236     if (Distance == 0.f) 
0237     {
0238         return PtResult{.pt_reco = NAN, .ok = false, .err = "Distance is zero"};
0239     }
0240     const float pt = in.Energy_Calo*(Distance_R/Distance);
0241     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0242 }
0243 
0244 PtResult PtCalculator::ComputeMLEMD(const InputMLEMD& in) const 
0245 {
0246     if (!m_mlEMD_infer) 
0247     {
0248         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD infer function not set"};
0249     }
0250 
0251     std::vector<float> x = in.features;
0252 
0253     if (x.size() != 2)
0254     {
0255         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD expects 2 features: {dphi, eta}"};
0256     }
0257 
0258     const float dphi = x[0];
0259     if (!std::isfinite(dphi) || dphi == 0.f)
0260     {
0261         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD dphi is invalid or zero"};
0262     }
0263 
0264     // convert external input dphi -> internal model input 1/dphi
0265     x[0] = 1.f / dphi;
0266 
0267     if (!m_mlEMD_mean.empty() && !m_mlEMD_scale.empty()) 
0268     {
0269         if (m_mlEMD_mean.size() != x.size() || m_mlEMD_scale.size() != x.size()) 
0270         {
0271             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD standardizer dim mismatch"};
0272         }
0273         applyStandardize(x, m_mlEMD_mean, m_mlEMD_scale);
0274     }
0275 
0276     const float pt = m_mlEMD_infer(x);
0277     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0278 }
0279 
0280 PtResult PtCalculator::ComputeMLEproj(const InputMLEproj& in) const 
0281 {
0282     if (!m_mlEproj_infer) 
0283     {
0284         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj infer function not set"};
0285     }
0286 
0287     std::vector<float> x = in.features;
0288 
0289     // --- compute pt0 (baseline) ---
0290     // training uses: pt0 = E * sin(theta),  pt_pred = pt0 * exp(delta)
0291     // If we only have R,Z, approximate direction in RZ plane:
0292     //   vT = |dR|, vmag = sqrt(dR^2 + dZ^2), sin(theta) ~ vT/vmag
0293     double pt0 = NAN;
0294 
0295     if (x.size() == 7)
0296     {
0297         const double R0 = x[0], Z0 = x[1];
0298         const double R2 = x[4], Z2 = x[5];
0299         const double E  = x[6];
0300 
0301         const double dR = (R2 - R0);
0302         const double dZ = (Z2 - Z0);
0303 
0304         const double vT   = std::fabs(dR);
0305         const double vmag = std::sqrt(dR*dR + dZ*dZ) + 1e-12;
0306         const double uT   = vT / vmag;   // ~ sin(theta)
0307 
0308         pt0 = E * uT;
0309 
0310         // 7 raw -> 3 engineered
0311         x = MLEprojRaw7_to_Feature3(x);
0312     }
0313     else if (x.size() == 3)
0314     {
0315         // x = [eta_dir, logE, kink]
0316         const double eta_dir = x[0];
0317         const double logE    = x[1];
0318 
0319         // E = exp(logE) - 1e-6  (inverse of log(E+1e-6))
0320         const double E = std::exp(logE) - 1e-6;
0321 
0322         // sin(theta) = 1/cosh(eta)  (since eta = asinh(vz/vT), equivalent to usual direction eta)
0323         const double uT = 1.0 / std::cosh(eta_dir);
0324 
0325         pt0 = E * uT;
0326     }
0327     else
0328     {
0329         return PtResult{.pt_reco = NAN, .ok = false,
0330                         .err = "MLEproj expects 7 raw inputs or 3 engineered features"};
0331     }
0332 
0333     if (!std::isfinite(pt0) || pt0 <= 0.0)
0334     {
0335         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj pt0 invalid"};
0336     }
0337 
0338     // --- standardize x (must match dim=3 after conversion) ---
0339     if (!m_mlEproj_mean.empty() && !m_mlEproj_scale.empty()) 
0340     {
0341         if (m_mlEproj_mean.size() != x.size() || m_mlEproj_scale.size() != x.size()) 
0342         {
0343             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj standardizer dim mismatch"};
0344         }
0345         applyStandardize(x, m_mlEproj_mean, m_mlEproj_scale);
0346     }
0347 
0348     // --- infer delta, then pt = pt0 * exp(delta) ---
0349     const float delta_f = m_mlEproj_infer(x);
0350     const double delta  = static_cast<double>(delta_f);
0351 
0352     if (!std::isfinite(delta))
0353     {
0354         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj delta invalid"};
0355     }
0356 
0357     const double pt = pt0 * std::exp(delta);
0358 
0359     if (!std::isfinite(pt) || pt <= 0.0)
0360     {
0361         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj pt invalid"};
0362     }
0363 
0364     return PtResult{.pt_reco = static_cast<float>(pt), .ok = true, .err = ""};
0365 }
0366 
0367 PtResult PtCalculator::ComputeMLCombined(const InputMLCombined& in) const 
0368 {
0369     if (!m_mlCombined_infer) 
0370     {
0371         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLCombined infer function not set"};
0372     }
0373     std::vector<float> x = in.features;
0374     if (!m_mlCombined_mean.empty() && !m_mlCombined_scale.empty()) 
0375     {
0376         if (m_mlCombined_mean.size() != x.size() || m_mlCombined_scale.size() != x.size()) 
0377         {
0378             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLCombined standardizer dim mismatch"};
0379         }
0380         applyStandardize(x, m_mlCombined_mean, m_mlCombined_scale);
0381     }
0382     const float pt = m_mlCombined_infer(x);
0383     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0384 }
0385 
0386 void PtCalculator::setMLEMDInfer(InferFn fn) { m_mlEMD_infer = std::move(fn); }
0387 void PtCalculator::setMLEprojInfer(InferFn fn) { m_mlEproj_infer = std::move(fn); }
0388 void PtCalculator::setMLCombinedInfer(InferFn fn) { m_mlCombined_infer = std::move(fn); }
0389 
0390 void PtCalculator::setMLEMDStandardizer(std::vector<float> mean, std::vector<float> scale) 
0391 {
0392     m_mlEMD_mean  = std::move(mean);
0393     m_mlEMD_scale = std::move(scale);
0394 }
0395 
0396 void PtCalculator::setMLEprojStandardizer(std::vector<float> mean, std::vector<float> scale) 
0397 {
0398     m_mlEproj_mean  = std::move(mean);
0399     m_mlEproj_scale = std::move(scale);
0400 }
0401 
0402 void PtCalculator::setMLCombinedStandardizer(std::vector<float> mean, std::vector<float> scale) 
0403 {
0404     m_mlCombined_mean  = std::move(mean);
0405     m_mlCombined_scale = std::move(scale);
0406 }
0407 
0408 void PtCalculator::applyStandardize(std::vector<float>& x,
0409                                     const std::vector<float>& mean,
0410                                     const std::vector<float>& scale) 
0411 {
0412     const size_t n = x.size();
0413     for (size_t i=0; i<n; ++i) 
0414     {
0415         if (scale[i] == 0.f) 
0416         {
0417             throw std::runtime_error("Scale value is zero during standardization");
0418         }
0419         x[i] = (x[i] - mean[i]) / scale[i];
0420     }
0421 }
0422 
0423 // --------------------------------------------------------------------
0424 bool PtCalculator::LoadScalerJson(const std::string& path,
0425                                   std::vector<float>& mean,
0426                                   std::vector<float>& scale,
0427                                   std::string* err)
0428 {
0429     try {
0430         std::ifstream fin(path);
0431         nlohmann::json js; fin >> js;
0432         if (!js.contains("mean") || !js.contains("scale")) {
0433             if (err) *err = "scaler json missing keys";
0434             return false;
0435         }
0436         mean  = js["mean"].get<std::vector<float>>();
0437         scale = js["scale"].get<std::vector<float>>();
0438         return true;
0439     } catch (const std::exception& e) {
0440         if (err) *err = e.what();
0441         return false;
0442     }
0443 }
0444 
0445 PtCalculator::InferFn PtCalculator::MakeOnnxInfer(const std::string& onnx_path,
0446                                                   std::string* err)
0447 {
0448     try {
0449         static Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ptcalc");
0450         Ort::SessionOptions so;
0451         so.SetIntraOpNumThreads(1);
0452         so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0453 
0454         auto sess = std::make_shared<Ort::Session>(env, onnx_path.c_str(), so);
0455 
0456         Ort::AllocatorWithDefaultOptions alloc;
0457         auto in_name  = std::string(sess->GetInputNameAllocated(0, alloc).get());
0458         auto out_name = std::string(sess->GetOutputNameAllocated(0, alloc).get());
0459 
0460         return [sess, in_name, out_name](const std::vector<float>& feats) -> float {
0461             const int64_t N = 1;
0462             const int64_t D = static_cast<int64_t>(feats.size());
0463             std::array<int64_t,2> shape{N, D};
0464 
0465             Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
0466             Ort::Value input = Ort::Value::CreateTensor<float>(
0467                 mem,
0468                 const_cast<float*>(feats.data()),
0469                 feats.size(),
0470                 shape.data(), shape.size()
0471             );
0472 
0473             const char* in_names[]  = { in_name.c_str()  };
0474             const char* out_names[] = { out_name.c_str() };
0475 
0476             auto outputs = sess->Run(Ort::RunOptions{nullptr}, in_names, &input, 1, out_names, 1);
0477             float* out_ptr = outputs.front().GetTensorMutableData<float>();
0478             return out_ptr[0];
0479         };
0480     } catch (const std::exception& e) {
0481         if (err) *err = e.what();
0482         return {};
0483     }
0484 }
0485 
0486 } // namespace SiCaloPt