Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit ae30a1c

Browse filesBrowse files
yiming0416facebook-github-bot
authored andcommitted
[nativert] Move Weights to PyTorch core (#155156)
Summary: Pull Request resolved: #155156 Moves Weights class to PyTorch core Torch Native Runtime RFC: pytorch/rfcs#72 README: https://github.com/pytorch/pytorch/blob/main/torch/nativert/OVERVIEW.md Test Plan: ``` buck2 run mode/dev-nosan caffe2/test/cpp/nativert:weights_test ``` Reviewed By: zhxchen17 Differential Revision: D75973156
1 parent 9b4db09 commit ae30a1c
Copy full SHA for ae30a1c

File tree

Expand file treeCollapse file tree

5 files changed

+678
-0
lines changed
Filter options
Expand file treeCollapse file tree

5 files changed

+678
-0
lines changed

‎build_variables.bzl

Copy file name to clipboardExpand all lines: build_variables.bzl
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ libtorch_nativert_sources = [
596596
"torch/nativert/graph/TensorMeta.cpp",
597597
"torch/nativert/executor/Placement.cpp",
598598
"torch/nativert/executor/PlacementUtils.cpp",
599+
"torch/nativert/executor/Weights.cpp",
599600
"torch/nativert/common/FileUtil.cpp",
600601
]
601602

‎test/cpp/nativert/CMakeLists.txt

Copy file name to clipboardExpand all lines: test/cpp/nativert/CMakeLists.txt
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(NATIVERT_TEST_SRCS
99
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
1010
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
1111
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
12+
${TORCH_ROOT}/torch/nativert/executor/Weights.cpp
1213
${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp
1314
)
1415

‎test/cpp/nativert/test_weights.cpp

Copy file name to clipboard
+92Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include <gtest/gtest.h>
2+
#include <torch/csrc/jit/serialization/pickle.h>
3+
#include <torch/custom_class.h>
4+
#include <torch/torch.h>
5+
#include <memory>
6+
7+
#include <torch/nativert/executor/Placement.h>
8+
#include <torch/nativert/executor/Weights.h>
9+
#include <torch/nativert/graph/Graph.h>
10+
11+
namespace torch::nativert {
12+
class WeightsTest : public ::testing::Test {
13+
protected:
14+
void SetUp() override {
15+
static constexpr std::string_view source =
16+
R"(graph(%foo, %bar, %baz):
17+
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
18+
return(%o2, %baz)
19+
)";
20+
graph = stringToGraph(source);
21+
placement = std::make_unique<Placement>(c10::Device(c10::DeviceType::CPU));
22+
}
23+
std::shared_ptr<Graph> graph;
24+
std::unique_ptr<Placement> placement;
25+
};
26+
TEST_F(WeightsTest, ConstructEmptyStateDict) {
27+
std::unordered_map<std::string, c10::IValue> stateDict;
28+
Weights weights(graph.get(), stateDict, *placement);
29+
// Check that weights are initialized correctly
30+
EXPECT_TRUE(weights.parameters().empty());
31+
EXPECT_TRUE(weights.buffers().empty());
32+
EXPECT_FALSE(weights.contains("non_existent_weight"));
33+
}
34+
TEST_F(WeightsTest, SetAndGetValue) {
35+
std::unordered_map<std::string, c10::IValue> stateDict;
36+
Weights weights(graph.get(), stateDict, *placement);
37+
at::Tensor tensor = at::ones({2, 2});
38+
weights.setValue("added_weight", tensor);
39+
EXPECT_TRUE(weights.contains("added_weight"));
40+
EXPECT_EQ(weights.at("added_weight").sizes(), tensor.sizes());
41+
}
42+
43+
} // namespace torch::nativert
44+
45+
using namespace ::testing;
46+
struct ContainsTensorDict : torch::CustomClassHolder {
47+
explicit ContainsTensorDict(at::Tensor t) : t_(t) {}
48+
49+
explicit ContainsTensorDict(c10::Dict<std::string, at::Tensor> dict) {
50+
t_ = dict.at(std::string("init_tensor"));
51+
}
52+
53+
c10::Dict<std::string, at::Tensor> serialize() const {
54+
c10::Dict<std::string, at::Tensor> dict;
55+
dict.insert(std::string("init_tensor"), t_);
56+
return dict;
57+
}
58+
59+
at::Tensor t_;
60+
};
61+
62+
static auto reg =
63+
torch::class_<ContainsTensorDict>("testing", "ContainsTensorDict")
64+
.def(torch::init<at::Tensor>())
65+
.def_pickle(
66+
// __getstate__
67+
[](const c10::intrusive_ptr<ContainsTensorDict>& self)
68+
-> c10::Dict<std::string, at::Tensor> {
69+
return self->serialize();
70+
},
71+
// __setstate__
72+
[](c10::Dict<std::string, at::Tensor> data)
73+
-> c10::intrusive_ptr<ContainsTensorDict> {
74+
return c10::make_intrusive<ContainsTensorDict>(std::move(data));
75+
});
76+
77+
TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) {
78+
// Save
79+
auto customObj =
80+
c10::make_intrusive<ContainsTensorDict>(torch::tensor({1, 2, 3}));
81+
const auto bytes = torch::jit::pickle_save(c10::IValue(std::move(customObj)));
82+
83+
// Load
84+
const auto loadedCustomObj =
85+
torch::jit::pickle_load_obj(std::string{bytes.begin(), bytes.end()});
86+
EXPECT_TRUE(loadedCustomObj.isObject());
87+
EXPECT_EQ(
88+
loadedCustomObj.to<c10::intrusive_ptr<ContainsTensorDict>>()
89+
->t_[0]
90+
.item<int>(),
91+
1);
92+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.