Back to home page

sPHENIX code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:19:49

0001 #include "CaloWaveformProcessing.h"
0002 #include "CaloWaveformFitting.h"
0003 
0004 #include <ffamodules/CDBInterface.h>
0005 
0006 #include <phool/onnxlib.h>
0007 
0008 #include <algorithm>  // for max
0009 #include <cassert>
0010 #include <cstdlib>  // for getenv
0011 #include <iostream>
0012 #include <limits>
0013 #include <memory>  // for allocator_traits<>::value_type
0014 #include <string>
0015 
0016 namespace
0017 {
0018   Ort::Session *onnxmodule;
0019 }
0020 
0021 CaloWaveformProcessing::~CaloWaveformProcessing()
0022 {
0023   delete m_Fitter;
0024 }
0025 
0026 void CaloWaveformProcessing::initialize_processing()
0027 {
0028   char *calibrationsroot = getenv("CALIBRATIONROOT");
0029   assert(calibrationsroot);
0030   if (m_processingtype == CaloWaveformProcessing::TEMPLATE || m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0031   {
0032     std::string calibrations_repo_template = std::string(calibrationsroot) + "/WaveformProcessing/templates/" + m_template_input_file;
0033     url_template = CDBInterface::instance()->getUrl(m_template_name, calibrations_repo_template);
0034     m_Fitter = new CaloWaveformFitting();
0035     m_Fitter->initialize_processing(url_template);
0036     if (m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0037     {
0038       m_Fitter->set_handleSaturation(false);
0039     }
0040     m_Fitter->set_nthreads(get_nthreads());
0041     if (m_setTimeLim)
0042     {
0043       m_Fitter->set_timeFitLim(m_timeLim_low, m_timeLim_high);
0044     }
0045 
0046     if (_bdosoftwarezerosuppression)
0047     {
0048       m_Fitter->set_softwarezerosuppression(_bdosoftwarezerosuppression, _nsoftwarezerosuppression);
0049     }
0050     if (_dobitfliprecovery)
0051     {
0052       m_Fitter->set_bitFlipRecovery(_dobitfliprecovery);
0053     }
0054   }
0055   else if (m_processingtype == CaloWaveformProcessing::ONNX)
0056   {
0057     // std::string calibrations_repo_model = m_model_name;
0058     // url_onnx = CDBInterface::instance()->getUrl("CEMC_ONNX", m_model_name);
0059     onnxmodule = onnxSession(m_model_name, Verbosity());
0060   }
0061   else if (m_processingtype == CaloWaveformProcessing::NYQUIST)
0062   {
0063     std::string calibrations_repo_template = std::string(calibrationsroot) + "/WaveformProcessing/templates/" + m_template_input_file;
0064     url_template = CDBInterface::instance()->getUrl(m_template_name, calibrations_repo_template);
0065     m_Fitter = new CaloWaveformFitting();
0066     m_Fitter->initialize_processing(url_template);
0067   }
0068 }
0069 
0070 std::vector<std::vector<float>> CaloWaveformProcessing::process_waveform(std::vector<std::vector<float>> waveformvector)
0071 {
0072   unsigned int size1 = waveformvector.size();
0073   std::vector<std::vector<float>> fitresults;
0074   if (m_processingtype == CaloWaveformProcessing::TEMPLATE || m_processingtype == CaloWaveformProcessing::TEMPLATE_NOSAT)
0075   {
0076     for (unsigned int i = 0; i < size1; i++)
0077     {
0078       waveformvector.at(i).push_back((float) i);
0079     }
0080     fitresults = m_Fitter->calo_processing_templatefit(waveformvector);
0081   }
0082   if (m_processingtype == CaloWaveformProcessing::ONNX)
0083   {
0084     fitresults = CaloWaveformProcessing::calo_processing_ONNX(waveformvector);
0085   }
0086   if (m_processingtype == CaloWaveformProcessing::FAST)
0087   {
0088     fitresults = CaloWaveformFitting::calo_processing_fast(waveformvector);
0089   }
0090   if (m_processingtype == CaloWaveformProcessing::NYQUIST)
0091   {
0092     fitresults = m_Fitter->calo_processing_nyquist(waveformvector);
0093   }
0094   return fitresults;
0095 }
0096 
0097 std::vector<std::vector<float>> CaloWaveformProcessing::calo_processing_ONNX(const std::vector<std::vector<float>> &chnlvector)
0098 {
0099   std::vector<std::vector<float>> fit_values;
0100   std::vector<float> val;  // single row to return
0101   unsigned int nchnls = chnlvector.size();
0102   for (unsigned int m = 0; m < nchnls; m++)
0103   {
0104     val.clear();
0105     const std::vector<float> &v = chnlvector.at(m);
0106     int size1 = v.size();
0107     if (size1 == _nzerosuppresssamples)
0108     {
0109       val.push_back(v.at(1) - v.at(0));
0110       val.push_back(std::numeric_limits<float>::quiet_NaN());
0111       val.push_back(v.at(0));
0112       if (v.at(0) != 0 && v.at(1) == 0)  // check if post-sample is 0, if so set high chi2
0113       {
0114         val.push_back(1000000);
0115       }
0116       else
0117       {
0118         val.push_back(std::numeric_limits<float>::quiet_NaN());
0119       }
0120       val.push_back(0);
0121       fit_values.push_back(val);
0122     }
0123     else
0124     {
0125       float maxheight = 0;
0126       int maxbin = 0;
0127       for (int i = 0; i < size1; i++)
0128       {
0129         if (v.at(i) > maxheight)
0130         {
0131           maxheight = v.at(i);
0132           maxbin = i;
0133         }
0134       }
0135       float pedestal = 1500;
0136       if (maxbin > 4)
0137       {
0138         pedestal = 0.5 * (v.at(maxbin - 4) + v.at(maxbin - 5));
0139       }
0140       else if (maxbin > 3)
0141       {
0142         pedestal = (v.at(maxbin - 4));
0143       }
0144       else
0145       {
0146         pedestal = 0.5 * (v.at(size1 - 3) + v.at(size1 - 2));
0147       }
0148 
0149       if ((_bdosoftwarezerosuppression && v.at(6) - v.at(0) < _nsoftwarezerosuppression) || (_maxsoftwarezerosuppression && maxheight - pedestal < _nsoftwarezerosuppression))
0150       {
0151         val.push_back(v.at(6) - v.at(0));
0152         val.push_back(std::numeric_limits<float>::quiet_NaN());
0153         val.push_back(v.at(0));
0154         if (v.at(0) != 0 && v.at(1) == 0)  // check if post-sample is 0, if so set high chi2
0155         {
0156           val.push_back(1000000);
0157         }
0158         else
0159         {
0160           val.push_back(std::numeric_limits<float>::quiet_NaN());
0161         }
0162         val.push_back(0);
0163         fit_values.push_back(val);
0164       }
0165       else
0166       {
0167         unsigned int nsamples = v.size();
0168         if (nsamples == 12)
0169         {
0170           // downstream onnx does not have a static input vector API,
0171           // so we need to make a copy
0172           std::vector<float> vtmp(v); //NOLINT(performance-unnecessary-copy-initialization)
0173           val = onnxInference(onnxmodule, vtmp, 1, onnxlib::n_input, onnxlib::n_output);
0174           unsigned int nvals = val.size();
0175           for (unsigned int i = 0; i < nvals; i++)
0176           {
0177             val.at(i) = val.at(i) * m_Onnx_factor.at(i) + m_Onnx_offset.at(i);
0178           }
0179           val.push_back(2000);
0180           val.push_back(0);
0181           fit_values.push_back(val);
0182         }
0183         else
0184         {
0185           float v_diff = v[1] - v[0];
0186           std::vector<float> val1{v_diff, std::numeric_limits<float>::quiet_NaN(), v[1], std::numeric_limits<float>::quiet_NaN(), 0};
0187           fit_values.push_back(val1);
0188         }
0189       }
0190     }
0191   }
0192   return fit_values;
0193 }
0194 
0195 int CaloWaveformProcessing::get_nthreads()
0196 {
0197   if (m_Fitter)
0198   {
0199     return m_Fitter->get_nthreads();
0200   }
0201   return _nthreads;
0202 }
0203 void CaloWaveformProcessing::set_nthreads(int nthreads)
0204 {
0205   _nthreads = nthreads;
0206   if (m_Fitter)
0207   {
0208     m_Fitter->set_nthreads(nthreads);
0209   }
0210   return;
0211 }