Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:19:51

/*
Applies a simple (not heavily optimized) CNN, loaded as an ONNX model, to classify
EMCal clusters as photon or not photon. S. Li, 8.16.2024
*/
0004 
#include "RawClusterCNNClassifier.h"

#include <calobase/RawCluster.h>
#include <calobase/RawClusterContainer.h>
#include <calobase/RawClusterUtility.h>
#include <calobase/RawTower.h>
#include <calobase/RawTowerContainer.h>
#include <calobase/RawTowerDefs.h>
#include <calobase/RawTowerGeom.h>
#include <calobase/RawTowerGeomContainer.h>

// Tower stuff
#include <calobase/TowerInfo.h>
#include <calobase/TowerInfoContainer.h>
#include <calobase/TowerInfoDefs.h>

#include <ffaobjects/EventHeader.h>

#include <fun4all/Fun4AllReturnCodes.h>
#include <fun4all/Fun4AllServer.h>

#include <phool/PHCompositeNode.h>
#include <phool/getClass.h>
#include <phool/onnxlib.h>
#include <phool/phool.h>

#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
0033 
0034 RawClusterCNNClassifier::RawClusterCNNClassifier(const std::string &name)
0035   : SubsysReco(name)
0036 {
0037 }
0038 
0039 RawClusterCNNClassifier::~RawClusterCNNClassifier()
0040 {
0041   delete onnxmodule;
0042 }
0043 
0044 int RawClusterCNNClassifier::Init(PHCompositeNode *topNode)
0045 {
0046   // init the onnx model
0047   onnxmodule = onnxSession(m_modelPath,Verbosity());
0048 
0049   if (m_inputNodeName == m_outputNodeName)
0050   {
0051     std::cout << "RawClusterCNNClassifier::Init: inputNodeName and outputNodeName are the same, setting inplace to true" << std::endl;
0052     inplace = true;
0053   }
0054   CreateNodes(topNode);
0055 
0056   return Fun4AllReturnCodes::EVENT_OK;
0057 }
0058 
0059 int RawClusterCNNClassifier::process_event(PHCompositeNode *topNode)
0060 {
0061   // get the cluster container
0062   std::string clusterNodeName = m_inputNodeName;
0063   RawClusterContainer *clusterContainer = findNode::getClass<RawClusterContainer>(topNode, clusterNodeName);
0064   if (!clusterContainer)
0065   {
0066     std::cout << "RawClusterCNNClassifier::process_event::Could not locate input cluster node " << clusterNodeName << std::endl;
0067     return Fun4AllReturnCodes::ABORTEVENT;
0068   }
0069 
0070   if (inplace)
0071   {
0072     _clusters = clusterContainer;
0073   }
0074   else
0075   {
0076     _clusters->Reset();
0077     RawClusterContainer::Map clusterMap = clusterContainer->getClustersMap();
0078     for (auto &clusterPair : clusterMap)
0079     {
0080 
0081       RawCluster *recoCluster = (clusterPair.second);
0082       float clusterE = recoCluster->get_energy();
0083       if (clusterE < m_min_cluster_e)
0084       {
0085         continue;
0086       }
0087       RawCluster *newCluster = (RawCluster *) recoCluster->CloneMe();
0088       _clusters->AddCluster(newCluster);
0089     }
0090   }
0091 
0092   // I trained the model with the info from towerinfo container, raw tower should also work
0093   std::string towerNodeName = m_towerNodeName;
0094   TowerInfoContainer *emcTowerContainer = findNode::getClass<TowerInfoContainer>(topNode, towerNodeName);
0095   if (!emcTowerContainer)
0096   {
0097     std::cout << "RawClusterCNNClassifier::process_event Could not locate tower node " << towerNodeName << std::endl;
0098     return Fun4AllReturnCodes::ABORTEVENT;
0099   }
0100 
0101   RawClusterContainer::Map clusterMap = _clusters->getClustersMap();
0102   for (auto &clusterPair : clusterMap)
0103   {
0104     RawCluster *recoCluster = clusterPair.second;
0105     // reset the prob inplace
0106     recoCluster->set_prob(-1);
0107     CLHEP::Hep3Vector vertex(0, 0, 0);
0108     CLHEP::Hep3Vector E_vec_cluster_Full = RawClusterUtility::GetEVec(*recoCluster, vertex);
0109     float ET = E_vec_cluster_Full.perp();
0110     if (ET < minET)
0111     {
0112       continue;
0113     }
0114     const RawCluster::TowerMap tower_map =
0115         recoCluster->get_towermap();
0116 
0117     int maxtowerE = 0;
0118     int maxtowerieta = -1;
0119     int maxtoweriphi = -1;
0120 
0121     for (auto tower_iter : tower_map)
0122     {
0123       RawTowerDefs::keytype tower_key = tower_iter.first;
0124 
0125       // get ieta iphi
0126       int ix = RawTowerDefs::decode_index2(tower_key);  // iphi?
0127       int iy = RawTowerDefs::decode_index1(tower_key);  // ieta I  guess?(S.L.)
0128       RawTowerDefs::CalorimeterId caloid = RawTowerDefs::decode_caloid(tower_key);
0129       // check if cemc, but I guess they shoul all be anyways lol
0130       if (caloid != RawTowerDefs::CalorimeterId::CEMC)
0131       {
0132         continue;
0133       }
0134       // get the towerinfo key
0135       unsigned int towerinfokey = TowerInfoDefs::encode_emcal(iy, ix);  // this is the key for the towerinfo container get the towerinfo
0136       TowerInfo *towerinfo = emcTowerContainer->get_tower_at_key(towerinfokey);
0137       if (!towerinfo)
0138       {
0139         // should not happen
0140         std::cout << "No towerinfo for tower key " << towerinfokey
0141                   << std::endl;
0142         continue;
0143       }
0144       float towerE = towerinfo->get_energy();
0145       if (towerE > maxtowerE)
0146       {
0147         maxtowerE = towerE;
0148         maxtowerieta = iy;
0149         maxtoweriphi = ix;
0150       }
0151     }
0152     // find the N by N tower around the max tower
0153     std::vector<float> input;
0154     // resize to inputDimx * inputDimy
0155     int vectorSize = inputDimx * inputDimy;
0156     input.resize(vectorSize, 0);
0157 
0158     if (maxtowerE > 0)
0159     {
0160       int xlength = ((inputDimx - 1) / 2);
0161       int ylength = ((inputDimy - 1) / 2);
0162       if (maxtowerieta - ylength < 0 || maxtowerieta + ylength >= 96)
0163       {
0164         continue;
0165       }
0166       for (int ieta = maxtowerieta - ylength; ieta <= maxtowerieta + ylength; ieta++)
0167       {
0168         for (int iphi = maxtoweriphi - xlength; iphi <= maxtoweriphi + xlength; iphi++)
0169         {
0170           int mappediphi = iphi;
0171 
0172           if (mappediphi < 0)
0173           {
0174             mappediphi += 256;
0175           }
0176           if (mappediphi > 255)
0177           {
0178             mappediphi -= 256;
0179           }
0180           unsigned int towerinfokey = TowerInfoDefs::encode_emcal(ieta, mappediphi);
0181           TowerInfo *towerinfo = emcTowerContainer->get_tower_at_key(towerinfokey);
0182           if (!towerinfo)
0183           {
0184             // should not happen
0185             std::cout << "No towerinfo for tower key " << towerinfokey << std::endl;
0186             std::cout << "ieta: " << ieta << " iphi: " << mappediphi << std::endl;
0187             continue;
0188           }
0189           int index = ((ieta - maxtowerieta + ylength) * inputDimx) + iphi - maxtoweriphi + xlength;
0190           input.at(index) = towerinfo->get_energy();
0191         }
0192       }
0193     }
0194     std::vector<float> prob = onnxInference(onnxmodule, input, 1, inputDimx, inputDimy, inputDimz, outputDim);
0195     // std::cout << "new prob: " << prob[0] << "ET: " << ET << " original prob: " << recoCluster->get_prob() << std::endl;
0196     // inplace change for the prob for now
0197     recoCluster->set_prob(prob[0]);
0198   }
0199 
0200   return Fun4AllReturnCodes::EVENT_OK;
0201 }
0202 
0203 void RawClusterCNNClassifier::CreateNodes(PHCompositeNode *topNode)
0204 {
0205   PHNodeIterator iter(topNode);
0206 
0207   // Grab the CEMC node
0208   PHCompositeNode *dstNode = dynamic_cast<PHCompositeNode *>(iter.findFirst("PHCompositeNode", "DST"));
0209   if (!dstNode)
0210   {
0211     std::cout << PHWHERE << "DST Node missing, doing nothing." << std::endl;
0212     throw std::runtime_error("Failed to find DST node in EmcRawTowerBuilder::CreateNodes");
0213   }
0214 
0215   // Get the _det_name subnode
0216   PHCompositeNode *cemcNode = dynamic_cast<PHCompositeNode *>(iter.findFirst("PHCompositeNode", "CEMC"));
0217 
0218   // Check that it is there
0219   if (!cemcNode)
0220   {
0221     cemcNode = new PHCompositeNode("CEMC");
0222     dstNode->addNode(cemcNode);
0223   }
0224   std::string clusterNodeName = m_outputNodeName;
0225   _clusters = findNode::getClass<RawClusterContainer>(dstNode, clusterNodeName);
0226   if (!_clusters)
0227   {
0228     _clusters = new RawClusterContainer();
0229     PHIODataNode<PHObject> *clusterNode = new PHIODataNode<PHObject>(_clusters, clusterNodeName, "PHObject");
0230     cemcNode->addNode(clusterNode);
0231   }
0232 }