File indexing completed on 2025-08-06 08:14:56
#include <TChain.h>
#include <TCut.h>
#include <TFile.h>
#include <TString.h>
#include <TTree.h>

#include <iostream>
#include <string>

#if not defined(__CINT__) || defined(__MAKECINT__)

#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#endif
0012
0013
0014
0015 TChain* handleFile(string name, string extension, string treename, int filecount){
0016 TChain *all = new TChain(treename.c_str());
0017 string temp;
0018 for (int i = 0; i < filecount; ++i)
0019 {
0020
0021 ostringstream s;
0022 s<<i;
0023 temp = name+string(s.str())+extension;
0024 all->Add(temp.c_str());
0025 }
0026 return all;
0027 }
0028
0029
0030 void makeFactory(TTree* signalTree,std::string outfile,std::string factoryname)
0031 {
0032 using namespace TMVA;
0033 TString jobname(factoryname.c_str());
0034 TFile *targetFile = new TFile(outfile.c_str(),"RECREATE");
0035 Factory *factory = new Factory(jobname,targetFile,"AnalysisType=Regression");
0036 factory->AddRegressionTree(signalTree,1.0);
0037 factory->AddVariable("track1_pt",'F');
0038 factory->AddVariable("track2_pt",'F');
0039 factory->AddVariable("track1_phi",'F');
0040 factory->AddVariable("track1_phi-track2_phi","d#phi","rad");
0041 factory->AddVariable("track1_eta",'F');
0042 factory->AddVariable("track1_eta-track2_eta","d#eta","rad");
0043 factory->AddVariable("vtx_radius","radius","[cm]");
0044 factory->AddTarget("tvtx_radius","radius","[cm]");
0045
0046 string track_pT_cut = "";
0047
0048
0049 string tCutInitializer = track_pT_cut;
0050 TCut preTraingCuts(tCutInitializer.c_str());
0051 factory->PrepareTrainingAndTestTree(preTraingCuts,"nTrain_Regression=0:nTest_Regression=0");
0052 factory->BookMethod(Types::kMLP,"MLP_ANN","HiddenLayers=2000");
0053 factory->BookMethod(Types::kMLP,"MLP_ANN2","HiddenLayers=500,6");
0054
0055
0056 factory->TrainAllMethods();
0057 factory->TestAllMethods();
0058 factory->EvaluateAllMethods();
0059 targetFile->Write();
0060 targetFile->Close();
0061 }
0062
0063
0064 int vtxPredictionTraining(){
0065 using namespace std;
0066 string treePath = "/sphenix/user/vassalli/gammasample/conversiononlineanalysis";
0067 string treeExtension = ".root";
0068 string outname = "vtxTrain.root";
0069 unsigned int nFiles=200;
0070
0071 TChain *signalTree = handleFile(treePath,treeExtension,"vtxingTree",nFiles);
0072 makeFactory(signalTree,outname,"vtxFactory");
0073
0074
0075 }