File indexing completed on 2025-08-06 08:11:12
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011
0012 #include "Acts/Plugins/Sycl/Seeding/detail/Types.hpp"
0013
0014 #include "../Utilities/Arrays.hpp"
0015 #include "SpacePointType.hpp"
0016
0017
0018 #include <CL/sycl.hpp>
0019
0020
0021 #include <cassert>
0022 #include <cstdint>
0023
0024
0025 #include "vecmem/containers/data/vector_view.hpp"
0026 #include "vecmem/containers/device_vector.hpp"
0027
0028 namespace Acts::Sycl::detail {
0029
0030
0031 template <SpacePointType OtherSPType>
0032 class LinearTransform {
0033
0034 static_assert((OtherSPType == SpacePointType::Bottom) ||
0035 (OtherSPType == SpacePointType::Top),
0036 "Class must be instantiated with either "
0037 "Acts::Sycl::detail::SpacePointType::Bottom or "
0038 "Acts::Sycl::detail::SpacePointType::Top");
0039
0040 public:
0041
0042 LinearTransform(
0043 vecmem::data::vector_view<const DeviceSpacePoint> middleSPs,
0044 vecmem::data::vector_view<const DeviceSpacePoint> otherSPs,
0045 vecmem::data::vector_view<uint32_t> middleIndexLUT,
0046 vecmem::data::vector_view<uint32_t> otherIndexLUT, uint32_t nEdges,
0047 vecmem::data::vector_view<detail::DeviceLinEqCircle> resultArray)
0048 : m_middleSPs(middleSPs),
0049 m_otherSPs(otherSPs),
0050 m_middleIndexLUT(middleIndexLUT),
0051 m_otherIndexLUT(otherIndexLUT),
0052 m_nEdges(nEdges),
0053 m_resultArray(resultArray) {}
0054
0055
0056 void operator()(cl::sycl::nd_item<1> item) const {
0057
0058 const auto idx = item.get_global_linear_id();
0059 if (idx >= m_nEdges) {
0060 return;
0061 }
0062
0063
0064
0065
0066
0067 vecmem::device_vector<uint32_t> middleIndexLUT(m_middleIndexLUT);
0068 const uint32_t middleIndex = middleIndexLUT[idx];
0069 assert(middleIndex < m_middleSPs.size());
0070 (void)m_middleSPs.size();
0071 vecmem::device_vector<uint32_t> otherIndexLUT(m_otherIndexLUT);
0072 const uint32_t otherIndex = otherIndexLUT[idx];
0073 assert(otherIndex < m_otherSPs.size());
0074 (void)m_otherSPs.size();
0075
0076
0077
0078
0079 const vecmem::device_vector<const DeviceSpacePoint> middleSPs(m_middleSPs);
0080 const DeviceSpacePoint middleSP = middleSPs[middleIndex];
0081 const vecmem::device_vector<const DeviceSpacePoint> otherSPs(m_otherSPs);
0082 const DeviceSpacePoint otherSP = otherSPs[otherIndex];
0083
0084
0085
0086 const float cosPhiM = middleSP.x / middleSP.r;
0087 const float sinPhiM = middleSP.y / middleSP.r;
0088
0089 const float deltaX = otherSP.x - middleSP.x;
0090 const float deltaY = otherSP.y - middleSP.y;
0091 const float deltaZ = otherSP.z - middleSP.z;
0092
0093 const float x = deltaX * cosPhiM + deltaY * sinPhiM;
0094 const float y = deltaY * cosPhiM - deltaX * sinPhiM;
0095 const float iDeltaR2 = 1.f / (deltaX * deltaX + deltaY * deltaY);
0096
0097
0098 DeviceLinEqCircle result;
0099 result.iDeltaR = cl::sycl::sqrt(iDeltaR2);
0100 result.cotTheta = deltaZ * result.iDeltaR;
0101 if constexpr (OtherSPType == SpacePointType::Bottom) {
0102 result.cotTheta = -(result.cotTheta);
0103 }
0104 result.zo = middleSP.z - middleSP.r * result.cotTheta;
0105 result.u = x * iDeltaR2;
0106 result.v = y * iDeltaR2;
0107 result.er =
0108 ((middleSP.varZ + otherSP.varZ) +
0109 (result.cotTheta * result.cotTheta) * (middleSP.varR + otherSP.varR)) *
0110 iDeltaR2;
0111
0112
0113 vecmem::device_vector<detail::DeviceLinEqCircle> resultArray(m_resultArray);
0114 resultArray[idx] = result;
0115 return;
0116 }
0117
0118 private:
0119
0120 vecmem::data::vector_view<const DeviceSpacePoint> m_middleSPs;
0121
0122 vecmem::data::vector_view<const DeviceSpacePoint> m_otherSPs;
0123
0124
0125 vecmem::data::vector_view<uint32_t> m_middleIndexLUT;
0126
0127 vecmem::data::vector_view<uint32_t> m_otherIndexLUT;
0128
0129
0130 uint32_t m_nEdges;
0131
0132
0133 vecmem::data::vector_view<detail::DeviceLinEqCircle> m_resultArray;
0134
0135 };
0136
0137 }