Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-08-06 08:17:58

#include "PHFieldInterpolated.h"

#include <phool/phool.h> // PHWHERE

#include <TFile.h>
#include <TTree.h>

#include <Geant4/G4SystemOfUnits.hh>

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <vector>

#include <boost/format.hpp>
0017 void
0018 PHFieldInterpolated::load_fieldmap (
0019     std::string const& filename,
0020     float const& magfield_rescale
0021 ) {
0022     // Necessary as calling c_str() on an empty string returns NULL,
0023     // but TFile::Open doesn't have a guard clause for a NULL argument
0024     if (filename.empty()) {
0025         std::stringstream what;
0026         what
0027             << PHWHERE
0028             << " Empty filename";
0029         throw std::runtime_error(what.str());
0030     }
0031 
0032     TFile* file = TFile::Open(filename.c_str(), "READ");
0033     if (!file) {
0034         std::stringstream what;
0035         what
0036             << PHWHERE
0037             << " Could not open file " << filename;
0038         throw std::runtime_error(what.str());
0039     }
0040 
0041     TTree* tree{};
0042     file->GetObject("fieldmap", tree);
0043     if (!tree) {
0044         std::stringstream what;
0045         what
0046             << PHWHERE
0047             << " Could not get tree " << "fieldmap"
0048             << " from file " << filename;
0049         throw std::runtime_error(what.str());
0050     }
0051 
0052     Point_t point;
0053     Field_t field;
0054     std::map<std::string, float&> branches = {
0055         {"x", point(0)}, {"y", point(1)}, {"z", point(2)},
0056         {"bx", field(0)}, {"by", field(1)}, {"bz", field(2)},
0057     };
0058 
0059     for (auto& [name, reference] : branches) {
0060         if (tree->SetBranchAddress(name.c_str(), &reference) != TTree::kMatch) {
0061             std::stringstream what;
0062             what
0063                 << PHWHERE
0064                 << " Could not get branch " << name
0065                 << " from tree " << "fieldmap"
0066                 << " from file " << filename;
0067             throw std::runtime_error(what.str());
0068         }
0069     }
0070 
0071     if (Verbosity()) {
0072         std::cout
0073             << PHWHERE << "\n"
0074             << " Loading fieldmap from file " << filename
0075             << std::endl;
0076     }
0077 
0078     for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0079         tree->GetEntry(n);
0080 
0081         // Unit conversions
0082         point *= cm;
0083         field *= tesla * magfield_rescale;
0084 
0085         // Check if what we read in is consistent with our deterministic computations of the domain
0086         Point_t expected = get_point(get_indices(n));
0087         if (1.0E-4 < (point - expected).norm()) {
0088             std::stringstream what;
0089             what
0090                 << PHWHERE
0091                 << " Read point from file which is dissimilar to calculated value"
0092                 << " entry: " << n
0093                 << " expected: " << expected.transpose()
0094                 << " read: " << point.transpose();
0095             throw std::runtime_error(what.str());
0096         };
0097 
0098         m_field.push_back(field);
0099 
0100         // Print information when < 5 and an ellipsis when == 5
0101         int print_index{6};
0102         switch (Verbosity()) {
0103             case 0:
0104             case 1:
0105                 // Left at 6 and will print nothing
0106                 break;
0107             case 2:
0108                 // Prints the next 5 terms after every x coordinate roll over
0109                 print_index = (int)n % (GRID_COUNT * GRID_COUNT * GRID_COUNT);
0110                 break;
0111             case 3:
0112                 // Prints the next 5 terms after every y coordinate roll over
0113                 print_index = (int)n % (GRID_COUNT * GRID_COUNT);
0114                 break;
0115             case 4:
0116                 // Prints the next 5 terms after every z coordinate roll over
0117                 print_index = (int)n % (GRID_COUNT);
0118                 break;
0119             default:
0120                 // Prints the whole map
0121                 print_index = 0;
0122                 break;
0123         }
0124 
0125         if (print_index < 5) {
0126             std::cout
0127                 << " index: " << n
0128                 << " point: " << point.transpose()
0129                 << " field: " << field.transpose()
0130                 << std::endl;
0131         }
0132 
0133         if (print_index == 5) {
0134             std::cout
0135                 << " ... "
0136                 << std::endl;
0137         }
0138     }
0139 
0140     file->Close();
0141 }
0142 
0143 void
0144 PHFieldInterpolated::GetFieldValue (
0145     double const* point_as_arr, // pointer to (immutable) double[4]
0146     double* field_as_arr // pointer to (mutable) double[3]
0147 ) const {
0148     try {
0149         Field_t field = get_interpolated ({
0150             (float)point_as_arr[0],
0151             (float)point_as_arr[1],
0152             (float)point_as_arr[2],
0153         });
0154         for (int i = 0; i < 3; ++i) {
0155             field_as_arr[i] = field(i);
0156         }
0157     } catch (std::exception const& e) {
0158         // If you're from a printout,
0159         // you'll need to look around this file
0160         // for the literal "e.what()"
0161         std::cout
0162             << PHWHERE << "\n"
0163             << "\t" << e.what() << "\n"
0164             << std::endl;
0165     }
0166 }
0167 
0168 Eigen::VectorXf
0169 PHFieldInterpolated::get_design_vector (
0170     Point_t const& point
0171 ) const {
0172     float x = point(0) - m_center(0);
0173     float y = point(1) - m_center(1);
0174     float z = point(2) - m_center(2);
0175 
0176     Eigen::VectorXf design_vector(20);
0177     design_vector <<
0178         // O(0)
0179         1,
0180 
0181         // O(1)
0182         x, y, z,
0183 
0184         // O(2)
0185         x*x, x*y, x*z,
0186         y*y, y*z,
0187         z*z,
0188 
0189         // O(3)
0190         x*x*x, x*x*y, x*x*z, x*y*y, x*y*z, x*z*z,
0191         y*y*y, y*y*z, y*z*z,
0192         z*z*z;
0193 
0194     return design_vector;
0195 }
0196 
0197 void
0198 PHFieldInterpolated::cache_interpolation (
0199     Point_t const& point
0200 ) const {
0201     Indices_t indices = get_indices(point);
0202     if ((indices - m_buffered_indices).norm() == 0) { return; }
0203     m_buffered_indices = indices;
0204 
0205     // The point at the the center of the cell
0206     // get_point gets the left-down-back corner
0207     m_center = get_point(indices);
0208     for (int i = 0; i < 3; ++i) {
0209         m_center(i) += GRID_STEP / 2;
0210     }
0211 
0212     // Effectively we are solving three scalar interpolation problems
0213     // However, all systems of equations will share the same design matrix
0214     Eigen::MatrixXf M(64, 20);
0215     std::array<Eigen::VectorXf, 3> solution_vectors {
0216         Eigen::VectorXf(64),
0217         Eigen::VectorXf(64),
0218         Eigen::VectorXf(64),
0219     };
0220 
0221     // Get the 64 neighboring points about the cell containing point and its neighboring cells
0222     // Note that the indices are of the left-down-back corner of this cell
0223     for (int row = 0; row < 64; ++row) {
0224         indices = {
0225             m_buffered_indices(0) + (row / 16) - 1,
0226             m_buffered_indices(1) + ((row / 4) % 4) - 1,
0227             m_buffered_indices(2) + (row % 4) - 1,
0228         };
0229 
0230         // Possible to throw while searching neighbors
0231         // Re-throw an error, but with a different message
0232         try {
0233             validate_indices(indices);
0234         } catch (std::exception const&) {
0235             m_buffered_indices = {-1, -1, -1};
0236             std::stringstream what;
0237             what
0238                 << PHWHERE
0239                 << " Point too close to edge of fieldmap "
0240                 << " (at " << point << ")";
0241             throw std::runtime_error(what.str());
0242         }
0243 
0244         M.row(row) = get_design_vector(get_point(indices));
0245         for (int i = 0; i < 3; ++i) {
0246             solution_vectors[i](row) = get_field(indices)(i);
0247         }
0248     }
0249 
0250     for (int i = 0; i < 3; ++i) {
0251         m_coefficients[i] = M.bdcSvd (
0252             Eigen::ComputeThinU | Eigen::ComputeThinV
0253         ).solve (solution_vectors[i]);
0254     }
0255 }
0256 
0257 PHFieldInterpolated::Field_t
0258 PHFieldInterpolated::get_interpolated (
0259     Point_t const& point
0260 ) const {
0261     cache_interpolation(point);
0262     return {
0263         get_design_vector(point).dot(m_coefficients[0]),
0264         get_design_vector(point).dot(m_coefficients[1]),
0265         get_design_vector(point).dot(m_coefficients[2]),
0266     };
0267 }
0268 
0269 PHFieldInterpolated::Indices_t
0270 PHFieldInterpolated::get_indices (
0271     std::size_t const& index
0272 ) {
0273     return {
0274         ((int)index / (GRID_COUNT * GRID_COUNT)),
0275         ((int)index / GRID_COUNT) % GRID_COUNT,
0276         ((int)index) % GRID_COUNT,
0277     };
0278 }
0279 
0280 std::size_t
0281 PHFieldInterpolated::get_index (
0282     Indices_t const& indices
0283 ) {
0284     validate_indices(indices);
0285     return
0286         (indices(0) * GRID_COUNT * GRID_COUNT) +
0287         (indices(1) * GRID_COUNT) +
0288         indices(2);
0289 }
0290 
0291 PHFieldInterpolated::Indices_t
0292 PHFieldInterpolated::get_indices (
0293     Point_t const& point
0294 ) {
0295     validate_point(point);
0296     return {
0297         (int)std::floor(point(0) / GRID_STEP) + (GRID_COUNT / 2),
0298         (int)std::floor(point(1) / GRID_STEP) + (GRID_COUNT / 2),
0299         (int)std::floor(point(2) / GRID_STEP) + (GRID_COUNT / 2),
0300     };
0301 }
0302 
0303 PHFieldInterpolated::Point_t
0304 PHFieldInterpolated::get_point (
0305     Indices_t const& indices
0306 ) {
0307     validate_indices(indices);
0308     return {
0309         // NOLINTNEXTLINE(bugprone-integer-division)
0310         (indices(0) - (GRID_COUNT / 2)) * GRID_STEP,
0311         // NOLINTNEXTLINE(bugprone-integer-division)
0312         (indices(1) - (GRID_COUNT / 2)) * GRID_STEP,
0313         // NOLINTNEXTLINE(bugprone-integer-division)
0314         (indices(2) - (GRID_COUNT / 2)) * GRID_STEP,
0315     };
0316 }
0317 
0318 void
0319 PHFieldInterpolated::validate_indices (
0320     Indices_t const& indices
0321 ) {
0322     for (int i = 0; i < 3; ++i) {
0323         if (indices(i) < 0 || indices(i) >= GRID_COUNT) {
0324             std::stringstream what;
0325             what
0326                 << PHWHERE
0327                 << " Component out of range"
0328                 << " (at " << indices << ")";
0329             throw std::runtime_error(what.str());
0330         }
0331     }
0332 }
0333 
0334 void
0335 PHFieldInterpolated::validate_point (
0336     Point_t const& point 
0337 ) {
0338     for (int i = 0; i < 3; ++i) {
0339         if (GRID_MAX < point(i) || point(i) < GRID_MIN) {
0340             std::stringstream what;
0341             what
0342                 << PHWHERE
0343                 << " Component out of range"
0344                 << " (at " << point << ")";
0345             throw std::runtime_error(what.str());
0346         }
0347     }
0348 }
0349 
0350 void
0351 PHFieldInterpolated::print_map (
0352 ) const {
0353     for (std::size_t index = 0; index < m_field.size(); ++index) {
0354         Indices_t indices = get_indices(index);
0355         std::cout
0356             << " indices: " << indices.transpose()
0357             << " point: " << get_point(indices).transpose()
0358             << " field: " << get_field(indices).transpose()
0359             << std::endl;
0360     }
0361 }
0362 
0363 void
0364 PHFieldInterpolated::print_coefficients (
0365 ) const {
0366     for (int i = 0; i < 3; ++i) {
0367         std::cout
0368             << m_coefficients[i].transpose()
0369             << std::endl;
0370     }
0371 }
0372