Skip to content

Commit ced0cc4

Browse files
q10 authored and facebook-github-bot committed
TensorAccessor cleanup (#3973)
Summary: X-link: facebookresearch/FBGEMM#1059 - The existing `tensor_accessor.h` duplicates a lot of code from `ATen/core/TensorAccessor.h`. This diff removes the duplication and simplifies the class template specializations by using SFINAE methods instead. - Add unit tests to check that the index checking works. Reviewed By: sryap Differential Revision: D72940055
1 parent ce02bfd commit ced0cc4

File tree

3 files changed

+227
-1
lines changed

3 files changed

+227
-1
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once

#include <ATen/ATen.h>
#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>
22+
////////////////////////////////////////////////////////////////////////////////
23+
// Extended TensorAccessor
24+
//
25+
// This file contains TensorAccessor and PackedTensorAccessor implementations
26+
// that are used in FBGEMM_GPU for additional bounds checks that are not
27+
// available in the standard ATen implementation. Using the builder macro
28+
// MAKE_TA_WITH_NAME and MAKE_PTA_WITH_NAME, bounds checks can be enabled using
29+
// the FBGEMM_GPU_MEMCHECK flag.
30+
//
31+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TensorAccessor.h
32+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/core/TensorBase.h
33+
////////////////////////////////////////////////////////////////////////////////
34+
35+
namespace fbgemm_gpu::utils {
36+
37+
template <typename T>
38+
using DefaultPtrTraits = at::DefaultPtrTraits<T>;
39+
40+
#if defined(__CUDACC__) || defined(__HIPCC__)
41+
template <typename T>
42+
using RestrictPtrTraits = at::RestrictPtrTraits<T>;
43+
#endif
44+
45+
static constexpr size_t NAME_MAX_LEN = 32;
46+
static constexpr size_t CONTEXT_MAX_LEN = 256;
47+
48+
C10_HOST_DEVICE inline void
49+
copy_str(char* dst, const char* src, const size_t max_len) {
50+
// If dst is nullptr, then skip.
51+
if (dst == nullptr) {
52+
return;
53+
}
54+
55+
// If src is nullptr or max_len is zero, then mark empty string and skip.
56+
if (src == nullptr || max_len == 0) {
57+
dst[0] = '\0';
58+
return;
59+
}
60+
61+
// Count src buffer length up to max_len
62+
size_t len = 0;
63+
for (len = 0; src[len] != 0 && len < max_len; len++) {
64+
// no action - calculating string length
65+
}
66+
len = len < (max_len - 1) ? len : (max_len - 1);
67+
68+
// Copy src to dst
69+
for (size_t i = 0; i < len; i++) {
70+
dst[i] = src[i];
71+
}
72+
dst[len] = '\0';
73+
}
74+
75+
////////////////////////////////////////////////////////////////////////////////
76+
// TensorAccessor
77+
//
78+
// This is an extension of at::TensorAccessorBase that consolidates some methods
79+
// defined in at::TensorAccessor.
80+
////////////////////////////////////////////////////////////////////////////////
81+
82+
template <
83+
typename T,
84+
size_t N,
85+
template <typename U> class PtrTraits = DefaultPtrTraits,
86+
typename index_t = int64_t>
87+
class TensorAccessor : public at::TensorAccessorBase<T, N, PtrTraits, index_t> {
88+
public:
89+
typedef typename PtrTraits<T>::PtrType PtrType;
90+
91+
C10_HOST_DEVICE TensorAccessor(
92+
const PtrType data_,
93+
const index_t* const sizes_,
94+
const index_t* const strides_,
95+
const char* const _name_,
96+
const char* const _context_)
97+
: at::TensorAccessorBase<T, N, PtrTraits, index_t>(
98+
data_,
99+
sizes_,
100+
strides_) {
101+
if (sizes_ && strides_) {
102+
numel_ = 1;
103+
for (size_t d = 0; d < N; d++) {
104+
numel_ += (sizes_[d] - 1) * strides_[d];
105+
}
106+
}
107+
108+
copy_str(name_, _name_, NAME_MAX_LEN);
109+
copy_str(context_, _context_, CONTEXT_MAX_LEN);
110+
}
111+
112+
template <size_t M = N>
113+
C10_HOST_DEVICE inline auto operator[](const index_t i)
114+
-> std::
115+
enable_if_t<(M > 1), TensorAccessor<T, N - 1, PtrTraits, index_t>> {
116+
return TensorAccessor<T, N - 1, PtrTraits, index_t>(
117+
this->data_ + this->strides_[0] * i,
118+
this->sizes_ + 1,
119+
this->strides_ + 1,
120+
this->name_,
121+
this->context_);
122+
}
123+
124+
template <size_t M = N>
125+
C10_HOST_DEVICE inline auto operator[](const index_t i) const
126+
-> std::enable_if_t<
127+
(M > 1),
128+
const TensorAccessor<T, N - 1, PtrTraits, index_t>> {
129+
return TensorAccessor<T, N - 1, PtrTraits, index_t>(
130+
this->data_ + this->strides_[0] * i,
131+
this->sizes_ + 1,
132+
this->strides_ + 1,
133+
this->name_,
134+
this->context_);
135+
}
136+
137+
template <size_t M = N>
138+
C10_HOST_DEVICE inline auto operator[](const index_t i)
139+
-> std::enable_if_t<(M == 1), T&> {
140+
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
141+
return this->at(this->strides_[0] * i);
142+
}
143+
144+
template <size_t M = N>
145+
C10_HOST_DEVICE inline auto operator[](const index_t i) const
146+
-> std::enable_if_t<(M == 1), const T&> {
147+
// NOLINTNEXTLINE(clang-analyzer-core.NullDereference)
148+
return this->at(this->strides_[0] * i);
149+
}
150+
151+
C10_HOST_DEVICE T& at(const index_t idx) const {
152+
if (idx < 0) {
153+
printf(
154+
"[%s][Tensor %s] ERROR: (idx=%ld) < 0\n",
155+
this->context_,
156+
this->name_,
157+
static_cast<int64_t>(idx));
158+
CUDA_KERNEL_ASSERT(idx >= 0);
159+
160+
} else if (idx >= numel_) {
161+
printf(
162+
"[%s][Tensor %s] ERROR: (idx=%ld) >= (numel=%ld)\n",
163+
this->context_,
164+
this->name_,
165+
static_cast<int64_t>(idx),
166+
static_cast<int64_t>(numel_));
167+
CUDA_KERNEL_ASSERT(idx < numel_);
168+
}
169+
170+
return this->data_[idx];
171+
}
172+
173+
protected:
174+
size_t numel_;
175+
char name_[NAME_MAX_LEN];
176+
char context_[CONTEXT_MAX_LEN];
177+
};
178+
179+
} // namespace fbgemm_gpu::utils

fbgemm_gpu/test/utils/kernel_launcher_test.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ TEST(KernelLauncherTest, kernel_launch_checks) {
307307
}
308308

309309
// NOTE: This test currently fails in fbcode CI for HIP with the following
310-
// error:
310+
// error (but runs without issues on both NVIDIA and AMD machines):
311311
//
312312
// void fbgemm_gpu::utils::always_fail_assertion_kernel(const int,
313313
// c10::hip::DeviceAssertionsData *const, uint32_t): Device-side assertion `(a
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/ATen.h>
10+
#include <gtest/gtest.h>
11+
#include <torch/types.h> // @manual=//caffe2:torch-cpp-cpu
12+
13+
#include "fbgemm_gpu/utils/tensor_accessor2.h"
14+
15+
namespace fbgemm_gpu::utils {
16+
17+
TEST(TensorAccessorTest, tensor_access) {
18+
const auto tensor1 = torch::tensor(
19+
{{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f},
20+
{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f}},
21+
torch::kFloat32);
22+
23+
const auto tensor2 = torch::tensor(
24+
{{1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f},
25+
{1.0f, 1.1f, 1.2f, 1.3f, 42.0f, 1.5f, 1.6f, 1.7f}},
26+
torch::kFloat32);
27+
28+
auto accessor = TensorAccessor<float, 2, DefaultPtrTraits, int64_t>(
29+
static_cast<typename DefaultPtrTraits<float>::PtrType>(
30+
tensor1.data_ptr<float>()),
31+
tensor1.sizes().data(),
32+
tensor1.strides().data(),
33+
"tensor",
34+
"context");
35+
36+
// Accessor should work as expected
37+
accessor[1][4] = 42.0f;
38+
39+
EXPECT_TRUE(torch::equal(tensor1, tensor1))
40+
<< "tensor1 is not equal to tensor2";
41+
42+
#ifndef __HIPCC__
43+
EXPECT_DEATH({ accessor[10][20] = 3.14f; }, "idx < numel_");
44+
#endif
45+
}
46+
47+
} // namespace fbgemm_gpu::utils

0 commit comments

Comments
 (0)