Skip to content

Commit 5713190

Browse files
seanxcwang authored and doxutx committed
fix layernorm and graph ir logic
1 parent 49ed37c commit 5713190

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

source/tnn/device/cuda/acc/cuda_layer_norm_layer_acc.cu

+1-1
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ template<typename T>
6363
__global__ void ln_mul_add_kernel(const T *input, T *output, const T *scale, const T *bias,
6464
const LNFloat2 *mean_var,
6565
const int count, const float eps) {
66-
int offset = blockIdx.y * blockDim.y + threadIdx.x;
66+
int offset = blockIdx.y * blockDim.x + threadIdx.x;
6767
int total_offset = blockIdx.x * count + offset;
6868
if (offset < count) {
6969
const float* mean_var_float = reinterpret_cast<const float*>(mean_var);

source/tnn/optimizer/graph_matcher/ir.cc

+15
Original file line number | Diff line number | Diff line change
@@ -995,6 +995,21 @@ namespace TNN_NS {
995995
// 5. remove unused Nodes
996996
// NB. we need to keep the original graph output tensor names un-changed.
997997

998+
auto return_check = [&]() {
999+
std::set<std::string> graph_outputs;
1000+
for (auto &output: g->outputs())
1001+
graph_outputs.insert(output->name);
1002+
1003+
return std::any_of(anchor->nodes.begin(), anchor->nodes.end(), [&](const std::shared_ptr<Node> &node) {
1004+
return std::any_of(node->info->outputs.begin(), node->info->outputs.end(), [&](const std::string &output) {
1005+
return (graph_outputs.count(output));
1006+
});
1007+
});
1008+
};
1009+
1010+
if (return_check())
1011+
return;
1012+
9981013
std::set<std::string> tensor_names;
9991014
for(auto & p : tensor_map) tensor_names.insert(p.first);
10001015
for(auto &name : tensor_names) renameTensor(name, name_prefix + name);

0 commit comments

Comments (0)