Skip to content

Commit

Permalink
Fix bug for ShapeTopKV2 don't mask 2 input, stridedslice don't add 4 …
Browse files Browse the repository at this point in the history
…for inputs, and GeometryPad's zero opt don't care about not mutable inputs
  • Loading branch information
xiaying committed Sep 5, 2023
1 parent db78514 commit 43b2406
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 25 deletions.
48 changes: 25 additions & 23 deletions source/geometry/GeometryCrop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,32 +210,34 @@ class GeometryPad : public GeometryComputer {
// Check Zero for inputs[2]
bool zero = false;
auto type = inputs[2]->getType();
switch (type.code) {
case halide_type_int:
{
if (type.bits == 8) {
zero = inputs[2]->host<int8_t>()[0] == 0;
} else if (type.bits == 32) {
zero = inputs[2]->host<int32_t>()[0] == 0;
if (!TensorUtils::getDescribe(inputs[2])->isMutable && inputs[2]->deviceId() == 0) {
switch (type.code) {
case halide_type_int:
{
if (type.bits == 8) {
zero = inputs[2]->host<int8_t>()[0] == 0;
} else if (type.bits == 32) {
zero = inputs[2]->host<int32_t>()[0] == 0;
}
}
}
break;
case halide_type_uint:
{
if (type.bits == 8) {
zero = inputs[2]->host<uint8_t>()[0] == 0;
} else if (type.bits == 32) {
zero = inputs[2]->host<uint32_t>()[0] == 0;
break;
case halide_type_uint:
{
if (type.bits == 8) {
zero = inputs[2]->host<uint8_t>()[0] == 0;
} else if (type.bits == 32) {
zero = inputs[2]->host<uint32_t>()[0] == 0;
}
}
break;
case halide_type_float:
{
zero = inputs[2]->host<float>()[0] == 0.0f;
}
break;
default:
break;
}
break;
case halide_type_float:
{
zero = inputs[2]->host<float>()[0] == 0.0f;
}
break;
default:
break;
}
if (zero) {
return true;
Expand Down
2 changes: 1 addition & 1 deletion source/shape/ShapeStridedSlice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,5 +294,5 @@ class StridedSliceComputer : public SizeComputer {
}
};

REGISTER_SHAPE_INPUTS(StridedSliceComputer, OpType_StridedSlice, (std::vector<int>{1,2,3}));
REGISTER_SHAPE_INPUTS(StridedSliceComputer, OpType_StridedSlice, (std::vector<int>{1,2,3,4}));
} // namespace MNN
2 changes: 1 addition & 1 deletion source/shape/ShapeTopKV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ class TopKV2SizeComputer : public SizeComputer {
}
};

REGISTER_SHAPE_INPUTS(TopKV2SizeComputer, OpType_TopKV2, {1});
REGISTER_SHAPE_INPUTS(TopKV2SizeComputer, OpType_TopKV2, (std::vector<int>{1,2}));
} // namespace MNN

0 comments on commit 43b2406

Please sign in to comment.