File indexing completed on 2025-08-06 08:09:57
0001
0002
0003
0004
0005
0006
0007
0008
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <iterator>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>

#include <boost/pending/disjoint_sets.hpp>
0013
0014 namespace Acts::Ccl::internal {
0015
0016
0017
/// Detection trait: `value` is true when the cell type `CellT` provides the
/// free functions required for clustering on a grid of dimension `Dim`.
/// Any (type, dimension) pair not matched by a specialization below falls
/// back to this primary template and yields false.
template <typename CellT, std::size_t Dim, typename Enable = void>
struct cellTypeHasRequiredFunctions : std::false_type {};

/// 1D grids: the cell needs a column accessor and a label accessor that
/// accepts a mutable lvalue.
template <typename CellT>
struct cellTypeHasRequiredFunctions<
    CellT, 1,
    std::void_t<decltype(getCellColumn(std::declval<CellT>())),
                decltype(getCellLabel(std::declval<CellT&>()))>>
    : std::true_type {};

/// 2D grids: same as 1D, plus a row accessor.
template <typename CellT>
struct cellTypeHasRequiredFunctions<
    CellT, 2,
    std::void_t<decltype(getCellRow(std::declval<CellT>())),
                decltype(getCellColumn(std::declval<CellT>())),
                decltype(getCellLabel(std::declval<CellT&>()))>>
    : std::true_type {};
0035
/// Detection trait: `value` is true when a call
/// `clusterAddCell(<ClusterT>, <CellT>)` is well-formed. Callers pass the
/// exact reference/cv-qualified argument types they intend to use.
template <typename ClusterT, typename CellT, typename Enable = void>
struct clusterTypeHasRequiredFunctions : std::false_type {};

template <typename ClusterT, typename CellT>
struct clusterTypeHasRequiredFunctions<
    ClusterT, CellT,
    std::void_t<decltype(clusterAddCell(std::declval<ClusterT>(),
                                        std::declval<CellT>()))>>
    : std::true_type {};
0044
/// Compile-time guard: rejects any grid dimension other than 1 or 2.
template <std::size_t GridDim>
constexpr void staticCheckGridDim() {
  constexpr bool isSupportedDim = (GridDim == 1) || (GridDim == 2);
  static_assert(
      isSupportedDim,
      "mergeClusters is only defined for grid dimensions of 1 or 2. ");
}
0051
0052 template <typename T, std::size_t GridDim>
0053 constexpr void staticCheckCellType() {
0054 constexpr bool hasFns = cellTypeHasRequiredFunctions<T, GridDim>();
0055 static_assert(hasFns,
0056 "Cell type should have the following functions: "
0057 "'int getCellRow(const Cell&)', "
0058 "'int getCellColumn(const Cell&)', "
0059 "'Label& getCellLabel(Cell&)'");
0060 }
0061
0062 template <typename T, typename U>
0063 constexpr void staticCheckClusterType() {
0064 constexpr bool hasFns = clusterTypeHasRequiredFunctions<T, U>();
0065 static_assert(hasFns,
0066 "Cluster type should have the following function: "
0067 "'void clusterAddCell(Cluster&, const Cell&)'");
0068 }
0069
/// Primary template of the cell ordering used before labelling.
/// Only the specializations for GridDim == 1 and GridDim == 2 provide an
/// ordering; instantiating this primary template means an unsupported grid
/// dimension was requested.
template <typename Cell, std::size_t GridDim>
struct Compare {
  // The condition depends on GridDim, so it is evaluated only when this
  // primary template is instantiated, i.e. for a dimension other than 1 or 2,
  // where it is false and produces the diagnostic. (The previous condition,
  // `GridDim != 1 && GridDim != 2`, was true exactly in that case and so
  // never fired.)
  static_assert(GridDim == 1 || GridDim == 2,
                "Only grid dimensions of 1 or 2 are supported");
};
0075
0076
0077
0078 template <typename Cell>
0079 struct Compare<Cell, 2> {
0080 bool operator()(const Cell& c0, const Cell& c1) const {
0081 int row0 = getCellRow(c0);
0082 int row1 = getCellRow(c1);
0083 int col0 = getCellColumn(c0);
0084 int col1 = getCellColumn(c1);
0085 return (col0 == col1) ? row0 < row1 : col0 < col1;
0086 }
0087 };
0088
0089
0090 template <typename Cell>
0091 struct Compare<Cell, 1> {
0092 bool operator()(const Cell& c0, const Cell& c1) const {
0093 int col0 = getCellColumn(c0);
0094 int col1 = getCellColumn(c1);
0095 return col0 < col1;
0096 }
0097 };
0098
0099
0100
0101
0102 class DisjointSets {
0103 public:
0104 explicit DisjointSets(std::size_t initial_size = 128)
0105 : m_size(initial_size),
0106 m_rank(m_size),
0107 m_parent(m_size),
0108 m_ds(&m_rank[0], &m_parent[0]) {}
0109
0110 Label makeSet() {
0111
0112
0113 while (m_globalId >= m_size) {
0114 m_size *= 2;
0115 m_rank.resize(m_size);
0116 m_parent.resize(m_size);
0117 m_ds = boost::disjoint_sets<std::size_t*, std::size_t*>(&m_rank[0],
0118 &m_parent[0]);
0119 }
0120 m_ds.make_set(m_globalId);
0121 return static_cast<Label>(m_globalId++);
0122 }
0123
0124 void unionSet(std::size_t x, std::size_t y) { m_ds.union_set(x, y); }
0125 Label findSet(std::size_t x) { return static_cast<Label>(m_ds.find_set(x)); }
0126
0127 private:
0128 std::size_t m_globalId = 1;
0129 std::size_t m_size;
0130 std::vector<std::size_t> m_rank;
0131 std::vector<std::size_t> m_parent;
0132 boost::disjoint_sets<std::size_t*, std::size_t*> m_ds;
0133 };
0134
0135 template <std::size_t BufSize>
0136 struct ConnectionsBase {
0137 std::size_t nconn{0};
0138 std::array<Label, BufSize> buf;
0139 ConnectionsBase() { std::fill(buf.begin(), buf.end(), NO_LABEL); }
0140 };
0141
0142 template <std::size_t GridDim>
0143 class Connections {};
0144
0145
0146 template <>
0147 struct Connections<1> : public ConnectionsBase<1> {
0148 using ConnectionsBase::ConnectionsBase;
0149 };
0150
0151
0152 template <>
0153 struct Connections<2> : public ConnectionsBase<4> {
0154 using ConnectionsBase::ConnectionsBase;
0155 };
0156
0157
0158 template <typename Cell, typename Connect, std::size_t GridDim>
0159 Connections<GridDim> getConnections(typename std::vector<Cell>::iterator it,
0160 std::vector<Cell>& set, Connect connect) {
0161 Connections<GridDim> seen;
0162 typename std::vector<Cell>::iterator it_2{it};
0163
0164 while (it_2 != set.begin()) {
0165 it_2 = std::prev(it_2);
0166
0167 ConnectResult cr = connect(*it, *it_2);
0168 if (cr == ConnectResult::eNoConnStop) {
0169 break;
0170 }
0171 if (cr == ConnectResult::eNoConn) {
0172 continue;
0173 }
0174 if (cr == ConnectResult::eConn) {
0175 seen.buf[seen.nconn] = getCellLabel(*it_2);
0176 seen.nconn += 1;
0177 if (seen.nconn == seen.buf.size()) {
0178 break;
0179 }
0180 }
0181 }
0182 return seen;
0183 }
0184
0185 template <typename CellCollection, typename ClusterCollection>
0186 ClusterCollection mergeClustersImpl(CellCollection& cells) {
0187 using Cluster = typename ClusterCollection::value_type;
0188
0189 if (cells.empty()) {
0190 return {};
0191 }
0192
0193
0194 ClusterCollection outv;
0195 Cluster cl;
0196 int lbl = getCellLabel(cells.front());
0197 for (auto& cell : cells) {
0198 if (getCellLabel(cell) != lbl) {
0199
0200 outv.push_back(std::move(cl));
0201 cl = Cluster();
0202 lbl = getCellLabel(cell);
0203 }
0204 clusterAddCell(cl, cell);
0205 }
0206
0207 outv.push_back(std::move(cl));
0208
0209 return outv;
0210 }
0211
0212 }
0213
0214 namespace Acts::Ccl {
0215
0216 template <typename Cell>
0217 ConnectResult Connect2D<Cell>::operator()(const Cell& ref,
0218 const Cell& iter) const {
0219 int deltaRow = std::abs(getCellRow(ref) - getCellRow(iter));
0220 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0221
0222
0223 if (deltaCol > 1) {
0224 return ConnectResult::eNoConnStop;
0225 }
0226
0227
0228 if (deltaRow > 1) {
0229 return ConnectResult::eNoConn;
0230 }
0231
0232
0233 if ((deltaRow + deltaCol) <= (conn8 ? 2 : 1)) {
0234 return ConnectResult::eConn;
0235 }
0236 return ConnectResult::eNoConn;
0237 }
0238
0239 template <typename Cell>
0240 ConnectResult Connect1D<Cell>::operator()(const Cell& ref,
0241 const Cell& iter) const {
0242 int deltaCol = std::abs(getCellColumn(ref) - getCellColumn(iter));
0243 return deltaCol == 1 ? ConnectResult::eConn : ConnectResult::eNoConnStop;
0244 }
0245
0246 template <std::size_t GridDim>
0247 void recordEquivalences(const internal::Connections<GridDim> seen,
0248 internal::DisjointSets& ds) {
0249
0250
0251 if (seen.nconn > 0 && seen.buf[0] == NO_LABEL) {
0252 throw std::logic_error("seen.nconn > 0 but seen.buf[0] == NO_LABEL");
0253 }
0254 for (std::size_t i = 1; i < seen.nconn; i++) {
0255
0256
0257
0258 if (seen.buf[i] == NO_LABEL) {
0259 throw std::logic_error("i < seen.nconn but see.buf[i] == NO_LABEL");
0260 }
0261
0262 if (seen.buf[0] != seen.buf[i]) {
0263 ds.unionSet(seen.buf[0], seen.buf[i]);
0264 }
0265 }
0266 }
0267
0268 template <typename CellCollection, std::size_t GridDim, typename Connect>
0269 void labelClusters(CellCollection& cells, Connect connect) {
0270 using Cell = typename CellCollection::value_type;
0271 internal::staticCheckCellType<Cell, GridDim>();
0272
0273 internal::DisjointSets ds{};
0274
0275
0276 std::sort(cells.begin(), cells.end(), internal::Compare<Cell, GridDim>());
0277
0278
0279 for (auto it = cells.begin(); it != cells.end(); ++it) {
0280 const internal::Connections<GridDim> seen =
0281 internal::getConnections<Cell, Connect, GridDim>(it, cells, connect);
0282 if (seen.nconn == 0) {
0283
0284 getCellLabel(*it) = ds.makeSet();
0285 } else {
0286 recordEquivalences(seen, ds);
0287
0288 getCellLabel(*it) = seen.buf[0];
0289 }
0290 }
0291
0292
0293 for (auto& cell : cells) {
0294 Label& lbl = getCellLabel(cell);
0295 lbl = ds.findSet(lbl);
0296 }
0297 }
0298
0299 template <typename CellCollection, typename ClusterCollection,
0300 std::size_t GridDim = 2>
0301 ClusterCollection mergeClusters(CellCollection& cells) {
0302 using Cell = typename CellCollection::value_type;
0303 using Cluster = typename ClusterCollection::value_type;
0304 internal::staticCheckGridDim<GridDim>();
0305 internal::staticCheckCellType<Cell, GridDim>();
0306 internal::staticCheckClusterType<Cluster&, const Cell&>();
0307
0308 if constexpr (GridDim > 1) {
0309
0310
0311 std::sort(cells.begin(), cells.end(), [](Cell& lhs, Cell& rhs) {
0312 return getCellLabel(lhs) < getCellLabel(rhs);
0313 });
0314 }
0315
0316 return internal::mergeClustersImpl<CellCollection, ClusterCollection>(cells);
0317 }
0318
0319 template <typename CellCollection, typename ClusterCollection,
0320 std::size_t GridDim, typename Connect>
0321 ClusterCollection createClusters(CellCollection& cells, Connect connect) {
0322 using Cell = typename CellCollection::value_type;
0323 using Cluster = typename ClusterCollection::value_type;
0324 internal::staticCheckCellType<Cell, GridDim>();
0325 internal::staticCheckClusterType<Cluster&, const Cell&>();
0326 labelClusters<CellCollection, GridDim, Connect>(cells, connect);
0327 return mergeClusters<CellCollection, ClusterCollection, GridDim>(cells);
0328 }
0329
0330 }