Skip to content

Commit d110bb3

Browse files
yiming0416facebook-github-bot
authored andcommitted
[nativert] Move Weights to PyTorch core (pytorch#155156)
Summary: Pull Request resolved: pytorch#155156 Moves Weights class to PyTorch core Torch Native Runtime RFC: pytorch/rfcs#72 README: https://github.com/pytorch/pytorch/blob/main/torch/nativert/OVERVIEW.md Test Plan: buck2 run mode/dev-nosan caffe2/test/cpp/nativert:weights_test Differential Revision: D75973156
1 parent 3398d1d commit d110bb3

File tree

5 files changed

+678
-0
lines changed

5 files changed

+678
-0
lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,7 @@ libtorch_nativert_sources = [
596596
"torch/nativert/graph/TensorMeta.cpp",
597597
"torch/nativert/executor/Placement.cpp",
598598
"torch/nativert/executor/PlacementUtils.cpp",
599+
"torch/nativert/executor/Weights.cpp",
599600
"torch/nativert/common/FileUtil.cpp",
600601
]
601602

test/cpp/nativert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ set(NATIVERT_TEST_SRCS
99
${TORCH_ROOT}/torch/nativert/graph/Graph.cpp
1010
${TORCH_ROOT}/torch/nativert/graph/GraphSignature.cpp
1111
${TORCH_ROOT}/torch/nativert/executor/PlacementUtils.cpp
12+
${TORCH_ROOT}/torch/nativert/executor/Weights.cpp
1213
${TORCH_ROOT}/torch/nativert/common/FileUtil.cpp
1314
)
1415

test/cpp/nativert/test_weights.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
#include <gtest/gtest.h>
2+
#include <torch/csrc/jit/serialization/pickle.h>
3+
#include <torch/custom_class.h>
4+
#include <torch/torch.h>
5+
#include <memory>
6+
7+
#include <torch/nativert/executor/Placement.h>
8+
#include <torch/nativert/executor/Weights.h>
9+
#include <torch/nativert/graph/Graph.h>
10+
11+
namespace torch::nativert {
12+
class WeightsTest : public ::testing::Test {
13+
protected:
14+
void SetUp() override {
15+
static constexpr std::string_view source =
16+
R"(graph(%foo, %bar, %baz):
17+
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
18+
return(%o2, %baz)
19+
)";
20+
graph = stringToGraph(source);
21+
placement = std::make_unique<Placement>(c10::Device(c10::DeviceType::CPU));
22+
}
23+
std::shared_ptr<Graph> graph;
24+
std::unique_ptr<Placement> placement;
25+
};
26+
TEST_F(WeightsTest, ConstructEmptyStateDict) {
27+
std::unordered_map<std::string, c10::IValue> stateDict;
28+
Weights weights(graph.get(), stateDict, *placement);
29+
// Check that weights are initialized correctly
30+
EXPECT_TRUE(weights.parameters().empty());
31+
EXPECT_TRUE(weights.buffers().empty());
32+
EXPECT_FALSE(weights.contains("non_existent_weight"));
33+
}
34+
TEST_F(WeightsTest, SetAndGetValue) {
35+
std::unordered_map<std::string, c10::IValue> stateDict;
36+
Weights weights(graph.get(), stateDict, *placement);
37+
at::Tensor tensor = at::ones({2, 2});
38+
weights.setValue("added_weight", tensor);
39+
EXPECT_TRUE(weights.contains("added_weight"));
40+
EXPECT_EQ(weights.at("added_weight").sizes(), tensor.sizes());
41+
}
42+
43+
} // namespace torch::nativert
44+
45+
using namespace ::testing;
46+
struct ContainsTensorDict : torch::CustomClassHolder {
47+
explicit ContainsTensorDict(at::Tensor t) : t_(t) {}
48+
49+
explicit ContainsTensorDict(c10::Dict<std::string, at::Tensor> dict) {
50+
t_ = dict.at(std::string("init_tensor"));
51+
}
52+
53+
c10::Dict<std::string, at::Tensor> serialize() const {
54+
c10::Dict<std::string, at::Tensor> dict;
55+
dict.insert(std::string("init_tensor"), t_);
56+
return dict;
57+
}
58+
59+
at::Tensor t_;
60+
};
61+
62+
static auto reg =
63+
torch::class_<ContainsTensorDict>("testing", "ContainsTensorDict")
64+
.def(torch::init<at::Tensor>())
65+
.def_pickle(
66+
// __getstate__
67+
[](const c10::intrusive_ptr<ContainsTensorDict>& self)
68+
-> c10::Dict<std::string, at::Tensor> {
69+
return self->serialize();
70+
},
71+
// __setstate__
72+
[](c10::Dict<std::string, at::Tensor> data)
73+
-> c10::intrusive_ptr<ContainsTensorDict> {
74+
return c10::make_intrusive<ContainsTensorDict>(std::move(data));
75+
});
76+
77+
TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) {
78+
// Save
79+
auto customObj =
80+
c10::make_intrusive<ContainsTensorDict>(torch::tensor({1, 2, 3}));
81+
const auto bytes = torch::jit::pickle_save(c10::IValue(std::move(customObj)));
82+
83+
// Load
84+
const auto loadedCustomObj =
85+
torch::jit::pickle_load_obj(std::string{bytes.begin(), bytes.end()});
86+
EXPECT_TRUE(loadedCustomObj.isObject());
87+
EXPECT_EQ(
88+
loadedCustomObj.to<c10::intrusive_ptr<ContainsTensorDict>>()
89+
->t_[0]
90+
.item<int>(),
91+
1);
92+
}

0 commit comments

Comments
 (0)