merge_id_lists_op.h
#ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
#define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_

#include <set>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"

#include <c10/util/irange.h>

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(MergeIdLists);

namespace caffe2 {
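
/**
 * MergeIdListsOp merges multiple ID_LIST features into a single ID_LIST.
 *
 * Inputs come in (LENGTHS, VALUES) pairs: for each feature, LENGTHS is a 1-D
 * int32 tensor that segments the 1-D VALUES tensor into per-sample id lists,
 * and every LENGTHS input must have the same number of elements (the batch
 * size). For each sample, the ids from all features are pooled into a single
 * set, so the merged output is deduplicated and emitted in sorted order; the
 * original relative order of ids within a sample is therefore not preserved.
 * The outputs are a merged (LENGTHS, VALUES) pair.
 */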
template <class Context>
class MergeIdListsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);

  template <typename T>
  bool DoRunWithType() {
    auto& first_lengths = Input(0);
    CAFFE_ENFORCE_EQ(first_lengths.dim(), 1, "LENGTHS should be 1-D");
    const auto batch_size = first_lengths.numel();

    auto* out_lengths = Output(0, first_lengths.sizes(), at::dtype<int32_t>());
    auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();

    /**
     * First pass over the (LENGTHS, VALUES) input pairs: validate their
     * shapes and figure out how much space to reserve for the output.
     */
    auto M = 0;
    // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
    for (size_t i = 0; i < InputSize(); i += 2) {
      auto& lengths = Input(i);
      CAFFE_ENFORCE_EQ(lengths.dim(), 1, "LENGTHS should be 1-D");
      CAFFE_ENFORCE_EQ(lengths.numel(), batch_size, "LENGTHS should be equal");
      auto& values = Input(i + 1);
      CAFFE_ENFORCE_EQ(values.dim(), 1, "VALUES should be 1-D");
      M += values.numel();
    }

    auto* out_values = Output(1, {M}, at::dtype<T>());
    T* out_values_data = out_values->template mutable_data<T>();

    auto pos = 0;
    // TODO(badri): Use unordered_set if performance is an issue
    std::set<T> deduped;
    std::vector<int> offsets(InputSize(), 0);

    // Second pass: for each sample, pool the ids from every feature into the
    // set (deduplicating and sorting them), then append the merged list to
    // the output.
    for (const auto sample : c10::irange(batch_size)) {
      for (size_t i = 0; i < InputSize(); i += 2) {
        auto& lengths = Input(i);
        const auto* lengths_data = lengths.template data<int32_t>();
        auto& values = Input(i + 1);
        const T* values_data = values.template data<T>();
        const auto length = lengths_data[sample];
        for (auto j = offsets[i]; j < offsets[i] + length; j++) {
          deduped.insert(values_data[j]);
        }
        offsets[i] += length;
      }
      for (auto val : deduped) {
        out_values_data[pos++] = val;
      }
      out_lengths_data[sample] = deduped.size();
      deduped.clear();
    }
    // Shrink the values output to the number of ids actually written, which
    // can be smaller than M because of deduplication.
    out_values->Resize(pos);
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
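
The sketch below is a stand-alone illustration of the merge semantics implemented by DoRunWithType(), written against the standard library only rather than the Caffe2 tensor API; the feature count, lengths and id values are made-up example data, and the file name is hypothetical. It mirrors the operator's per-sample logic: walk the (lengths, values) pairs in step, pool each sample's ids into a std::set, and emit the deduplicated, sorted ids together with their new lengths.

// merge_id_lists_sketch.cc -- illustration only, not part of the operator.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

int main() {
  // Two ID_LIST features over a batch of two samples, each given as a
  // (lengths, values) pair. The numbers are made up for illustration.
  std::vector<std::vector<int32_t>> lengths = {{2, 1}, {1, 2}};
  std::vector<std::vector<int64_t>> values = {{1, 2, 3}, {2, 4, 5}};

  std::vector<int32_t> out_lengths;
  std::vector<int64_t> out_values;
  std::vector<int> offsets(lengths.size(), 0); // read position per feature

  const size_t batch_size = lengths[0].size();
  for (size_t sample = 0; sample < batch_size; ++sample) {
    std::set<int64_t> deduped; // dedups and sorts, like the op's std::set
    for (size_t f = 0; f < lengths.size(); ++f) {
      const int32_t length = lengths[f][sample];
      for (int j = offsets[f]; j < offsets[f] + length; ++j) {
        deduped.insert(values[f][j]);
      }
      offsets[f] += length;
    }
    out_lengths.push_back(static_cast<int32_t>(deduped.size()));
    out_values.insert(out_values.end(), deduped.begin(), deduped.end());
  }

  // Prints out_lengths = 2 3 and out_values = 1 2 3 4 5: sample 0 merges
  // {1, 2} with {2}, sample 1 merges {3} with {4, 5}.
  for (auto l : out_lengths) std::cout << l << ' ';
  std::cout << '\n';
  for (auto v : out_values) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}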