Back to home page

sPhenix code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 09:20:19

#include "PHFieldInterpolated.h"

#include <phool/phool.h> // PHWHERE

#include <TFile.h>
#include <TTree.h>

#include <Geant4/G4SystemOfUnits.hh>

#include <array>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <boost/format.hpp>
0018 
0019 void
0020 PHFieldInterpolated::load_fieldmap (
0021     std::string const& filename,
0022     float const& magfield_rescale
0023 ) {
0024     // Necessary as calling c_str() on an empty string returns NULL,
0025     // but TFile::Open doesn't have a guard clause for a NULL argument
0026     if (filename.empty()) {
0027         std::stringstream what;
0028         what
0029             << PHWHERE
0030             << " Empty filename";
0031         throw std::runtime_error(what.str());
0032     }
0033 
0034     TFile* file = TFile::Open(filename.c_str(), "READ");
0035     if (!file) {
0036         std::stringstream what;
0037         what
0038             << PHWHERE
0039             << " Could not open file " << filename;
0040         throw std::runtime_error(what.str());
0041     }
0042 
0043     TTree* tree{};
0044     file->GetObject("fieldmap", tree);
0045     if (!tree) {
0046         std::stringstream what;
0047         what
0048             << PHWHERE
0049             << " Could not get tree " << "fieldmap"
0050             << " from file " << filename;
0051         throw std::runtime_error(what.str());
0052     }
0053 
0054     Point_t point;
0055     Field_t field;
0056     std::map<std::string, float&> branches = {
0057         {"x", point(0)}, {"y", point(1)}, {"z", point(2)},
0058         {"bx", field(0)}, {"by", field(1)}, {"bz", field(2)},
0059     };
0060 
0061     for (auto& [name, reference] : branches) {
0062         if (tree->SetBranchAddress(name.c_str(), &reference) != TTree::kMatch) {
0063             std::stringstream what;
0064             what
0065                 << PHWHERE
0066                 << " Could not get branch " << name
0067                 << " from tree " << "fieldmap"
0068                 << " from file " << filename;
0069             throw std::runtime_error(what.str());
0070         }
0071     }
0072 
0073     if (Verbosity()) {
0074         std::cout
0075             << PHWHERE << "\n"
0076             << " Loading fieldmap from file " << filename
0077             << std::endl;
0078     }
0079 
0080     // The unique values taken on coordinate axes
0081     // used to check that the tree is a perfect grid
0082     std::array<std::set<float>, 3> points;
0083 
0084     for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0085         tree->GetEntry(n);
0086 
0087         // Unit conversions
0088         point *= cm;
0089         for (int i = 0; i < 3; ++i) {
0090             points[i].insert(point(i));
0091         }
0092     }
0093 
0094     for (int i = 0; i < 3; ++i) {
0095         if (points[i].size() < 2) {
0096             std::stringstream what;
0097             what
0098                 << PHWHERE
0099                 << " not enough points";
0100             throw std::runtime_error(what.str());
0101         }
0102         m_N(i) = points[i].size();
0103         m_min(i) = *points[i].begin();
0104         m_max(i) = *points[i].rbegin();
0105         m_D(i) = (m_max(i) - m_min(i)) / (m_N(i) - 1);
0106     }
0107 
0108     if (static_cast<Long64_t>(m_N(0)) * static_cast<Long64_t>(m_N(1)) * static_cast<Long64_t>(m_N(2)) != tree->GetEntriesFast()) {
0109         std::stringstream what;
0110         what
0111             << PHWHERE
0112             << " tree is not a grid";
0113         throw std::runtime_error(what.str());
0114     }
0115 
0116     std::lock_guard lock(m_mutex);
0117     m_field.resize(tree->GetEntriesFast());
0118     for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0119         tree->GetEntry(n);
0120 
0121         // Unit conversions
0122         point *= cm;
0123         field *= tesla * magfield_rescale;
0124 
0125         // Check that the deterministic computation
0126         // actually matches what we're reading in
0127         Indices_t indices = get_indices(point);
0128         if (!point.isApprox(get_point(indices))) {
0129             std::stringstream what;
0130             what
0131                 << PHWHERE
0132                 << " point does not round trip"
0133                 << " (with " << point.transpose() << ")";
0134             throw std::runtime_error(what.str());
0135         }
0136 
0137         m_field[get_index(indices)] = field;
0138     }
0139 
0140     file->Close();
0141 }
0142 
0143 void
0144 PHFieldInterpolated::GetFieldValue (
0145     double const* point_as_arr, // pointer to (immutable) double[4]
0146     double* field_as_arr // pointer to (mutable) double[3]
0147 ) const {
0148     Point_t point {
0149         (float)point_as_arr[0],
0150         (float)point_as_arr[1],
0151         (float)point_as_arr[2],
0152     };
0153 
0154     for (int i = 0; i < 3; ++i) {
0155         field_as_arr[i] = 0;
0156     }
0157 
0158     // Catches points out of bounds and early returns leaving field as 0
0159     try {
0160         validate_point(point);
0161     } catch (std::exception const&) {
0162         if (1 < Verbosity()) {
0163             std::cout
0164                 << PHWHERE
0165                 << " Returning 0 for point out of fieldmap bounds"
0166                 << " (at " << point.transpose() << ")"
0167                 << std::endl;
0168         }
0169         return;
0170     }
0171 
0172     // This should not throw if point is in bounds, which we just checked
0173     // If this throws, there is a bug in class logic
0174     Field_t field = get_interpolated(point);
0175     for (int i = 0; i < 3; ++i) {
0176         field_as_arr[i] = field(i);
0177     }
0178 }
0179 
0180 Eigen::VectorXf
0181 PHFieldInterpolated::get_design_vector (
0182     Point_t const& point,
0183     InterpolationCache const& cache
0184 ) {
0185     float x = point(0) - cache.m_center(0);
0186     float y = point(1) - cache.m_center(1);
0187     float z = point(2) - cache.m_center(2);
0188 
0189     Eigen::VectorXf design_vector(20);
0190     design_vector <<
0191         // O(0)
0192         1,
0193 
0194         // O(1)
0195         x, y, z,
0196 
0197         // O(2)
0198         x*x, x*y, x*z,
0199         y*y, y*z,
0200         z*z,
0201 
0202         // O(3)
0203         x*x*x, x*x*y, x*x*z, x*y*y, x*y*z, x*z*z,
0204         y*y*y, y*y*z, y*z*z,
0205         z*z*z;
0206 
0207     return design_vector;
0208 }
0209 
0210 void
0211 PHFieldInterpolated::cache_interpolation (
0212     Point_t const& point,
0213     InterpolationCache& cache
0214 ) const {
0215     Indices_t indices = get_indices(point);
0216 
0217     // We get neighbors offset by [-1, +2] relative to the buffered point
0218     // If we're close enough to an edge, this range isn't valid
0219     // buffer about a shifted voxel further in instead
0220     for (int i = 0; i < 3; ++i ) {
0221         while (indices(i) - 1 < 0) { ++indices(i); }
0222         while (m_N(i) < indices(i) + 3) { --indices(i); }
0223     }
0224 
0225     // Our coefficients have already been computed for the voxel we want to evaluate in
0226     if ((indices - cache.m_buffered_indices).norm() == 0) { return; }
0227 
0228     cache.m_buffered_indices = indices;
0229 
0230     // The point at the the center of the voxel
0231     // get_point gets the left-down-back corner
0232     cache.m_center = get_point(indices);
0233     for (int i = 0; i < 3; ++i) {
0234         cache.m_center(i) += m_D(i) / 2;
0235     }
0236 
0237     // Effectively we are solving three scalar interpolation problems
0238     // However, all systems of equations will share the same design matrix
0239     Eigen::MatrixXf M(64, 20);
0240     std::array<Eigen::VectorXf, 3> solution_vectors {
0241         Eigen::VectorXf(64),
0242         Eigen::VectorXf(64),
0243         Eigen::VectorXf(64),
0244     };
0245 
0246     // Get the 64 (nearest) neighboring points about the cell containing point and its neighboring cells
0247     // Note that the indices are of the left-down-back corner of this cell
0248     for (int row = 0; row < 64; ++row) {
0249         indices = {
0250             cache.m_buffered_indices(0) + (row / 16) - 1,
0251             cache.m_buffered_indices(1) + ((row / 4) % 4) - 1,
0252             cache.m_buffered_indices(2) + (row % 4) - 1,
0253         };
0254 
0255         M.row(row) = get_design_vector(get_point(indices), cache);
0256         for (int i = 0; i < 3; ++i) {
0257             solution_vectors[i](row) = get_field(indices)(i);
0258         }
0259     }
0260 
0261     // Use Eigen to solve the least squares problem we've set up
0262     for (int i = 0; i < 3; ++i) {
0263         cache.m_coefficients[i] = M.bdcSvd (
0264             Eigen::ComputeThinU | Eigen::ComputeThinV
0265         ).solve (solution_vectors[i]);
0266     }
0267 }
0268 
0269 PHFieldInterpolated::InterpolationCache&
0270 PHFieldInterpolated::get_cache (
0271 ) const {
0272     // Update the access counts
0273     InterpolationCache& this_cache = m_caches[std::this_thread::get_id()];
0274     for (auto& [thread_id, cache] : m_caches) {
0275         if (cache.m_queue_index < this_cache.m_queue_index) { ++cache.m_queue_index; }
0276     }
0277     this_cache.m_queue_index = 0;
0278 
0279     return this_cache;
0280 }
0281 
0282 PHFieldInterpolated::Field_t
0283 PHFieldInterpolated::get_interpolated (
0284     Point_t const& point
0285 ) const {
0286 
0287     InterpolationCache this_cache;
0288 
0289     // Get a copy of the interpolation information this thread is using
0290     {
0291         std::lock_guard lock(m_mutex);
0292         this_cache = get_cache();
0293     }
0294 
0295     // Update the cache to be about the point
0296     cache_interpolation(point, this_cache);
0297 
0298     // Update its place in the member map
0299     {
0300         std::lock_guard lock(m_mutex);
0301         get_cache() = this_cache;
0302 
0303         // Prune map entries which haven't been used in a while
0304         std::erase_if (m_caches, [](auto const& key_val_pair) {
0305             return MAX_THREADS <= key_val_pair.second.m_queue_index;
0306         });
0307     }
0308 
0309     return {
0310         get_design_vector(point, this_cache).dot(this_cache.m_coefficients[0]),
0311         get_design_vector(point, this_cache).dot(this_cache.m_coefficients[1]),
0312         get_design_vector(point, this_cache).dot(this_cache.m_coefficients[2]),
0313     };
0314 }
0315 
0316 PHFieldInterpolated::Indices_t
0317 PHFieldInterpolated::get_indices (
0318     std::size_t const& index
0319 ) const {
0320     return {
0321         ((int)index / (m_N(1) * m_N(2))) % m_N(0),
0322         ((int)index / m_N(2)) % m_N(1),
0323         (int)index % m_N(2),
0324     };
0325 }
0326 
0327 std::size_t
0328 PHFieldInterpolated::get_index (
0329     Indices_t const& indices
0330 ) const {
0331     validate_indices(indices);
0332     return
0333         (indices(0) * m_N(1) * m_N(2)) +
0334         (indices(1) * m_N(2)) +
0335         indices(2);
0336 }
0337 
0338 PHFieldInterpolated::Indices_t
0339 PHFieldInterpolated::get_indices (
0340     Point_t const& point
0341 ) const {
0342     validate_point(point);
0343     return {
0344         (int)std::floor(point(0) / m_D(0)) + (m_N(0) / 2),
0345         (int)std::floor(point(1) / m_D(1)) + (m_N(1) / 2),
0346         (int)std::floor(point(2) / m_D(2)) + (m_N(2) / 2),
0347     };
0348 }
0349 
0350 PHFieldInterpolated::Point_t
0351 PHFieldInterpolated::get_point (
0352     Indices_t const& indices
0353 ) const {
0354     validate_indices(indices);
0355     return {
0356         static_cast<float>(indices(0) - ((m_N(0) - 1.0) / 2.0)) * m_D(0),
0357         static_cast<float>(indices(1) - ((m_N(1) - 1.0) / 2.0)) * m_D(1),
0358         static_cast<float>(indices(2) - ((m_N(2) - 1.0) / 2.0)) * m_D(2),
0359     };
0360 }
0361 
0362 void
0363 PHFieldInterpolated::validate_indices (
0364     Indices_t const& indices
0365 ) const {
0366     for (int i = 0; i < 3; ++i) {
0367         if (indices(i) < 0 || indices(i) >= m_N(i)) {
0368             std::stringstream what;
0369             what
0370                 << PHWHERE
0371                 << " Component out of range"
0372                 << " (at " << indices.transpose() << ")";
0373             throw std::runtime_error(what.str());
0374         }
0375     }
0376 }
0377 
0378 void
0379 PHFieldInterpolated::validate_point (
0380     Point_t const& point 
0381 ) const {
0382     for (int i = 0; i < 3; ++i) {
0383         if (m_max(i) < point(i) || point(i) < m_min(i)) {
0384             std::stringstream what;
0385             what
0386                 << PHWHERE
0387                 << " Component out of range"
0388                 << " (at " << point.transpose() << ")";
0389             throw std::runtime_error(what.str());
0390         }
0391     }
0392 }
0393 
0394 void
0395 PHFieldInterpolated::print (
0396     std::ostream& stream
0397 ) const {
0398     std::lock_guard lock(m_mutex);
0399 
0400     stream
0401         << PHWHERE << "\n"
0402         << " size: " << m_field.size()
0403         << " num caches: " << m_caches.size()
0404         << std::endl;
0405 
0406     if (Verbosity() < 1) { return; }
0407 
0408     for (int i = 0; i < 3; ++i) {
0409         stream
0410             << " component: " << i
0411             << " min: " << m_min(i)
0412             << " max: " << m_max(i)
0413             << " step: " << m_D(i)
0414             << " count: " << m_N(i)
0415             << std::endl;
0416     }
0417 
0418     for (auto const& [thread_id, cache] : m_caches) {
0419         stream
0420             << " thread id: " << thread_id
0421             << " queue index: " << cache.m_queue_index
0422             << std::endl;
0423     }
0424 
0425     if (Verbosity() < 2) { return; }
0426 
0427     for (std::size_t index = 0; index < m_field.size(); ++index) {
0428         Indices_t indices = get_indices(index);
0429         stream
0430             << " indices: " << indices.transpose()
0431             << " point: " << get_point(indices).transpose()
0432             << " field: " << get_field(indices).transpose()
0433             << std::endl;
0434 
0435         if (Verbosity() < 3 && index == 10) {
0436             stream << " ..." << std::endl;
0437             return;
0438         }
0439     }
0440 }
0441 
0442 void
0443 PHFieldInterpolated::print_coefficients (
0444     std::ostream& stream
0445 ) const {
0446     std::lock_guard lock(m_mutex);
0447     InterpolationCache const& cache = m_caches[std::this_thread::get_id()];
0448     for (int i = 0; i < 3; ++i) {
0449         stream
0450             << cache.m_coefficients[i].transpose()
0451             << std::endl;
0452     }
0453 }
0454