File indexing completed on 2025-12-18 09:17:53
0001
0002 #ifndef MULTIARRAY_H
0003 #define MULTIARRAY_H
0004
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <format>
#include <iostream>
#include <memory>
0009
/// Dense array of up to MAX_DIM (6) axes stored contiguously in row-major
/// order (the last active index varies fastest).
///
/// The object owns its storage and is non-copyable.
template <class T>
class MultiArray
{
public:
    /// Maximum number of axes accepted by the constructor and the accessors.
    static constexpr int MAX_DIM = 6;

    /// Number of active axes: the count of leading constructor extents that are >= 1.
    int dim = 0;
    /// Extent of each axis; entries at positions >= dim are 0.
    int n[MAX_DIM] = {};
    /// Total element count: product of the active extents (1 when dim == 0).
    long int length = 1;
    /// Start of the `length` contiguous elements. This is a non-owning view of
    /// storage owned by this object: never free/delete it, and do not use it
    /// after the MultiArray is destroyed.
    T *field = nullptr;

    /// Create an array with extents a, b, c, d, e, f.
    ///
    /// Axes are taken in order up to the first extent < 1. That extent and every
    /// later one are ignored, e.g. MultiArray(3, 0, 4) is one-dimensional with 3
    /// elements, and MultiArray() has dim == 0 and a single element.
    /// Every element is value-initialized: 0 for arithmetic T, default-constructed
    /// for class types.
    /// @throws std::bad_alloc if the storage cannot be allocated.
    MultiArray(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0)
    {
        const int extents[MAX_DIM] = {a, b, c, d, e, f};
        dim = MAX_DIM;
        for (int i = 0; i < MAX_DIM; i++)
        {
            if (extents[i] < 1)
            {
                dim = i;
                break;
            }
            n[i] = extents[i];
            length *= n[i];
        }
        // make_unique<T[]> value-initializes every element. A raw malloc left
        // them uninitialized (and unconstructed for non-trivial T), so e.g. Add()
        // on a fresh array read indeterminate values.
        storage_ = std::make_unique<T[]>(static_cast<std::size_t>(length));
        field = storage_.get();
    }

    MultiArray(const MultiArray &) = delete;
    MultiArray &operator=(const MultiArray &) = delete;

    ~MultiArray() = default;

    /// Add `in` to the element at (a, b, c), with the remaining indices 0.
    void Add(int a, int b, int c, T in)
    {
        Add(a, b, c, 0, 0, 0, in);
    }

    /// Add `in` to the element at (a, ..., f), computed as element + in.
    ///
    /// Only the first `dim` indices are used, and each is bounds-checked
    /// (see CheckedIndex).
    void Add(int a, int b, int c, int d, int e, int f, T in)
    {
        T &slot = field[CheckedIndex(a, b, c, d, e, f)];
        // Spelled as `slot + in` (not +=) so T only needs operator+.
        slot = slot + in;
    }

    /// Return a copy of the element at (a, ..., f).
    ///
    /// Only the first `dim` indices are used, and each is bounds-checked
    /// (see CheckedIndex).
    T Get(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0) const
    {
        return field[CheckedIndex(a, b, c, d, e, f)];
    }

    /// Return a pointer to the element at (a, ..., f).
    ///
    /// Only the first `dim` indices are used. No bounds checking is done, so
    /// callers may form end pointers. The caller is responsible for validity.
    T *GetPtr(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0)
    {
        return field + FlatIndex(a, b, c, d, e, f);
    }

    /// Return a pointer to the element at flat (row-major) position `a`.
    ///
    /// An out-of-range `a` prints a diagnostic to std::cerr and fails an
    /// assertion. The assertion is compiled out under NDEBUG.
    T *GetFlat(int a = 0)
    {
        if (a < 0 || a >= length)
        {
            std::cerr << "tried to seek element " << a
                      << " of multiarray, but bounds are 0<=a<" << length << '\n';
            // The previous assert re-tested the failure condition and could never fire.
            assert(a >= 0 && a < length && "MultiArray flat index out of bounds");
        }
        return &(field[a]);
    }

    /// Total number of elements, narrowed to int (kept for interface compatibility).
    int Length() const
    {
        return static_cast<int>(length);
    }

    /// Overwrite the element at (a, b, c), with the remaining indices 0.
    void Set(int a, int b, int c, T in)
    {
        Set(a, b, c, 0, 0, 0, in);
    }

    /// Overwrite the element at (a, ..., f) with `in`.
    ///
    /// Only the first `dim` indices are used, and each is bounds-checked
    /// (see CheckedIndex).
    void Set(int a, int b, int c, int d, int e, int f, T in)
    {
        field[CheckedIndex(a, b, c, d, e, f)] = in;
    }

    /// Assign `in` to every element.
    void SetAll(T in)
    {
        std::fill_n(field, length, in);
    }

private:
    /// Owning storage behind `field`; released automatically on destruction.
    std::unique_ptr<T[]> storage_;

    /// Row-major flat offset of (a, ..., f) over the first `dim` axes.
    ///
    /// Indices at positions >= dim are ignored, and dim == 0 always maps to 0.
    /// Performs no bounds checking.
    long int FlatIndex(int a, int b, int c, int d, int e, int f) const
    {
        const int idx[MAX_DIM] = {a, b, c, d, e, f};
        long int index = 0;
        for (int i = 0; i < dim; i++)
        {
            index = index * n[i] + idx[i];
        }
        return index;
    }

    /// FlatIndex() after validating 0 <= index < n[i] for each active axis.
    ///
    /// On a violation, prints a diagnostic to std::cerr and fails an assertion.
    /// Under NDEBUG only the diagnostic is printed and execution continues.
    long int CheckedIndex(int a, int b, int c, int d, int e, int f) const
    {
        const int idx[MAX_DIM] = {a, b, c, d, e, f};
        for (int i = 0; i < dim; i++)
        {
            if (idx[i] < 0 || idx[i] >= n[i])
            {
                std::cerr << "asking for el " << a << ' ' << b << ' ' << c << ' '
                          << d << ' ' << e << ' ' << f << ". index " << i << " ("
                          << idx[i] << ") is outside of bounds 0<=x<" << n[i] << '\n';
                assert(false && "MultiArray index out of bounds");
            }
        }
        return FlatIndex(a, b, c, d, e, f);
    }
};
0169 #endif