File indexing completed on 2025-08-05 08:15:40
0001
0002 #ifndef MULTIARRAY_H
0003 #define MULTIARRAY_H
0004
0005 #include <cassert>
0006 #include <cstdio> // for printf
0007 #include <cstdlib> // for malloc
0008
// Dense, heap-backed array of T with between 0 and MAX_DIM (6) dimensions.
//
// Elements are stored contiguously in row-major order (the last coordinate
// varies fastest). The shape is fixed at construction: extents are taken from
// the constructor arguments up to, but not including, the first argument that
// is < 1. So MultiArray<double>(3, 4) is 3x4, and MultiArray<double>(3, 0, 5)
// is one-dimensional with 3 elements. With no positive extent the array is
// zero-dimensional and holds exactly one element.
//
// Indexing methods take six coordinates. Only the first `dim` are used; the
// rest are ignored. Get/Set/Add/GetPtr/GetFlat print a diagnostic and assert
// on an out-of-range coordinate. With NDEBUG defined the access still
// happens, so callers must stay in range.
//
// The object owns its storage. Copying is disabled. Moving transfers the
// buffer and leaves the source empty (dim 0, length 0, field == nullptr); a
// moved-from object may only be assigned to or destroyed.
//
// Requires T to be default-constructible, assignable, and to support
// `T + T` (used by Add).
template <class T>
class MultiArray
{
 public:
  static const int MAX_DIM = 6;
  int dim;          // number of active dimensions (0..MAX_DIM)
  int n[MAX_DIM];   // extent of each active dimension; 0 in unused slots
  long int length;  // total element count (product of the active extents)
  T *field;         // owned storage of `length` elements, row-major

  // Builds an array with extents (a, b, c, d, e, f), truncated at the first
  // extent < 1. All `length` elements are value-initialized (zero for
  // arithmetic T) instead of being left indeterminate, so Add() is safe
  // without a prior SetAll().
  MultiArray(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0)
    : dim(0)
    , length(1)
    , field(nullptr)
  {
    const int requested[MAX_DIM] = {a, b, c, d, e, f};
    for (int i = 0; i < MAX_DIM; i++)
    {
      n[i] = 0;
    }
    // The first non-positive extent ends the shape; later ones are ignored.
    while (dim < MAX_DIM && requested[dim] >= 1)
    {
      n[dim] = requested[dim];
      length *= n[dim];
      dim++;
    }
    // new[] (unlike malloc) constructs the elements and throws on failure
    // rather than returning a null pointer that would be dereferenced later.
    field = new T[length]();
  }

  explicit MultiArray(const MultiArray &) = delete;
  MultiArray &operator=(const MultiArray &) = delete;

  // Takes over other's storage and shape; other is left empty.
  MultiArray(MultiArray &&other) noexcept
    : dim(0)
    , length(0)
    , field(nullptr)
  {
    TakeOver(other);
  }

  // Releases this array's storage, then takes over other's; other is left
  // empty. Self-move-assignment is a no-op.
  MultiArray &operator=(MultiArray &&other) noexcept
  {
    if (this != &other)
    {
      delete[] field;
      TakeOver(other);
    }
    return *this;
  }

  ~MultiArray()
  {
    delete[] field;
  }

  // Accumulates `in` into element (a, b, c): element = element + in.
  void Add(int a, int b, int c, T in)
  {
    Add(a, b, c, 0, 0, 0, in);
  }

  // Accumulates `in` into element (a, ..., f): element = element + in.
  void Add(int a, int b, int c, int d, int e, int f, T in)
  {
    const long int index = FlatIndex(a, b, c, d, e, f);
    field[index] = field[index] + in;
  }

  // Returns a copy of element (a, ..., f).
  T Get(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0) const
  {
    return field[FlatIndex(a, b, c, d, e, f)];
  }

  // Returns a pointer to element (a, ..., f); valid until this array is
  // destroyed or moved from.
  T *GetPtr(int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0)
  {
    return &(field[FlatIndex(a, b, c, d, e, f)]);
  }

  // Returns a pointer to the element at flat (row-major) position `a`,
  // where 0 <= a < length.
  T *GetFlat(int a = 0)
  {
    if (a < 0 || a >= length)
    {
      printf("tried to seek element %d of multiarray, but bounds are 0<a<%ld\n", a, length);
      // Fail unconditionally: the condition above already proved `a` is out
      // of range (the previous assert re-tested that condition and so could
      // never fire).
      assert(false);
    }
    return &(field[a]);
  }

  // Total number of elements, narrowed to int (truncates if length > INT_MAX).
  int Length() const
  {
    return static_cast<int>(length);
  }

  // Overwrites element (a, b, c) with `in`.
  void Set(int a, int b, int c, T in)
  {
    Set(a, b, c, 0, 0, 0, in);
  }

  // Overwrites element (a, ..., f) with `in`.
  void Set(int a, int b, int c, int d, int e, int f, T in)
  {
    field[FlatIndex(a, b, c, d, e, f)] = in;
  }

  // Overwrites every element with `in`.
  void SetAll(T in)
  {
    for (long int i = 0; i < length; i++)
    {
      field[i] = in;
    }
  }

 private:
  // Row-major flat offset of coordinates (a, ..., f). Only the first `dim`
  // coordinates are used; each must satisfy 0 <= coord < n[i], otherwise a
  // diagnostic is printed and the assert fails. This single helper keeps
  // Get/Set/Add/GetPtr consistent, including for a 0-dimensional array
  // (always offset 0).
  long int FlatIndex(int a, int b, int c, int d, int e, int f) const
  {
    const int coord[MAX_DIM] = {a, b, c, d, e, f};
    long int index = 0;
    for (int i = 0; i < dim; i++)
    {
      if (coord[i] < 0 || coord[i] >= n[i])
      {
        // Report which dimension (i) is out of range; the requested
        // coordinate values are already listed at the start of the message.
        printf("asking for el %d %d %d %d %d %d. %dth element is outside of bounds 0<x<%d\n", a, b, c, d, e, f, i, n[i]);
        assert(false);
      }
      index = (index * n[i]) + coord[i];
    }
    return index;
  }

  // Copies other's shape and buffer pointer into *this, then resets other to
  // the empty state (dim 0, length 0, all extents 0, field == nullptr) so its
  // destructor frees nothing.
  void TakeOver(MultiArray &other) noexcept
  {
    dim = other.dim;
    length = other.length;
    field = other.field;
    for (int i = 0; i < MAX_DIM; i++)
    {
      n[i] = other.n[i];
      other.n[i] = 0;
    }
    other.dim = 0;
    other.length = 0;
    other.field = nullptr;
  }
};
0168 #endif