File indexing completed on 2025-08-06 08:17:58
#include "PHFieldInterpolated.h"

#include <phool/phool.h> // PHWHERE

#include <TFile.h>
#include <TTree.h>

#include <Geant4/G4SystemOfUnits.hh>

#include <array>
#include <cmath>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <boost/format.hpp>
0016
0017 void
0018 PHFieldInterpolated::load_fieldmap (
0019 std::string const& filename,
0020 float const& magfield_rescale
0021 ) {
0022
0023
0024 if (filename.empty()) {
0025 std::stringstream what;
0026 what
0027 << PHWHERE
0028 << " Empty filename";
0029 throw std::runtime_error(what.str());
0030 }
0031
0032 TFile* file = TFile::Open(filename.c_str(), "READ");
0033 if (!file) {
0034 std::stringstream what;
0035 what
0036 << PHWHERE
0037 << " Could not open file " << filename;
0038 throw std::runtime_error(what.str());
0039 }
0040
0041 TTree* tree{};
0042 file->GetObject("fieldmap", tree);
0043 if (!tree) {
0044 std::stringstream what;
0045 what
0046 << PHWHERE
0047 << " Could not get tree " << "fieldmap"
0048 << " from file " << filename;
0049 throw std::runtime_error(what.str());
0050 }
0051
0052 Point_t point;
0053 Field_t field;
0054 std::map<std::string, float&> branches = {
0055 {"x", point(0)}, {"y", point(1)}, {"z", point(2)},
0056 {"bx", field(0)}, {"by", field(1)}, {"bz", field(2)},
0057 };
0058
0059 for (auto& [name, reference] : branches) {
0060 if (tree->SetBranchAddress(name.c_str(), &reference) != TTree::kMatch) {
0061 std::stringstream what;
0062 what
0063 << PHWHERE
0064 << " Could not get branch " << name
0065 << " from tree " << "fieldmap"
0066 << " from file " << filename;
0067 throw std::runtime_error(what.str());
0068 }
0069 }
0070
0071 if (Verbosity()) {
0072 std::cout
0073 << PHWHERE << "\n"
0074 << " Loading fieldmap from file " << filename
0075 << std::endl;
0076 }
0077
0078 for (std::size_t n = 0, N = tree->GetEntriesFast(); n < N; ++n) {
0079 tree->GetEntry(n);
0080
0081
0082 point *= cm;
0083 field *= tesla * magfield_rescale;
0084
0085
0086 Point_t expected = get_point(get_indices(n));
0087 if (1.0E-4 < (point - expected).norm()) {
0088 std::stringstream what;
0089 what
0090 << PHWHERE
0091 << " Read point from file which is dissimilar to calculated value"
0092 << " entry: " << n
0093 << " expected: " << expected.transpose()
0094 << " read: " << point.transpose();
0095 throw std::runtime_error(what.str());
0096 };
0097
0098 m_field.push_back(field);
0099
0100
0101 int print_index{6};
0102 switch (Verbosity()) {
0103 case 0:
0104 case 1:
0105
0106 break;
0107 case 2:
0108
0109 print_index = (int)n % (GRID_COUNT * GRID_COUNT * GRID_COUNT);
0110 break;
0111 case 3:
0112
0113 print_index = (int)n % (GRID_COUNT * GRID_COUNT);
0114 break;
0115 case 4:
0116
0117 print_index = (int)n % (GRID_COUNT);
0118 break;
0119 default:
0120
0121 print_index = 0;
0122 break;
0123 }
0124
0125 if (print_index < 5) {
0126 std::cout
0127 << " index: " << n
0128 << " point: " << point.transpose()
0129 << " field: " << field.transpose()
0130 << std::endl;
0131 }
0132
0133 if (print_index == 5) {
0134 std::cout
0135 << " ... "
0136 << std::endl;
0137 }
0138 }
0139
0140 file->Close();
0141 }
0142
0143 void
0144 PHFieldInterpolated::GetFieldValue (
0145 double const* point_as_arr,
0146 double* field_as_arr
0147 ) const {
0148 try {
0149 Field_t field = get_interpolated ({
0150 (float)point_as_arr[0],
0151 (float)point_as_arr[1],
0152 (float)point_as_arr[2],
0153 });
0154 for (int i = 0; i < 3; ++i) {
0155 field_as_arr[i] = field(i);
0156 }
0157 } catch (std::exception const& e) {
0158
0159
0160
0161 std::cout
0162 << PHWHERE << "\n"
0163 << "\t" << e.what() << "\n"
0164 << std::endl;
0165 }
0166 }
0167
0168 Eigen::VectorXf
0169 PHFieldInterpolated::get_design_vector (
0170 Point_t const& point
0171 ) const {
0172 float x = point(0) - m_center(0);
0173 float y = point(1) - m_center(1);
0174 float z = point(2) - m_center(2);
0175
0176 Eigen::VectorXf design_vector(20);
0177 design_vector <<
0178
0179 1,
0180
0181
0182 x, y, z,
0183
0184
0185 x*x, x*y, x*z,
0186 y*y, y*z,
0187 z*z,
0188
0189
0190 x*x*x, x*x*y, x*x*z, x*y*y, x*y*z, x*z*z,
0191 y*y*y, y*y*z, y*z*z,
0192 z*z*z;
0193
0194 return design_vector;
0195 }
0196
0197 void
0198 PHFieldInterpolated::cache_interpolation (
0199 Point_t const& point
0200 ) const {
0201 Indices_t indices = get_indices(point);
0202 if ((indices - m_buffered_indices).norm() == 0) { return; }
0203 m_buffered_indices = indices;
0204
0205
0206
0207 m_center = get_point(indices);
0208 for (int i = 0; i < 3; ++i) {
0209 m_center(i) += GRID_STEP / 2;
0210 }
0211
0212
0213
0214 Eigen::MatrixXf M(64, 20);
0215 std::array<Eigen::VectorXf, 3> solution_vectors {
0216 Eigen::VectorXf(64),
0217 Eigen::VectorXf(64),
0218 Eigen::VectorXf(64),
0219 };
0220
0221
0222
0223 for (int row = 0; row < 64; ++row) {
0224 indices = {
0225 m_buffered_indices(0) + (row / 16) - 1,
0226 m_buffered_indices(1) + ((row / 4) % 4) - 1,
0227 m_buffered_indices(2) + (row % 4) - 1,
0228 };
0229
0230
0231
0232 try {
0233 validate_indices(indices);
0234 } catch (std::exception const&) {
0235 m_buffered_indices = {-1, -1, -1};
0236 std::stringstream what;
0237 what
0238 << PHWHERE
0239 << " Point too close to edge of fieldmap "
0240 << " (at " << point << ")";
0241 throw std::runtime_error(what.str());
0242 }
0243
0244 M.row(row) = get_design_vector(get_point(indices));
0245 for (int i = 0; i < 3; ++i) {
0246 solution_vectors[i](row) = get_field(indices)(i);
0247 }
0248 }
0249
0250 for (int i = 0; i < 3; ++i) {
0251 m_coefficients[i] = M.bdcSvd (
0252 Eigen::ComputeThinU | Eigen::ComputeThinV
0253 ).solve (solution_vectors[i]);
0254 }
0255 }
0256
0257 PHFieldInterpolated::Field_t
0258 PHFieldInterpolated::get_interpolated (
0259 Point_t const& point
0260 ) const {
0261 cache_interpolation(point);
0262 return {
0263 get_design_vector(point).dot(m_coefficients[0]),
0264 get_design_vector(point).dot(m_coefficients[1]),
0265 get_design_vector(point).dot(m_coefficients[2]),
0266 };
0267 }
0268
0269 PHFieldInterpolated::Indices_t
0270 PHFieldInterpolated::get_indices (
0271 std::size_t const& index
0272 ) {
0273 return {
0274 ((int)index / (GRID_COUNT * GRID_COUNT)),
0275 ((int)index / GRID_COUNT) % GRID_COUNT,
0276 ((int)index) % GRID_COUNT,
0277 };
0278 }
0279
0280 std::size_t
0281 PHFieldInterpolated::get_index (
0282 Indices_t const& indices
0283 ) {
0284 validate_indices(indices);
0285 return
0286 (indices(0) * GRID_COUNT * GRID_COUNT) +
0287 (indices(1) * GRID_COUNT) +
0288 indices(2);
0289 }
0290
0291 PHFieldInterpolated::Indices_t
0292 PHFieldInterpolated::get_indices (
0293 Point_t const& point
0294 ) {
0295 validate_point(point);
0296 return {
0297 (int)std::floor(point(0) / GRID_STEP) + (GRID_COUNT / 2),
0298 (int)std::floor(point(1) / GRID_STEP) + (GRID_COUNT / 2),
0299 (int)std::floor(point(2) / GRID_STEP) + (GRID_COUNT / 2),
0300 };
0301 }
0302
0303 PHFieldInterpolated::Point_t
0304 PHFieldInterpolated::get_point (
0305 Indices_t const& indices
0306 ) {
0307 validate_indices(indices);
0308 return {
0309
0310 (indices(0) - (GRID_COUNT / 2)) * GRID_STEP,
0311
0312 (indices(1) - (GRID_COUNT / 2)) * GRID_STEP,
0313
0314 (indices(2) - (GRID_COUNT / 2)) * GRID_STEP,
0315 };
0316 }
0317
0318 void
0319 PHFieldInterpolated::validate_indices (
0320 Indices_t const& indices
0321 ) {
0322 for (int i = 0; i < 3; ++i) {
0323 if (indices(i) < 0 || indices(i) >= GRID_COUNT) {
0324 std::stringstream what;
0325 what
0326 << PHWHERE
0327 << " Component out of range"
0328 << " (at " << indices << ")";
0329 throw std::runtime_error(what.str());
0330 }
0331 }
0332 }
0333
0334 void
0335 PHFieldInterpolated::validate_point (
0336 Point_t const& point
0337 ) {
0338 for (int i = 0; i < 3; ++i) {
0339 if (GRID_MAX < point(i) || point(i) < GRID_MIN) {
0340 std::stringstream what;
0341 what
0342 << PHWHERE
0343 << " Component out of range"
0344 << " (at " << point << ")";
0345 throw std::runtime_error(what.str());
0346 }
0347 }
0348 }
0349
0350 void
0351 PHFieldInterpolated::print_map (
0352 ) const {
0353 for (std::size_t index = 0; index < m_field.size(); ++index) {
0354 Indices_t indices = get_indices(index);
0355 std::cout
0356 << " indices: " << indices.transpose()
0357 << " point: " << get_point(indices).transpose()
0358 << " field: " << get_field(indices).transpose()
0359 << std::endl;
0360 }
0361 }
0362
0363 void
0364 PHFieldInterpolated::print_coefficients (
0365 ) const {
0366 for (int i = 0; i < 3; ++i) {
0367 std::cout
0368 << m_coefficients[i].transpose()
0369 << std::endl;
0370 }
0371 }
0372