Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:18:06

0001 #include "PtCalculator.h"
0002 #include <cmath>
0003 #include <stdexcept>
0004 #include <limits>
0005 #include <array>
0006 #include <memory>
0007 
0008 #include <onnxruntime_cxx_api.h>   // ONNX Runtime C++ API
0009 #include <fstream>
0010 #include <nlohmann/json.hpp>       // Choose any json library (or the one you already have)
0011 
0012 namespace SiCaloPt {
0013 
0014 PtCalculator::PtCalculator(const PtCalculatorConfig& cfg)
0015 : m_cfg(cfg) {}
0016 
0017 void PtCalculator::setConfig(const PtCalculatorConfig& cfg) {
0018     m_cfg = cfg;
0019 }
0020 
0021 bool PtCalculator::init(std::string* err)
0022 {
0023     // EMD ML Model loading
0024     if (m_cfg.mlEMD_model_path) 
0025     {
0026         std::string err_string;
0027         auto fn = MakeOnnxInfer(*m_cfg.mlEMD_model_path, &err_string);
0028         if (!fn) 
0029         { 
0030             if (err) *err = "ONNX load mlEMD failed: " + err_string; 
0031             return false; 
0032         }
0033         setMLEMDInfer(std::move(fn));
0034         // scaler (optional)
0035         if (m_cfg.mlEMD_scaler_json) 
0036         {
0037             std::vector<float> mean, scale;
0038             if (!LoadScalerJson(*m_cfg.mlEMD_scaler_json, mean, scale, &err_string)) 
0039             {
0040                 if (err) *err = "Load mlEMD scaler failed: " + err_string; 
0041                 return false;
0042             }
0043             setMLEMDStandardizer(std::move(mean), std::move(scale));
0044         }
0045     }
0046 
0047     // Eproj ML model loading
0048     if (m_cfg.mlEproj_model_path) 
0049     {
0050         std::string err_string;
0051         auto fn = MakeOnnxInfer(*m_cfg.mlEproj_model_path, &err_string);
0052         if (!fn)
0053         { 
0054             if (err) *err = "ONNX load mlEproj failed: " + err_string; 
0055             return false; 
0056         }
0057         setMLEprojInfer(std::move(fn));
0058         if (m_cfg.mlEproj_scaler_json) 
0059         {
0060             std::vector<float> mean, scale;
0061             if (!LoadScalerJson(*m_cfg.mlEproj_scaler_json, mean, scale, &err_string)) 
0062             {
0063                 if (err) *err = "Load mlEproj scaler failed: " + err_string; 
0064                 return false;
0065             }
0066             setMLEprojStandardizer(std::move(mean), std::move(scale));
0067         }
0068     }
0069 
0070     // Combined Gate ML model loading
0071     if (m_cfg.mlCombined_model_path) 
0072     {
0073         std::string err_string;
0074         auto fn = MakeOnnxInfer(*m_cfg.mlCombined_model_path, &err_string);
0075         if (!fn) 
0076         { 
0077             if (err) *err = "ONNX load mlCombined failed: " + err_string; 
0078             return false; 
0079         }
0080         setMLCombinedInfer(std::move(fn));
0081         if (m_cfg.mlCombined_scaler_json) 
0082         {
0083             std::vector<float> mean, scale;
0084             if (!LoadScalerJson(*m_cfg.mlCombined_scaler_json, mean, scale, &err_string)) 
0085             {
0086                 if (err) *err = "Load mlCombined scaler failed: " + err_string; 
0087                 return false;
0088             }
0089             setMLCombinedStandardizer(std::move(mean), std::move(scale));
0090         }
0091     }
0092 
0093     return true;
0094 }
0095 
0096 
0097 PtResult PtCalculator::ComputePt(Method method, const AnyInput& input) const 
0098 {
0099     try 
0100     {
0101         switch (method) 
0102         {
0103             case Method::MethodEMD:
0104                 return ComputeEMD(std::get<InputEMD>(input));
0105             case Method::MethodEproj:
0106                 return ComputeEproj(std::get<InputEproj>(input));
0107             case Method::MethodMLEMD:
0108                 return ComputeMLEMD(std::get<InputMLEMD>(input));
0109             case Method::MethodMLEproj:
0110                 return ComputeMLEproj(std::get<InputMLEproj>(input));
0111             case Method::MethodMLCombined:
0112                 return ComputeMLCombined(std::get<InputMLCombined>(input));
0113             default:
0114                 return PtResult{.pt_reco = NAN, .ok = false, .err = "Unknown method"};
0115         }
0116     } 
0117     catch (const std::bad_variant_access&) 
0118     {
0119         return PtResult{.pt_reco = NAN, .ok = false, .err = "Input type does not match method"};
0120     }
0121 }
0122 
0123 PtResult PtCalculator::ComputeEMD(const InputEMD& in) const 
0124 {
0125     if (in.EMD_Angle == 0.f) 
0126     {
0127         return PtResult{.pt_reco = NAN, .ok = false, .err = "EMD_Angle is zero"};
0128     }
0129 
0130     // whethear considering the eta dependence on EMD compute
0131     if (consider_eta_dependence_on_EMDcompute)
0132     {
0133         float x_eta = in.EMD_Eta;
0134         m_par_Ceta = 0.198211 + (0.013064*x_eta*x_eta) + (-0.009812*x_eta*x_eta*x_eta*x_eta); // Function to calculate the correct factor of eta
0135         m_par_Power = -1.0;
0136     }  
0137     const float pt = m_par_Ceta * std::pow(std::fabs(in.EMD_Angle), m_par_Power);
0138     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0139 }
0140 
0141 PtResult PtCalculator::ComputeEproj(const InputEproj& in) const 
0142 {
0143     const float Distance_Z = std::fabs(in.Z_Calo - in.Z_vertex);
0144     const float Distance_R = std::fabs(in.Radius_Calo - in.Radius_vertex);
0145     const float Distance   = std::sqrt((Distance_R*Distance_R) + (Distance_Z*Distance_Z));
0146 
0147     if (Distance == 0.f) 
0148     {
0149         return PtResult{.pt_reco = NAN, .ok = false, .err = "Distance is zero"};
0150     }
0151     const float pt = in.Energy_Calo*(Distance_R/Distance);
0152     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0153 }
0154 
0155 PtResult PtCalculator::ComputeMLEMD(const InputMLEMD& in) const 
0156 {
0157     if (!m_mlEMD_infer) 
0158     {
0159         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD infer function not set"};
0160     }
0161     std::vector<float> x = in.features;
0162     if (!m_mlEMD_mean.empty() && !m_mlEMD_scale.empty()) 
0163     {
0164         if (m_mlEMD_mean.size() != x.size() || m_mlEMD_scale.size() != x.size()) 
0165         {
0166             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEMD standardizer dim mismatch"};
0167         }
0168         applyStandardize(x, m_mlEMD_mean, m_mlEMD_scale);
0169     }
0170     const float pt = m_mlEMD_infer(x);
0171     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0172 }
0173 
0174 PtResult PtCalculator::ComputeMLEproj(const InputMLEproj& in) const 
0175 {
0176     if (!m_mlEproj_infer) 
0177     {
0178         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj infer function not set"};
0179     }
0180     std::vector<float> x = in.features;
0181     if (!m_mlEproj_mean.empty() && !m_mlEproj_scale.empty()) 
0182     {
0183         if (m_mlEproj_mean.size() != x.size() || m_mlEproj_scale.size() != x.size()) 
0184         {
0185             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLEproj standardizer dim mismatch"};
0186         }
0187         applyStandardize(x, m_mlEproj_mean, m_mlEproj_scale);
0188     }
0189     const float pt = m_mlEproj_infer(x);
0190     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0191 }
0192 
0193 PtResult PtCalculator::ComputeMLCombined(const InputMLCombined& in) const 
0194 {
0195     if (!m_mlCombined_infer) 
0196     {
0197         return PtResult{.pt_reco = NAN, .ok = false, .err = "MLCombined infer function not set"};
0198     }
0199     std::vector<float> x = in.features;
0200     if (!m_mlCombined_mean.empty() && !m_mlCombined_scale.empty()) 
0201     {
0202         if (m_mlCombined_mean.size() != x.size() || m_mlCombined_scale.size() != x.size()) 
0203         {
0204             return PtResult{.pt_reco = NAN, .ok = false, .err = "MLCombined standardizer dim mismatch"};
0205         }
0206         applyStandardize(x, m_mlCombined_mean, m_mlCombined_scale);
0207     }
0208     const float pt = m_mlCombined_infer(x);
0209     return PtResult{.pt_reco = pt, .ok = true, .err = ""};
0210 }
0211 
0212 void PtCalculator::setMLEMDInfer(InferFn fn) { m_mlEMD_infer = std::move(fn); }
0213 void PtCalculator::setMLEprojInfer(InferFn fn) { m_mlEproj_infer = std::move(fn); }
0214 void PtCalculator::setMLCombinedInfer(InferFn fn) { m_mlCombined_infer = std::move(fn); }
0215 
0216 void PtCalculator::setMLEMDStandardizer(std::vector<float> mean, std::vector<float> scale) 
0217 {
0218     m_mlEMD_mean  = std::move(mean);
0219     m_mlEMD_scale = std::move(scale);
0220 }
0221 
0222 void PtCalculator::setMLEprojStandardizer(std::vector<float> mean, std::vector<float> scale) 
0223 {
0224     m_mlEproj_mean  = std::move(mean);
0225     m_mlEproj_scale = std::move(scale);
0226 }
0227 
0228 void PtCalculator::setMLCombinedStandardizer(std::vector<float> mean, std::vector<float> scale) 
0229 {
0230     m_mlCombined_mean  = std::move(mean);
0231     m_mlCombined_scale = std::move(scale);
0232 }
0233 
0234 void PtCalculator::applyStandardize(std::vector<float>& x,
0235                                     const std::vector<float>& mean,
0236                                     const std::vector<float>& scale) 
0237 {
0238     const size_t n = x.size();
0239     for (size_t i=0; i<n; ++i) 
0240     {
0241         if (scale[i] == 0.f) 
0242         {
0243             throw std::runtime_error("Scale value is zero during standardization");
0244         }
0245         x[i] = (x[i] - mean[i]) / scale[i];
0246     }
0247 }
0248 
0249 // --------------------------------------------------------------------
0250 bool PtCalculator::LoadScalerJson(const std::string& path,
0251                                   std::vector<float>& mean,
0252                                   std::vector<float>& scale,
0253                                   std::string* err)
0254 {
0255     try {
0256         std::ifstream fin(path);
0257         nlohmann::json js; fin >> js;
0258         if (!js.contains("mean") || !js.contains("scale")) {
0259             if (err) *err = "scaler json missing keys";
0260             return false;
0261         }
0262         mean  = js["mean"].get<std::vector<float>>();
0263         scale = js["scale"].get<std::vector<float>>();
0264         return true;
0265     } catch (const std::exception& e) {
0266         if (err) *err = e.what();
0267         return false;
0268     }
0269 }
0270 
0271 PtCalculator::InferFn PtCalculator::MakeOnnxInfer(const std::string& onnx_path,
0272                                                   std::string* err)
0273 {
0274     try {
0275         static Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "ptcalc");
0276         Ort::SessionOptions so;
0277         so.SetIntraOpNumThreads(1);
0278         so.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
0279 
0280         auto sess = std::make_shared<Ort::Session>(env, onnx_path.c_str(), so);
0281 
0282         Ort::AllocatorWithDefaultOptions alloc;
0283         auto in_name  = std::string(sess->GetInputNameAllocated(0, alloc).get());
0284         auto out_name = std::string(sess->GetOutputNameAllocated(0, alloc).get());
0285 
0286         return [sess, in_name, out_name](const std::vector<float>& feats) -> float {
0287             const int64_t N = 1;
0288             const int64_t D = static_cast<int64_t>(feats.size());
0289             std::array<int64_t,2> shape{N, D};
0290 
0291             Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
0292             Ort::Value input = Ort::Value::CreateTensor<float>(
0293                 mem,
0294                 const_cast<float*>(feats.data()),
0295                 feats.size(),
0296                 shape.data(), shape.size()
0297             );
0298 
0299             const char* in_names[]  = { in_name.c_str()  };
0300             const char* out_names[] = { out_name.c_str() };
0301 
0302             auto outputs = sess->Run(Ort::RunOptions{nullptr}, in_names, &input, 1, out_names, 1);
0303             float* out_ptr = outputs.front().GetTensorMutableData<float>();
0304             return out_ptr[0];
0305         };
0306     } catch (const std::exception& e) {
0307         if (err) *err = e.what();
0308         return {};
0309     }
0310 }
0311 
0312 } // namespace SiCaloPt