IncliArray
Loading...
Searching...
No Matches
NDArray.h
Go to the documentation of this file.
1
#pragma once

#include <functional>
#include <string>
#include <tuple>
#include <unordered_set>
#include <vector>

/**
 * @brief N-dimensional float array with a small reverse-mode autograd graph.
 *
 * Elements live in a raw, strided float buffer (`data`) addressed through
 * `shape` / `strides`; gradients are accumulated into `grad` by backward().
 * Nodes presumably remember the arrays they were computed from in `prev` so
 * backward() can walk the graph.
 *
 * NOTE(review): no destructor, copy or move operations are declared in this
 * class, so the implicitly generated ones copy the raw `data` / `grad`
 * pointers and release nothing. Whether (and where) owned buffers are freed
 * is decided in NDArray.cpp -- verify there is no leak and no double free
 * when an NDArray is copied or returned by value.
 */
class NDArray {
private:
  /**
   * @brief Internal constructor wrapping an existing buffer with explicit
   *        strides.
   * @param shape    Extent of each axis.
   * @param strides  Per-axis step through `data`.
   * @param data     Element buffer to wrap.
   * @param ownsData Whether the new array takes ownership of `data`
   *                 (presumably false for views built by slice() -- confirm).
   * @param label    Display name.
   * @param op       Name of the operation that produced this array.
   * @param prev     Parent arrays in the autograd graph.
   */
  NDArray(std::vector<int> shape, std::vector<int> strides, float *data,
          bool ownsData, std::string label = "", std::string op = "",
          std::vector<std::reference_wrapper<NDArray>> prev = {});

  /**
   * @brief Build a topological ordering of nodes reachable from `arr`.
   * @param visited Nodes already handled, so shared ancestors are not
   *                emitted twice.
   * @param arr     Node to start from.
   * @param topo    Output ordering; nodes are appended to it.
   */
  void build_topo(std::unordered_set<NDArray *> &visited, NDArray *arr,
                  std::vector<std::reference_wrapper<NDArray>> &topo);

public:
  /// Element buffer. Borrowed rather than owned when `ownsData` is false,
  /// e.g. for the non-owning views returned by slice().
  float *data = nullptr;
  /// Extent of each axis.
  std::vector<int> shape;
  /// Per-axis step used to address `data` (presumably in elements, not
  /// bytes).
  std::vector<int> strides;
  /// Number of axes (presumably shape.size()).
  int ndim = 0;
  /// Number of elements (presumably the product of `shape`).
  int size = 0;
  /// Whether this array owns `data`: clone() returns an owning copy, while
  /// slice() returns a non-owning view.
  bool ownsData = false;
  /// Gradient buffer that backward() accumulates into.
  float *grad = nullptr;
  /// Name of the operation that produced this node (presumably empty for
  /// arrays created directly by the user).
  std::string op;
  /// Human-readable name used for display.
  std::string label;
  /**
   * Arrays this node was computed from; presumably the edges that
   * build_topo()/backward() follow.
   *
   * NOTE(review): these are references, not owners, so every parent must
   * still be alive when backward() runs. If the operators record their
   * operands here (as the constructors' `prev` parameter suggests), a
   * temporary operand -- e.g. the `(a + b)` in `(a + b) + c` -- is destroyed
   * at the end of the full expression and leaves a dangling reference
   * behind. Bind intermediate results to named variables.
   */
  std::vector<std::reference_wrapper<NDArray>> prev;
  /// Selects which buffer get()/print() use: `data` or `grad`.
  enum class PrintType { Data, Grad };
  /// Gradient-propagation callback for this node; presumably installed by
  /// the producing operator and invoked by backward().
  std::function<void()> _backward;

  /**
   * @brief Create an array with the given shape.
   * @param shape Extent of each axis.
   * @param label Display name.
   * @param op    Name of the producing operation.
   * @param prev  Parent arrays in the autograd graph.
   */
  NDArray(std::vector<int> shape, std::string label = "", std::string op = "",
          std::vector<std::reference_wrapper<NDArray>> prev = {});

  /**
   * @brief Print selected metadata fields.
   *
   * Each flag enables one field: shape, strides, ndim, size, ownsData.
   */
  void metadata(bool shapeInfo = true, bool stridesInfo = false,
                bool ndimInfo = false, bool sizeInfo = false,
                bool ownsDataInfo = false);

  /**
   * @brief Read an element by multi-dimensional indices.
   * @param indices One index per axis.
   * @param type    Read from `data` (PrintType::Data) or `grad`
   *                (PrintType::Grad).
   */
  float get(std::vector<int> indices, PrintType type = PrintType::Data) const;

  /**
   * @brief Read an element addressed by a single integer index.
   *
   * How the index maps onto the strided layout is defined in NDArray.cpp.
   * @param type Read from `data` or `grad`.
   */
  float get(int index, PrintType type = PrintType::Data) const;

  /// Write an element by multi-dimensional indices.
  void set(std::vector<int> indices, float value);

  /// Write an element addressed by a single integer index (mapping defined
  /// in NDArray.cpp).
  void set(int index, float value);

  /// Return a non-owning view restricted by per-axis [start, stop) slices.
  NDArray slice(std::vector<std::tuple<int, int>> indices);

  /// Whether the logical layout matches standard row-major contiguous
  /// strides for the current shape.
  bool isContiguous() const;

  /// Reshape this array to a new shape with the same number of elements.
  void reshape(std::vector<int> newShape);

  /// Pretty-print the data or gradients.
  void print(PrintType type = PrintType::Data);

  /// Fill with sequential values 0, 1, 2, ... (for demos/testing).
  void fillSequential();

  /// Fill with a constant value.
  void fill(float value);

  /// Set all elements to 0.
  void zeros();

  /// Set all elements to 1.
  void ones();

  /// Fill with uniform integer values in [low, high).
  void randint(int low, int high);

  /// Fill with uniform real values in [0, 1).
  void rand();

  /// Fill with uniform real values between `low` and `high` (presumably the
  /// half-open [low, high), matching rand()).
  void rand(float low, float high);

  /// Broadcasted element-wise addition (this + other).
  NDArray operator+(NDArray &other);

  /// Add the scalar `value` (this + value).
  NDArray operator+(float value);

  /// Broadcasted element-wise subtraction (this - other).
  NDArray operator-(NDArray &other);

  /// Subtract the scalar `value` (this - value).
  NDArray operator-(float value);

  /// 2D matrix multiplication (no broadcasting). For element-wise products
  /// use element_wise_multiply().
  NDArray operator*(NDArray &other);

  /// Multiply by the scalar `value` (this * value).
  NDArray operator*(float value);

  /// Broadcasted element-wise division (this / other).
  NDArray operator/(NDArray &other);

  /// Divide by the scalar `value` (this / value).
  NDArray operator/(float value);

  /**
   * @brief Scalar element-wise power (this ^ value).
   *
   * Caution: `^` keeps its built-in C++ precedence, which is lower than that
   * of `+`, `-`, `*`, `/` and even `==`. `a ^ 2 + 1` parses as `a ^ (2 + 1)`
   * and `a * b ^ 2` as `(a * b) ^ 2`; parenthesize, e.g. `(a ^ 2) + 1`.
   */
  NDArray operator^(float value);

  /// Broadcasted element-wise multiplication.
  NDArray element_wise_multiply(NDArray &other);

  /// Element-wise multiplication by the scalar `value`.
  NDArray element_wise_multiply(float value);

  /// Reduce all elements to a scalar sum.
  NDArray sum();

  /// Sum along `axis`.
  NDArray sum(int axis);

  /// Reverse-mode backprop: accumulate gradients into all ancestors
  /// reachable from this node.
  void backward();

  /// Materialize a contiguous, owning copy. Detached from autograd.
  NDArray clone();
};
Definition NDArray.h:31
void rand()
Fill with uniform real values in [0, 1).
Definition NDArray.cpp:307
void backward()
Reverse-mode backprop: accumulate gradients into all ancestors reachable from this node.
Definition NDArray.cpp:855
void print(PrintType type=PrintType::Data)
Pretty‑print the data or gradients.
Definition NDArray.cpp:199
std::string op
Definition NDArray.h:76
void build_topo(std::unordered_set< NDArray * > &visited, NDArray *arr, std::vector< std::reference_wrapper< NDArray > > &topo)
Build a topological ordering of nodes reachable from arr.
Definition NDArray.cpp:844
NDArray operator/(NDArray &other)
Broadcasted element‑wise division (this / other).
Definition NDArray.cpp:615
std::string label
Definition NDArray.h:78
void randint(int low, int high)
Fill with uniform integer values in [low, high).
Definition NDArray.cpp:294
int ndim
Definition NDArray.h:65
float * data
Definition NDArray.h:59
NDArray sum()
Reduce all elements to a scalar sum.
Definition NDArray.cpp:870
NDArray operator*(NDArray &other)
2D matrix multiplication (no broadcasting). For element-wise products use element_wise_multiply().
Definition NDArray.cpp:529
void set(std::vector< int > indices, float value)
Write an element by multi‑dimensional indices.
Definition NDArray.cpp:138
std::function< void()> _backward
Definition NDArray.h:84
NDArray clone()
Materialize a contiguous, owning copy. Detached from autograd.
Definition NDArray.cpp:337
NDArray element_wise_multiply(NDArray &other)
Broadcasted element‑wise multiplication.
Definition NDArray.cpp:757
float get(std::vector< int > indices, PrintType type=PrintType::Data) const
Read an element by multi‑dimensional indices.
Definition NDArray.cpp:108
int size
Definition NDArray.h:67
std::vector< int > strides
Definition NDArray.h:63
NDArray slice(std::vector< std::tuple< int, int > > indices)
Return a non‑owning view restricted by per‑axis [start, stop) slices.
Definition NDArray.cpp:168
NDArray operator-(NDArray &other)
Broadcasted element‑wise subtraction (this - other).
Definition NDArray.cpp:445
std::vector< std::reference_wrapper< NDArray > > prev
Definition NDArray.h:80
void ones()
Set all elements to 1.
Definition NDArray.cpp:287
void reshape(std::vector< int > newShape)
Reshape this array to a new shape with the same number of elements.
Definition NDArray.cpp:236
void metadata(bool shapeInfo=true, bool stridesInfo=false, bool ndimInfo=false, bool sizeInfo=false, bool ownsDataInfo=false)
Print selected metadata fields.
Definition NDArray.cpp:72
bool ownsData
Definition NDArray.h:70
std::vector< int > shape
Definition NDArray.h:61
void fill(float value)
Fill with a constant value.
Definition NDArray.cpp:270
NDArray operator+(NDArray &other)
Broadcasted element‑wise addition (this + other).
Definition NDArray.cpp:360
NDArray operator^(float value)
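Scalar element-wise power (this ^ value). Note: `^` keeps its built-in C++ precedence, which is lower than that of `+`, `-`, `*`, `/` and `==`, so `a ^ 2 + 1` means `a ^ 3`; parenthesize as `(a ^ 2) + 1`.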
Definition NDArray.cpp:723
float * grad
Definition NDArray.h:74
void fillSequential()
Fill with sequential values 0, 1, 2, ... (for demos/testing).
Definition NDArray.cpp:260
bool isContiguous() const
Whether the logical layout matches standard row‑major contiguous strides for the current shape.
Definition NDArray.cpp:190
void zeros()
Set all elements to 0.
Definition NDArray.cpp:280
PrintType
Definition NDArray.h:82