File indexing completed on 2025-12-16 09:20:19
#include "PHFieldInterpolated.h"

#include <phool/phool.h> // PHWHERE

#include <TFile.h>
#include <TTree.h>

#include <Geant4/G4SystemOfUnits.hh>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <memory>
#include <mutex>
#include <set>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <boost/format.hpp>
0018
0019 void
0020 PHFieldInterpolated::load_fieldmap (
0021 std::string const& filename,
0022 float const& magfield_rescale
0023 ) {
0024
0025
0026 if (filename.empty()) {
0027 std::stringstream what;
0028 what
0029 << PHWHERE
0030 << " Empty filename";
0031 throw std::runtime_error(what.str());
0032 }
0033
0034 TFile* file = TFile::Open(filename.c_str(), "READ");
0035 if (!file) {
0036 std::stringstream what;
0037 what
0038 << PHWHERE
0039 << " Could not open file " << filename;
0040 throw std::runtime_error(what.str());
0041 }
0042
0043 TTree* tree{};
0044 file->GetObject("fieldmap", tree);
0045 if (!tree) {
0046 std::stringstream what;
0047 what
0048 << PHWHERE
0049 << " Could not get tree " << "fieldmap"
0050 << " from file " << filename;
0051 throw std::runtime_error(what.str());
0052 }
0053
0054 Point_t point;
0055 Field_t field;
0056 std::map<std::string, float&> branches = {
0057 {"x", point(0)}, {"y", point(1)}, {"z", point(2)},
0058 {"bx", field(0)}, {"by", field(1)}, {"bz", field(2)},
0059 };
0060
0061 for (auto& [name, reference] : branches) {
0062 if (tree->SetBranchAddress(name.c_str(), &reference) != TTree::kMatch) {
0063 std::stringstream what;
0064 what
0065 << PHWHERE
0066 << " Could not get branch " << name
0067 << " from tree " << "fieldmap"
0068 << " from file " << filename;
0069 throw std::runtime_error(what.str());
0070 }
0071 }
0072
0073 if (Verbosity()) {
0074 std::cout
0075 << PHWHERE << "\n"
0076 << " Loading fieldmap from file " << filename
0077 << std::endl;
0078 }
0079
0080
0081
0082 std::array<std::set<float>, 3> points;
0083
0084 for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0085 tree->GetEntry(n);
0086
0087
0088 point *= cm;
0089 for (int i = 0; i < 3; ++i) {
0090 points[i].insert(point(i));
0091 }
0092 }
0093
0094 for (int i = 0; i < 3; ++i) {
0095 if (points[i].size() < 2) {
0096 std::stringstream what;
0097 what
0098 << PHWHERE
0099 << " not enough points";
0100 throw std::runtime_error(what.str());
0101 }
0102 m_N(i) = points[i].size();
0103 m_min(i) = *points[i].begin();
0104 m_max(i) = *points[i].rbegin();
0105 m_D(i) = (m_max(i) - m_min(i)) / (m_N(i) - 1);
0106 }
0107
0108 if (static_cast<Long64_t>(m_N(0)) * static_cast<Long64_t>(m_N(1)) * static_cast<Long64_t>(m_N(2)) != tree->GetEntriesFast()) {
0109 std::stringstream what;
0110 what
0111 << PHWHERE
0112 << " tree is not a grid";
0113 throw std::runtime_error(what.str());
0114 }
0115
0116 std::lock_guard lock(m_mutex);
0117 m_field.resize(tree->GetEntriesFast());
0118 for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0119 tree->GetEntry(n);
0120
0121
0122 point *= cm;
0123 field *= tesla * magfield_rescale;
0124
0125
0126
0127 Indices_t indices = get_indices(point);
0128 if (!point.isApprox(get_point(indices))) {
0129 std::stringstream what;
0130 what
0131 << PHWHERE
0132 << " point does not round trip"
0133 << " (with " << point.transpose() << ")";
0134 throw std::runtime_error(what.str());
0135 }
0136
0137 m_field[get_index(indices)] = field;
0138 }
0139
0140 file->Close();
0141 }
0142
0143 void
0144 PHFieldInterpolated::GetFieldValue (
0145 double const* point_as_arr,
0146 double* field_as_arr
0147 ) const {
0148 Point_t point {
0149 (float)point_as_arr[0],
0150 (float)point_as_arr[1],
0151 (float)point_as_arr[2],
0152 };
0153
0154 for (int i = 0; i < 3; ++i) {
0155 field_as_arr[i] = 0;
0156 }
0157
0158
0159 try {
0160 validate_point(point);
0161 } catch (std::exception const&) {
0162 if (1 < Verbosity()) {
0163 std::cout
0164 << PHWHERE
0165 << " Returning 0 for point out of fieldmap bounds"
0166 << " (at " << point.transpose() << ")"
0167 << std::endl;
0168 }
0169 return;
0170 }
0171
0172
0173
0174 Field_t field = get_interpolated(point);
0175 for (int i = 0; i < 3; ++i) {
0176 field_as_arr[i] = field(i);
0177 }
0178 }
0179
0180 Eigen::VectorXf
0181 PHFieldInterpolated::get_design_vector (
0182 Point_t const& point,
0183 InterpolationCache const& cache
0184 ) {
0185 float x = point(0) - cache.m_center(0);
0186 float y = point(1) - cache.m_center(1);
0187 float z = point(2) - cache.m_center(2);
0188
0189 Eigen::VectorXf design_vector(20);
0190 design_vector <<
0191
0192 1,
0193
0194
0195 x, y, z,
0196
0197
0198 x*x, x*y, x*z,
0199 y*y, y*z,
0200 z*z,
0201
0202
0203 x*x*x, x*x*y, x*x*z, x*y*y, x*y*z, x*z*z,
0204 y*y*y, y*y*z, y*z*z,
0205 z*z*z;
0206
0207 return design_vector;
0208 }
0209
0210 void
0211 PHFieldInterpolated::cache_interpolation (
0212 Point_t const& point,
0213 InterpolationCache& cache
0214 ) const {
0215 Indices_t indices = get_indices(point);
0216
0217
0218
0219
0220 for (int i = 0; i < 3; ++i ) {
0221 while (indices(i) - 1 < 0) { ++indices(i); }
0222 while (m_N(i) < indices(i) + 3) { --indices(i); }
0223 }
0224
0225
0226 if ((indices - cache.m_buffered_indices).norm() == 0) { return; }
0227
0228 cache.m_buffered_indices = indices;
0229
0230
0231
0232 cache.m_center = get_point(indices);
0233 for (int i = 0; i < 3; ++i) {
0234 cache.m_center(i) += m_D(i) / 2;
0235 }
0236
0237
0238
0239 Eigen::MatrixXf M(64, 20);
0240 std::array<Eigen::VectorXf, 3> solution_vectors {
0241 Eigen::VectorXf(64),
0242 Eigen::VectorXf(64),
0243 Eigen::VectorXf(64),
0244 };
0245
0246
0247
0248 for (int row = 0; row < 64; ++row) {
0249 indices = {
0250 cache.m_buffered_indices(0) + (row / 16) - 1,
0251 cache.m_buffered_indices(1) + ((row / 4) % 4) - 1,
0252 cache.m_buffered_indices(2) + (row % 4) - 1,
0253 };
0254
0255 M.row(row) = get_design_vector(get_point(indices), cache);
0256 for (int i = 0; i < 3; ++i) {
0257 solution_vectors[i](row) = get_field(indices)(i);
0258 }
0259 }
0260
0261
0262 for (int i = 0; i < 3; ++i) {
0263 cache.m_coefficients[i] = M.bdcSvd (
0264 Eigen::ComputeThinU | Eigen::ComputeThinV
0265 ).solve (solution_vectors[i]);
0266 }
0267 }
0268
0269 PHFieldInterpolated::InterpolationCache&
0270 PHFieldInterpolated::get_cache (
0271 ) const {
0272
0273 InterpolationCache& this_cache = m_caches[std::this_thread::get_id()];
0274 for (auto& [thread_id, cache] : m_caches) {
0275 if (cache.m_queue_index < this_cache.m_queue_index) { ++cache.m_queue_index; }
0276 }
0277 this_cache.m_queue_index = 0;
0278
0279 return this_cache;
0280 }
0281
0282 PHFieldInterpolated::Field_t
0283 PHFieldInterpolated::get_interpolated (
0284 Point_t const& point
0285 ) const {
0286
0287 InterpolationCache this_cache;
0288
0289
0290 {
0291 std::lock_guard lock(m_mutex);
0292 this_cache = get_cache();
0293 }
0294
0295
0296 cache_interpolation(point, this_cache);
0297
0298
0299 {
0300 std::lock_guard lock(m_mutex);
0301 get_cache() = this_cache;
0302
0303
0304 std::erase_if (m_caches, [](auto const& key_val_pair) {
0305 return MAX_THREADS <= key_val_pair.second.m_queue_index;
0306 });
0307 }
0308
0309 return {
0310 get_design_vector(point, this_cache).dot(this_cache.m_coefficients[0]),
0311 get_design_vector(point, this_cache).dot(this_cache.m_coefficients[1]),
0312 get_design_vector(point, this_cache).dot(this_cache.m_coefficients[2]),
0313 };
0314 }
0315
0316 PHFieldInterpolated::Indices_t
0317 PHFieldInterpolated::get_indices (
0318 std::size_t const& index
0319 ) const {
0320 return {
0321 ((int)index / (m_N(1) * m_N(2))) % m_N(0),
0322 ((int)index / m_N(2)) % m_N(1),
0323 (int)index % m_N(2),
0324 };
0325 }
0326
0327 std::size_t
0328 PHFieldInterpolated::get_index (
0329 Indices_t const& indices
0330 ) const {
0331 validate_indices(indices);
0332 return
0333 (indices(0) * m_N(1) * m_N(2)) +
0334 (indices(1) * m_N(2)) +
0335 indices(2);
0336 }
0337
0338 PHFieldInterpolated::Indices_t
0339 PHFieldInterpolated::get_indices (
0340 Point_t const& point
0341 ) const {
0342 validate_point(point);
0343 return {
0344 (int)std::floor(point(0) / m_D(0)) + (m_N(0) / 2),
0345 (int)std::floor(point(1) / m_D(1)) + (m_N(1) / 2),
0346 (int)std::floor(point(2) / m_D(2)) + (m_N(2) / 2),
0347 };
0348 }
0349
0350 PHFieldInterpolated::Point_t
0351 PHFieldInterpolated::get_point (
0352 Indices_t const& indices
0353 ) const {
0354 validate_indices(indices);
0355 return {
0356 static_cast<float>(indices(0) - ((m_N(0) - 1.0) / 2.0)) * m_D(0),
0357 static_cast<float>(indices(1) - ((m_N(1) - 1.0) / 2.0)) * m_D(1),
0358 static_cast<float>(indices(2) - ((m_N(2) - 1.0) / 2.0)) * m_D(2),
0359 };
0360 }
0361
0362 void
0363 PHFieldInterpolated::validate_indices (
0364 Indices_t const& indices
0365 ) const {
0366 for (int i = 0; i < 3; ++i) {
0367 if (indices(i) < 0 || indices(i) >= m_N(i)) {
0368 std::stringstream what;
0369 what
0370 << PHWHERE
0371 << " Component out of range"
0372 << " (at " << indices.transpose() << ")";
0373 throw std::runtime_error(what.str());
0374 }
0375 }
0376 }
0377
0378 void
0379 PHFieldInterpolated::validate_point (
0380 Point_t const& point
0381 ) const {
0382 for (int i = 0; i < 3; ++i) {
0383 if (m_max(i) < point(i) || point(i) < m_min(i)) {
0384 std::stringstream what;
0385 what
0386 << PHWHERE
0387 << " Component out of range"
0388 << " (at " << point.transpose() << ")";
0389 throw std::runtime_error(what.str());
0390 }
0391 }
0392 }
0393
0394 void
0395 PHFieldInterpolated::print (
0396 std::ostream& stream
0397 ) const {
0398 std::lock_guard lock(m_mutex);
0399
0400 stream
0401 << PHWHERE << "\n"
0402 << " size: " << m_field.size()
0403 << " num caches: " << m_caches.size()
0404 << std::endl;
0405
0406 if (Verbosity() < 1) { return; }
0407
0408 for (int i = 0; i < 3; ++i) {
0409 stream
0410 << " component: " << i
0411 << " min: " << m_min(i)
0412 << " max: " << m_max(i)
0413 << " step: " << m_D(i)
0414 << " count: " << m_N(i)
0415 << std::endl;
0416 }
0417
0418 for (auto const& [thread_id, cache] : m_caches) {
0419 stream
0420 << " thread id: " << thread_id
0421 << " queue index: " << cache.m_queue_index
0422 << std::endl;
0423 }
0424
0425 if (Verbosity() < 2) { return; }
0426
0427 for (std::size_t index = 0; index < m_field.size(); ++index) {
0428 Indices_t indices = get_indices(index);
0429 stream
0430 << " indices: " << indices.transpose()
0431 << " point: " << get_point(indices).transpose()
0432 << " field: " << get_field(indices).transpose()
0433 << std::endl;
0434
0435 if (Verbosity() < 3 && index == 10) {
0436 stream << " ..." << std::endl;
0437 return;
0438 }
0439 }
0440 }
0441
0442 void
0443 PHFieldInterpolated::print_coefficients (
0444 std::ostream& stream
0445 ) const {
0446 std::lock_guard lock(m_mutex);
0447 InterpolationCache const& cache = m_caches[std::this_thread::get_id()];
0448 for (int i = 0; i < 3; ++i) {
0449 stream
0450 << cache.m_coefficients[i].transpose()
0451 << std::endl;
0452 }
0453 }
0454