Skip to content

Commit

Permalink
Merge pull request #2513 from alibaba/feature/sync
Browse files Browse the repository at this point in the history
[MNN:Sync] Sync Internal 2.6.2
  • Loading branch information
jxt1234 authored Jul 31, 2023
2 parents d8266f9 + 84d6bd7 commit 8697ec8
Show file tree
Hide file tree
Showing 148 changed files with 11,925 additions and 5,443 deletions.
71 changes: 47 additions & 24 deletions 3rd_party/OpenCLHeaders/CL/cl2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2213,9 +2213,13 @@ class Platform : public detail::Wrapper<cl_platform_id>
{
// If default wasn't passed ,generate one
// Otherwise set it
cl_uint n = 0;

cl_int err = ::clGetPlatformIDs(0, NULL, &n);
// Default only check first card info
cl_uint n = 1;
cl_int err = CL_SUCCESS;

// Ignore platform number acquire
#if 0
err = ::clGetPlatformIDs(0, NULL, &n);
if (err != CL_SUCCESS) {
default_error_ = err;
return;
Expand All @@ -2224,7 +2228,7 @@ class Platform : public detail::Wrapper<cl_platform_id>
default_error_ = CL_INVALID_PLATFORM;
return;
}

#endif
vector<cl_platform_id> ids(n);
err = ::clGetPlatformIDs(n, ids.data(), NULL);
if (err != CL_SUCCESS) {
Expand Down Expand Up @@ -2466,15 +2470,17 @@ class Platform : public detail::Wrapper<cl_platform_id>
* Wraps clGetPlatformIDs().
*/
static cl_int get(
vector<Platform>* platforms)
vector<Platform>* platforms, int platformSize = 0)
{
cl_uint n = 0;
cl_uint n = platformSize;

if( platforms == NULL ) {
return detail::errHandler(CL_INVALID_ARG_VALUE, __GET_PLATFORM_IDS_ERR);
}

cl_int err = ::clGetPlatformIDs(0, NULL, &n);
cl_int err = CL_SUCCESS;
if(n == 0) {
err = ::clGetPlatformIDs(0, NULL, &n);
}
if (err != CL_SUCCESS) {
return detail::errHandler(err, __GET_PLATFORM_IDS_ERR);
}
Expand All @@ -2486,40 +2492,57 @@ class Platform : public detail::Wrapper<cl_platform_id>
}

// more than one gpu card
#if defined(_WIN32) || defined(__linux__) // Windows or Linux
if (n > 1) {
// first select nvidia gpu as discrete card, if multi gpu cards are available, x86_64 platform
// first select nvidia gpu as discrete card
//const char* integrate_gpu = "Intel";
const char* discrete_gpu = "NVIDIA";
const char* discrete_gpu_0 = "NVIDIA";
bool hasFirstPriority = false;
for (cl_uint i = 0; i < n; ++i) {
// get the length of platform name
size_t platform_name_length = 0;
err = clGetPlatformInfo(ids[i], CL_PLATFORM_NAME, 0, 0, &platform_name_length);
vector<char> platform_name(10240);
err = clGetPlatformInfo(ids[i], CL_PLATFORM_NAME, 10240, platform_name.data(), 0);
if (err != CL_SUCCESS) {
return detail::errHandler(err, __GET_PLATFORM_INFO_ERR);
}
// get platform name
char* platform_name = new char[platform_name_length];
err = clGetPlatformInfo(ids[i], CL_PLATFORM_NAME, platform_name_length, platform_name, 0);
if (err != CL_SUCCESS) {
delete[] platform_name;
return detail::errHandler(err, __GET_PLATFORM_INFO_ERR);
}
// if nvidia card is detected, set it as default ids[0]
if (strstr(platform_name, discrete_gpu)) {
if (strstr(platform_name.data(), discrete_gpu_0)) {
hasFirstPriority = true;
if (i == 0) {
delete[] platform_name;
break;
}
// swap
cl_platform_id tmp = ids[0];
ids[0] = ids[i];
ids[i] = tmp;
delete[] platform_name;
break;
}
delete[] platform_name;
}

// second select amd gpu as discrete card
if(!hasFirstPriority) {
const char* discrete_gpu_1 = "AMD";
for (cl_uint i = 0; i < n; ++i) {
vector<char> platform_name(10240);
err = clGetPlatformInfo(ids[i], CL_PLATFORM_NAME, 10240, platform_name.data(), 0);
if (err != CL_SUCCESS) {
return detail::errHandler(err, __GET_PLATFORM_INFO_ERR);
}
// if amd card is detected, set it as default ids[0]
if (strstr(platform_name.data(), discrete_gpu_1)) {
hasFirstPriority = true;
if (i == 0) {
break;
}
// swap
cl_platform_id tmp = ids[0];
ids[0] = ids[i];
ids[i] = tmp;
break;
}
}
}
}
#endif

if (platforms) {
platforms->resize(ids.size());
Expand Down
35 changes: 24 additions & 11 deletions benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <MNN/Interpreter.hpp>
#include <MNN/MNNDefine.h>
#include <MNN/Tensor.hpp>
#include <MNN/AutoTime.hpp>
#include "revertMNNModel.hpp"

/**
Expand Down Expand Up @@ -119,7 +120,6 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) {
auto revertor = std::unique_ptr<Revert>(new Revert(model.model_file.c_str()));
if (testQuantModel) {
printf("Auto set sparsity=0 when test quantized model in benchmark...\n");
revertor->initialize(0, sparseBlockOC, false, true);
} else {
revertor->initialize(sparsity, sparseBlockOC);
Expand Down Expand Up @@ -168,19 +168,19 @@ std::vector<float> doBench(Model& model, int loop, int warmup = 10, int forward
}

for (int round = 0; round < loop; round++) {
auto timeBegin = getTimeInUs();
MNN::Timer _t;
void* host = input->map(MNN::Tensor::MAP_TENSOR_WRITE, input->getDimensionType());
input->unmap(MNN::Tensor::MAP_TENSOR_WRITE, input->getDimensionType(), host);
net->runSession(session);
host = outputTensor->map(MNN::Tensor::MAP_TENSOR_READ, outputTensor->getDimensionType());
outputTensor->unmap(MNN::Tensor::MAP_TENSOR_READ, outputTensor->getDimensionType(), host);
auto timeEnd = getTimeInUs();
costs.push_back((timeEnd - timeBegin) / 1000.0);
auto time = (float)_t.durationInUs() / 1000.0f;
costs.push_back(time);
}
return costs;
}

void displayStats(const std::string& name, const std::vector<float>& costs) {
void displayStats(const std::string& name, const std::vector<float>& costs, int quant = 0) {
float max = 0, min = FLT_MAX, sum = 0, avg;
for (auto v : costs) {
max = fmax(max, v);
Expand All @@ -189,7 +189,11 @@ void displayStats(const std::string& name, const std::vector<float>& costs) {
//printf("[ - ] cost:%f ms\n", v);
}
avg = costs.size() > 0 ? sum / costs.size() : 0;
printf("[ - ] %-24s max = %8.3f ms min = %8.3f ms avg = %8.3f ms\n", name.c_str(), max, avg == 0 ? 0 : min, avg);
std::string model = name;
if (quant == 1) {
model = "quant-" + name;
}
printf("[ - ] %-24s max = %8.3f ms min = %8.3f ms avg = %8.3f ms\n", model.c_str(), max, avg == 0 ? 0 : min, avg);
}
static inline std::string forwardType(MNNForwardType type) {
switch (type) {
Expand Down Expand Up @@ -417,22 +421,31 @@ int main(int argc, const char* argv[]) {
testQuantizedModel = atoi(argv[9]);
}

std::cout << "Forward type: **" << forwardType(forward) << "** thread=" << numberThread << "** precision=" <<precision << "** sparsity=" <<sparsity << "** sparseBlockOC=" << sparseBlockOC << "** testQuantizedModel=" << testQuantizedModel << std::endl;
std::cout << "Forward type: " << forwardType(forward) << " thread=" << numberThread << " precision=" <<precision << " sparsity=" <<sparsity << " sparseBlockOC=" << sparseBlockOC << " testQuantizedModel=" << testQuantizedModel << std::endl;
std::vector<Model> models = findModelFiles(argv[1]);

std::cout << "--------> Benchmarking... loop = " << argv[2] << ", warmup = " << warmup << std::endl;
std::string fpInfType = "precision!=2, use fp32 inference.";
if (precision == 2) {
fpInfType = "precision=2, use fp16 inference if your device supports and open MNN_ARM82=ON.";
}
MNN_PRINT("[-INFO-]: %s\n", fpInfType.c_str());
if (testQuantizedModel) {
MNN_PRINT("[-INFO-]: Auto set sparsity=0 when test quantized model in benchmark...\n");
}

/* not called yet */
// set_cpu_affinity();
if (testQuantizedModel) {
printf("Auto set sparsity=0 when test quantized model in benchmark...\n");
}

for (auto& m : models) {
printf("Float model test...\n");
std::vector<float> costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, false);
displayStats(m.name, costs);
displayStats(m.name.c_str(), costs, false);
if (testQuantizedModel) {
printf("Quantized model test...\n");
costs = doBench(m, loop, warmup, forward, false, numberThread, precision, sparsity, sparseBlockOC, true);
displayStats(m.name, costs);
displayStats(m.name, costs, 1);
}
}
}
Expand Down
18 changes: 17 additions & 1 deletion docs/start/demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,22 @@ $ python quant_aware_training.py --model_file quant_demo/mobilenet_v2_tfpb_trai
![android_demo.png](../_static/images/start/android_demo.jpg)
## iOS Demo
### 模型下载与转换:
首先编译(如果已编译可以跳过)`MNNConvert`,操作如下:
```
cd MNN
mkdir build && cd build
cmake -DMNN_BUILD_CONVERTER=ON ..
make -j8
```
然后下载并转换模型:
切到编译了 MNNConvert 的目录,如上为 build 目录,执行
```
sh ../tools/script/get_model.sh
```
### 工程编译
代码位置:`project/ios`
使用`xcode`打开`project/ios/MNN.xcodeproj`, `target`选择`demo`,既可编译运行。
Expand Down Expand Up @@ -223,4 +239,4 @@ $ python quant_aware_training.py --model_file quant_demo/mobilenet_v2_tfpb_trai
- [人脸追踪](https://github.com/qaz734913414/MNN_FaceTrack)
- [视频抠图](https://github.com/DefTruth/RobustVideoMatting.lite.ai.toolkit)
- [SuperGlue关键点匹配](https://github.com/Hanson0910/MNNSuperGlue)
- [OCR](https://github.com/DayBreak-u/chineseocr_lite/tree/onnx/android_projects/OcrLiteAndroidMNN)
- [OCR](https://github.com/DayBreak-u/chineseocr_lite/tree/onnx/android_projects/OcrLiteAndroidMNN)
5 changes: 3 additions & 2 deletions docs/tools/convert.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ model_script.save('model_script.pt')
- 测试 pb / tflite :安装`tensorflow`(`pip install tensorflow`
- 测试 onnx : 安装`onnxruntime`(`pip install onnxruntime`
- 测试 torchscript:安装`torch`(`pip install torch`)
- MNN模型转换工具编译完成(编译完成产生`TestConvertResult`可执行文件)
- 【可选】MNN模型转换工具编译完成(编译完成产生`MNNConvert`可执行文件)
### 使用
- 使用:在MNN的`build`目录下(包含`TestConvertResult`)运行`python3 testMNNFromTf.py SRC.pb`(Onnx为`python3 testMNNFromOnnx.py SRC.onnx`,Tflite 类似),若最终结果为`TEST_SUCCESS`则表示 MNN 的模型转换与运行结果正确
- 使用:在MNN的`build`目录下(包含`MNNConvert`)运行`python3 testMNNFromTf.py SRC.pb`(Onnx为`python3 testMNNFromOnnx.py SRC.onnx`,Tflite 类似),若最终结果为`TEST_SUCCESS`则表示 MNN 的模型转换与运行结果正确
- 若路径下面没有编译好的 MNNConvert 可执行文件,脚本会使用 pymnn 去进行校验
- 由于 MNN 图优化会去除 Identity ,有可能出现 find var error ,这个时候可以打开原始模型文件,找到 identity 之前的一层(假设为 LAYER_NAME )校验,示例:`python3 ../tools/script/testMNNFromTF.py SRC.pb LAYER_NAME`
- 完整实例如下(以onnx为例):
- 成功执行,当结果中显示`TEST_SUCCESS`时,就表示模型转换与推理没有错误
Expand Down
17 changes: 16 additions & 1 deletion docs/tools/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Model Version: < 2.0.0
- `numberThread:int` 线程数仅对CPU有效,可选,默认为`4`
- `precision_memory:int` 测试精度与内存模式,precision_memory % 16 为精度,有效输入为:0(Normal), 1(High), 2(Low), 3(Low_BF16),可选,默认为`2` ; precision_memory / 16 为内存设置,默认为 0 (memory_normal) 。例如测试 memory 为 low (2) ,precision 为 1 (high) 时,设置 precision_memory = 9 (2 * 4 + 1)
- `inputSize:str` 输入tensor的大小,输入格式为:`1x3x224x224`,可选,默认使用模型默认输入


### 默认输入与输出
只支持单一输入、单一输出。输入为运行目录下的input_0.txt;输出为推理完成后的第一个输出tensor,转换为文本后,输出到output.txt中。
### 示例
Expand Down Expand Up @@ -67,13 +69,26 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms
`./ModuleBasic.out model dir [runMask forwardType runLoops numberThread precision_memory cacheFile]`
- `model:str` 模型文件路径
- `dir:str` 输入输出信息文件夹,可使用 fastTestOnnx.py / fastTestTf.py / fastTestTflite.py 等脚本生成,参考模型转换的正确性校验部分。
- `runMask:int` 是否输出推理中间结果,0为不输出,1为只输出每个算子的输出结果({op_name}.txt),2为输出每个算子的输入(Input_{op_name}.txt)和输出({op_name}.txt)结果; 默认输出当前目录的output目录下(使用工具之前要自己建好output目录),可选,默认为`0`
- `runMask:int` 默认为 0 ,为一系列功能的开关,如需开启多个功能,可把对齐的 mask 值相加(不能叠加的情况另行说明),具体见下面的 runMask 参数解析
- `forwardType:int` 执行推理的计算设备,有效值为:0(CPU)、1(Metal)、2(CUDA)、3(OpenCL)、6(OpenGL),7(Vulkan) ,9 (TensorRT),可选,默认为`0`
- `runLoops:int` 性能测试的循环次数,可选,默认为`0`即不做性能测试
- `numberThread:int` GPU的线程数,可选,默认为`1`
- `precision_memory:int` 测试精度与内存模式,precision_memory % 16 为精度,有效输入为:0(Normal), 1(High), 2(Low), 3(Low_BF16),可选,默认为`2` ; precision_memory / 16 为内存设置,默认为 0 (memory_normal) 。例如测试 memory 为 2(low) ,precision 为 1 (high) 时,设置 precision_memory = 9 (2 * 4 + 1)


### 默认输出
在当前目录 output 文件夹下,依次打印输出为 0.txt , 1.txt , 2.txt , etc

### runMask 参数说明
- 1 : 输出推理中间结果,每个算子的输入存到(Input_{op_name}.txt),输出存为({op_name}.txt), 默认输出当前目录的output目录下(使用工具之前要自己建好output目录),不支持与 2 / 4 叠加
- 2 : 打印推理中间结果的统计值(最大值/最小值/平均值),只支持浮点类型的统计,不支持与 1 / 4 叠加
- 4 : 统计推理过程中各算子耗时,不支持与 1 / 2 叠加,仅在 runLoops 大于 0 时生效
- 8 : shapeMutable 设为 false (默认为 true)
- 16 : 适用于使用 GPU 的情况,由 MNN 优先选择 CPU 运行,并将 GPU 的 tuning 信息存到 cache 文件,所有算子 tuning 完成则启用 GPU
- 32 : rearrange 设为 true ,降低模型加载后的内存大小,但会增加模型加载的初始化时间
- 64 : 创建模型后,clone 出一个新的模型运行,用于测试 clone 功能(主要用于多并发推理)的正确性


### 示例
```bash
$ python ../tools/script/fastTestOnnx.py mobilenetv2-7.onnx
Expand Down
2 changes: 1 addition & 1 deletion include/MNN/MNNDefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 6
#define MNN_VERSION_PATCH 1
#define MNN_VERSION_PATCH 2
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
7 changes: 6 additions & 1 deletion include/MNN/MNNSharedContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ MNN_PUBLIC int MNNMetalGetTensorContent(MNNMetalTensorContent* content, void* te
#ifdef MNN_USER_SET_DEVICE

struct MNNDeviceContext {
uint32_t deviceId;
// When one gpu card has multi devices, choose which device. set deviceId
uint32_t deviceId = 0;
// When has multi gpu cards, choose which card. set platformId
uint32_t platformId = 0;
// User set number of gpu cards
uint32_t platformSize = 0;
};

#endif
Expand Down
4 changes: 2 additions & 2 deletions project/android/demo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ cmake -DMNN_BUILD_CONVERTER=ON ..
make -j8
```

然后下载模型,转换模型并将模型拷贝到资源文件夹下
然后下载模型,可以直接执行 sh ../tools/script/get_model.sh ,也可以按如下步骤自行下载与转换
#### MobileNet_v2
```
wget https://github.com/shicai/MobileNet-Caffe/blob/master/mobilenet_v2.caffemodel
Expand All @@ -40,4 +40,4 @@ mv Portrait.tflite.mnn ../resource/model/Portrait/

## 2. 编译运行

使用`Android Studio`打开`demo`目录,在`local.properties`中指定`sdk.dir``ndk.dir`,即可编译执行。
使用`Android Studio`打开`demo`目录,在`local.properties`中指定`sdk.dir``ndk.dir`,即可编译执行。
Loading

0 comments on commit 8697ec8

Please sign in to comment.