This repository was archived by the owner on Jan 26, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathIO.cpp
More file actions
104 lines (88 loc) · 2.83 KB
/
IO.cpp
File metadata and controls
104 lines (88 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
// SPDX-License-Identifier: BSD-3-Clause
/*
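I/O operations: gathering a FutureArray into a numpy array (to_numpy) and
wrapping local numpy arrays as a FutureArray without copying (from_locals).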
*/
#include "sharpy/IO.hpp"
#include "sharpy/Factory.hpp"
#include "sharpy/NDArray.hpp"
#include "sharpy/PyTypes.hpp"
#include "sharpy/SetGetItem.hpp"
#include "sharpy/Transceiver.hpp"
#include "sharpy/TypeDispatch.hpp"
namespace SHARPY {
// ***************************************************************************
/// @brief Deferred op forming a FutureArray from a local numpy array
/// in place (no copy of the data).
///
/// The numpy array is attached as the base object of the resulting tensor so
/// that numpy's memory is not released while the tensor still refers to it.
struct DeferredFromLocal : public Deferred {
  /// The wrapped numpy array (ref-counted handle to numpy's buffer).
  py::array _npa;

  DeferredFromLocal() = default;

  /// @param npa local numpy array to wrap; its dtype must be one accepted by
  ///            getDTypeId().
  /// @throws std::invalid_argument if the dtype is not supported.
  DeferredFromLocal(py::array npa)
      // getDTypeId must be static: it is called before the Deferred base
      // is constructed, where calling a non-static member is UB.
      : Deferred(getDTypeId(npa.dtype()),
                 {npa.shape(), npa.shape() + npa.ndim()}, {}, 0),
        _npa(std::move(npa)) {}

  /// Map a numpy dtype to our DTypeId.
  /// Supported: signed integers of 1/2/4/8 bytes (kind 'i') and floats of
  /// 4/8 bytes (kind 'f').
  /// TODO(review): unsigned ('u') and bool ('b') kinds are rejected here;
  /// add them if the DTypeId enum and the tensor backend support them.
  /// @throws std::invalid_argument for any other dtype.
  static DTypeId getDTypeId(const py::dtype &dtype) {
    const auto bw = dtype.itemsize();
    switch (dtype.kind()) {
    case 'i':
      switch (bw) {
      case 1:
        return INT8;
      case 2:
        return INT16;
      case 4:
        return INT32;
      case 8:
        return INT64;
      default:
        break;
      }
      break;
    case 'f':
      switch (bw) {
      case 4:
        return FLOAT32;
      case 8:
        return FLOAT64;
      default:
        break;
      }
      break;
    default:
      break;
    }
    throw std::invalid_argument("Unsupported dtype");
  }

  /// Build a tensor directly over numpy's buffer and publish it as this
  /// deferred's value.
  /// @throws std::invalid_argument if a byte stride is not a whole multiple
  ///         of the element size (e.g. a field view into a structured array);
  ///         such a layout cannot be expressed in element strides.
  void run() override {
    const auto byteStrides = _npa.strides();
    auto shape = _npa.shape();
    auto data = _npa.mutable_data();
    const auto dtype = _npa.dtype();
    auto ndim = _npa.ndim();
    const auto eSz = dtype.itemsize();
    // py::array stores strides in bytes; the tensor expects element strides.
    std::vector<intptr_t> strides(ndim);
    for (decltype(ndim) i = 0; i < ndim; ++i) {
      if (byteStrides[i] % eSz != 0) {
        throw std::invalid_argument(
            "Unsupported array layout: strides must be multiples of the "
            "item size");
      }
      strides[i] = byteStrides[i] / eSz;
    }
    auto res = mk_tnsr(this->guid(), getDTypeId(dtype), ndim, shape,
                       strides.data(), data, this->device(), this->team());
    // Make sure we do not delete numpy's memory before the numpy array is
    // dead (py::objects are ref-counted, so holding one keeps it alive).
    res->set_base(new SharedBaseObject<py::object>(_npa));
    set_value(std::move(res));
  }

  /// Emits no MLIR ops; the data is taken directly from numpy in run().
  /// NOTE(review): confirm what the `true` return signals to the caller.
  bool generate_mlir(::mlir::OpBuilder & /*builder*/,
                     const ::mlir::Location & /*loc*/,
                     jit::DepManager & /*dm*/) override {
    return true;
  }

  FactoryId factory() const override { return F_FROMLOCALS; }

  /// Serializes nothing: _npa is not part of the serialized state.
  template <typename S> void serialize(S & /*ser*/) {}
};
/// Gather the distributed array @p a into a numpy array.
/// In c/w mode the gather targets rank 0 and only rank 0 may call this.
/// Otherwise the gather uses REPLICATED (presumably every rank receives the
/// data; see GetItem::gather).
/// @throws std::runtime_error when called on a non-zero rank in c/w mode.
GetItem::py_future_type IO::to_numpy(const FutureArray &a) {
  auto tc = getTransceiver();
  const bool cwMode = tc->is_cw();
  if (cwMode && tc->rank() != 0) {
    throw std::runtime_error(
        "In c/w mode, to_numpy is only supported on rank 0");
  }
  const auto root = cwMode ? 0 : REPLICATED;
  return GetItem::gather(a, root);
}
/// Wrap local numpy array(s) as a FutureArray without copying the data.
/// Only a single local array is currently supported.
/// Returns a heap-allocated FutureArray (ownership presumably passes to the
/// caller / Python bindings — confirm).
/// @throws std::runtime_error unless exactly one array is given.
FutureArray *IO::from_locals(const std::vector<py::array> &a) {
  if (a.size() == 1) {
    return new FutureArray(defer<DeferredFromLocal>(a[0]));
  }
  throw std::runtime_error("from_locals only supports a single local array");
}
// Associate DeferredFromLocal with factory id F_FROMLOCALS (the same id its
// factory() returns). NOTE(review): see Factory.hpp for the macro's exact
// expansion (presumably a factory registration).
FACTORY_INIT(DeferredFromLocal, F_FROMLOCALS);
} // namespace SHARPY