diff --git a/3rd_party/OpenCLHeaders/CL/cl2.hpp b/3rd_party/OpenCLHeaders/CL/cl2.hpp index 4db4f7cf6..305e88f30 100644 --- a/3rd_party/OpenCLHeaders/CL/cl2.hpp +++ b/3rd_party/OpenCLHeaders/CL/cl2.hpp @@ -403,10 +403,6 @@ # pragma message("cl2.hpp: USE_CL_DEVICE_FISSION is deprecated. Define CL_HPP_USE_CL_DEVICE_FISSION instead") # define CL_HPP_USE_CL_DEVICE_FISSION #endif -#if !defined(CL_HPP_ENABLE_EXCEPTIONS) && defined(__CL_ENABLE_EXCEPTIONS) -# pragma message("cl2.hpp: __CL_ENABLE_EXCEPTIONS is deprecated. Define CL_HPP_ENABLE_EXCEPTIONS instead") -# define CL_HPP_ENABLE_EXCEPTIONS -#endif #if !defined(CL_HPP_NO_STD_VECTOR) && defined(__NO_STD_VECTOR) # pragma message("cl2.hpp: __NO_STD_VECTOR is deprecated. Define CL_HPP_NO_STD_VECTOR instead") # define CL_HPP_NO_STD_VECTOR diff --git a/CMakeLists.txt b/CMakeLists.txt index a893e0854..7b940476e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,6 +63,7 @@ option(MNN_INTERNAL "Build with MNN internal features, such as model authenticat option(MNN_JNI "Build MNN Jni for java to use" OFF) option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF) option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." OFF) +option(MNN_CPU_WEIGHT_DEQUANT_GEMM "Build MNN CPU weight dequant related gemm kernels." OFF) IF (OHOS) include($ENV{NODE_PATH}/@ali/tcpkg/tcpkg.cmake) diff --git a/codegen/OpFuse.cpp b/codegen/OpFuse.cpp index 480825470..5e1ccbb91 100644 --- a/codegen/OpFuse.cpp +++ b/codegen/OpFuse.cpp @@ -275,7 +275,7 @@ bool codegen(std::vector& infos, std::vector cmdPlugin; + std::shared_ptr cmdPlugin; { auto sourceCode = fuseModule.codegen(); if(mapKernelSources.find(sourceCode) == mapKernelSources.end()) { diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 092c9d1ec..95f9d5760 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -80,6 +80,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_OPENCV_BENCH | 构建MNN的OpenCV功能是否开启性能benchmark,默认为`OFF` | | MNN_VULKAN_IMAGE | 构建MNN的Vulkan后端时采用Image内存模式,以便支持FP16和部分移动端上GPU的加速,默认为`ON` | | MNN_LOW_MEMORY | 是否支持低内存模式,支持低内存模式使用权值量化模型并设置`low_memory`则会使用计算时反量化,默认为`OFF` | +| MNN_CPU_WEIGHT_DEQUANT_GEMM | 是否编译CPU权重反量化的矩阵乘Kernel, 如果打开该编译宏并且在CPU推理时设置MNN::BackendConfig::MemoryMode=Memory_Normal,就会使用权重反量化算子进行权重量化模型的推理,默认为`OFF` | | MNN_SUPPORT_RENDER | 是否支持图形渲染相关算子实现,默认为 `OFF` | | MNN_SUPPORT_TRANSFORMER_FUSE | 是否支持Fuse Transformer相关OP实现,默认为 `OFF` | | MNN_BUILD_LLM | 是否构建基于MNN的llm库和demo,默认为`OFF` | diff --git a/docs/contribute/backend.md b/docs/contribute/backend.md index b54f177f6..caa10ee2a 100644 --- a/docs/contribute/backend.md +++ b/docs/contribute/backend.md @@ -1,5 +1,7 @@ # 自定义后端 -Backend是MNN对计算设备的抽象。MNN当前已经支持CPU、Vulkan、OpenCL、Metal等Backend,**只在计算设备暂未支持时新增Backend**,新增Op,请参阅[新增Op文档](customize_op)。 +Runtime-Backend是MNN对计算设备的抽象。MNN当前已经支持CPU、Vulkan、OpenCL、Metal、CUDA等Backend,**只在计算设备暂未支持时新增Backend**,新增Op,请参阅[新增Op文档](op)。 + + ## 声明 所有新增Backend都需继承`Backend`类,并实现所有纯虚函数。 @@ -10,8 +12,10 @@ class XPUBackend final : public Backend { virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op) override; virtual void onExecuteBegin() const override; virtual void onExecuteEnd() const override; - virtual bool onAcquireBuffer(const Tensor* tensor, StorageType storageType) override; - virtual bool onReleaseBuffer(const Tensor* tensor, StorageType storageType) override; + virtual void onResizeBegin() override; + virtual ErrorCode onResizeEnd() override; + + virtual MemObj* onAcquire(const Tensor* tensor, StorageType storageType) override; virtual bool 
onClearBuffer() override; virtual void onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const override; } @@ -91,7 +95,7 @@ static XPUCreatorRegister __reg(OpType_Pooling); ``` ## 内存管理 -Backend通过`onAcquireBuffer`为tensor分配内存,通过`onReleaseBuffer`为tensor释放内存。内存有三种存储模式:`STATIC`内存不复用,一般用于op常量存储;`DYNAMIC`内存可复用,一般用于变量存储;`DYNAMIC_SEPERATE`内存在pipeline间可复用,一般用于pipeline常量存储。`_onAcquireBuffer_`_和_`_onReleaseBuffer_`_中可以不实际分配/释放内存,只记录内存用量变更,在_`_onAllocateBuffer_`_调用时,再根据用量计算出优化方案,一次性完成分配/释放。_ +Backend通过`onAcquire`创建`MemObj`内存对象,定义其析构函数以便为tensor释放内存。内存有三种存储模式:`STATIC`内存不复用,一般用于op常量存储;`DYNAMIC`内存可复用,一般用于变量存储;`DYNAMIC_SEPERATE`内存在pipeline间可复用,一般用于pipeline常量存储。 ```cpp /** backend buffer storage type */ @@ -118,31 +122,13 @@ enum StorageType { */ DYNAMIC_SEPERATE }; -/** - * @brief allocate buffer of tensor for given storage type. - * @param tensor buffer provider. - * @param storageType buffer storage type. - * @return success or not. - */ -virtual bool onAcquireBuffer(const Tensor* tensor, StorageType storageType) = 0; -/** - * @brief release buffer of tensor for given storage type. - * @param tensor buffer provider. - * @param storageType buffer storage type. - * @return success or not. - */ -virtual bool onReleaseBuffer(const Tensor* tensor, StorageType storageType) = 0; -``` - -在所有内存都分配完成后,backend会收到`onAllocateBuffer`回调: -```cpp -/** - * @brief callback after all buffers needed by backend ops were allocated. - * @return success or not. (result not used currently) - */ -virtual bool onAllocateBuffer() { - return true; -} + /** + * @brief allocate buffer of tensor for given storage type. + * @param tensor buffer provider. + * @param storageType buffer storage type. + * @return MemObj for release, if failed, return nullptr. + */ + virtual MemObj* onAcquire(const Tensor* tensor, StorageType storageType) = 0; ``` Backend在调用`onClearBuffer`时,需要释放所有`DYNAMIC`和`DYNAMIC_SEPERATE`存储模式的内存: @@ -189,17 +175,47 @@ virtual void onExecuteEnd() const = 0; ``` -## 注册Backend -最后,定义Backend Creator,注册方法中调用`MNNInsertExtraBackendCreator`就可以完成Backend的注册,这里的注册方法需要在BackendRegister.cpp中声明并调用: +## Runtime(运行时) +对于使用同一种后端,且存在先后顺序,不会同时运行的模型,MNN提供机制使其共享部分计算资源,比如线程池,内存池等等。 +这部分计算资源使用Runtime存储。而Backend则由Runtime创建 + +### 实现Runtime +Runtime主要实现如下接口: + +``` + virtual Backend* onCreate(const BackendConfig* config = nullptr, Backend* origin = nullptr) const = 0; + + /** + @brief reset runtime + */ + virtual void onReset(int numberThread, const BackendConfig* config, bool full) { + // Do nothing + } + + /** + @brief clear unuseful resource + @param level clear level: 0 - 100, bigger mean clear more, smaller mean cache more + */ + virtual void onGabageCollect(int level) = 0; + +``` + +- onCreate :创建 Backend +- onReset :重设默认配置 +- onGabageCollect :清理资源以节省内存 + + +### 注册Runtime +注册方法中调用`MNNInsertExtraRuntimeCreator`就可以完成Runtime的注册,这里的注册方法需要在Backend.cpp中声明并调用: ```cpp -class XPUBackendCreator : public BackendCreator { - virtual Backend *onCreate(const Backend::Info &info) const { - return new MetalBackend; +class XPURuntimeCreator : public RuntimeCreator { + virtual Runtime* onCreate(const Backend::Info &info) const { + return new XPURuntime; } }; -void registerCPUBackendCreator() { - MNNInsertExtraBackendCreator(MNN_FORWARD_CPU, new CPUBackendCreator); +void registerXPURuntimeCreator() { + MNNInsertExtraBackendCreator(MNN_FORWARD_XPU, new XPURuntimeCreator); }; ``` -使用cmake编译时,完成代码修改后,也需要相应修改CMakeLists.txt。 \ No newline at end of file +使用cmake编译时,完成代码修改后,也需要相应修改CMakeLists.txt。 diff --git a/docs/contribute/op.md b/docs/contribute/op.md 
index 059a84d25..7a28397de 100644 --- a/docs/contribute/op.md +++ b/docs/contribute/op.md @@ -1,6 +1,14 @@ # 自定义算子 ## 概述 -在添加自定义算子前,请参阅[算子列表](../en/ops),避免不必要的重复。 +在添加自定义算子前,请查看算子列表,避免不必要的重复。 + +```bash +./MNNConvert -f CAFFE --OP +./MNNConvert -f TF --OP +./MNNConvert -f ONNX --OP +./MNNConvert -f TORCH --OP +``` + ### MNN 算子转换与实现结构 MNN 的算子转换与实现如下图, - 模型转换包括以下步骤,二选一: diff --git a/docs/faq.md b/docs/faq.md index c9b6344c0..db7241f12 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -250,7 +250,7 @@ OpenCL / Vulkan 采用静态变量自注册的方式往 MNN 主库注册后端. ## 性能相关 -### 使用 GPU 时,调用 copyToHostTensor / copyFromHostTensor 非常慢 +### 使用 GPU 时,调用 copyToHostTensor / readMap 非常慢 GPU 后端调用 copy 的时间包含两个部分 - 异构数据拷贝 @@ -258,7 +258,7 @@ GPU 后端调用 copy 的时间包含两个部分 对 GPU 后端而言,在数据被要求对用户可见(比如复制 output tensor 数据出来)之前,是允许异步执行的。 在数据被用户要求可见之时,会等待相应的异步操作完成。 -因此有可能 复制 output tensor 的过程包括了等待 GPU 算子异步执行完成,导致缓慢。 +因此有可能 复制 output tensor 的过程包括了等待 GPU 算子异步执行完成,导致看上去缓慢。 ### GPU 为什么比 CPU 跑得慢? 有如下原因: diff --git a/docs/index.rst b/docs/index.rst index 827a85235..174a53cf3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -72,7 +72,7 @@ .. toctree:: :maxdepth: 1 - :caption: 测试工具 + :caption: 工具集 :name: tools tools/convert diff --git a/docs/inference/module.md b/docs/inference/module.md index 7ec90a8a4..3fea3589a 100644 --- a/docs/inference/module.md +++ b/docs/inference/module.md @@ -5,19 +5,25 @@ - 模型推理与`Session`的区别是不需要用户显式resize,支持控制流,所以当模型中有`if`或`while`时必须使用`Module`推理 ### 相关数据结构 - `Module` Module接口的核心类,表示一个模型的虚类;实际加载模型时会创建其子类 -- `Executor` 包含若干个`RuntimeManager`,提供内存管理接口,每个`Executor`必须在单线程环境下运行。默认提供全局 `Executor`,需要并发执行时,可自行创建。 -- `ExecutorScope` 用于在子线程中绑定`Executor`,多线程并发必需 -- `VARP` 作为`Module`的输入输出,也是[Expr API](expr.md)中的基础数据结构 +- `Executor` 提供内存管理和后端资源管理能力,每个`Executor`必须在单线程环境下运行。同一个`Executor`可以用于多个顺序执行的`Module` +- `ExecutorScope` 用于在子线程中绑定`Executor`,多线程并发必需。默认在创建`Module`时使用全局 `Executor`,如果有多个Module在不同线程并发执行时,需要各自创建`Executor`,并用`ExecutorScope`绑定。 +- `VARP` 是`Module`的输入输出,也是[Expr API](expr.md)中的基础数据结构 ## 工作流程 -配置Executor(可选) -> 创建 RuntimeManager(可选) -> 创建Module -> 创建输入VARP -> 使用Module::forwad推理 -> 使用输出VARP -> 销毁Module -### (可选)配置Executor -`Executor`给用户提供接口来配置推理后端、线程数等属性,以及做性能统计、算子执行的回调函数、内存回收等功能。 提供一个全局的Exector对象,用户不用创建或持有对象即可直接使用。 +创建和配置Executor -> 创建 RuntimeManager(可选) -> 创建Module -> 创建输入VARP -> 使用Module::forward推理 -> 使用输出VARP -> 销毁Module -> 销毁Executor +### 创建和配置Executor +`Executor`给用户提供接口来配置推理后端、线程数等属性,以及做性能统计、算子执行的回调函数、内存回收等功能。 推荐针对自身模块创建单独的Executor ,若使用全局的Executor对象,对于多个模块在不同线程运行时可能会发生冲突。 ```cpp -// 配置默认全局Exector -MNN::BackendConfig backend_config; // default backend config +// 创建Executor +MNN::BackendConfig backendConfig; // default backend config +std::shared_ptr<MNN::Express::Executor> executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1); + // 设置使用4线程+CPU -MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(MNN_FORWARD_CPU, backend_config, 4); +executor->setGlobalExecutorConfig(MNN_FORWARD_CPU, backendConfig, 4); + +// 绑定Executor,在创建/销毁/使用Module或进行表达式计算之前都需要绑定 +MNN::Express::ExecutorScope _s(executor); + ``` ### (可选)创建 RuntimeManager @@ -39,6 +45,68 @@ std::shared_ptr rtmgr(MNN::Express::Exec rtmgr->setCache(".cachefile"); ``` +RuntimeManager 可以设置 hint , mode , cache, externalpath ,以支持扩展功能。 + +``` +void setCache(std::string cacheName); +void updateCache(); +void setMode(Interpreter::SessionMode mode); +void setHint(Interpreter::HintMode mode, int value); +void setExternalPath(std::string path, int type); +bool getInfo(Interpreter::SessionInfoCode code, void* ptr); +``` + +#### cache 设置
+对于GPU后端(Metal/OpenCL等),可以设置缓存文件路径,存储AutoTuning结果和Program编译结果,以加速第二次之后的Module load 过程。 + +``` + std::shared_ptr rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); + rtmgr->setCache(cacheFileName); + + std::shared_ptr module(Module::load(inputNames, outputNames, modelName.c_str(), rtmgr, mdConfig)); + /*... Make Inputs*/ + auto outputs = module->onForward(inputs); + + // Update cache file + rtmgr->updateCache(); +``` + +#### mode 设置 +可以通过设置mode开启/关闭一些功能,示例: + +``` +// 创建出来的 Module 支持插入回调函数 +rtmgr->setMode(Interpreter::Session_Debug); +``` + +并非所有枚举都适用 Module 的创建,有效值如下: + +- Interpreter::SessionMode::Session_Debug : 支持逐算子调试 +- Interpreter::SessionMode::Session_Release : 关闭逐算子调试功能,可以轻微提升性能【默认选项】 +- Interpreter::SessionMode::Session_Backend_Fix : 固定使用用户设置的后端【默认选项】 +- Interpreter::SessionMode::Session_Backend_Auto : MNN根据用户倾向,预估load Module耗时,如果耗时较短则使用用户设置的后端,否则使用CPU + + +#### hint 设置 +通过 hint 设置,可以在后端支持的情况下设置相应属性,有效值如下: + +- Interpreter::HintMode::WINOGRAD_MEMORY_LEVEL :使用 Winograd 算法优化卷积时,内存占用倾向,默认为 3 ,若希望降低内存占用可设为 0 +- Interpreter::HintMode::GEOMETRY_COMPUTE_MASK :几何计算相关优化开关,1为区域合并,2为复合区域合并,4为使用loop算子,8为支持几何计算重计算,需要多个功能开启时把对应值叠加。默认为功能全开。 +- Interpreter::HintMode::DYNAMIC_QUANT_OPTIONS :动态量化选项,1为 Per Batch,2为Per Tensor 。默认为2。 +- Interpreter::HintMode::CPU_LITTLECORE_DECREASE_RATE :对于 Android 设备存在大中小核的情况,大核算力到中核算力的衰减比例。默认为50(中核算力为大核的50%) + + +#### ExternalPath +在设备可能出现内存不足时,可以通过 setExternalPath 指定路径,让MNN把部分内存用mmap分配。这样操作系统可在内存不足时会将其转换为读写文件,避免内存不足程序闪退。示例: + +``` +runtime_manager_->setExternalPath("tmp", MNN::Interpreter::EXTERNAL_WEIGHT_DIR); +runtime_manager_->setExternalPath("tmp", MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR); +``` + +- MNN::Interpreter::EXTERNAL_WEIGHT_DIR : 权重重排后的内存转换为文件存储 +- MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR : 中间内存转换为文件存储 + ### 创建Module `Module`可以通过指定模型,输入输出的名称,配置文件创建 ```cpp diff --git a/docs/start/overall.md b/docs/start/overall.md index 02203b20e..de13f21e0 100644 --- a/docs/start/overall.md +++ b/docs/start/overall.md @@ -6,6 +6,6 @@ ### 训练 在训练框架上,根据训练数据训练出模型的阶段。虽然当前MNN也提供了[训练模型的能力](../train/expr.md),但主要用于端侧训练或模型调优。在数据量较大时,依然建议使用成熟的训练框架,如TensorFlow、PyTorch等。除了自行训练外,也可以直接利用开源的预训练模型。 ### 转换 -将其他训练框架模型转换为MNN模型的阶段。MNN当前支持Tensorflow(Lite)、Caffe、ONNX和TorchScript的模型转换。模型转换工具可以参考[编译文档](../compile/tools.html#id2)和[使用说明](../tools/convert.md)。支持转换的算子,可以参考[算子列表文档](../tools/convert.html#id7);在遇到不支持的算子时,可以尝试[自定义算子](../contribute/op.md),或在Github上给我们[提交issue](https://github.com/alibaba/MNN/issues/74)。此外,[模型打印工具](../tools/convert.html#id8)可以用于输出模型结构,辅助调试。除模型转换外,MNN也提供了[模型量化工具](../tools/quant.md),可以对浮点模型进行量化压缩。 +将其他训练框架模型转换为MNN模型的阶段。MNN当前支持Tensorflow(Lite)、Caffe、ONNX和TorchScript的模型转换。模型转换工具可以参考[使用说明](../tools/convert.md)。支持转换的算子,可以参考[算子列表文档](../tools/convert.html#id7);在遇到不支持的算子时,可以尝试[自定义算子](../contribute/op.md),或在Github上给我们[提交issue](https://github.com/alibaba/MNN/issues/74)。此外,[模型打印工具](../tools/convert.html#id8)可以用于输出模型结构,辅助调试。除模型转换外,MNN也提供了[模型量化工具](../tools/quant.md),可以对浮点模型进行量化压缩。 ### 推理 -在端侧加载MNN模型进行推理的阶段。端侧运行库的编译请参考各平台的编译文档:[iOS](../compile/engine.html#ios)、[Android](../compile/engine.html#android)、[Linux/macOS/Ubuntu](../compile/engine.html#linux-macos)、[Windows](../compile/engine.html#windows)。我们提供了[API接口文档](https://github.com/alibaba/MNN/tree/master/doc/API),也详细说明了[会话创建](../inference/session.html#id1)、[数据输入](../inference/session.html#id8)、[执行推理](../inference/session.html#id17)、[数据输出](../inference/session.html#id21)相关的接口和参数。`demo/exec`下提供了使用示例,如图像识别 `demo/exec/pictureRecognition.cpp` 
,图像实例分割(人像分割)`demo/exec/segment.cpp`,[更多demo](demo.md)。此外,[测试工具](../tools/test.md)和[benchmark工具](../tools/benchmark.md)也可以用于问题定位。 \ No newline at end of file +在端侧加载MNN模型进行推理的阶段。端侧运行库的编译请参考各平台的编译文档:[iOS](../compile/engine.html#ios)、[Android](../compile/engine.html#android)、[Linux/macOS/Ubuntu](../compile/engine.html#linux-macos)、[Windows](../compile/engine.html#windows)。我们提供了[API接口文档](https://github.com/alibaba/MNN/tree/master/doc/API),也详细说明了[会话创建](../inference/session.html#id1)、[数据输入](../inference/session.html#id8)、[执行推理](../inference/session.html#id17)、[数据输出](../inference/session.html#id21)相关的接口和参数。`demo/exec`下提供了使用示例,如图像识别 `demo/exec/pictureRecognition.cpp` ,图像实例分割(人像分割)`demo/exec/segment.cpp`,[更多demo](demo.md)。此外,[测试工具](../tools/test.md)和[benchmark工具](../tools/benchmark.md)也可以用于问题定位。 diff --git a/docs/tools/convert.md b/docs/tools/convert.md index fdc707bc1..b815405bf 100644 --- a/docs/tools/convert.md +++ b/docs/tools/convert.md @@ -1,5 +1,4 @@ # 模型转换工具 -[从源码编译](../compile/tools.html#id2) ## 参数说明 ```bash Usage: diff --git a/docs/tools/quant.md b/docs/tools/quant.md index c2a26d1d5..1a66b6e1b 100644 --- a/docs/tools/quant.md +++ b/docs/tools/quant.md @@ -1,7 +1,7 @@ # 单输入模型离线量化工具 `./quantized.out origin.mnn quan.mnn imageInputConfig.json` -通用(任意输入个数、维度、类型)模型离线量化请看[说明](https://mnn-docs.readthedocs.io/en/latest/tools/compress.html#id10) +MNN quantized.out工具已支持通用(任意输入个数、维度、类型)模型离线量化, 但这里的多输入模型仅仅支持非图片输入类模型。 MNN现已推出基于TensorFlow/Pytorch的模型压缩工具mnncompress,请查看[文档](https://mnn-docs.readthedocs.io/en/latest/tools/compress.html)选择使用 @@ -38,6 +38,10 @@ MNN现已推出基于TensorFlow/Pytorch的模型压缩工具mnncompress,请查 | MAX_ABS | 使用权值的绝对值的最大值进行对称量化 | | ADMM | 使用ADMM方法进行权值量化 | +## 多输入模型的参数设置的特别说明(MNN现阶段仅支持输入数据类型是非图片的多输入模型) +| input_type | `str` | 输入数据的类型,"sequence" | +| path | `str` | 存放校正特征量化系数的输入数据目录 |,例如该目录下包含2个输入数据集input_0和input_1,子目录input_0和input_1中包含模型的输入数据和一个input.json文件。input_0和input_1分别是两个输入输出信息文件夹,可使用 testMNNFromOnnx.py 等脚本生成,参考模型转换的正确性校验部分。 + ## 量化模型的使用 和浮点模型同样使用方法,输入输出仍然为浮点类型 ## 参考资料 diff --git a/docs/tools/test.md b/docs/tools/test.md index 532877f9e..02c2d3df0 100644 --- a/docs/tools/test.md +++ b/docs/tools/test.md @@ -1,5 +1,5 @@ # 测试工具 -[从源码编译](../compile/tools.html#id4)使用cmake编译时,build目录下的产物也包含测试使用的工具集,下面逐项说明。 +使用cmake编译时,默认打开 MNN_BUILD_TOOLS 编译宏,对应build目录下的产物也包含测试使用的工具集,下面逐项说明。 ## GetMNNInfo ### 功能 @@ -95,6 +95,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms - 128 : 使用文件夹下面的 input.mnn 和 output.mnn 做为输入和对比输出,对于数据量较大的情况宜用此方案 - 512 : 开启使用Winograd算法计算卷积时的内存优化,开启后模型的运行时内存会降低,但可能导致性能损失。 - 1024: 使用动态量化推理时,对输入数据分batch量化以提高模型的推理精度 +- 2048: 使用mmap方式,使用文件存储中间内存。存储文件的目录为当前目录/tmp,需要先建tmp文件夹 ### 示例 @@ -262,19 +263,10 @@ stopOp.c_str()=s in main, 278 Correct ! Run second pass Correct ! 
``` -### 在Android中使用 -先编译相关的库和可执行文件,然后push到Android手机上,用adb执行命令,参考`project/android/testCommon.sh` -```bash -cd project/android -mkdir build_64 -cd build_64 && ../build_64.sh -../updateTest.sh -../testCommon.sh ./backendTest.out temp.mnn 3 0.15 1 -``` ## getPerformance ### 功能 -获取当前设备的CPU性能,打印出每个CPU核心的频率;在Android设备上还会打印该设备CPU的浮点计算能力(GFLOPS) +获取当前设备的CPU性能和内存访问性能,打印出每个CPU核心的频率;在Android设备上还会打印该设备CPU的浮点计算能力(GFLOPS) *各核心频率仅在Linux/Android环境中有效,计算能力仅在Android中有效* ### 参数 @@ -475,6 +467,7 @@ Matrix: ### 示例 ```bash $ ./fuseTest user.spirv user.json +``` ## GpuInterTest.out ### 功能 @@ -488,3 +481,22 @@ GPU 内存输入测试用例 - `forwardType:int` 执行推理的计算设备,有效值为:0(CPU)、1(Metal)、2(CUDA)、3(OpenCL)、6(OpenGL),7(Vulkan) ,9 (TensorRT),可选,默认为`0` - `numberThread:int` GPU的线程数,可选,默认为`1` - `precision_memory:int` 测试精度与内存模式,precision_memory % 16 为精度,有效输入为:0(Normal), 1(High), 2(Low), 3(Low_BF16),可选,默认为`2` ; precision_memory / 16 为内存设置,默认为 0 (memory_normal) 。例如测试 memory 为 2(low) ,precision 为 1 (high) 时,设置 precision_memory = 9 (2 * 4 + 1) + + +## 在Android中使用测试工具 +- project/android/updateTest.sh 可以把编译好的库和可执行文件 push 到Android手机的/data/local/tmp/MNN 目录 +- project/android/testCommon.sh 可以在 /data/local/tmp/MNN 目录下执行可执行程序 + +其他的资源文件需要自行使用 adb push ,将其放到手机的 /data/local/tmp/MNN 目录下,比如 adb push temp.mnn /data/local/tmp/MNN/temp.mnn + +如下例子是在Android设备上使用 backendTest.out ,其中 temp.mnn 路径为 /data/local/tmp/MNN/temp.mnn + +```bash +cd project/android +mkdir build_64 +cd build_64 && ../build_64.sh +../updateTest.sh +../testCommon.sh ./backendTest.out temp.mnn 3 0.15 1 +``` + + diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index 70e64766b..5c6d341fb 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -17,8 +17,8 @@ https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/mai ## 模型转换 ### 将Huggingface的Stable Diffusion模型 转为onnx模型 ```sh -cd mnn_path/transformers/diffusion/ -python export/onnx_export.py \ +cd mnn_path/transformers/diffusion/export +python onnx_export.py \ --model_path hf_sd_load_path \ --output_path onnx_save_path ``` @@ -30,20 +30,19 @@ conda activate ldm 在conda环境中执行模型转换脚本 ### 将onnx模型转为mnn模型 -新建diffusion mnn模型文件夹,将转好的mnn文件放在该文件夹下。 -1. 实现encoder从onnx模型 -> mnn模型 -``` -./MNNConvert -f ONNX --modelFile onnx_save_path/text_encoder/model.onnx --MNNModel mnn_save_path/text_encoder.mnn --weightQuantBits 8 --bizCode biz -``` -2. 实现denoiser unet从onnx模型 -> mnn模型 +新建diffusion mnn模型文件夹 mnn_save_path ,将转好的mnn文件放在该文件夹下。 + +执行脚本 ``` -./MNNConvert -f ONNX --modelFile onnx_save_path/unet/model.onnx --MNNModel mnn_save_path/unet.mnn --transformerFuse --weightQuantBits 8 --bizCode biz -注意:对于非OpenCL后端推理,需要去掉--transformerFuse。 +python3 convert_mnn.py ../onnx ~/alicnn/AliNNPrivate/build/diffusion "--weightQuantBits=8" ``` -3. 实现decoder从onnx模型 -> mnn模型 + +若希望在OpenCL后端上进一步加速,可加上--transformerFuse: ``` -./MNNConvert -f ONNX --modelFile onnx_save_path/vae_decoder/model.onnx --keepInputFormat --MNNModel mnn_save_path/vae_decoder.mnn --weightQuantBits 8 --bizCode biz +# 适用OpenCL 后端推理 +python3 convert_mnn.py onnx_path mnn_save_path "--weightQuantBits=8 --transformerFuse" ``` + ## 编译Diffusion Demo ### Linux/MAC/Windows上 ``` diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index 5e77ab0cb..0d00de862 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -6,68 +6,59 @@ ## 模型导出 -`llm_export`是一个llm模型导出工具,能够将llm模型导出为onnx和mnn模型。 +`llmexport`是一个llm模型导出工具,能够将llm模型导出为onnx和mnn模型。 ### 用法 1. 
将需要导出的LLM项目clone到本地,如:Qwen2-0.5B-Instruct ```sh git clone https://www.modelscope.cn/qwen/Qwen2-0.5B-Instruct.git ``` -3. 执行`llm_export.py`导出模型 +3. 执行`llmexport.py`导出模型 ```sh cd ./transformers/llm/export # 导出模型,tokenizer和embedding,并导出对应的mnn模型 -python llm_export.py \ - --type Qwen2-0_5B-Instruct \ +python llmexport.py \ --path /path/to/Qwen2-0.5B-Instruct \ - --export \ - --export_token \ - --export_embed --embed_bin \ - --export_mnn + --export mnn ``` 4. 导出产物 导出产物为: -1. `embeddings_bf16.bin`: 模型的embedding权重二进制文件,推理时使用; -2. `llm_config.json`: 模型的配置信息,推理时使用; -3. `llm.onnx`: 模型的onnx文件,推理时不使用; -4. `tokenizer.txt`: 模型的tokenzier文件,推理时使用; -5. `llm.mnn`: 模型的mnn文件,推理时使用; -6. `llm.mnn.weight`: 模型的mnn权重,推理时使用; +1. `config.json`: 模型运行时的配置,可手动修改; +2. `embeddings_bf16.bin`: 模型的embedding权重二进制文件,推理时使用; +3. `llm.mnn`: 模型的mnn文件,推理时使用; +4. `llm.mnn.json`: mnn模型对应的json文件,apply_lora或者gptq量化权重时使用; +5. `llm.mnn.weight`: 模型的mnn权重,推理时使用; +6. `llm.onnx`: 模型的onnx文件,不包含权重,推理时不使用; +7. `llm_config.json`: 模型的配置信息,推理时使用; +8. `tokenizer.txt`: 模型的tokenzier文件,推理时使用; 目录结构如下所示: ``` . -├── onnx -| ├── embeddings_bf16.bin -| ├── llm_config.json -| ├── llm.onnx -| └── tokenizer.txt -└── mnn +└── model + ├── config.json + ├── embeddings_bf16.bin ├── llm.mnn - └── llm.mnn.weight + ├── llm.mnn.json + ├── llm.mnn.weight + ├── llm.onnx + ├── llm_config.json + └── tokenizer.txt ``` ### 功能 -- 支持将模型完整导出为一个onnx模型,使用`--export` -- 支持将模型分段导出为多个模型,使用`--export_split` -- 支持导出模型的词表到一个文本文件,每行代表一个token;其中token使用base64编码;使用`--export_verbose` -- 支持导出模型的Embedding层为一个onnx模型,使用`--export_embed`,同时支持bf16格式,使用`--embed_bf16` -- 支持分层导出模型的block,使用`--export_blocks`导出全部层;使用`--export_block $id`导出指定层 -- 支持导出模型的lm_head层为一个onnx模型,使用`--export_lm` -- 支持导出多模态模型的visual模型为一个onnx模型,使用`--export_visual` +- 支持将模型为onnx或mnn模型,使用`--export onnx`或`--export mnn` - 支持对模型进行对话测试,使用`--test $query`会返回llm的回复内容 -- 支持在导出onnx模型后使用onnxruntime对结果一致性进行校验,使用`--export_test` -- 支持将tokenizer导出为文本文件,使用`--export_token` -- 支持将导出的onnx模型转换为mnn模型,默认转换为非对称4bit量化,使用`--export_mnn` -- 指定导出路径使用`--onnx_path`和`--mnn_path` - 默认会使用onnx-slim对onnx模型进行优化,跳过该步骤使用`--skip_slim` - 支持合并lora权重后导出,指定lora权重的目录使用`--lora_path` +- 制定量化bit数使用`--quant_bit`;量化的block大小使用`--quant_block` +- 使用`--lm_quant_bit`来制定lm_head层权重的量化bit数,不指定则使用`--quant_bit`的量化bit数 +- 支持使用自己编译的`MNNConvert`,使用`--mnnconvert` ### 参数 ``` -usage: llm_export.py [-h] --path PATH - [--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Qwen2-1_5B-Instruct,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora}] - [--lora_path LORA_PATH] [--onnx_path ONNX_PATH] [--mnn_path MNN_PATH] [--export_mnn] [--export_verbose] [--export_test] [--test TEST] [--export] [--export_split] [--export_token] - [--export_embed] [--export_visual] [--export_lm] [--export_block EXPORT_BLOCK] [--export_blocks] [--embed_bin] [--embed_bf16] [--skip_slim] +usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT] + [--skip_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT] + [--mnnconvert MNNCONVERT] llm_exporter @@ -77,33 +68,22 @@ options: Can be either: - A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO] - A path to a *directory* clone from repo like `../chatglm-6b`. 
- --type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Qwen2-1_5B-Instruct,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora} - type(`str`, *optional*): + --type TYPE type(`str`, *optional*): The pretrain llm model type. --lora_path LORA_PATH lora path, defaut is `None` mean not apply lora. - --onnx_path ONNX_PATH - export onnx model path, defaut is `./onnx`. - --mnn_path MNN_PATH export mnn model path, defaut is `./mnn`. - --export_mnn Whether or not to export mnn model after onnx. - --export_verbose Whether or not to export onnx with verbose. - --export_test Whether or not to export onnx with test using onnxruntime. + --dst_path DST_PATH export onnx/mnn model to path, defaut is `./model`. --test TEST test model inference with query `TEST`. - --export export model to an `onnx` model. - --export_split export model split to some `onnx` models: - - embedding model. - - block models. - - lm_head model. - --export_token export llm tokenizer to a txt file. - --export_embed export llm embedding to an `onnx` model. - --export_visual export llm visual model to an `onnx` model. - --export_lm export llm lm_head to an `onnx` model. - --export_block EXPORT_BLOCK - export llm block [id] to an `onnx` model. - --export_blocks export llm all blocks to `onnx` models. - --embed_bin export embedding weight as bin file with dtype `bfloat16` - --embed_bf16 using `bfloat16` replace `float32` in embedding. + --export EXPORT export model to an onnx/mnn model. --skip_slim Whether or not to skip onnx-slim. + --quant_bit QUANT_BIT + mnn quant bit, 4 or 8, default is 4. + --quant_block QUANT_BLOCK + mnn quant block, default is 0 mean channle-wise. + --lm_quant_bit LM_QUANT_BIT + mnn lm_head quant bit, 4 or 8, default is `quant_bit`. + --mnnconvert MNNCONVERT + local mnnconvert path, if invalid, using pymnn. 
``` ## 模型推理 @@ -111,6 +91,37 @@ options: ### 编译 [从源码编译](../compile/other.html#id4) +在原有编译过程中增加必需编译宏即可: -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true + +- mac / linux / windows + +以 mac / linux 为例 : +``` +make build +cd build +cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true +make -j16 +``` + +x86架构额外加 MNN_AVX512 的宏: +``` +make build +cd build +cmake ../ -DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_AVX512=true +make -j16 +``` + +- Android:额外增加 MNN_ARM82 的宏 +``` +cd project/android +mkdir build_64 +../build_64.sh "-DMNN_LOW_MEMORY=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true -DMNN_BUILD_LLM=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_ARM82=true" +``` + +- iOS: 参考 transformers/llm/engine/ios/README.md +``` +sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_BUILD_LLM=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true" +``` ### 使用 #### 运行时配置 @@ -144,11 +155,16 @@ options: - 推理配置 - max_new_tokens: 生成时最大token数,默认为`512` - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false` - - quant_kv: 存储`kv cache`时是否量化,可选为:`0, 1, 2, 3`,默认为`0`,含义如下: + - quant_qkv: CPU attention 算子中`query, key, value`是否量化,可选为:`0, 1, 2, 3, 4`,默认为`0`,含义如下: - 0: key和value都不量化 - 1: 使用非对称8bit量化存储key - - 2: 使用fp8格式寸处value - - 3: 使用非对称8bit量化存储key,使用fp8格式寸处value + - 2: 使用fp8格式量化存储value + - 3: 使用非对称8bit量化存储key,使用fp8格式量化存储value + - 4: 量化kv的同时使用非对称8bit量化query,并使用int8矩阵乘计算Q*K + - use_mmap: 是否使用mmap方式,在内存不足时将权重写入磁盘,避免溢出,默认为false,手机上建议设成true + - kvcache_mmap: 是否使用mmap方式,在内存不足时将在KV Cache 写入磁盘,避免溢出,默认为false + - tmp_path: 启用 mmap 相关功能时,写入磁盘的缓存目录 + - iOS 上可用如下语句创建临时目录并设置:`NSString *tempDirectory = NSTemporaryDirectory();llm->set_config("{\"tmp_path\":\"" + std::string([tempDirectory UTF8String]) + "\"}")` - 硬件配置 - backend_type: 推理使用硬件后端类型,默认为:`"cpu"` - thread_num: CPU推理使用硬件线程数,默认为:`4`; OpenCL推理时使用`68` @@ -266,4 +282,4 @@ options: thread1.join(); thread2.join(); } - ``` \ No newline at end of file + ``` diff --git a/express/Executor.cpp b/express/Executor.cpp index 437d72df6..5f6a6dd48 100644 --- a/express/Executor.cpp +++ b/express/Executor.cpp @@ -154,9 +154,8 @@ std::shared_ptr Executor::getGlobalExecutor() { RuntimeHint hint; hint.memoryAllocatorType = 0;// Defer bn->setRuntimeHint(hint); - static std::shared_ptr executorStatic; - executorStatic.reset(new Executor(bn, MNN_FORWARD_CPU, 1)); - gExecutor = &executorStatic; + gExecutor = new std::shared_ptr; + gExecutor->reset(new Executor(bn, MNN_FORWARD_CPU, 1)); }); return *gExecutor; } diff --git a/express/module/Module.cpp b/express/module/Module.cpp index 4ba49c27a..a0976bd67 100644 --- a/express/module/Module.cpp +++ b/express/module/Module.cpp @@ -330,11 +330,17 @@ Module* Module::load(const std::vector& inputs, const std::vectorgetInside()->mExternalFile.empty()) { // Set Default externalFile rtMgr->setExternalFile(std::string(fileName) + ".weight"); + needReset = true; } - return loadInternal(inputs, outputs, buffer.get(), buffer.size(), rtMgr, config); + auto res = loadInternal(inputs, outputs, buffer.get(), buffer.size(), rtMgr, config); + if (needReset) { + rtMgr->setExternalFile(""); + } + return res; } Module* Module::load(const std::vector& inputs, const std::vector& outputs, const uint8_t* buffer, size_t length, const std::shared_ptr _rtMgr, const Module::Config* config) { diff --git a/express/module/StaticModule.cpp 
b/express/module/StaticModule.cpp index 986185534..31a07c632 100644 --- a/express/module/StaticModule.cpp +++ b/express/module/StaticModule.cpp @@ -33,7 +33,7 @@ static const StaticModule* getStaticModule(const Module* m) { } static std::vector> preRearrangeWeights( // NOLINT - Schedule::ScheduleInfo& scheduleInfo, Backend* backend, Backend* backupBackend, const Module* base = nullptr) { + Schedule::ScheduleInfo& scheduleInfo, Backend* firstbackend, Backend* backupBackend, const Module* base = nullptr) { std::map> base_executions; if (base != nullptr) { // has base module @@ -59,6 +59,10 @@ static std::vector> preRearrangeWeights( // NOLIN auto op = pipelineInfo[i].op; std::unique_ptr op_table(op->UnPack()); std::shared_ptr exe; + Backend* backend = firstbackend; + if (info.type == Schedule::CONSTANT) { + backend = backupBackend; + } switch (op->type()) { case MNN::OpType_DepthwiseConvInt8: case MNN::OpType_ConvInt8: @@ -304,20 +308,8 @@ StaticModule::StaticModule(std::vector inputs, std::map, DataType>> exeCache; MNN_ASSERT(1 == scheduleInfo.pipelineInfo.size()); auto& bnCache = scheduleInfo.pipelineInfo[0].first; - bnCache.cache.first.reset(rt.first[bnCache.info.type]->onCreate(bnCache.info.user)); - if (bnCache.cache.first->type() == MNN_FORWARD_CPU) { - bnCache.cache.second = bnCache.cache.first; - } else { - // Use Multi-thread if user has set numberthread > 1 - BackendConfig defaultConfig; - defaultConfig.flags = 4; - auto cpurt = rt.first.find(MNN_FORWARD_CPU); - if (cpurt != rt.first.end()) { - bnCache.cache.second.reset(cpurt->second->onCreate(&defaultConfig)); - } else { - bnCache.cache.second.reset(rt.second->onCreate(&defaultConfig)); - } - } + // Create Backend for prearrange + Session::createPipelineBackend(scheduleInfo.pipelineInfo[0], rt); if (config.rearrange) { mResource->mBuffer = preRearrangeWeights(scheduleInfo, bnCache.cache.first.get(), bnCache.cache.second.get(), config.base); } else { diff --git a/include/MNN/Interpreter.hpp b/include/MNN/Interpreter.hpp index bac8fb341..edeceb296 100644 --- a/include/MNN/Interpreter.hpp +++ b/include/MNN/Interpreter.hpp @@ -224,11 +224,12 @@ class MNN_PUBLIC Interpreter { // Default is 50 CPU_LITTLECORE_DECREASE_RATE = 6, - // 0: Do not quantize kvcache, just store float - // 1: Only quantize key cache, use int8 asymmetric quantization - // 2: Only quantize value cache, use fp8 quantization - // 3: quantize both key and value cache as described above - KVCACHE_QUANT_OPTIONS = 7, + // 0: Do not quantize + // 1: Only quantize key, use int8 asymmetric quantization + // 2: Only quantize value, use fp8 quantization + // 3: quantize both key and value + // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + QKV_QUANT_OPTIONS = 7, // size limit of kvcache in memory (for a single layer) // if the size of kvcache exceeds the limit, it will be moved to disk @@ -238,6 +239,12 @@ class MNN_PUBLIC Interpreter { enum ExternalPathType { // Path of the kvcache directory EXTERNAL_PATH_KVCACHE_DIR = 0, + + // Mid Buffer Cache File + EXTERNAL_FEATUREMAP_DIR = 1, + + // Weight Buffer Cache File + EXTERNAL_WEIGHT_DIR = 2, // Other types ... 
}; diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index 215939a99..8a0af32de 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 2 #define MNN_VERSION_MINOR 9 -#define MNN_VERSION_PATCH 4 +#define MNN_VERSION_PATCH 5 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */ diff --git a/package_scripts/ios/buildiOS.sh b/package_scripts/ios/buildiOS.sh index 3722f1f61..0f0942d31 100755 --- a/package_scripts/ios/buildiOS.sh +++ b/package_scripts/ios/buildiOS.sh @@ -12,31 +12,12 @@ cd Static rm -rf ios_64 mkdir ios_64 cd ios_64 -cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="arm64" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_ARM82=true -DMNN_BUILD_SHARED_LIBS=false $1 +cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="arm64" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_ARM82=true -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $1 echo "Building AArch64" make MNN -j16 echo "End Building AArch64" cd ../ -rm -rf ios_32 -mkdir ios_32 -cd ios_32 -cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false $1 -echo "Building AArch32" -make MNN -j16 -echo "End Building AArch32" -cd ../ - -find ios_32 -name "MNN*framework" -find ios_64 -name "MNN*framework" - -mv ios_32/MNN.framework/MNN ios_32/MNN.framework/MNN_32 +mv ios_64/MNN.framework MNN.framework -echo "Creating Fat Binary" -lipo -create ios_32/MNN.framework/MNN_32 ios_64/MNN.framework/MNN -output ios_32/MNN.framework/MNN -rm ios_32/MNN.framework/MNN_32 -echo "Patching Framework Headers" -rm -rf ./MNN.framework -cp -R ios_32/MNN.framework ./MNN.framework -rm -rf ios_32 rm -rf ios_64 diff --git a/package_scripts/ios/buildiOS_with_armv7.sh b/package_scripts/ios/buildiOS_with_armv7.sh new file mode 100755 index 000000000..ea5851791 --- /dev/null +++ b/package_scripts/ios/buildiOS_with_armv7.sh @@ -0,0 +1,42 @@ +#!/bin/sh +echo "Change directory to MNN_SOURCE_ROOT/project/ios before running this script" +echo "Current PWD: ${PWD}" + +rm -rf MNN-iOS-CPU-GPU +mkdir MNN-iOS-CPU-GPU +cd MNN-iOS-CPU-GPU +# Static Begin +mkdir Static +cd Static + +rm -rf ios_64 +mkdir ios_64 +cd ios_64 +cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="arm64" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_ARM82=true -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $1 +echo "Building AArch64" +make MNN -j16 +echo "End Building AArch64" +cd ../ + +rm -rf ios_32 +mkdir ios_32 +cd ios_32 +cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../../../cmake/ios.toolchain.cmake -DMNN_METAL=ON -DARCHS="armv7;armv7s" -DENABLE_BITCODE=0 -DMNN_AAPL_FMWK=1 -DMNN_SEP_BUILD=0 -DMNN_BUILD_SHARED_LIBS=false -DMNN_USE_THREAD_POOL=OFF $1 +echo "Building AArch32" +make MNN -j16 +echo "End Building AArch32" +cd ../ + +find ios_32 -name "MNN*framework" +find ios_64 -name "MNN*framework" + +mv ios_32/MNN.framework/MNN ios_32/MNN.framework/MNN_32 + +echo "Creating Fat Binary" +lipo -create 
ios_32/MNN.framework/MNN_32 ios_64/MNN.framework/MNN -output ios_32/MNN.framework/MNN +rm ios_32/MNN.framework/MNN_32 +echo "Patching Framework Headers" +rm -rf ./MNN.framework +cp -R ios_32/MNN.framework ./MNN.framework +rm -rf ios_32 +rm -rf ios_64 diff --git a/package_scripts/mac/buildFrameWork.sh b/package_scripts/mac/buildFrameWork.sh index e0810e07e..7f955ed89 100755 --- a/package_scripts/mac/buildFrameWork.sh +++ b/package_scripts/mac/buildFrameWork.sh @@ -18,7 +18,7 @@ cd Static # ARM mkdir mac_a64 cd mac_a64 -cmake ../../../ -DMNN_USE_SSE=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_METAL=ON -DARCHS="arm64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DMNN_ARM82=ON -DCMAKE_OSX_ARCHITECTURES=arm64 -DMNN_BUILD_SHARED_LIBS=OFF $1 +cmake ../../../ -DMNN_USE_SSE=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_METAL=ON -DARCHS="arm64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DMNN_ARM82=ON -DCMAKE_OSX_ARCHITECTURES=arm64 -DMNN_BUILD_SHARED_LIBS=OFF $1 echo "Building ARM64" make MNN -j16 echo "End Building ARM64" @@ -27,7 +27,7 @@ cd ../ # X86 mkdir mac_x64 cd mac_x64 -cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_METAL=ON -DARCHS="x86_64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DCMAKE_OSX_ARCHITECTURES=x86_64 -DMNN_BUILD_SHARED_LIBS=OFF $1 +cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_METAL=ON -DARCHS="x86_64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DCMAKE_OSX_ARCHITECTURES=x86_64 -DMNN_BUILD_SHARED_LIBS=OFF $1 echo "Building x86" make MNN -j16 echo "End Building x86" @@ -52,7 +52,7 @@ cd Dynamic # ARM mkdir mac_a64 cd mac_a64 -cmake ../../../ -DMNN_USE_SSE=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_METAL=ON -DARCHS="arm64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DMNN_ARM82=ON -DCMAKE_OSX_ARCHITECTURES=arm64 $1 +cmake ../../../ -DMNN_USE_SSE=OFF -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_METAL=ON -DARCHS="arm64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DMNN_ARM82=ON -DCMAKE_OSX_ARCHITECTURES=arm64 $1 echo "Building ARM64" make MNN -j16 echo "End Building ARM64" @@ -61,7 +61,7 @@ cd ../ # X86 mkdir mac_x64 cd mac_x64 -cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_METAL=ON -DARCHS="x86_64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DCMAKE_OSX_ARCHITECTURES=x86_64 $1 +cmake ../../../ -DCMAKE_BUILD_TYPE=Release -DMNN_OPENCL=ON -DMNN_USE_THREAD_POOL=OFF -DMNN_METAL=ON -DARCHS="x86_64" -DMNN_AAPL_FMWK=ON -DMNN_SEP_BUILD=OFF -DCMAKE_OSX_ARCHITECTURES=x86_64 $1 echo "Building x86" make MNN -j16 echo "End Building x86" diff --git a/project/android/build_32.sh b/project/android/build_32.sh index e83655009..24f0eb8cc 100755 --- a/project/android/build_32.sh +++ b/project/android/build_32.sh @@ -4,7 +4,6 @@ cmake ../../../ \ -DCMAKE_BUILD_TYPE=Release \ -DANDROID_ABI="armeabi-v7a" \ -DANDROID_STL=c++_static \ --DCMAKE_BUILD_TYPE=Release \ -DANDROID_NATIVE_API_LEVEL=android-14 \ -DANDROID_TOOLCHAIN=clang \ -DMNN_USE_LOGCAT=false \ diff --git a/project/ios/MNN.xcodeproj/project.pbxproj b/project/ios/MNN.xcodeproj/project.pbxproj index f576703bf..535f50d27 100644 --- a/project/ios/MNN.xcodeproj/project.pbxproj +++ b/project/ios/MNN.xcodeproj/project.pbxproj @@ -771,6 +771,25 @@ C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = C4F906B127688C3A0026B847 /* NMSModule.hpp */; }; C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C4F906B227688C3A0026B847 /* NMSModule.cpp */; }; 
C4FB6CB22769DF0800963B07 /* GeometryCumSum.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C4FB6CB12769DF0800963B07 /* GeometryCumSum.cpp */; }; + CE072A132C91AEE700F190FD /* MNNBGRToBGR555.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A032C91AEE700F190FD /* MNNBGRToBGR555.S */; }; + CE072A142C91AEE700F190FD /* MNNBGRAToGRAY.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A042C91AEE700F190FD /* MNNBGRAToGRAY.S */; }; + CE072A152C91AEE700F190FD /* MNNRGBAToGRAYFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A052C91AEE700F190FD /* MNNRGBAToGRAYFast.S */; }; + CE072A162C91AEE700F190FD /* MNNBGRAToBGR.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A062C91AEE700F190FD /* MNNBGRAToBGR.S */; }; + CE072A172C91AEE700F190FD /* MNNSamplerC3BilinearOpt.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A072C91AEE700F190FD /* MNNSamplerC3BilinearOpt.S */; }; + CE072A182C91AEE700F190FD /* MNNGRAYToC4Fast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A082C91AEE700F190FD /* MNNGRAYToC4Fast.S */; }; + CE072A192C91AEE700F190FD /* MNNBGRToGRAY.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A092C91AEE700F190FD /* MNNBGRToGRAY.S */; }; + CE072A1A2C91AEE700F190FD /* MNNRGBToGRAYFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0A2C91AEE700F190FD /* MNNRGBToGRAYFast.S */; }; + CE072A1B2C91AEE700F190FD /* MNNBGRToBGR565.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0B2C91AEE700F190FD /* MNNBGRToBGR565.S */; }; + CE072A1C2C91AEE700F190FD /* MNNRGBAToBGRFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0C2C91AEE700F190FD /* MNNRGBAToBGRFast.S */; }; + CE072A1D2C91AEE700F190FD /* MNNRGBAToBGRAFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0D2C91AEE700F190FD /* MNNRGBAToBGRAFast.S */; }; + CE072A1E2C91AEE700F190FD /* MNNRGBToBGR555.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0E2C91AEE700F190FD /* MNNRGBToBGR555.S */; }; + CE072A1F2C91AEE700F190FD /* MNNRGBToBGR.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A0F2C91AEE700F190FD /* MNNRGBToBGR.S */; }; + CE072A202C91AEE700F190FD /* MNNGRAYToC3Fast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A102C91AEE700F190FD /* MNNGRAYToC3Fast.S */; }; + CE072A212C91AEE700F190FD /* MNNRGBToBGR565.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A112C91AEE700F190FD /* MNNRGBToBGR565.S */; }; + CE072A222C91AEE700F190FD /* MNNPackC2.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A122C91AEE700F190FD /* MNNPackC2.S */; }; + CE072A262C91AF0700F190FD /* MNNC3ToYUVFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A232C91AF0700F190FD /* MNNC3ToYUVFast.S */; }; + CE072A272C91AF0700F190FD /* MNNC3ToC4Fast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A242C91AF0700F190FD /* MNNC3ToC4Fast.S */; }; + CE072A282C91AF0700F190FD /* MNNC3ToXYZFast.S in Sources */ = {isa = PBXBuildFile; fileRef = CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */; }; CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; }; CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; }; CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */; }; @@ -805,6 +824,8 @@ CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9572A3AA4D4006438F2 /* 
MNNBilinearLineC8.S */; }; CEE9B95C2A3AA4D4006438F2 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9582A3AA4D4006438F2 /* MNNBilinearSampleC8.S */; }; CEE9B95D2A3AA4D4006438F2 /* MNNCubicSampleC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9592A3AA4D4006438F2 /* MNNCubicSampleC16.S */; }; + CEEDB5542C7475A100FED0DC /* MNNFileUtils.h in Headers */ = {isa = PBXBuildFile; fileRef = CEEDB5522C7475A100FED0DC /* MNNFileUtils.h */; }; + CEEDB5552C7475A100FED0DC /* MNNFileUtils.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CEEDB5532C7475A100FED0DC /* MNNFileUtils.cpp */; }; EB45C774244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */ = {isa = PBXBuildFile; fileRef = EB45C773244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */; }; EB45C776244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */ = {isa = PBXBuildFile; fileRef = EB45C775244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */; }; EB8D2ABE246A4975009948D1 /* Arm82OpRegister.cpp in Sources */ = {isa = PBXBuildFile; fileRef = EB8D2ABD246A4975009948D1 /* Arm82OpRegister.cpp */; }; @@ -1607,6 +1628,25 @@ C4F906B127688C3A0026B847 /* NMSModule.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = NMSModule.hpp; sourceTree = ""; }; C4F906B227688C3A0026B847 /* NMSModule.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = NMSModule.cpp; sourceTree = ""; }; C4FB6CB12769DF0800963B07 /* GeometryCumSum.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryCumSum.cpp; sourceTree = ""; }; + CE072A032C91AEE700F190FD /* MNNBGRToBGR555.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNBGRToBGR555.S; path = arm/arm64/MNNBGRToBGR555.S; sourceTree = ""; }; + CE072A042C91AEE700F190FD /* MNNBGRAToGRAY.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNBGRAToGRAY.S; path = arm/arm64/MNNBGRAToGRAY.S; sourceTree = ""; }; + CE072A052C91AEE700F190FD /* MNNRGBAToGRAYFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBAToGRAYFast.S; path = arm/arm64/MNNRGBAToGRAYFast.S; sourceTree = ""; }; + CE072A062C91AEE700F190FD /* MNNBGRAToBGR.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNBGRAToBGR.S; path = arm/arm64/MNNBGRAToBGR.S; sourceTree = ""; }; + CE072A072C91AEE700F190FD /* MNNSamplerC3BilinearOpt.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNSamplerC3BilinearOpt.S; path = arm/arm64/MNNSamplerC3BilinearOpt.S; sourceTree = ""; }; + CE072A082C91AEE700F190FD /* MNNGRAYToC4Fast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGRAYToC4Fast.S; path = arm/arm64/MNNGRAYToC4Fast.S; sourceTree = ""; }; + CE072A092C91AEE700F190FD /* MNNBGRToGRAY.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNBGRToGRAY.S; path = arm/arm64/MNNBGRToGRAY.S; sourceTree = ""; }; + CE072A0A2C91AEE700F190FD /* MNNRGBToGRAYFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBToGRAYFast.S; path = arm/arm64/MNNRGBToGRAYFast.S; sourceTree = ""; }; + CE072A0B2C91AEE700F190FD /* MNNBGRToBGR565.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name 
= MNNBGRToBGR565.S; path = arm/arm64/MNNBGRToBGR565.S; sourceTree = ""; }; + CE072A0C2C91AEE700F190FD /* MNNRGBAToBGRFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBAToBGRFast.S; path = arm/arm64/MNNRGBAToBGRFast.S; sourceTree = ""; }; + CE072A0D2C91AEE700F190FD /* MNNRGBAToBGRAFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBAToBGRAFast.S; path = arm/arm64/MNNRGBAToBGRAFast.S; sourceTree = ""; }; + CE072A0E2C91AEE700F190FD /* MNNRGBToBGR555.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBToBGR555.S; path = arm/arm64/MNNRGBToBGR555.S; sourceTree = ""; }; + CE072A0F2C91AEE700F190FD /* MNNRGBToBGR.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBToBGR.S; path = arm/arm64/MNNRGBToBGR.S; sourceTree = ""; }; + CE072A102C91AEE700F190FD /* MNNGRAYToC3Fast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNGRAYToC3Fast.S; path = arm/arm64/MNNGRAYToC3Fast.S; sourceTree = ""; }; + CE072A112C91AEE700F190FD /* MNNRGBToBGR565.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNRGBToBGR565.S; path = arm/arm64/MNNRGBToBGR565.S; sourceTree = ""; }; + CE072A122C91AEE700F190FD /* MNNPackC2.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNPackC2.S; path = arm/arm64/MNNPackC2.S; sourceTree = ""; }; + CE072A232C91AF0700F190FD /* MNNC3ToYUVFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNC3ToYUVFast.S; path = arm/arm64/MNNC3ToYUVFast.S; sourceTree = ""; }; + CE072A242C91AF0700F190FD /* MNNC3ToC4Fast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNC3ToC4Fast.S; path = arm/arm64/MNNC3ToC4Fast.S; sourceTree = ""; }; + CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; name = MNNC3ToXYZFast.S; path = arm/arm64/MNNC3ToXYZFast.S; sourceTree = ""; }; CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = ""; }; CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = ""; }; CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ShapeConvTranspose3D.cpp; sourceTree = ""; }; @@ -1643,6 +1683,8 @@ CEE9B9572A3AA4D4006438F2 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = ""; }; CEE9B9582A3AA4D4006438F2 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = ""; }; CEE9B9592A3AA4D4006438F2 /* MNNCubicSampleC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNCubicSampleC16.S; sourceTree = ""; }; + CEEDB5522C7475A100FED0DC /* MNNFileUtils.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = MNNFileUtils.h; sourceTree = ""; }; + CEEDB5532C7475A100FED0DC /* MNNFileUtils.cpp */ = {isa = PBXFileReference; 
fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = MNNFileUtils.cpp; sourceTree = ""; }; EB45C773244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S; sourceTree = ""; }; EB45C775244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S; sourceTree = ""; }; EB8D2ABD246A4975009948D1 /* Arm82OpRegister.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = Arm82OpRegister.cpp; path = ../arm82/Arm82OpRegister.cpp; sourceTree = ""; }; @@ -1878,6 +1920,8 @@ 488873AC215B639D0079B12E /* core */ = { isa = PBXGroup; children = ( + CEEDB5532C7475A100FED0DC /* MNNFileUtils.cpp */, + CEEDB5522C7475A100FED0DC /* MNNFileUtils.h */, 48C84B9B250F722B00EE7666 /* Command.hpp */, 4819FB1524C138DF0050BD09 /* GeometryConvUtils.cpp */, 4819FB1324C138DF0050BD09 /* GeometryConvUtils.hpp */, @@ -1921,6 +1965,25 @@ 48887410215B639D0079B12E /* cpu */ = { isa = PBXGroup; children = ( + CE072A242C91AF0700F190FD /* MNNC3ToC4Fast.S */, + CE072A252C91AF0700F190FD /* MNNC3ToXYZFast.S */, + CE072A232C91AF0700F190FD /* MNNC3ToYUVFast.S */, + CE072A062C91AEE700F190FD /* MNNBGRAToBGR.S */, + CE072A042C91AEE700F190FD /* MNNBGRAToGRAY.S */, + CE072A032C91AEE700F190FD /* MNNBGRToBGR555.S */, + CE072A0B2C91AEE700F190FD /* MNNBGRToBGR565.S */, + CE072A092C91AEE700F190FD /* MNNBGRToGRAY.S */, + CE072A102C91AEE700F190FD /* MNNGRAYToC3Fast.S */, + CE072A082C91AEE700F190FD /* MNNGRAYToC4Fast.S */, + CE072A122C91AEE700F190FD /* MNNPackC2.S */, + CE072A0D2C91AEE700F190FD /* MNNRGBAToBGRAFast.S */, + CE072A0C2C91AEE700F190FD /* MNNRGBAToBGRFast.S */, + CE072A052C91AEE700F190FD /* MNNRGBAToGRAYFast.S */, + CE072A0F2C91AEE700F190FD /* MNNRGBToBGR.S */, + CE072A0E2C91AEE700F190FD /* MNNRGBToBGR555.S */, + CE072A112C91AEE700F190FD /* MNNRGBToBGR565.S */, + CE072A0A2C91AEE700F190FD /* MNNRGBToGRAYFast.S */, + CE072A072C91AEE700F190FD /* MNNSamplerC3BilinearOpt.S */, CEE4566A2BC0E23D00F062C1 /* CPUExternalConst.cpp */, 95278CE62B9F0999009E9B29 /* CPUDynamicQuant.cpp */, 95278CE52B9F0999009E9B29 /* CPUDynamicQuant.hpp */, @@ -2969,6 +3032,7 @@ 489D7A982550FDC900AD896A /* MNNMetalContext.h in Headers */, 952298B82B4D4CC80043978B /* coreMLLayerNorm.hpp in Headers */, 92FF029323AA0B5A00AC97F6 /* CPURange.hpp in Headers */, + CEEDB5542C7475A100FED0DC /* MNNFileUtils.h in Headers */, 4D9A937526255BDA00F9B43C /* CoreMLCommonExecution.hpp in Headers */, 4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */, 48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */, @@ -3260,6 +3324,8 @@ 950B29002A0C9B4D0002F454 /* MNNScaleAndAddBiasInt8.S in Sources */, 92FF04BD23AA0BFB00AC97F6 /* Execution.cpp in Sources */, 92FF030A23AA0B5A00AC97F6 /* MNNLineDepthWiseInt8AddBiasScaleUnit.S in Sources */, + CE072A212C91AEE700F190FD /* MNNRGBToBGR565.S in Sources */, + CE072A282C91AF0700F190FD /* MNNC3ToXYZFast.S in Sources */, 92FF03B023AA0B5A00AC97F6 /* ConvolutionGroup.cpp in Sources */, 48FA474623AA127B00172C3B /* NeuralNetWorkOp.cpp in Sources */, 4D9A936E26255BDA00F9B43C /* CoreMLArgMax.cpp in Sources */, @@ -3270,6 +3336,7 @@ 48747D63245D9E33000B9709 /* GeometryPermute.cpp in Sources */, 92FF032C23AA0B5A00AC97F6 /* MNNWinogradMatrixProductRight.S in Sources */, 48BB6EF625220AA80056E195 /* MNNTranspose32Bit4x4.S in Sources */, + 
CE072A1C2C91AEE700F190FD /* MNNRGBAToBGRFast.S in Sources */, CEE9B95C2A3AA4D4006438F2 /* MNNBilinearSampleC8.S in Sources */, 48BB6EF025220A930056E195 /* MNNTranspose32Bit4x4.S in Sources */, 92FF031223AA0B5A00AC97F6 /* MNNMaxFloat.S in Sources */, @@ -3296,6 +3363,7 @@ 4D9A935F26255BDA00F9B43C /* NeuralNetwork.pb-c.c in Sources */, 4D0C80E32862FC4100C7CAD6 /* CoreMLOPRegister.cpp in Sources */, 92FF02BE23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */, + CE072A1A2C91AEE700F190FD /* MNNRGBToGRAYFast.S in Sources */, 4A224A0B27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.cpp in Sources */, 48608B52250632EC00CB1D71 /* GeometryComputerUtils.cpp in Sources */, 489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */, @@ -3330,6 +3398,7 @@ 92FF042323AA0B7100AC97F6 /* ShapeScatterNd.cpp in Sources */, 92FF045A23AA0B7100AC97F6 /* ShapeBinaryOp.cpp in Sources */, 92FF02E523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */, + CE072A192C91AEE700F190FD /* MNNBGRToGRAY.S in Sources */, EBECA37B24643D110062C7A3 /* MNNGemmInt8AddBiasScale_ARMV82_Unit.S in Sources */, 481C2DF525FE2CD6001ED6DF /* Arm82OptFunc.cpp in Sources */, 92FF033623AA0B5A00AC97F6 /* MNNConvRunForUnitDepthWise.S in Sources */, @@ -3353,6 +3422,7 @@ 48747D6F245D9E33000B9709 /* GeometryConcat.cpp in Sources */, 4819FB3224C1396A0050BD09 /* GeometryReduce.cpp in Sources */, 950B28EF29F627F70002F454 /* MNNBinaryMaxInt8.S in Sources */, + CE072A132C91AEE700F190FD /* MNNBGRToBGR555.S in Sources */, 92FF02B023AA0B5A00AC97F6 /* CPUDequantize.cpp in Sources */, 92FF04C223AA0BFB00AC97F6 /* Pipeline.cpp in Sources */, 92FF04C423AA0BFB00AC97F6 /* Session.cpp in Sources */, @@ -3395,6 +3465,7 @@ 48958783268EBA7C00EA01A7 /* ShapeSegmentMean.cpp in Sources */, 48747D61245D9E33000B9709 /* ConvertUtils.cpp in Sources */, 92FF043B23AA0B7100AC97F6 /* ShapeDetectionPostProcess.cpp in Sources */, + CE072A1B2C91AEE700F190FD /* MNNBGRToBGR565.S in Sources */, 48417FF124D13BF50056D9A7 /* GeometryELU.cpp in Sources */, 48C84B9A250F720C00EE7666 /* CPULayerNorm.cpp in Sources */, 4DF87C4A2887D3560003E2D4 /* calib3d.cpp in Sources */, @@ -3449,6 +3520,7 @@ 92FF034223AA0B5A00AC97F6 /* CPUReduction.cpp in Sources */, 92FF02CF23AA0B5A00AC97F6 /* MNNMinFloat.S in Sources */, C4F906B0276886040026B847 /* GeometryTopK.cpp in Sources */, + CEEDB5552C7475A100FED0DC /* MNNFileUtils.cpp in Sources */, 48CA2F572681844C003A1796 /* MNNUnpackC8FP16.S in Sources */, 92FF030E23AA0B5A00AC97F6 /* MNNNV21ToRGBUnit.S in Sources */, 4837147225A599EC004DBDED /* Arm82Binary.cpp in Sources */, @@ -3473,6 +3545,7 @@ 4D9A936726255BDA00F9B43C /* CoreMLReduction.cpp in Sources */, 48F5881324DEA3F000C484A2 /* GeometryConv3D.cpp in Sources */, 4882C8BA241A22B800DAC168 /* OpCommonUtils.cpp in Sources */, + CE072A202C91AEE700F190FD /* MNNGRAYToC3Fast.S in Sources */, 92FF02B523AA0B5A00AC97F6 /* CPUTopKV2.cpp in Sources */, 92FF02BD23AA0B5A00AC97F6 /* MNNMatrixProd.S in Sources */, 489D7A872550FDC900AD896A /* MetalOPRegister.mm in Sources */, @@ -3536,17 +3609,21 @@ 4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */, 11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */, 48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */, + CE072A272C91AF0700F190FD /* MNNC3ToC4Fast.S in Sources */, CECF8C7D299CAD9400D3875B /* md5.c in Sources */, 92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */, 92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */, + CE072A182C91AEE700F190FD /* MNNGRAYToC4Fast.S in Sources */, CE125CC92A52BF6B003698C9 /* 
MNNBilinearLineC8.S in Sources */, 92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */, + CE072A222C91AEE700F190FD /* MNNPackC2.S in Sources */, 92FF026523AA0B5A00AC97F6 /* CPUQuantizedAvgPool.cpp in Sources */, 92FF029423AA0B5A00AC97F6 /* CPUMatMul.cpp in Sources */, 48747D62245D9E33000B9709 /* GeometryOPRegister.cpp in Sources */, 4838EA8B2611C1310027232C /* ShapeGridSample.cpp in Sources */, 92FF03A323AA0B5A00AC97F6 /* ConvOpt.cpp in Sources */, 92FF02CD23AA0B5A00AC97F6 /* MNNNV21ToRGBUnit.S in Sources */, + CE072A172C91AEE700F190FD /* MNNSamplerC3BilinearOpt.S in Sources */, 92FF029A23AA0B5A00AC97F6 /* CPUQuantizedMaxPool.cpp in Sources */, 48F5881124DEA3F000C484A2 /* GeometryPooling3D.cpp in Sources */, 92FF042423AA0B7100AC97F6 /* ShapeROIPooling.cpp in Sources */, @@ -3569,11 +3646,13 @@ 92FF02B123AA0B5A00AC97F6 /* CPUBackend.cpp in Sources */, 4D9A936226255BDA00F9B43C /* FeatureTypes.pb-c.c in Sources */, 486E1A9924F5078D00C16006 /* CPURandomUniform.cpp in Sources */, + CE072A1F2C91AEE700F190FD /* MNNRGBToBGR.S in Sources */, 92FF02C823AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */, 92FF045C23AA0B7100AC97F6 /* ShapeBroadcastTo.cpp in Sources */, 48747D49245D9D24000B9709 /* RuntimeFactory.cpp in Sources */, 92FF02AE23AA0B5A00AC97F6 /* CPUProposal.cpp in Sources */, 92FF042723AA0B7100AC97F6 /* ShapeMatMul.cpp in Sources */, + CE072A262C91AF0700F190FD /* MNNC3ToYUVFast.S in Sources */, 92FF042823AA0B7100AC97F6 /* ShapeInterp.cpp in Sources */, 92FF02D623AA0B5A00AC97F6 /* MNNConvRunForLineDepthWiseInt8.S in Sources */, 48FB9DCA24A848D0008E1A2D /* MNNAxByClampBroadcastC4.S in Sources */, @@ -3610,6 +3689,7 @@ CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */, 48FA474523AA127B00172C3B /* Executor.cpp in Sources */, 92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */, + CE072A162C91AEE700F190FD /* MNNBGRAToBGR.S in Sources */, 48A8A61A21D101DE00C2B9A7 /* Matrix_CV.cpp in Sources */, 4DDD8E102B1D70C1005065D1 /* MNNTranspose16Bit8x8.S in Sources */, 489D7A8C2550FDC900AD896A /* MetalDeconvolution.mm in Sources */, @@ -3659,6 +3739,7 @@ 48F9E54C2493511200E46522 /* MNNPackedMatMul.S in Sources */, C4D4824327BA67DE0021C2B9 /* GeometryDet.cpp in Sources */, 92FF026F23AA0B5A00AC97F6 /* CPUInt8ToFloat.cpp in Sources */, + CE072A142C91AEE700F190FD /* MNNBGRAToGRAY.S in Sources */, 92FF037E23AA0B5A00AC97F6 /* CPUDetectionPostProcess.cpp in Sources */, 4D4CF4682760946500A36D9F /* geometric.cpp in Sources */, 92FF045023AA0B7100AC97F6 /* ShapeCropAndResize.cpp in Sources */, @@ -3671,6 +3752,7 @@ 92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */, CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */, CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */, + CE072A152C91AEE700F190FD /* MNNRGBAToGRAYFast.S in Sources */, 92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */, 952298B22B4D39050043978B /* MetalLoop.mm in Sources */, 48925F372744AC2A00919B37 /* ShapeROIAlign.cpp in Sources */, @@ -3691,6 +3773,7 @@ 92FF045423AA0B7100AC97F6 /* ShapeRNNSequenceGRU.cpp in Sources */, 4896D37C25FE2A6B00717702 /* MNNConvDwF23SourceTransUnitFP16.S in Sources */, EB8D2ABE246A4975009948D1 /* Arm82OpRegister.cpp in Sources */, + CE072A1E2C91AEE700F190FD /* MNNRGBToBGR555.S in Sources */, 48C84B87250F711700EE7666 /* WhileModule.cpp in Sources */, 48608B51250632EC00CB1D71 /* GeometryComputer.cpp in Sources */, 92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */, @@ -3720,6 +3803,7 @@ 92FF03AD23AA0B5A00AC97F6 /* 
ConvolutionDepthwise3x3.cpp in Sources */, 92FF031723AA0B5A00AC97F6 /* MNNConvRunForLineDepthWiseInt8.S in Sources */, 4DD1793A2694076700B0098F /* MNNSoftmax.S in Sources */, + CE072A1D2C91AEE700F190FD /* MNNRGBAToBGRAFast.S in Sources */, 489D7A762550FDC800AD896A /* MetalReduction.mm in Sources */, 92FF032023AA0B5A00AC97F6 /* MNNMatrixSub.S in Sources */, C43C81FF251894BD00A0FF84 /* ThreadPool.cpp in Sources */, @@ -4101,7 +4185,7 @@ CODE_SIGN_STYLE = Automatic; DEAD_CODE_STRIPPING = YES; DEFINES_MODULE = YES; - DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = "@rpath"; @@ -4188,7 +4272,7 @@ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; CODE_SIGN_STYLE = Automatic; - DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; GCC_ENABLE_CPP_EXCEPTIONS = NO; GCC_ENABLE_CPP_RTTI = NO; HEADER_SEARCH_PATHS = ( diff --git a/pymnn/test/model_test.py b/pymnn/test/model_test.py index 1da4e85f0..011cedb71 100644 --- a/pymnn/test/model_test.py +++ b/pymnn/test/model_test.py @@ -80,18 +80,21 @@ def MNNDataType2NumpyDataType(data_type): else: return np.float32 -def createTensor(tensor, file=''): +def createTensor(tensor, file='', empty=False): shape = tensor.getShape() data_type = tensor.getDataType() dtype = MNNDataType2NumpyDataType(data_type) if file == '': - data = np.ones(shape, dtype=dtype) + if empty: + data = np.zeros(shape, dtype=dtype) + else: + data = np.ones(shape, dtype=dtype) else: data = loadtxt(file, shape, dtype) - return MNN.Tensor(shape, tensor.getDataType(), data, tensor.getDimensionType()) + return MNN.Tensor(shape, tensor.getDataType(), data.copy(), tensor.getDimensionType()) def compareTensor(tensor, file, tolerance=5e-2): - outputNumpyData = tensor.getNumpyData() + outputNumpyData = tensor.getNumpyData().copy() expectNumpyData = loadtxt(file, tensor.getShape()) max_abs_dif = np.abs(outputNumpyData - expectNumpyData).max() max_exp_val = np.abs(expectNumpyData).max() @@ -117,6 +120,11 @@ def modelTest(modelPath, givenName, expectName): net = MNN.Interpreter(modelPath) session = net.createSession() allInput = net.getSessionInputAll(session) + # zero for all inputs + for name in allInput: + inputTensor = allInput[name] + inputHost = createTensor(inputTensor, givenName, True) + inputTensor.copyFrom(inputHost) # input inputTensor = net.getSessionInput(session) inputHost = createTensor(inputTensor, givenName) diff --git a/source/backend/arm82/Arm82Backend.cpp b/source/backend/arm82/Arm82Backend.cpp index 7b13b852b..377243388 100644 --- a/source/backend/arm82/Arm82Backend.cpp +++ b/source/backend/arm82/Arm82Backend.cpp @@ -118,6 +118,7 @@ void Arm82Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor CPUBackend::onCopyBuffer(srcTensor, dstTensor); return; } + _resetDynamicMemory(); auto source = TensorUtils::getDescribe(srcTensor)->dimensionFormat; auto dest = TensorUtils::getDescribe(dstTensor)->dimensionFormat; auto srcType = MNN_FORWARD_CPU; diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp index 2e4e9dc6b..92749c426 100644 --- a/source/backend/arm82/Arm82Functions.cpp +++ b/source/backend/arm82/Arm82Functions.cpp @@ -35,12 +35,14 @@ void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* // parameter: [aStride, l, h, cStride, bExtraStride] void MNNPackedMatMulRemainFP16(float* C, const float* A, const float* B, size_t eSize, const size_t* 
parameter, const float* postParameters, const float* bias, const float* k, const float* b); -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM void MNNPackedMatMulFP16_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulRemainFP16_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulFP16_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulRemainFP16_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif +#ifdef MNN_LOW_MEMORY void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void MNNQuantScaleFP16(float* sum, float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack); @@ -48,8 +50,6 @@ void MNNQuantSumFP16(float* sum, const float* dequant_scale, size_t thread, size #endif #if defined(__aarch64__) void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad); -void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); -void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); #endif void MNNConvDwF23MulTransUnitFP16(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, size_t ow); @@ -735,29 +735,25 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16); #if defined(__aarch64__) -#ifdef MNN_LOW_MEMORY + gInstance->supportFp16arith = origin->supportFp16arith; + gInstance->supportSDot = origin->supportSDot; + gInstance->supportI8mm = origin->supportI8mm; +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM // Weight Dequant Gemm Kernels FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int4, MNNPackedMatMulFP16_int4); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int4, MNNPackedMatMulRemainFP16_int4); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int8, MNNPackedMatMulRemainFP16_int8); +#endif +#ifdef MNN_LOW_MEMORY // Dynamic Qaunt Helper Functions FUNC_PTR_ASSIGN(gInstance->MNNAbsMax, MNNAbsMaxFP16); FUNC_PTR_ASSIGN(gInstance->MNNQuantScale, MNNQuantScaleFP16); FUNC_PTR_ASSIGN(gInstance->MNNDynamicQuant, MNNDynamicQuantFP16); FUNC_PTR_ASSIGN(gInstance->MNNQuantSum, MNNQuantSumFP16); FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); - // Dynamic Quant Gemm Kernels. 
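// ---------------------------------------------------------------------------
// Editor's note (illustrative only, not part of this patch): the hunk above moves
// the int4/int8 weight-dequant GEMM kernels out of the MNN_LOW_MEMORY guard and
// under the new MNN_CPU_WEIGHT_DEQUANT_GEMM guard. A minimal scalar sketch of the
// math these kernels implement is given below. The layout (row-major weights, two
// int4 values per byte with the high nibble first, L even) and the affine form
// w = (q - zero) * scale are assumptions made for readability; the real kernels
// consume a repacked, tiled layout and accumulate in FP16.
#include <cstddef>
#include <cstdint>

void weightDequantGemmRef(float* C, const float* A, const uint8_t* Wq,
                          const float* scale, const float* zero,
                          size_t E, size_t L, size_t H) {
    // C[e][h] = sum over l of A[e][l] * dequant(Wq[h][l])
    for (size_t h = 0; h < H; ++h) {
        for (size_t e = 0; e < E; ++e) {
            float acc = 0.f;
            for (size_t l = 0; l < L; ++l) {
                uint8_t byte = Wq[(h * L + l) / 2];
                int q = (l & 1) ? (byte & 0x0F) : (byte >> 4); // two int4 per byte
                acc += A[e * L + l] * ((float)q - zero[h]) * scale[h];
            }
            C[e * H + h] = acc;
        }
    }
}
// ---------------------------------------------------------------------------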
- gInstance->supportFp16arith = origin->supportFp16arith; - gInstance->supportSDot = origin->supportSDot; - gInstance->supportI8mm = origin->supportI8mm; #endif - if (gInstance->supportSDot) { - FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, MNNSumByAxisLForMatmul_A_ARM82); - } - if (gInstance->supportI8mm) { - FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, MNNSumByAxisLForMatmul_A_ARM86); - } + FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, origin->MNNSumByAxisLForMatmul_A); #endif FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode); diff --git a/source/backend/arm82/CMakeLists.txt b/source/backend/arm82/CMakeLists.txt index cc9fc0ab7..afbe55dbb 100644 --- a/source/backend/arm82/CMakeLists.txt +++ b/source/backend/arm82/CMakeLists.txt @@ -10,10 +10,17 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64") if (MNN_LOW_MEMORY) file(GLOB MNN_ARM82_SRCS_ASM ${MNN_ARM82_SRCS_ASM} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/low_memory/*) endif() + if (MNN_CPU_WEIGHT_DEQUANT_GEMM) + file(GLOB MNN_ARM82_SRCS_ASM ${MNN_ARM82_SRCS_ASM} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/normal_memory/*) + endif() add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM}) if (MNN_LOW_MEMORY) target_compile_options(MNN_Arm82 PRIVATE -DMNN_LOW_MEMORY) endif() + + if (MNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNN_Arm82 PRIVATE -DMNN_CPU_WEIGHT_DEQUANT_GEMM) + endif() target_compile_options(MNN_Arm82 PRIVATE -march=armv8.2-a+fp16 -DENABLE_ARMV82) else() # Building fat binary requires multiple separate builds and lipo-by-hand under CMake's design diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S index 22919922f..0812f45c4 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S @@ -90,21 +90,31 @@ Note: Only used in dynamic quant,so do not need compare min max! */ asm_function DynamicQuanInput_ARM82 -//void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, size_t zeroPoint); -//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint +//void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, float* zeroPoint, ssize_t quanParamVec); +//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint, x7:quanParamVec stp d14, d15, [sp, #-64]! 
stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] ld1 {v29.s}[0], [x3] // Load scale -// copy zero point -dup v30.4s, w6 -fcvtn v31.4h, v29.4s -scvtf v30.4s, v30.4s +ld1 {v30.s}[0], [x6] // Load zero + +and x8, x7, #1 // if load vector scale +and x9, x7, #2 // if load vector zero +cbz x8, LOAD_VECTOR_ZERO +ld1 {v29.4s}, [x3] // scale + +LOAD_VECTOR_ZERO: +cbz x9, START +ld1 {v30.4s}, [x6] // zero + +START: +// copy zero point +fcvtn v31.4h, v29.4s // fp16 scale +fcvtn v30.4h, v30.4s // fp16 zero dup v31.8h, v31.h[0] -fcvtn v30.4h, v30.4s dup v30.8h, v30.h[0] FL28: diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S index 44e3568f1..5a8381765 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S @@ -1,5 +1,5 @@ // -// DynamicQuanInput_ARM82.S +// DynamicQuanInputAndReorder_ARM82.S // MNN // // Created by MNN on 2019/01/22. @@ -101,15 +101,12 @@ stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] ld1 {v29.s}[0], [x3] // Load scale -// copy zero point -dup v30.4s, w6 +ld1 {v30.s}[0], [x6] // Load zero point fcvtn v31.4h, v29.4s -scvtf v30.4s, v30.4s - +fcvtn v30.4h, v30.4s add x13, x8, x8 dup v31.8h, v31.h[0] -fcvtn v30.4h, v30.4s dup v30.8h, v30.h[0] mov x9, x1 // first N*4 diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S index 143ec060a..ad9313244 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S @@ -115,19 +115,16 @@ ldr x27, [x6, #64] // blockNum mov x21, #16 // sizeof(float16_t) * PACK mul x27, x27, x3 Start: -lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT +lsl x15, x27, #5 // x15 = src_depth_quad * UNIT * SRC_UNIT mov x22, #48 // src_steps -add x24, x15, x15 ldr x27, [x6, #80] // extra scale TILE_12: cmp x7, #12 blt TILE_8 L8LoopDz_TILE_12: - // ld1 {v0.4s, v1.4s}, [x9], #32 // bias mov x11, x1 mov x13, x3 - // Init 0 SET_BIAS v8, v9, v10, v11 SET_BIAS v12, v13, v14, v15 SET_BIAS v16, v17, v18, v19 @@ -137,13 +134,13 @@ L8LoopDz_TILE_12: mov x28, x2 L8LoopSz_TILE_12: - ld1 {v3.16b}, [x2], x15 // weight + ld1 {v3.16b, v4.16b}, [x2], #32 // weight ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - ld1 {v4.16b}, [x2], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] @@ -156,7 +153,7 @@ L8LoopDz_TILE_12: .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] - sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] @@ -169,9 +166,7 @@ L8LoopDz_TILE_12: bne L8LoopSz_TILE_12 L8LoopSzEnd_TILE_12: - //add x2, x2, x15 - //add x24, x15, x15 - add x2, x28, x24 + add x2, x28, x15 sub x5, x5, #1 L8Tile12Quan: @@ -217,8 +212,6 @@ L8LoopDz_TILE_12: MLA_WEIGHTZERO v18, v4, v5, 2 
// tile:10, oc:0-3 MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - //ld1r {v0.4s}, [x23] // f32 min - //ld1r {v1.4s}, [x24] // f32 max MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 @@ -297,8 +290,6 @@ L8LoopDz_TILE_12: blt End TILE_8: - //ld1r {v26.4s}, [x23] // f32 min - //ld1r {v27.4s}, [x24] // f32 max cmp x7, #8 blt TILE_4 mov x10, x0 @@ -319,18 +310,18 @@ L8LoopDz_TILE_8: SET_BIAS v20, v21, v22, v23 mov x28, x12 L8LoopSz_TILE_8: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - ld1 {v4.16b}, [x12], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] @@ -343,9 +334,7 @@ L8LoopDz_TILE_8: bne L8LoopSz_TILE_8 L8LoopSzEnd_TILE_8: - //add x12, x12, x15 - //add x24, x15, x15 - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile8Quan: @@ -468,15 +457,13 @@ L8LoopDz_TILE_4: mov x28, x12 L8LoopSz_TILE_4: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.16b}, [x11], x22 // src - ld1 {v4.16b}, [x12], #16 // weight .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] @@ -484,9 +471,7 @@ L8LoopDz_TILE_4: bne L8LoopSz_TILE_4 L8LoopSzEnd_TILE_4: - //add x12, x12, x15 - //add x24, x15, x15 - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile4Quan: @@ -571,23 +556,17 @@ L8LoopDz_TILE_1: movi v8.16b, #0 movi v9.16b, #0 - //mov v8.16b, v0.16b - //mov v9.16b, v1.16b mov x28, x12 L8LoopSz_TILE_1: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.s}[0], [x11], x22 // src - ld1 {v4.16b}, [x12], #16 // weight .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] bne L8LoopSz_TILE_1 L8LoopSzEnd_TILE_1: - //add x12, x12, x15 - //add x24, x15, x15 - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile1Quan: @@ -630,11 +609,7 @@ L8LoopDz_TILE_1: sub x23, x23, #2 fmax v0.8h, v24.8h, v0.8h fmin v0.8h, v25.8h, v0.8h - // st1 {v8.4s}, [x10], x4 - // st1 {v9.4s}, [x10], x4 - //fcvtn v0.4h, v8.4s - //fcvtn2 v0.8h, v9.4s TILE1_STORE: st1 {v0.8h}, [x10], x4 diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S index 5d92ae056..dd893b292 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S @@ -114,16 +114,14 @@ ldr x27, [x6, #64] // blockNum mov 
x21, #16 // sizeof(float16_t) * PACK mul x27, x27, x3 Start: -lsl x15, x27, #3 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) +lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) mov x22, #48 // src_steps -add x24, x15, x15 ldr x27, [x6, #80] // extra scale TILE_12: cmp x7, #12 blt TILE_8 L8LoopDz_TILE_12: - // ld1 {v0.4s, v1.4s}, [x9], #32 // bias mov x11, x1 mov x13, x3 movi v7.16b, #15 @@ -138,13 +136,11 @@ L8LoopDz_TILE_12: mov x28, x2 L8LoopSz_TILE_12: - ld1 {v3.d}[0], [x2], x15 // weight - ld1 {v4.d}[0], [x2], #8 + ld1 {v5.16b}, [x2], #16 // weight ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -155,10 +151,6 @@ L8LoopDz_TILE_12: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] @@ -168,7 +160,7 @@ L8LoopDz_TILE_12: .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] - sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] @@ -181,7 +173,7 @@ L8LoopDz_TILE_12: bne L8LoopSz_TILE_12 L8LoopSzEnd_TILE_12: - add x2, x28, x24 + add x2, x28, x15 sub x5, x5, #1 L8Tile12Quan: @@ -227,8 +219,6 @@ L8LoopDz_TILE_12: MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 - //ld1r {v0.4s}, [x23] // f32 min - //ld1r {v1.4s}, [x24] // f32 max MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 @@ -304,7 +294,7 @@ L8LoopDz_TILE_12: L8Tile12LoopCheck: cmp x5, #1 bge L8LoopDz_TILE_12 - blt End + b End TILE_8: cmp x7, #8 @@ -327,27 +317,24 @@ L8LoopDz_TILE_8: SET_BIAS v20, v21, v22, v23 mov x28, x12 L8LoopSz_TILE_8: - ld1 {v3.d}[0], [x12], x15 // weight - ld1 {v4.d}[0], [x12], #8 + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b + //zip1 v3.16b, v5.16b, v6.16b + //zip2 v4.16b, v5.16b, v6.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] @@ -360,7 +347,7 @@ L8LoopDz_TILE_8: bne L8LoopSz_TILE_8 L8LoopSzEnd_TILE_8: - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile8Quan: @@ 
-446,10 +433,6 @@ L8LoopDz_TILE_8: st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], x4 - //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 - //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 - //st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 - //st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 add x4, x4, #64 L8Tile8LoopCheck: @@ -483,24 +466,20 @@ L8LoopDz_TILE_4: mov x28, x12 L8LoopSz_TILE_4: - ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b}, [x11], x22 // src - ld1 {v4.d}[0], [x12], #8 // weight // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b + //zip1 v3.16b, v5.16b, v6.16b + //zip2 v4.16b, v5.16b, v6.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b + subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] @@ -508,7 +487,7 @@ L8LoopDz_TILE_4: bne L8LoopSz_TILE_4 L8LoopSzEnd_TILE_4: - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile4Quan: @@ -593,29 +572,61 @@ L8LoopDz_TILE_1: movi v8.16b, #0 movi v9.16b, #0 + mov x28, x12 - L8LoopSz_TILE_1: - ld1 {v3.d}[0], [x12], x15 // weight + cmp x13, #4 + blt L8LoopSz_TILE_1_lu1 + + L8LoopSz_TILE_1_lu4: + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x12], #64 // weight: hu=0,1,2,3,pack=0~7 ld1 {v0.s}[0], [x11], x22 // src - ld1 {v4.d}[0], [x12], #8 // weight - // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ld1 {v0.s}[1], [x11], x22 + ld1 {v0.s}[2], [x11], x22 + ld1 {v0.s}[3], [x11], x22 - .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + sub x13, x13, #4 + // int4->int8 + ushr v12.16b, v3.16b, #4 + and v22.16b, v3.16b, v7.16b + + ushr v15.16b, v4.16b, #4 + and v23.16b, v4.16b, v7.16b + + ushr v18.16b, v5.16b, #4 + and v24.16b, v5.16b, v7.16b + + ushr v21.16b, v6.16b, #4 + and v25.16b, v6.16b, v7.16b + + cmp x13, #4 + //sub x12, x12, x15 + .inst 0x4f80e188 // sdot v8.4s, v12.16b, v0.4b[0] + .inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0] + .inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1] + .inst 0x4fa0e2e9 // sdot v9.4s, v23.16b, v0.4b[1] + .inst 0x4f80ea48 // sdot v8.4s, v18.16b, v0.4b[2] + .inst 0x4f80eb09 // sdot v9.4s, v24.16b, v0.4b[2] + .inst 0x4fa0eaa8 // sdot v8.4s, v21.16b, v0.4b[3] + .inst 0x4fa0eb29 // sdot v9.4s, v25.16b, v0.4b[3] + bge L8LoopSz_TILE_1_lu4 + + cbz x13, L8LoopSzEnd_TILE_1 + + L8LoopSz_TILE_1_lu1: + ld1 {v4.16b}, [x12], #16 // weight + ld1 {v0.s}[0], [x11], x22 // src + //ld1 {v4.d}[0], [x12], #8 // weight subs x13, x13, #1 // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b - sub x12, x12, x15 + ushr v3.16b, v4.16b, #4 + and v12.16b, v4.16b, v7.16b - .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] - bne L8LoopSz_TILE_1 + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4f80e189 // sdot v9.4s, v12.16b, v0.4b[0] + bne L8LoopSz_TILE_1_lu1 L8LoopSzEnd_TILE_1: - add x12, x28, x24 + add x12, x28, x15 sub x14, x14, #1 L8Tile1Quan: @@ -658,8 +669,6 @@ L8LoopDz_TILE_1: sub x23, x23, #2 fmax v0.8h, v24.8h, v0.8h fmin v0.8h, v25.8h, v0.8h - // st1 {v8.4s}, 
[x10], x4 - // st1 {v9.4s}, [x10], x4 TILE1_STORE: st1 {v0.8h}, [x10], x4 diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S index 7022af3a1..f6f6625d7 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S @@ -150,6 +150,7 @@ LoopDz_TILE_10: mov x12, x2 // weight mov x13, x3 // src_depth_quad mov x10, x0 // tag dst address + movi v2.16b, #15 SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 @@ -158,7 +159,6 @@ LoopDz_TILE_10: LoopSz_TILE_10: ld1 {v0.16b, v1.16b}, [x12], #32 // weight - movi v2.16b, #15 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v7.16b}, [x11], #16 // int4->int8 @@ -763,50 +763,88 @@ TILE_1: mov x20, x9 // bias mov x6, x28 // weightQuanBias LoopDz_TILE_1: - //ld1 {v7.4s, v8.4s}, [x20], #32 // bias mov x11, x1 // src mov x12, x25 // weight mov x13, x3 // src_depth_quad mov x10, x26 - //dup v16.2d, v7.d[0] // oc:0,1,0,1 - //dup v17.2d, v7.d[1] // oc:2,3,2,3 - //dup v18.2d, v8.d[0] // oc:4,5,4,5 - //dup v19.2d, v8.d[1] // oc:6,7,6,7 movi v16.4s, #0 // oc:0,1,0,1 movi v17.4s, #0 // oc:2,3,2,3 movi v18.4s, #0 // oc:4,5,4,5 movi v19.4s, #0 // oc:6,7,6,7 - //movi v22.4s, #0 // oc:0,1,0,1 - //movi v23.4s, #0 // oc:2,3,2,3 - //movi v24.4s, #0 // oc:4,5,4,5 - //movi v25.4s, #0 // oc:6,7,6,7 + cmp x13, #4 + blt LoopSz1_TILE_1_lu1 +LoopSz1_TILE_1_lu4: + ld1 {v5.16b, v6.16b, v7.16b, v8.16b}, [x12], #64 // weight + ld1 {v9.16b, v10.16b, v11.16b, v12.16b}, [x12], #64 + ld1 {v0.8b}, [x11], x22 // src + ld1 {v1.8b}, [x11], x22 + ld1 {v2.8b}, [x11], x22 + ld1 {v3.8b}, [x11], x22 + + // int4->int8 + ushr v4.16b, v5.16b, #4 + ushr v14.16b, v6.16b, #4 + and v13.16b, v5.16b, v30.16b + and v15.16b, v6.16b, v30.16b + + ushr v20.16b, v7.16b, #4 + ushr v21.16b, v8.16b, #4 + and v22.16b, v7.16b, v30.16b + and v23.16b, v8.16b, v30.16b + + ushr v24.16b, v9.16b, #4 + ushr v25.16b, v10.16b, #4 + and v26.16b, v9.16b, v30.16b + and v27.16b, v10.16b, v30.16b + + ushr v5.16b, v11.16b, #4 + ushr v6.16b, v12.16b, #4 + and v7.16b, v11.16b, v30.16b + and v8.16b, v12.16b, v30.16b + + sub x13, x13, #4 + + .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b + .inst 0x4e8ea411 // smmla v17.4s, v0.16b, v14.16b + .inst 0x4e8da412 // smmla v18.4s, v0.16b, v13.16b + .inst 0x4e8fa413 // smmla v19.4s, v0.16b, v15.16b + + .inst 0x4e94a430 // smmla v16.4s, v1.16b, v20.16b + .inst 0x4e95a431 // smmla v17.4s, v1.16b, v21.16b + .inst 0x4e96a432 // smmla v18.4s, v1.16b, v22.16b + .inst 0x4e97a433 // smmla v19.4s, v1.16b, v23.16b + cmp x13, #4 + .inst 0x4e98a450 // smmla v16.4s, v2.16b, v24.16b + .inst 0x4e99a451 // smmla v17.4s, v2.16b, v25.16b + .inst 0x4e9aa452 // smmla v18.4s, v2.16b, v26.16b + .inst 0x4e9ba453 // smmla v19.4s, v2.16b, v27.16b + + .inst 0x4e85a470 // smmla v16.4s, v3.16b, v5.16b + .inst 0x4e86a471 // smmla v17.4s, v3.16b, v6.16b + .inst 0x4e87a472 // smmla v18.4s, v3.16b, v7.16b + .inst 0x4e88a473 // smmla v19.4s, v3.16b, v8.16b + + bge LoopSz1_TILE_1_lu4 + cbz x13, LoopSzEnd_TILE_1 -LoopSz1_TILE_1: - // src : 1 x [1 x 8] : v2 - // weight : 2 x [2 x 8] : v0-1 - // dst : 1 x 2 x [2] : v30-v31 +LoopSz1_TILE_1_lu1: ld1 {v13.16b, v14.16b}, [x12], #32 // weight - ld1 {v2.8b}, [x11], x22 // src + ld1 {v2.8b}, [x11], x22 // src // int4->int8 ushr v0.16b, v13.16b, #4 and 
v3.16b, v13.16b, v30.16b ushr v1.16b, v14.16b, #4 and v4.16b, v14.16b, v30.16b + subs x13, x13, #1 .inst 0x4e80a450 // smmla v16.4s, v2.16b, v0.16b .inst 0x4e81a451 // smmla v17.4s, v2.16b, v1.16b .inst 0x4e83a452 // smmla v18.4s, v2.16b, v3.16b .inst 0x4e84a453 // smmla v19.4s, v2.16b, v4.16b - subs x13, x13, #1 - bne LoopSz1_TILE_1 - - LoopSz_TILE_1_ADD: - //add v16.4s, v16.4s, v22.4s - //add v17.4s, v17.4s, v23.4s - //add v18.4s, v18.4s, v24.4s - //add v19.4s, v19.4s, v25.4s + + bne LoopSz1_TILE_1_lu1 LoopSzEnd_TILE_1: add x25, x25, x15 diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S b/source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulFP16_int4.S similarity index 100% rename from source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S rename to source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulFP16_int4.S diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S b/source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulFP16_int8.S similarity index 100% rename from source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S rename to source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulFP16_int8.S diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S b/source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulRemainFP16_int4.S similarity index 100% rename from source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S rename to source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulRemainFP16_int4.S diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S b/source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulRemainFP16_int8.S similarity index 100% rename from source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S rename to source/backend/arm82/asm/arm64/normal_memory/MNNPackedMatMulRemainFP16_int8.S diff --git a/source/backend/coreml/backend/CoreMLBackend.cpp b/source/backend/coreml/backend/CoreMLBackend.cpp index 0a8ee125e..8342e68dd 100644 --- a/source/backend/coreml/backend/CoreMLBackend.cpp +++ b/source/backend/coreml/backend/CoreMLBackend.cpp @@ -300,7 +300,7 @@ namespace MNN { CoreMLRuntime::~CoreMLRuntime() {} - Backend* CoreMLRuntime::onCreate(const BackendConfig* config) const { + Backend* CoreMLRuntime::onCreate(const BackendConfig* config, Backend* origin) const { return new CoreMLBackend(this); } diff --git a/source/backend/coreml/backend/CoreMLBackend.hpp b/source/backend/coreml/backend/CoreMLBackend.hpp index 121b18192..b9136690b 100644 --- a/source/backend/coreml/backend/CoreMLBackend.hpp +++ b/source/backend/coreml/backend/CoreMLBackend.hpp @@ -26,7 +26,7 @@ namespace MNN { CoreMLRuntime(const Backend::Info& info); virtual ~CoreMLRuntime(); virtual CompilerType onGetCompilerType() const override; - virtual Backend* onCreate(const BackendConfig* conf) const override; + virtual Backend* onCreate(const BackendConfig* conf, Backend* origin) const override; virtual void onGabageCollect(int level) override; virtual std::pair onGetCache() override { return std::make_pair(mCacheBuffer, mCacheSize); diff --git a/source/backend/cpu/CMakeLists.txt b/source/backend/cpu/CMakeLists.txt index 41426c66c..e37ae3e55 100644 --- a/source/backend/cpu/CMakeLists.txt +++ b/source/backend/cpu/CMakeLists.txt @@ -24,6 +24,10 @@ if(MNN_LOW_MEMORY) target_compile_options(MNNCPU PRIVATE -DMNN_LOW_MEMORY) endif() +if(MNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNNCPU 
PRIVATE -DMNN_CPU_WEIGHT_DEQUANT_GEMM) +endif() + # X86_64 AVX/SSE if (MNN_USE_SSE) include(${CMAKE_CURRENT_LIST_DIR}/x86_x64/CMakeLists.txt) diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index 8a5a89ec3..7f4c6ff44 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -30,22 +30,51 @@ namespace MNN { template -static void pack_query(Tensor* query, char* pack_q, int mNumHead, int mHeadDim, int eP, int seq_len, int h, float q_scale) { - T * query_src = query->host(); - T * query_dst = reinterpret_cast(pack_q); - for (int i = 0; i < seq_len; i++) { - int out_index = i / eP; - int in_index = i % eP; - for (int j = 0; j < mHeadDim; j++) { - query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale; +void CPUAttention::pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale) { + if (mUseGemmInt8) { // Shape of Query: numhead, [seqlen/eP8, headdim/lP8, eP8, lP8] + mMinQ[h] = query->host()[h * mHeadDim]; + mMaxQ[h] = query->host()[h * mHeadDim]; + for (int i = 0; i < seq_len; i++) { + T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim; + for (int j = 0; j < mHeadDim; j++) { + mMinQ[h] = ALIMIN(mMinQ[h], query_src[j]); + mMaxQ[h] = ALIMAX(mMaxQ[h], query_src[j]); + } + } + mQueryScale[h] = (mMaxQ[h] - mMinQ[h]) / 255.0f; + mQueryZeroPoint[h] = -255.0f * mMinQ[h] / (mMaxQ[h] - mMinQ[h]) - 128.0; + for (int i = 0; i < seq_len; i++) { + T * query_src = query->host() + i * mNumHead * mHeadDim + h * mHeadDim; + float sumQ = 0; + int out_index = i / eP8; + int in_index = i % eP8; + for (int j = 0; j < mHeadDim; j++) { + int a = j / lP8; + int b = j % lP8; + int quant_res = (int)roundf(query_src[j] / mQueryScale[h] + mQueryZeroPoint[h]); + sumQ += quant_res; + *((int8_t*)pack_q + out_index * UP_DIV(mHeadDim, lP8) * eP8 * lP8 + a * eP8 * lP8 + in_index * lP8 + b) = quant_res; + } + *((float*)sum_q + out_index * eP8 + in_index) = sumQ * mQueryScale[h]; + } + } + else { + T * query_src = query->host(); + T * query_dst = reinterpret_cast(pack_q); + for (int i = 0; i < seq_len; i++) { + int out_index = i / eP; + int in_index = i % eP; + for (int j = 0; j < mHeadDim; j++) { + query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale; + } } } } template -static void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len, int unit) { +void CPUAttention::unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len) { float * dst = unpack_qk_dst; - T * src = (T *)(pack_qk_src); + T * src = (T *)(pack_qk_src); // [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] for (int i = 0; i < seq_len; i++) { for (int j = 0; j < kv_seq_len; j++) { @@ -119,6 +148,11 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: mThreadNum = ((CPUBackend *)backend())->threadNumber(); unit = core->pack; bytes = core->bytes; + int qkvQuantOptions = static_cast(backend())->getRuntime()->hint().qkvQuantOption; + mUseGemmInt8 = (qkvQuantOptions == 4); + if (mUseGemmInt8) { + static_cast(backend())->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); + } auto query = inputs[0]; auto key = inputs[1]; int seq_len = query->shape()[1]; @@ -126,12 +160,28 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: mHeadDim = query->shape()[3]; mKvNumHead = key->shape()[2]; 
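// ---------------------------------------------------------------------------
// Editor's note (illustrative only, not part of this patch): when mUseGemmInt8 is
// set, pack_query above quantizes each head of the query asymmetrically from its
// per-head min/max and also records a per-row sum; those sums are later passed to
// the int8 GEMM post-processing (srcKernelSum, together with the key sums) to
// correct the integer dot products. A self-contained sketch of that quantization,
// with hypothetical names and assuming max > min:
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct QuantizedRow {
    std::vector<int8_t> q;
    float scale;         // (max - min) / 255
    float zeroPoint;     // chosen so that min -> -128 and max -> +127
    float sumTimesScale; // sum of the quantized values, rescaled
};

QuantizedRow quantizeRowAsymmetric(const std::vector<float>& x) {
    float mn = *std::min_element(x.begin(), x.end());
    float mx = *std::max_element(x.begin(), x.end());
    QuantizedRow r;
    r.scale = (mx - mn) / 255.0f;
    r.zeroPoint = -255.0f * mn / (mx - mn) - 128.0f;
    r.sumTimesScale = 0.0f;
    for (float v : x) {
        int qv = (int)std::round(v / r.scale + r.zeroPoint);
        r.sumTimesScale += (float)qv;
        r.q.push_back((int8_t)qv);
    }
    r.sumTimesScale *= r.scale;
    return r;
}
// ---------------------------------------------------------------------------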
mKVCacheManager->onResize(mKvNumHead, mHeadDim); - mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), mHeadDim, eP})); - mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); - backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); + if (mUseGemmInt8) { + mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), UP_DIV(mHeadDim, lP8), eP8 * lP8})); + mSumQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP8), eP8})); + mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); + backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mSumQ.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mSumQ.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); + mMinQ.resize(mNumHead); + mMaxQ.resize(mNumHead); + mQueryScale.resize(mNumHead); + mQueryZeroPoint.resize(mNumHead); + } else { + mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), mHeadDim, eP})); + mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mHeadDim, unit), seq_len, unit})); + backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mPackQ.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mPackQKV.get(), Backend::DYNAMIC); + } return NO_ERROR; } @@ -179,12 +229,12 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: // Temporary tensors for intermediate results std::shared_ptr packQK(Tensor::createDevice({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit})); std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); - std::shared_ptr softmaxQK(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); + std::shared_ptr softmMaxQ(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), kv_seq_len, eP})); std::shared_ptr dequantV(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), kv_seq_len, hP})); backend()->onAcquireBuffer(packQK.get(), Backend::STATIC); backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); - backend()->onAcquireBuffer(softmaxQK.get(), Backend::STATIC); + backend()->onAcquireBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); if (quant_value) { backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC); @@ -194,48 +244,100 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: std::function mCompute = [=](int tId) { auto pack_q = mPackQ->host() + tId * UP_DIV(seq_len, eP) * mHeadDim * eP * bytes; auto pack_qk = packQK->host() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes; + char * sum_q = nullptr; auto unpack_qk = unpackQK->host() + tId * seq_len * kv_seq_len; - auto softmax_qk = softmaxQK->host() + tId * seq_len * kv_seq_len; + auto softmax_qk = softmMaxQ->host() + tId * seq_len * kv_seq_len; auto new_pack_qk = newPackQK->host() + tId * UP_DIV(seq_len, eP) * kv_seq_len * eP * bytes; auto pack_qkv = mPackQKV->host() + tId * UP_DIV(mHeadDim, unit) * seq_len * unit * bytes; auto QxK = 
quant_key ? core->MNNPackedMatMul_int8 : core->MNNPackedMatMul; auto QxK_remain = quant_key ? core->MNNPackedMatMulRemain_int8 : core->MNNPackedMatMulRemain; int head_index = tId * tileCount; + if (mUseGemmInt8) { + pack_q = mPackQ->host() + tId * UP_DIV(seq_len, eP8) * UP_DIV(mHeadDim, lP8) * eP8 * lP8; + sum_q = mSumQ->host() + tId * UP_DIV(seq_len, eP8) * eP8 * 4; + } for (int h = head_index; h < head_index + tileCount && h < mNumHead; h++) { int kv_h = h / group_size; char * key_addr = mKVCacheManager->addrOfKey(kv_h); - char * scale_addr = quant_key ? mKVCacheManager->addrOfScale(kv_h) : nullptr; - char * zero_point_addr = quant_key ? mKVCacheManager->addrOfZeroPoint(kv_h) : nullptr; - char * value_addr = quant_value ? dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * kv_seq_len * hP * bytes : mKVCacheManager->addrOfValue(kv_h); + char * scale_addr = mKVCacheManager->addrOfScale(kv_h); + char * zero_point_addr = mKVCacheManager->addrOfZeroPoint(kv_h); + char * key_sum_addr = mKVCacheManager->addrOfKeySum(kv_h); + char * value_addr = quant_value ? (dequantV->host() + kv_h * UP_DIV(mHeadDim, hP) * kv_seq_len * hP * bytes) : mKVCacheManager->addrOfValue(kv_h); if (bytes == 2) { - pack_query(query, pack_q, mNumHead, mHeadDim, eP, seq_len, h, q_scale); + pack_query(query, pack_q, sum_q, seq_len, h, q_scale); } else { - pack_query(query, pack_q, mNumHead, mHeadDim, eP, seq_len, h, q_scale); + pack_query(query, pack_q, sum_q, seq_len, h, q_scale); } // query @ key - int loop_e = seq_len / eP; - int remain = seq_len % eP; - size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)mHeadDim, (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0}; - for (int i = 0 ; i < loop_e; i++) { - QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + if (mUseGemmInt8) { + auto GemmInt8Kernel = static_cast(backend())->int8Functions()->Int8GemmKernel; + if (bytes == 2 && unit == 8) { + GemmInt8Kernel = static_cast(backend())->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16; + } + std::vector postScale(ROUND_UP(kv_seq_len, hP8), 0.0f); + for (int i = 0; i < kv_seq_len; i++) { + postScale[i] = ((float*)scale_addr)[i] * mQueryScale[h] * q_scale; + } + std::vector weightQuantBias(ROUND_UP(kv_seq_len, hP8), 0.0f); + for (int i = 0; i < kv_seq_len; i++) { + weightQuantBias[i] = -((float*)scale_addr)[i] * ((float*)zero_point_addr)[i] * q_scale; + } + std::vector biasFloat(ROUND_UP(kv_seq_len, hP8), 0.0f); + for (int i = 0; i < kv_seq_len; i++) { + biasFloat[i] = -mQueryScale[h] * mQueryZeroPoint[h] * ((float*)key_sum_addr)[i] * q_scale; + } + QuanPostTreatParameters post; + post.bias = nullptr; + post.biasFloat = biasFloat.data(); + post.blockNum = 1; + post.extraBias = nullptr; + post.extraScale = nullptr; + post.fp32minmax = nullptr; + post.scale = postScale.data(); + post.useInt8 = false; + post.weightQuanBias = weightQuantBias.data(); + int N = UP_DIV(seq_len, eP8); + for (int i = 0; i < N; i++) { + int realcount = ALIMIN(eP8, seq_len - i * eP8); + post.srcKernelSum = (float*)((char*)sum_q + i * eP8 * 4); + GemmInt8Kernel( + (int8_t*)pack_qk + i * eP8 * unit * bytes, + (int8_t*)pack_q + i * ROUND_UP(mHeadDim, lP8) * eP8, + (int8_t*)key_addr, + UP_DIV(mHeadDim, lP8), + seq_len * unit * bytes, + UP_DIV(kv_seq_len, unit), + &post, + realcount + ); + } + } + else { + int loop_e = seq_len / eP; + int remain = seq_len % eP; + size_t shapeParameters[7] = {(size_t)eP * 
bytes, (size_t)mHeadDim, (size_t)kv_seq_len, (size_t)seq_len * unit * bytes, 0, 0, 0}; + for (int i = 0 ; i < loop_e; i++) { + QxK((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mHeadDim * eP) * bytes), (float*)key_addr, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); + } + QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); } - QxK_remain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mHeadDim * eP) * bytes), (float*)key_addr, remain, shapeParameters, nullptr, nullptr, (float*)scale_addr, (float*)zero_point_addr); // qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP] if(bytes == 2) { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len, unit); + unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask->host(), float_mask); softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP); } else { - unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len, unit); + unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len); mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask->host(), float_mask); softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP); } // qk @ v - shapeParameters[1] = kv_seq_len; - shapeParameters[2] = mHeadDim; + size_t shapeParameters[7] = {(size_t)eP * bytes, (size_t)kv_seq_len, (size_t)mHeadDim, (size_t)seq_len * unit * bytes, 0, 0, 0}; shapeParameters[5] = quant_value ? 
0 : (max_len - kv_seq_len) * hP * bytes; + int loop_e = seq_len / eP; + int remain = seq_len % eP; for (int i = 0 ; i < loop_e; i++) { core->MNNPackedMatMul((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(new_pack_qk + (i * kv_seq_len * eP) * bytes), (float*)value_addr, shapeParameters, nullptr, nullptr, nullptr, nullptr); } @@ -257,7 +359,7 @@ ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std: backend()->onReleaseBuffer(packQK.get(), Backend::STATIC); backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC); - backend()->onReleaseBuffer(softmaxQK.get(), Backend::STATIC); + backend()->onReleaseBuffer(softmMaxQ.get(), Backend::STATIC); backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC); if (quant_value){ backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC); @@ -277,10 +379,13 @@ bool CPUAttention::onClone(Backend* bn, const Op* op, Execution** dst) { CPUAttention::CPUAttention(Backend *backend, bool kv_cache) : Execution(backend), mKVCache(kv_cache) { if (mKVCache) { + mPackQ.reset(Tensor::createDevice({1, 1, 1, 1})); + mPackQKV.reset(Tensor::createDevice({1, 1, 1, 1})); MNN::KVCacheManager::KVCacheConfig kvconfig; - int kvcacheQuantOptions = static_cast(backend)->getRuntime()->hint().kvcacheQuantOption; - kvconfig.mQuantKey = (kvcacheQuantOptions & 1); - kvconfig.mQuantValue = ((kvcacheQuantOptions >> 1) & 1); + int qkvQuantOptions = static_cast(backend)->getRuntime()->hint().qkvQuantOption; + kvconfig.mUseInt8Kernel = (qkvQuantOptions == 4); + kvconfig.mQuantKey = (qkvQuantOptions == 4) || (qkvQuantOptions & 1); + kvconfig.mQuantValue = (qkvQuantOptions == 4) || ((qkvQuantOptions >> 1) & 1); kvconfig.mKVCacheDir = static_cast(backend)->getRuntime()->hint().kvcacheDirPath; kvconfig.mKVCacheSizeLimit = static_cast(backend)->getRuntime()->hint().kvcacheSizeLimit; kvconfig.mExpandChunk = 64; @@ -305,4 +410,4 @@ REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention); } // namespace MNN -#endif // MNN_SUPPORT_TRANSFORMER_FUSE \ No newline at end of file +#endif // MNN_SUPPORT_TRANSFORMER_FUSE diff --git a/source/backend/cpu/CPUAttention.hpp b/source/backend/cpu/CPUAttention.hpp index 4aba816f3..a05b68712 100644 --- a/source/backend/cpu/CPUAttention.hpp +++ b/source/backend/cpu/CPUAttention.hpp @@ -29,12 +29,17 @@ class CPUAttention : public Execution { bool mIsPrefill = true; bool mIsFirstPrefill = true; bool mKVCache = true; + bool mUseGemmInt8 = false; int bytes = 4; int mThreadNum = 1;; - int eP, lP, hP, unit; + int eP, lP, hP, unit; // float matmul packing + int eP8, lP8, hP8; // GemmInt8 packing int mNumHead, mKvNumHead, mHeadDim; - std::shared_ptr mPackQ, mPackQKV; + std::shared_ptr mPackQ, mPackQKV, mSumQ; std::shared_ptr mKVCacheManager = nullptr; + std::vector mMinQ, mMaxQ, mQueryScale, mQueryZeroPoint; + template void pack_query(Tensor* query, char* pack_q, char* sum_q, int seq_len, int h, float q_scale); + template void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len); }; } // namespace MNN diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 99156a447..dd3401dcf 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -37,6 +37,7 @@ #include "x86_x64/AVX2Backend.hpp" #endif +#define MNN_CPU_MAX_BUFFER_INDEX 2 #define MNN_CPU_CHECK_NAN 1 #define MNN_CPU_USE_DEFAULT_BACKEND 4 namespace MNN { @@ -208,7 +209,12 @@ void CPURuntime::onReset(int numberThread, const BackendConfig* config, bool ful } 
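// ---------------------------------------------------------------------------
// Editor's note (illustrative only, not part of this patch): the CPUAttention
// constructor change above replaces the old kvcacheQuantOption with a single
// qkvQuantOption hint. A standalone sketch of the decoding it performs (the struct
// and function names here are hypothetical, not MNN API):
struct KVQuantConfig {
    bool useInt8Kernel; // use the int8 GEMM path for Q x K
    bool quantKey;      // quantize the K cache
    bool quantValue;    // quantize the V cache
};

KVQuantConfig decodeQkvQuantOption(int option) {
    KVQuantConfig c;
    c.useInt8Kernel = (option == 4);
    c.quantKey      = (option == 4) || ((option >> 0) & 1); // bit 0
    c.quantValue    = (option == 4) || ((option >> 1) & 1); // bit 1
    return c;
}
// Examples: option==1 -> quantize K only; option==2 -> V only; option==3 -> both;
// option==4 -> both caches quantized and the int8 GEMM kernels enabled.
// ---------------------------------------------------------------------------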
CPURuntime::CPURuntime(const Backend::Info& info) { - mStaticAllocator.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createDefault())); + auto rawAlloc = BufferAllocator::Allocator::createDefault(); + mStaticAllocator.reset(new EagerBufferAllocator(rawAlloc)); + mDynamic.resize(MNN_CPU_MAX_BUFFER_INDEX); + for (auto& buf : mDynamic) { + buf.root = rawAlloc; + } mThreadNumber = info.numThread; mPower = BackendConfig::Power_Normal; mMemory = BackendConfig::Memory_Normal; @@ -231,17 +237,49 @@ CPURuntime:: ~ CPURuntime() { } float CPURuntime::onGetMemoryInMB() { auto staticMemoryInMB = mStaticAllocator->totalSize() / 1024.0f / 1024.0f; - return staticMemoryInMB; + float dynamicMemoryInMB = 0.0f; + for (auto& buf : mDynamic) { + dynamicMemoryInMB += buf.currentSize / 1024.0f / 1024.0f; + } + return staticMemoryInMB + dynamicMemoryInMB; } bool CPURuntime::onCheckInfo(Backend::Info& info) const { info.numThread = mThreadNumber; return true; } +SingleBufferWithAllocator* CPURuntime::buffer(int index) const { + if (mDynamicMmap.empty()) { + return mDynamic.data() + index; + } + return mDynamicMmap.data() + index; +} -Backend* CPURuntime::onCreate(const BackendConfig* config) const { +Backend* CPURuntime::onCreate(const BackendConfig* config, Backend* origin) const { + if (hint().midMemoryPath.size() > 0) { + if (mDynamicMmap.empty()) { + // Only support set featuremap dir once + mDynamicMmap.resize(2); + auto mmapMem = BufferAllocator::Allocator::createMmap(hint().midMemoryPath.c_str(), "dynamic"); + for (auto& buf : mDynamicMmap) { + buf.root = mmapMem; + } + } + } + if (hint().weightMemoryPath.size() > 0) { + if (nullptr == mStaticAllocatorCache.get()) { + // Only support set weightmap dir once + mStaticAllocatorCache = mStaticAllocator; + auto mmapMem = BufferAllocator::Allocator::createMmap(hint().weightMemoryPath.c_str(), "static"); + mStaticAllocator.reset(new EagerBufferAllocator(mmapMem, 32, 1024 * 1024 * 1024)); + } + } auto precision = mPrecision; auto memory = mMemory; size_t flags = mFlags; + if (nullptr != origin) { + auto cpuBn = static_cast(origin); + mSharedDmaInfo = cpuBn->mDmaInfo; + } _resetGroupCompute(); if (nullptr != config) { precision = config->precision; @@ -251,30 +289,36 @@ Backend* CPURuntime::onCreate(const BackendConfig* config) const { #ifdef LOG_VERBOSE MNN_PRINT("cpu backend was created by runtime:%p\n", this); #endif - + CPUBackend* res = nullptr; + do { #ifdef MNN_USE_ARMV82 - auto core = MNNGetCoreFunctions(); - if (core->supportFp16arith && precision == BackendConfig::Precision_Low) { - return new Arm82Backend(this, memory); - } + auto core = MNNGetCoreFunctions(); + if (core->supportFp16arith && precision == BackendConfig::Precision_Low) { + res = new Arm82Backend(this, memory); + break; + } #endif #ifdef MNN_SUPPORT_BF16 - if (precision == BackendConfig::Precision_Low_BF16 && BF16Functions::get()) { - auto res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU_EXTENSION, 0); - res->mCoreFunctions = BF16Functions::get(); - return res; - } + if (precision == BackendConfig::Precision_Low_BF16 && BF16Functions::get()) { + res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU_EXTENSION, 0); + res->mCoreFunctions = BF16Functions::get(); + break; + } #endif - if (flags == MNN_CPU_USE_DEFAULT_BACKEND) { - return new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0); - } + if (flags == MNN_CPU_USE_DEFAULT_BACKEND) { + res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, 0); + break; + } #ifdef MNN_USE_SSE - if 
(AVX2Backend::isValid()) { - return new AVX2Backend(this, memory, flags); - } + if (AVX2Backend::isValid()) { + res = new AVX2Backend(this, memory, flags); + break; + } #endif - - return new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags); + res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU, flags); + } while (false); + mSharedDmaInfo = nullptr; + return res; } int CPURuntime::onGetRuntimeStatus(RuntimeStatus statusEnum) const { @@ -298,6 +342,11 @@ int CPURuntime::onGetRuntimeStatus(RuntimeStatus statusEnum) const { void CPURuntime::onGabageCollect(int level) { mStaticAllocator->release(false); + if (level >= 100) { + for (auto& buf : mDynamic) { + buf.release(); + } + } } @@ -339,25 +388,34 @@ bool CPUBackend::addCreator(OpType t, Creator* c) { map->insert(std::make_pair(t, c)); return true; } - +BufferAllocator* CPURuntime::createDynamicBufferAlloctor(int index) const { + if (hint().memoryAllocatorType == Runtime::Allocator_Defer) { + return new DeferBufferAllocator(buffer(index)); + } + if (nullptr != mStaticAllocatorCache.get()) { + return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocatorCache.get())); + } + return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get())); +} CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode precision, BackendConfig::MemoryMode memory, MNNForwardType type, size_t flags) : Backend(type) { #ifdef LOG_VERBOSE MNN_PRINT("cpu backend create\n"); #endif mMemory = memory; mRuntime = const_cast(runtime); - std::shared_ptr defaultAlloc(BufferAllocator::Allocator::createRecurse(runtime->mStaticAllocator.get())); - if (mRuntime->hint().memoryAllocatorType == Runtime::Allocator_Defer) { - mDynamicAllocator.reset(new DeferBufferAllocator(defaultAlloc)); + auto dynamicAlloc = mRuntime->mSharedDmaInfo; + if (nullptr == dynamicAlloc.get()) { + mDmaInfo.reset(new CPURuntime::DynamicAllocator); + mDmaInfo->mDynamicAllocator.reset(mRuntime->createDynamicBufferAlloctor(0)); + mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); } else { - mDynamicAllocator.reset(new EagerBufferAllocator(defaultAlloc)); + mDmaInfo = dynamicAlloc; } - mCurrentDynamicAllocator = mDynamicAllocator.get(); mStaticAllocator = runtime->mStaticAllocator; mPrecisionMode = precision; mCoreFunctions = MNNGetCoreFunctions(); mInt8CoreFunctions = MNNGetInt8CoreFunctions(); - mCacheGroup.resize(2); + mCacheGroup.resize(MNN_CPU_MAX_BUFFER_INDEX); for (int i=0; imDynamicAllocator->apply(); + if (nullptr != mDmaInfo->mDynamicAllocatorBackup.get()) { + mDmaInfo->mDynamicAllocatorBackup->apply(); + } +} void CPUBackend::onExecuteBegin() const { + _resetDynamicMemory(); mRuntime->onConcurrencyBegin(); } @@ -377,23 +442,20 @@ void CPUBackend::onExecuteEnd() const { } void CPUBackend::onResizeBegin() { - mCurrentDynamicAllocator->reset(); + mDmaInfo->mCurrentDynamicAllocator->reset(); } bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) { if (maxIndex > 2) { return false; } - if (maxIndex == 2 && mDynamicAllocatorBackup.get() == nullptr) { - if (mRuntime->hint().memoryAllocatorType == Runtime::Allocator_Defer) { - mDynamicAllocatorBackup.reset(new DeferBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()))); - } else { - mDynamicAllocatorBackup.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()))); - } + if (maxIndex == 2 && mDmaInfo->mDynamicAllocatorBackup.get() == nullptr) { + 
mDmaInfo->mDynamicAllocatorBackup.reset(mRuntime->createDynamicBufferAlloctor(1)); } if (1 == index) { - mCurrentDynamicAllocator = mDynamicAllocatorBackup.get(); + mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocatorBackup.get(); } else { - mCurrentDynamicAllocator = mDynamicAllocator.get(); + mRuntime->buffer(0)->release(); + mDmaInfo->mCurrentDynamicAllocator = mDmaInfo->mDynamicAllocator.get(); } mCache = mCacheGroup[index].get(); return true; @@ -401,7 +463,11 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) { ErrorCode CPUBackend::onResizeEnd() { getCache()->release(); - return mCurrentDynamicAllocator->compute(); + auto code = mDmaInfo->mCurrentDynamicAllocator->compute(); + if (NO_ERROR != code) { + return code; + } + return NO_ERROR; } Backend::MemObj* CPUBackend::allocBuffer(size_t size, Tensor* dest, StorageType storageType) { @@ -431,11 +497,11 @@ Backend::MemObj* CPUBackend::allocBuffer(size_t size, Tensor* dest, StorageType break; } case DYNAMIC: { - chunk = mCurrentDynamicAllocator->alloc(size, false); + chunk = mDmaInfo->mCurrentDynamicAllocator->alloc(size, false); break; } case DYNAMIC_SEPERATE: { - chunk = mCurrentDynamicAllocator->alloc(size, true); + chunk = mDmaInfo->mCurrentDynamicAllocator->alloc(size, true); break; } default: @@ -453,7 +519,7 @@ Backend::MemObj* CPUBackend::allocBuffer(size_t size, Tensor* dest, StorageType if (storageType == STATIC) { res = new CPUMemObj(mStaticAllocator.get(), chunk, size); } else { - res = new CPUMemObj(mCurrentDynamicAllocator, chunk, size); + res = new CPUMemObj(mDmaInfo->mCurrentDynamicAllocator, chunk, size); chunk.attach(dest); } if (chunk.ptr()) { @@ -591,8 +657,11 @@ const Runtime* CPUBackend::getRuntime() { } bool CPUBackend::onClearBuffer() { + if (nullptr != mRuntime->mStaticAllocatorCache.get()) { + mStaticAllocator = mRuntime->mStaticAllocatorCache; + } mCache->reset(); - mCurrentDynamicAllocator->release(true); + mDmaInfo->mCurrentDynamicAllocator->release(true); return true; } @@ -606,9 +675,9 @@ std::pair CPUBackend::multiThreadDivide(int size) const { return std::make_pair(sizeDivide, scheduleNumber); } void CPUBackend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) const { + _resetDynamicMemory(); auto& srcBuffer = srcTensor->buffer(); auto& dstBuffer = dstTensor->buffer(); - if (srcBuffer.dimensions != dstBuffer.dimensions ) { if (srcBuffer.dim[srcBuffer.dimensions - 1].extent != 1 && dstBuffer.dim[dstBuffer.dimensions - 1].extent != 1) { MNN_ERROR("srcBuffer dimension not equal to dstBuffer, can't copy buffer\n"); diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 1286df907..b4c9843d0 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -20,11 +20,16 @@ namespace MNN { class CPURuntime : public Runtime { public: + struct DynamicAllocator { + std::shared_ptr mDynamicAllocator; + std::shared_ptr mDynamicAllocatorBackup; + BufferAllocator* mCurrentDynamicAllocator = nullptr; + }; friend class CPUBackend; CPURuntime(const Backend::Info& info); virtual ~ CPURuntime(); int onGetRuntimeStatus(RuntimeStatus statusEnum) const override; - virtual Backend* onCreate(const BackendConfig* config) const override; + virtual Backend* onCreate(const BackendConfig* config, Backend* origin) const override; virtual void onReset(int numberThread, const BackendConfig* config, bool full) override; virtual void onGabageCollect(int level) override; virtual float onGetMemoryInMB() override; @@ -43,10 +48,13 @@ 
class CPURuntime : public Runtime { return mThreadOpen; } #endif + SingleBufferWithAllocator* buffer(int index) const; + BufferAllocator* createDynamicBufferAlloctor(int index) const; + private: void _bindCPUCore() const; void _resetThreadPool(); - std::shared_ptr mStaticAllocator; + mutable std::shared_ptr mStaticAllocator; int mThreadNumber; #ifdef MNN_USE_THREAD_POOL mutable int mTaskIndex = -1; @@ -64,6 +72,10 @@ class CPURuntime : public Runtime { static Backend*(*gExtraCreate)(const Runtime* runtime); size_t mFlags = 0; mutable int mCurrentTID = 0; + mutable std::vector mDynamic; + mutable std::vector mDynamicMmap; + mutable std::shared_ptr mSharedDmaInfo; + mutable std::shared_ptr mStaticAllocatorCache; }; struct CoreFunctions; struct CoreInt8Functions; @@ -122,6 +134,7 @@ class CPUBackend : public Backend { const CoreInt8Functions* int8Functions() const { return mInt8CoreFunctions; } + void _resetDynamicMemory() const; public: class Creator { public: @@ -141,7 +154,7 @@ class CPUBackend : public Backend { #endif BufferAllocator* getBufferAllocator(bool defer_allocator = true) const { - return mCurrentDynamicAllocator; + return mDmaInfo->mCurrentDynamicAllocator; } BackendConfig::MemoryMode memoryMode() const { @@ -164,22 +177,19 @@ class CPUBackend : public Backend { static DataType getDataType(const Tensor* tensor); friend class CPURuntime; - protected: MemObj* allocBuffer(size_t size, Tensor* dest, StorageType storageType); CoreFunctions* mCoreFunctions; CoreInt8Functions* mInt8CoreFunctions; private: + std::shared_ptr mDmaInfo; std::shared_ptr mStaticAllocator; - std::shared_ptr mDynamicAllocator; - std::shared_ptr mDynamicAllocatorBackup; CPURuntime* mRuntime; BackendConfig::PrecisionMode mPrecisionMode; BackendConfig::MemoryMode mMemory; static std::map* gCreator; CPUResizeCache* mCache; std::vector> mCacheGroup; - BufferAllocator* mCurrentDynamicAllocator = nullptr; }; /** execution cast wrapper. insert tensor cast dynamic. */ class CastWrapExecution : public Execution { diff --git a/source/backend/cpu/CPUCast.cpp b/source/backend/cpu/CPUCast.cpp index ad989f0f3..1bc72dbb1 100644 --- a/source/backend/cpu/CPUCast.cpp +++ b/source/backend/cpu/CPUCast.cpp @@ -21,13 +21,12 @@ ErrorCode CPUCastCreator::cast(const void* inputRaw, void* outputRaw, ConvertTyp int remain = number % pack; if (type == FlOAT_TO_INT8) { scale = (scale == 0.f ? 
0.f : 1.f / scale); - std::vector scales(pack, scale); - bn->int8Functions()->MNNFloat2Int8((float*)(inputRaw), (int8_t*)(outputRaw), c4Size, scales.data(), min, max, zero); + bn->int8Functions()->MNNFloat2Int8((float*)(inputRaw), (int8_t*)(outputRaw), c4Size, &scale, min, max, &zero, 0); if (remain > 0) { std::vector tempSrc(pack); std::vector tempDst(pack); ::memcpy(tempSrc.data(), (float*)(inputRaw) + c4Size * pack, remain * sizeof(float)); - bn->int8Functions()->MNNFloat2Int8(tempSrc.data(), tempDst.data(), 1, scales.data(), min, max, zero); + bn->int8Functions()->MNNFloat2Int8(tempSrc.data(), tempDst.data(), 1, &scale, min, max, &zero, 0); ::memcpy(static_cast(outputRaw) + c4Size * pack, tempDst.data(), remain * sizeof(int8_t)); } return NO_ERROR; diff --git a/source/backend/cpu/CPUConvolution.cpp b/source/backend/cpu/CPUConvolution.cpp index 109b4cc6a..eb34aa9c2 100644 --- a/source/backend/cpu/CPUConvolution.cpp +++ b/source/backend/cpu/CPUConvolution.cpp @@ -117,7 +117,6 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vectormOutputCount; const int kernelNum = static_cast(mResource->mInt8WeightKernelSum.size()); auto biasData = mResource->mOriginBias->host(); auto alphaData = mResource->mOriginScale->host(); @@ -189,7 +188,6 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B const int8_t* weightSrc = nullptr; int weightSize = 0; std::shared_ptr quanCommon; - resource->mOutputCount = outputCount; if (!ConvolutionCommon::getConvInt8Parameters(op, quanCommon, backend, weightSrc, weightSize, scalePtr, biasPtr, betaPtr)) { return nullptr; } @@ -254,174 +252,6 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B return resource; } -void CPUConvolution::makeResource(Backend* backend, std::shared_ptr resource, const MNN::Op *op, std::shared_ptr resourceInt8) { - /* Used to compute weight quant scale and bias and weightKernelSum of type float. 
*/ - auto conv2d = op->main_as_Convolution2D(); - bool quanBuffer = (conv2d->quanParameter() != nullptr && conv2d->quanParameter()->buffer() != nullptr); - MNN_ASSERT(quanBuffer || resourceInt8); - resource->backend = backend; - auto core = static_cast(backend)->functions(); - // common parameters - int outputCount = conv2d->common()->outputCount(); - int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); - int ocUp4 = ROUND_UP(outputCount, core->pack); - int8_t* weightOrigin; - - // Save weight quant scale and bias: wf=scale*wi+bias - resource->mDequantize.mScaleBias.reset(Tensor::createDevice({2 * ocUp4 * core->bytes})); - auto success = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC); - if (!success) { - MNN_ERROR("Alloc denquant scaleBias memory error\n"); - return; - } - auto alphaPtr = resource->mDequantize.mScaleBias->host(); - auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); - ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); - - std::shared_ptr quantCommon; - // Load quant scale and bias - if (quanBuffer) { - quantCommon = ConvolutionCommon::load(op, backend, false, true); - weightOrigin = quantCommon->weight.get(); // weight before reorder - - int h = quantCommon->alpha.size(); - if (core->bytes == 2) { - if (quantCommon->asymmetric) { - std::unique_ptr tmp(new int16_t[h]); - core->MNNFp32ToLowp(quantCommon->alpha.get(), tmp.get(), h); - for (int i=0; i< h/2; ++i) { - reinterpret_cast(alphaPtr)[i] = tmp[2 * i + 1]; - reinterpret_cast(biasPtr)[i] = tmp[2 * i]; - } - } else { - core->MNNFp32ToLowp(quantCommon->alpha.get(), reinterpret_cast(alphaPtr), h); - } - } else { - if (quantCommon->asymmetric) { - h = h / 2; - for (int i=0; ialpha.get()[2 * i + 1]; - biasPtr[i] = quantCommon->alpha.get()[2 * i]; - } - } else { - for (int i=0; ialpha.get()[i]; - biasPtr[i] = 0.f; - } - } - } - } else { - weightOrigin = resourceInt8->mWeightInt8->host(); - auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 - auto wScale = resourceInt8->mOriginScale->host(); - int h = ocUp4; - if (core->bytes == 2) { - std::unique_ptr tmp(new int16_t[h]); - core->MNNFp32ToLowp(wScale, tmp.get(), h); - for (int i=0; i< h; ++i) { - reinterpret_cast(alphaPtr)[i] = tmp[i]; - reinterpret_cast(biasPtr)[i] = (-1.f) * wZero[i] * tmp[i]; - } - } else { - for (int i=0; i< h; ++i) { - alphaPtr[i] = wScale[i]; - biasPtr[i] = (-1.f) * wZero[i] * wScale[i]; - } - } - } - - // Compute float weightKernelSum - resource->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); - success = resource->backend->onAcquireBuffer(resource->mWeightKernelSum.get(), Backend::STATIC); - if (!success) { - MNN_ERROR("Alloc denquant mWeightKernelSum memory error\n"); - return; - } - auto weightKernelSum = resource->mWeightKernelSum->host(); - for (int i = 0; i < outputCount; ++i) { - int sum = 0; - for (int j = 0; j < LSize; ++j) { - sum = sum + static_cast(weightOrigin[j + i * LSize]); - } - if(core->bytes == 2) { - auto scale = reinterpret_cast(alphaPtr)[i]; - auto bias = reinterpret_cast(biasPtr)[i]; - weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; - } else { - auto scale = alphaPtr[i]; - auto bias = biasPtr[i]; - weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; - } - } -} - -void CPUConvolution::makeResourceNew(Backend* backend, const Convolution2D* conv2d, std::shared_ptr resourceInt8) { - /* Used to compute weight quant scale and bias and weightKernelSum of 
type float. */ - bool quanBuffer = (conv2d->quanParameter() != nullptr && conv2d->quanParameter()->buffer() != nullptr); - MNN_ASSERT(quanBuffer || resourceInt8); - auto core = static_cast(backend)->functions(); - // common parameters - int outputCount = conv2d->common()->outputCount(); - int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); - int ocUp4 = ROUND_UP(outputCount, core->pack); - int8_t* weightOrigin; - - // Save weight quant scale and bias: wf=scale*wi+bias - std::shared_ptr scaleBias(Tensor::createDevice({2 * ocUp4 * core->bytes})); - auto success = backend->onAcquireBuffer(scaleBias.get(), Backend::STATIC); - if (!success) { - MNN_ERROR("Alloc dequant scaleBias memory error\n"); - return; - } - auto alphaPtr = scaleBias->host(); - auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); - ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); - - // Load quant scale and bias - weightOrigin = resourceInt8->mWeightInt8->host(); - auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 - auto wScale = resourceInt8->mOriginScale->host(); - int h = ocUp4; - if (core->bytes == 2) { - std::unique_ptr tmp(new int16_t[h]); - core->MNNFp32ToLowp(wScale, tmp.get(), h); - for (int i=0; i< h; ++i) { - reinterpret_cast(alphaPtr)[i] = tmp[i]; - reinterpret_cast(biasPtr)[i] = (-1.f) * wZero[i] * tmp[i]; - } - } else { - for (int i=0; i< h; ++i) { - alphaPtr[i] = wScale[i]; - biasPtr[i] = (-1.f) * wZero[i] * wScale[i]; - } - } - resourceInt8->mOriginScale = scaleBias; - - // Compute float weightKernelSum - resourceInt8->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); - success = backend->onAcquireBuffer(resourceInt8->mWeightKernelSum.get(), Backend::STATIC); - if (!success) { - MNN_ERROR("Alloc dequant mWeightKernelSum memory error\n"); - return; - } - auto weightKernelSum = resourceInt8->mWeightKernelSum->host(); - for (int i = 0; i < outputCount; ++i) { - int sum = 0; - for (int j = 0; j < LSize; ++j) { - sum = sum + static_cast(weightOrigin[j + i * LSize]); - } - if(core->bytes == 2) { - auto scale = reinterpret_cast(alphaPtr)[i]; - auto bias = reinterpret_cast(biasPtr)[i]; - weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; - } else { - auto scale = alphaPtr[i]; - auto bias = biasPtr[i]; - weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; - } - } -} - CPUConvolution::CPUConvolution(const Convolution2DCommon *convOp, Backend *b) : MNN::Execution(b), mCommon(convOp) { // Do nothing } diff --git a/source/backend/cpu/CPUConvolution.hpp b/source/backend/cpu/CPUConvolution.hpp index a34f68aad..8975f5963 100644 --- a/source/backend/cpu/CPUConvolution.hpp +++ b/source/backend/cpu/CPUConvolution.hpp @@ -69,12 +69,8 @@ class CPUConvolution : public Execution { bool mRelu; int mActBits; // quant bits - int mOutputCount; bool mUseConvQuan = true; bool mWeightAsymmetricQuant = true; -#ifdef MNN_USE_SSE - std::vector offsets; -#endif // Origin Attributes from net float mInputScale = 0.0f; float mOutputScale = 0.0f; @@ -82,6 +78,7 @@ class CPUConvolution : public Execution { int32_t mOutputZeroPoint; int8_t mClampMin; int8_t mClampMax; + bool mDynamicQuant = false; }; struct MutableResourceInt8 { MutableResourceInt8(std::shared_ptr res, Backend* backend); @@ -100,8 +97,6 @@ class CPUConvolution : public Execution { bool mValid; }; static std::shared_ptr makeResourceInt8(Backend *backend, const MNN::Op *op, int pack=4); - static void makeResource(Backend* backend, 
std::shared_ptr resource, const MNN::Op *op, std::shared_ptr resourceInt8 = nullptr); - static void makeResourceNew(Backend* backend, const Convolution2D* conv2d, std::shared_ptr resourceInt8); CPUConvolution(const Convolution2DCommon *convOp, Backend *b); virtual ~CPUConvolution() = default; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; diff --git a/source/backend/cpu/CPUDeconvolution.cpp b/source/backend/cpu/CPUDeconvolution.cpp index 0364ad58e..6a75b3c61 100644 --- a/source/backend/cpu/CPUDeconvolution.cpp +++ b/source/backend/cpu/CPUDeconvolution.cpp @@ -346,7 +346,7 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector& inputs, c } mPostFunctions.emplace_back(std::make_pair([ocC4, width, height, kh, kw, padY, padX, dilateY, dilateX, strideY, - strideX, threadNumber, src_width, src_height, plane, input, biasTensor, this, core, gcore, batch, outi8, scales, + strideX, threadNumber, src_width, src_height, plane, input, biasTensor, this, core, gcore, batch, outi8, scale, minValue, maxValue, zeroPoint, outputFp32Ptr](uint8_t* outputPtr, int tId) { auto colBufferPtr = mTempOutput->host(); auto biasPtr = biasTensor->host(); @@ -391,7 +391,9 @@ ErrorCode CPUDeconvolutionOrigin::onResize(const std::vector& inputs, c } core->MNNAxByClampBroadcastUnit((float*)dstZ, (float*)dstZ, (const float*)((uint8_t*)biasPtr + unitBytes * z), src_height * src_width * batch, 0, 0, 1, mPostParameters.data()); if (outi8) { - gcore->MNNFloat2Int8((float*)dstZ, (int8_t*)(outputPtr + z * float2Int8_step * core->pack), float2Int8_step, scales.data(), minValue, maxValue, zeroPoint); + float scaleOne = scale; + float zeroOne = zeroPoint; + gcore->MNNFloat2Int8((float*)dstZ, (int8_t*)(outputPtr + z * float2Int8_step * core->pack), float2Int8_step, &scaleOne, minValue, maxValue, &zeroOne, 0); } } }, threadNumber)); diff --git a/source/backend/cpu/CPUDynamicQuant.cpp b/source/backend/cpu/CPUDynamicQuant.cpp index 508d58627..ac4dd40c8 100644 --- a/source/backend/cpu/CPUDynamicQuant.cpp +++ b/source/backend/cpu/CPUDynamicQuant.cpp @@ -46,7 +46,7 @@ ErrorCode CPUDynamicQuant::onExecute(const std::vector &inputs, int pack = core->pack; std::vector qsVec(pack, quantScale); int sizeDiv = UP_DIV(size, pack); - int8core->MNNFloat2Int8(inputPtr, outputPtr, sizeDiv, qsVec.data(), -128, 127, (ssize_t)zeroPoint); + int8core->MNNFloat2Int8(inputPtr, outputPtr, sizeDiv, &quantScale, -128, 127, &zeroPoint, 0); float* scale = outputs[1]->host(); float* zeros = outputs[2]->host(); *scale = dequantScale; diff --git a/source/backend/cpu/CPUFloatToInt8.cpp b/source/backend/cpu/CPUFloatToInt8.cpp index 9a9329e5e..7770377c6 100644 --- a/source/backend/cpu/CPUFloatToInt8.cpp +++ b/source/backend/cpu/CPUFloatToInt8.cpp @@ -36,7 +36,7 @@ CPUFloatToInt8::CPUFloatToInt8(Backend* backend, const MNN::Op* param) : Executi memcpy(mScales->host(), scale->tensorScale()->data(), scaleLen * sizeof(float)); } - mZeroPoint = scale->zeroPoint(); + mZeroPoint = static_cast(scale->zeroPoint()); mClampMin = scale->clampMin(); mClampMax = scale->clampMax(); } @@ -78,7 +78,7 @@ ErrorCode CPUFloatToInt8::onExecute(const std::vector& inputs, const st const auto srcChannelPtr = inputDataPtr + tId * oc4Stride * pack; const auto scaleChannelPtr = scaleDataPtr + z * pack; auto dstChannlePtr = outputDataPtr + tId * oc4Stride * pack; - int8F->MNNFloat2Int8(srcChannelPtr, dstChannlePtr, oc4Stride, scaleChannelPtr, mClampMin, mClampMax, mZeroPoint); + int8F->MNNFloat2Int8(srcChannelPtr, dstChannlePtr, oc4Stride, 
scaleChannelPtr, mClampMin, mClampMax, &mZeroPoint, 1); } MNN_CONCURRENCY_END(); return NO_ERROR; diff --git a/source/backend/cpu/CPUFloatToInt8.hpp b/source/backend/cpu/CPUFloatToInt8.hpp index 7d26a90db..82ca68efe 100644 --- a/source/backend/cpu/CPUFloatToInt8.hpp +++ b/source/backend/cpu/CPUFloatToInt8.hpp @@ -22,7 +22,7 @@ class CPUFloatToInt8 : public Execution { private: std::shared_ptr mScales; - int8_t mZeroPoint; + float mZeroPoint; int8_t mClampMin; int8_t mClampMax; int mClipBits; diff --git a/source/backend/cpu/CPUImageProcess.cpp b/source/backend/cpu/CPUImageProcess.cpp index 078291c72..37d56b1b8 100644 --- a/source/backend/cpu/CPUImageProcess.cpp +++ b/source/backend/cpu/CPUImageProcess.cpp @@ -15,7 +15,6 @@ #include namespace MNN { -#define CACHE_SIZE 256 ErrorCode CPUImageProcess::onResize(const std::vector &inputs, const std::vector &outputs) { auto input = inputs[0]; diff --git a/source/backend/cpu/CPUProposal.cpp b/source/backend/cpu/CPUProposal.cpp index 84e67cdcb..6cc5ff4a4 100644 --- a/source/backend/cpu/CPUProposal.cpp +++ b/source/backend/cpu/CPUProposal.cpp @@ -16,12 +16,17 @@ #include namespace MNN { -CPUProposal::CPUProposal(Backend *backend, const Proposal *proposal) : Execution(backend), mProposal(proposal) { - auto ratioCount = mProposal->ratios()->float32s()->size(); - auto numScale = mProposal->scales()->float32s()->size(); +CPUProposal::CPUProposal(Backend *backend, const Proposal *proposal) : Execution(backend) { + auto ratioCount = proposal->ratios()->float32s()->size(); + auto numScale = proposal->scales()->float32s()->size(); mAnchors.reset(4 * ratioCount * numScale); + mCache.featStride = proposal->featStride(); + mCache.preNmsTopN = proposal->preNmsTopN(); + mCache.nmsThreshold = proposal->nmsThreshold(); + mCache.afterNmsTopN = proposal->afterNmsTopN(); + mCache.minSize = proposal->minSize(); - auto baseSize = mProposal->baseSize(); + auto baseSize = proposal->baseSize(); const auto cx = baseSize * 0.5f; const auto cy = baseSize * 0.5f; auto ratios = proposal->ratios()->float32s()->data(); @@ -117,11 +122,11 @@ ErrorCode CPUProposal::onExecute(const std::vector &inputs, const std: auto score = inputs[0]; auto boxes = inputs[1]; auto imInfo = inputs[2]; - auto featStride = mProposal->featStride(); - auto preNmsTopN = mProposal->preNmsTopN(); - auto nmsThreshold = mProposal->nmsThreshold(); - auto afterNmsTopN = mProposal->afterNmsTopN(); - auto minSize = mProposal->minSize(); + auto featStride = mCache.featStride; + auto preNmsTopN = mCache.preNmsTopN; + auto nmsThreshold = mCache.nmsThreshold; + auto afterNmsTopN = mCache.afterNmsTopN; + auto minSize = mCache.minSize; float* tmpScorePtr = (float*)mScoreBuffer.ptr(); // download diff --git a/source/backend/cpu/CPUProposal.hpp b/source/backend/cpu/CPUProposal.hpp index f002deb3c..8da27db4a 100644 --- a/source/backend/cpu/CPUProposal.hpp +++ b/source/backend/cpu/CPUProposal.hpp @@ -24,8 +24,15 @@ class CPUProposal : public Execution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; + struct ProposalCache { + int32_t featStride; + int32_t preNmsTopN; + int32_t minSize; + int32_t afterNmsTopN; + float nmsThreshold; + }; private: - const Proposal *mProposal; + ProposalCache mCache; AutoStorage mAnchors; MemChunk mScoreBuffer; }; diff --git a/source/backend/cpu/KVCacheManager.cpp b/source/backend/cpu/KVCacheManager.cpp index 7804d3dd5..5fd8c1d37 100644 --- 
a/source/backend/cpu/KVCacheManager.cpp +++ b/source/backend/cpu/KVCacheManager.cpp @@ -13,7 +13,7 @@ namespace MNN { -// @brief Translate an address to a hex number string +// Translate an address to a hex number string static inline std::string addrToHex(void *addr) { std::string result = ""; uint64_t n = (uint64_t)addr; @@ -106,11 +106,27 @@ void KVCacheManager::unmapKVCache(size_t keySize, size_t valueSize) */ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { /*=================================== Key ===================================*/ - if (mConfig.mQuantKey) { + if (mConfig.mUseInt8Kernel) { + auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8}); + mBackend->onAcquireBuffer(new_key, Backend::STATIC); + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + new_key->host() + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + ); + } + mPastKey.reset(new_key); + } + else if (mConfig.mQuantKey) { auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}); mBackend->onAcquireBuffer(new_key, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_key->host() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP); + memcpy( + new_key->host() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP + ); } mPastKey.reset(new_key); } @@ -118,7 +134,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { auto new_key = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP}); mBackend->onAcquireBuffer(new_key, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_key->host() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes); + memcpy( + new_key->host() + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes + ); } mPastKey.reset(new_key); } @@ -128,7 +148,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { mBackend->onAcquireBuffer(new_value, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP); + memcpy( + new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, + oldMaxLength * hP + ); } } mPastValue.reset(new_value); @@ -138,7 +162,11 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { mBackend->onAcquireBuffer(new_value, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes); + memcpy( + new_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * 
mBytes, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, + oldMaxLength * hP * mBytes + ); } } mPastValue.reset(new_value); @@ -151,16 +179,35 @@ void KVCacheManager::expandKVCacheInMem(int oldMaxLength) { */ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { /*=================================== Key ===================================*/ + if (mConfig.mUseInt8Kernel) { + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + ); + } + mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); + mPastKey.reset(); + } if (mConfig.mQuantKey) { for (int h = 0; h < mKvNumHead; h++) { - memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP); + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP + ); } mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); mPastKey.reset(); } else { for (int h = 0; h < mKvNumHead; h++) { - memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes); + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, + mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes + ); } mBackend->onReleaseBuffer(mPastKey.get(), Backend::STATIC); mPastKey.reset(); @@ -169,7 +216,11 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { if (mConfig.mQuantValue) { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP); + memcpy( + mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, + oldMaxLength * hP + ); } } mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC); @@ -178,7 +229,11 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { else { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes); + memcpy( + mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, + mPastValue->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, + oldMaxLength * hP * mBytes + ); } } mBackend->onReleaseBuffer(mPastValue.get(), Backend::STATIC); @@ -189,14 +244,12 @@ void KVCacheManager::moveKVCacheFromMemToDisk(int oldMaxLength) { /* ** @brief Expand the size of kvcache files in disk */ -void KVCacheManager::expandKVCacheInDisk(int oldMaxLength) { - size_t oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 
1 : mBytes); - size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); +void KVCacheManager::expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize) { // Step 1: Copy the old kvcache from files to temporary buffers in memory std::shared_ptr old_key, old_value; - if (mConfig.mQuantKey) { + if (mConfig.mUseInt8Kernel) { + old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8})); + } else if (mConfig.mQuantKey) { old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP})); } else { old_key.reset(Tensor::createDevice({mKvNumHead, UP_DIV(oldMaxLength, hP), mHeadDim, hP})); @@ -216,25 +269,49 @@ void KVCacheManager::expandKVCacheInDisk(int oldMaxLength) { resetKVCacheFileSize(keySize, valueSize); mmapKVCache(keySize, valueSize); // Step 3: Move the kvcache from temporary buffers in memory to disk - if (mConfig.mQuantKey) { + if (mConfig.mUseInt8Kernel) { + for (int h = 0; h < mKvNumHead; h++) { + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + old_key->host() + h * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8, + UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + ); + } + } else if (mConfig.mQuantKey) { for (int h = 0; h < mKvNumHead; h++) { - memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, old_key->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, UP_DIV(oldMaxLength, hP) * mHeadDim * hP); + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP, + old_key->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP + ); } } else { for (int h = 0; h < mKvNumHead; h++) { - memcpy(mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, old_key->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes); + memcpy( + mMapKeyAddr + h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes, + old_key->host() + h * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes, + UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes + ); } } if (mConfig.mQuantValue) { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, oldMaxLength * hP); + memcpy( + mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP, + old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP, + oldMaxLength * hP + ); } } } else { for (int h = 0; h < mKvNumHead; h++) { for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { - memcpy(mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, oldMaxLength * hP * mBytes); + memcpy( + mMapValueAddr + (h * UP_DIV(mHeadDim, hP) + i) * mMaxLength * hP * mBytes, + old_value->host() + (h * UP_DIV(mHeadDim, hP) + i) * oldMaxLength * hP * mBytes, + oldMaxLength * hP * mBytes + ); } } } @@ -253,12 +330,22 @@ void KVCacheManager::onResize(int kv_num_head, int head_dim) { if (mThreadNum > mKvNumHead) { mThreadNum = mKvNumHead; } + if (mConfig.mUseInt8Kernel) { + 
static_cast(mBackend)->int8Functions()->MNNGetGemmUnit(&hP8, &lP8, &eP8); + } } void KVCacheManager::onAlloc(int kv_seq_len) { mMaxLength = kv_seq_len + mConfig.mExpandChunk; - size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); + size_t keySize = 0, valueSize = 0; + if (mConfig.mUseInt8Kernel) { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; + } else if (mConfig.mQuantKey) { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP; + } else { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes; + } + valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); /*============== Put the kvcache in disk ===========*/ if (mConfig.mKVCacheSizeLimit != -1 && keySize + valueSize > mConfig.mKVCacheSizeLimit) { createKVCacheFile(); @@ -268,7 +355,9 @@ void KVCacheManager::onAlloc(int kv_seq_len) { } /*============== Put the kvcache in memory ===========*/ else { - if (mConfig.mQuantKey) { + if (mConfig.mUseInt8Kernel) { + mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), UP_DIV(mHeadDim, lP8), hP8 * lP8})); + } else if (mConfig.mQuantKey) { mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP})); } else { mPastKey.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), mHeadDim, hP})); @@ -278,15 +367,22 @@ void KVCacheManager::onAlloc(int kv_seq_len) { } else { mPastValue.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mHeadDim, hP), mMaxLength, hP})); } - mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC); - } - /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */ - if (mConfig.mQuantKey) { - mDequantKeyScale.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP})); - mDequantKeyZeroPoint.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP})); - mBackend->onAcquireBuffer(mDequantKeyScale.get(), Backend::STATIC); - mBackend->onAcquireBuffer(mDequantKeyZeroPoint.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mPastKey.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mPastValue.get(), Backend::STATIC); + } + // scale, zero point and sum of key for quantization + if (mConfig.mUseInt8Kernel) { + mKeyScale.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); + mKeyZeroPoint.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); + mKeySum.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8})); + mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mKeySum.get(), Backend::STATIC); + } else if (mConfig.mQuantKey) { + mKeyScale.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), hP})); + mKeyZeroPoint.reset(Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), hP})); + mBackend->onAcquireBuffer(mKeyScale.get(), Backend::STATIC); + mBackend->onAcquireBuffer(mKeyZeroPoint.get(), Backend::STATIC); } } @@ -296,10 +392,19 @@ void KVCacheManager::onRealloc(int kv_seq_len) { } int oldMaxLength = mMaxLength; mMaxLength = kv_seq_len + mConfig.mExpandChunk; - size_t oldKeySize = 
(size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); - size_t keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); + size_t oldKeySize, oldValueSize, keySize, valueSize; + if (mConfig.mUseInt8Kernel) { + oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; + } else if (mConfig.mQuantKey) { + oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP; + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP; + } else { + oldKeySize = (size_t)mKvNumHead * UP_DIV(oldMaxLength, hP) * mHeadDim * hP * mBytes; + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes; + } + oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * oldMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); + valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); /*==== No limit for kvcache ====*/ if (mConfig.mKVCacheSizeLimit == -1) { expandKVCacheInMem(oldMaxLength); @@ -318,51 +423,100 @@ void KVCacheManager::onRealloc(int kv_seq_len) { } /*==== Last time the kvcache is disk, now it should be in disk too ====*/ else { - expandKVCacheInDisk(oldMaxLength); + expandKVCacheInDisk(oldMaxLength, oldKeySize, oldValueSize, keySize, valueSize); } /* No matter where is the kvcache, the scales and zero points are always in memory, since their size is very small */ - if (mConfig.mQuantKey) { + if (mConfig.mUseInt8Kernel) { + auto new_scale = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); + auto new_zeroPoint = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); + auto new_sum = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP8), hP8}); + mBackend->onAcquireBuffer(new_scale, Backend::STATIC); + mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); + mBackend->onAcquireBuffer(new_sum, Backend::STATIC); + for (int h = 0; h < mKvNumHead; h++) { + memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + memcpy(new_sum->host() + h * UP_DIV(mMaxLength, hP8) * hP8 * 4, mKeySum->host() + h * UP_DIV(oldMaxLength, hP8) * hP8 * 4, UP_DIV(oldMaxLength, hP8) * hP8 * 4); + } + mKeyScale.reset(new_scale); + mKeyZeroPoint.reset(new_zeroPoint); + mKeySum.reset(new_sum); + } else if (mConfig.mQuantKey) { auto new_scale = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}); auto new_zeroPoint = Tensor::createDevice({mKvNumHead, UP_DIV(mMaxLength, hP), 1, hP}); mBackend->onAcquireBuffer(new_scale, Backend::STATIC); mBackend->onAcquireBuffer(new_zeroPoint, Backend::STATIC); for (int h = 0; h < mKvNumHead; h++) { - memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mDequantKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); - memcpy(new_zeroPoint->host() + h * 
UP_DIV(mMaxLength, hP) * hP * mBytes, mDequantKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); + memcpy(new_scale->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); + memcpy(new_zeroPoint->host() + h * UP_DIV(mMaxLength, hP) * hP * mBytes, mKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * mBytes, UP_DIV(oldMaxLength, hP) * hP * mBytes); } - mDequantKeyScale.reset(new_scale); - mDequantKeyZeroPoint.reset(new_zeroPoint); + mKeyScale.reset(new_scale); + mKeyZeroPoint.reset(new_zeroPoint); } } void KVCacheManager::onClear() { if (mKVCacheInDisk) { - size_t oldKeySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); - size_t oldValueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); - unmapKVCache(oldKeySize, oldValueSize); + size_t keySize = 0, valueSize = 0; + if (mConfig.mUseInt8Kernel) { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; + } else if (mConfig.mQuantKey) { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP; + } else { + keySize = (size_t)mKvNumHead * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes; + } + valueSize = (size_t)mKvNumHead * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); + unmapKVCache(keySize, valueSize); removeKVCacheFile(); mKVCacheInDisk = false; } - else { - mPastKey.reset(); - mPastValue.reset(); - } + mPastKey.reset(); + mPastValue.reset(); + mKeyScale.reset(); + mKeyZeroPoint.reset(); + mKeySum.reset(); mMaxLength = mPastLength = 0; } template -static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim, - int hP, int kv_h, bool quantKey, char* scale, char* zero_point, const MNN::CoreFunctions * core) { - if (quantKey) { - int8_t * key_dst = reinterpret_cast(pack_key); - T * scale_dst = reinterpret_cast(scale); - T * zeroPoint_dst = reinterpret_cast(zero_point); +void KVCacheManager::pack_key(const Tensor* key, int seq_len, int kv_h) { + if (mConfig.mUseInt8Kernel) { // [maxlen/hP8, headdim/lP8, hP8, lP8] + int8_t * key_dst = reinterpret_cast(addrOfKey(kv_h)); + float * scale_dst = reinterpret_cast(addrOfScale(kv_h)); + float * zeroPoint_dst = reinterpret_cast(addrOfZeroPoint(kv_h)); + float * sum_dst = reinterpret_cast(addrOfKeySum(kv_h)); + for (int s = 0; s < seq_len; s++) { + T * key_src = key->host() + s * mKvNumHead * mHeadDim + kv_h * mHeadDim; + float minKey = key_src[0]; + float maxKey = key_src[0]; + float sumKey = key_src[0]; + for (int d = 1; d < mHeadDim; d++) { + minKey = ALIMIN(minKey, key_src[d]); + maxKey = ALIMAX(maxKey, key_src[d]); + sumKey += key_src[d]; + } + int out_index = (mPastLength + s) / hP8; + int in_index = (mPastLength + s) % hP8; + scale_dst[out_index * hP8 + in_index] = (maxKey - minKey) / 255.0f; + zeroPoint_dst[out_index * hP8 + in_index] = -255.0f * minKey / (maxKey - minKey) - 128.0; + sum_dst[out_index * hP8 + in_index] = sumKey; + for (int d = 0; d < mHeadDim; d++) { + int i = d / lP8; + int j = d % lP8; + key_dst[out_index * UP_DIV(mHeadDim, lP8) * hP8 * lP8 + i * hP8 * lP8 + in_index * lP8 + j] = roundf((key_src[d] - minKey) / (maxKey - minKey) * 255.0f - 128.0f); + } + } + } + else if (mConfig.mQuantKey) { // [maxlen/hP, headdim, hP] + int8_t * key_dst = reinterpret_cast(addrOfKey(kv_h)); + T * 
scale_dst = reinterpret_cast(addrOfScale(kv_h)); + T * zeroPoint_dst = reinterpret_cast(addrOfZeroPoint(kv_h)); for (int i = 0; i < seq_len; i++) { T * key_src = key->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; int out_index = (mPastLength + i) / hP; int in_index = (mPastLength + i) % hP; T minKey, maxKey; - core->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim); + static_cast(mBackend)->functions()->MNNCountMaxMinValue((float*)key_src, (float*)&minKey, (float*)&maxKey, mHeadDim); scale_dst[out_index * hP + in_index] = (maxKey - minKey) / 255.0f; zeroPoint_dst[out_index * hP + in_index] = 128.0f * (maxKey - minKey) / 255.0f + minKey; for (int j = 0; j < mHeadDim; j++) { @@ -370,8 +524,8 @@ static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq } } } - else { - T * key_dst = reinterpret_cast(pack_key); + else { // [maxlen/hP, headdim, hP] + T * key_dst = reinterpret_cast(addrOfKey(kv_h)); for (int i = 0; i < seq_len; i++) { T * key_src = key->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; int out_index = (mPastLength + i) / hP; @@ -384,16 +538,16 @@ static void pack_key(const Tensor* key, char* pack_key, int mPastLength, int seq } template -static void pack_value(const Tensor* value, char* pack_value, int mMaxLength, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim, int hP, int kv_h, bool quantValue, const MNN::CoreFunctions * core) { - if (quantValue) { - fp8_t * value_dst = reinterpret_cast(pack_value); +void KVCacheManager::pack_value(const Tensor* value, int seq_len, int kv_h) { // [headdim/hP, maxlen, hP] + if (mConfig.mQuantValue) { + fp8_t * value_dst = reinterpret_cast(addrOfValue(kv_h)); uint8_t * buf = (uint8_t *)MNNMemoryAllocAlign(mHeadDim, MNN_MEMORY_ALIGN_DEFAULT); for (int i = 0; i < seq_len; i++) { T * value_src = value->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; if (sizeof(T) == 2) { - core->MNNFp16ToFp8(buf, (uint16_t*)value_src, mHeadDim); + static_cast(mBackend)->functions()->MNNFp16ToFp8(buf, (uint16_t*)value_src, mHeadDim); } else { - core->MNNFp32ToFp8(buf, (float*)value_src, mHeadDim); + static_cast(mBackend)->functions()->MNNFp32ToFp8(buf, (float*)value_src, mHeadDim); } for (int j = 0; j < mHeadDim; j++) { int out_index = j / hP; @@ -404,7 +558,7 @@ static void pack_value(const Tensor* value, char* pack_value, int mMaxLength, in MNNMemoryFreeAlign(buf); } else { - T * value_dst = reinterpret_cast(pack_value); + T * value_dst = reinterpret_cast(addrOfValue(kv_h)); for (int i = 0; i < seq_len; i++) { T * value_src = value->host() + i * mKvNumHead * mHeadDim + kv_h * mHeadDim; for (int j = 0; j < mHeadDim; j++) { @@ -423,11 +577,11 @@ void KVCacheManager::onPushBack(const Tensor * key, const Tensor * value) { std::function packKV = [=](int tid) { for (int kv_h = tid * tileCount; kv_h < (tid+1) * tileCount && kv_h < mKvNumHead; kv_h++) { if (mBytes == 2) { - pack_key(key, addrOfKey(kv_h), mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantKey, addrOfScale(kv_h), addrOfZeroPoint(kv_h), core); - pack_value(value, addrOfValue(kv_h), mMaxLength, mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantValue, core); + pack_key(key, seq_len, kv_h); + pack_value(value, seq_len, kv_h); } else { - pack_key(key, addrOfKey(kv_h), mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, mConfig.mQuantKey, addrOfScale(kv_h), addrOfZeroPoint(kv_h), core); - pack_value(value, addrOfValue(kv_h), mMaxLength, mPastLength, seq_len, mKvNumHead, mHeadDim, hP, kv_h, 
mConfig.mQuantValue, core); + pack_key(key, seq_len, kv_h); + pack_value(value, seq_len, kv_h); } } }; diff --git a/source/backend/cpu/KVCacheManager.hpp b/source/backend/cpu/KVCacheManager.hpp index 582481990..c34e25c82 100644 --- a/source/backend/cpu/KVCacheManager.hpp +++ b/source/backend/cpu/KVCacheManager.hpp @@ -29,8 +29,9 @@ namespace MNN { class KVCacheManager : public NonCopyable{ public: struct KVCacheConfig { - bool mQuantKey = false; // Quantize keys to int8 or not - bool mQuantValue = false; // Quantize values to fp8 or not + bool mQuantKey = false; // Quantize keys to int8 or not + bool mQuantValue = false; // Quantize values to fp8 or not + bool mUseInt8Kernel = false; // Whether to use int8 gemm kernel in CPU attention std::string mKVCacheDir = "/tmp"; // Path of the kvcache files in disk size_t mKVCacheSizeLimit = -1; // The limit of the kvcache size int mExpandChunk = 64; // Number of expand chunks when the buffer is full @@ -38,10 +39,11 @@ class KVCacheManager : public NonCopyable{ private: Backend * mBackend; KVCacheConfig mConfig; - std::shared_ptr mPastKey; // numhead, [maxlen/eP, headdim, eP] - std::shared_ptr mPastValue; // numhead, [headdim/eP, maxlen, eP] - std::shared_ptr mDequantKeyScale; // numhead, [maxlen/eP, 1, eP] - std::shared_ptr mDequantKeyZeroPoint; // numhead, [maxlen/eP, 1, eP] + std::shared_ptr mPastKey; // {numhead, [maxlen/hP, headdim, hP]} or {numhead, [maxlen/hP8, headdim/lP8, hP8, lP8]} + std::shared_ptr mPastValue; // numhead, [headdim/hP, maxlen, hP] + std::shared_ptr mKeyScale; // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]} + std::shared_ptr mKeyZeroPoint; // {numhead, [maxlen/hP, hP]} or {numhead, [maxlen/hP8, hP8]} + std::shared_ptr mKeySum; // numhead, [maxlen/hP8, hP8] file_t mKeyCacheFD = INVALID_FILE; // The file descriptor of keys file_t mValueCacheFD = INVALID_FILE; // The file descriptor of values char * mMapKeyAddr = nullptr; // Memory-mapped address of keys @@ -49,8 +51,10 @@ class KVCacheManager : public NonCopyable{ bool mKVCacheInDisk = false; // Whether the kvcache is in disk or in memory now int mPastLength = 0; // Length of past kvcache int mMaxLength = 0; // Capacity of current kvcache buffer (how many kv items can be stored at most) - int eP, lP, hP, mBytes, mThreadNum; - int mKvNumHead = 0, mHeadDim = 0; + int eP, lP, hP; // Packing mode for float matmul + int eP8, lP8, hP8; // Packing mode for int8 gemm kernel + int mBytes = 4, mThreadNum = 1; + int mKvNumHead = 0, mHeadDim = 0; void createKVCacheFile(); void removeKVCacheFile(); void resetKVCacheFileSize(size_t keySize, size_t valueSize); @@ -58,7 +62,9 @@ class KVCacheManager : public NonCopyable{ void unmapKVCache(size_t keySize, size_t valueSize); void expandKVCacheInMem(int oldMaxLength); void moveKVCacheFromMemToDisk(int oldMaxLength); - void expandKVCacheInDisk(int oldMaxLength); + void expandKVCacheInDisk(int oldMaxLength, int oldKeySize, int oldValueSize, int keySize, int valueSize); + template void pack_key(const Tensor* key, int seq_len, int kv_h); + template void pack_value(const Tensor* value, int seq_len, int kv_h); public: KVCacheManager(Backend * backend, KVCacheConfig & kvConfig) { mBackend = backend; @@ -80,10 +86,13 @@ class KVCacheManager : public NonCopyable{ return mPastValue.get(); } const Tensor * scale() { - return mDequantKeyScale.get(); + return mKeyScale.get(); } const Tensor * zeroPoint() { - return mDequantKeyZeroPoint.get(); + return mKeyZeroPoint.get(); + } + const Tensor * keySum() { + return mKeySum.get(); } bool inDisk() 
{ return mKVCacheInDisk; @@ -96,23 +105,46 @@ class KVCacheManager : public NonCopyable{ } char * addrOfKey(int kv_h) { char * baseAddr = mKVCacheInDisk ? mMapKeyAddr : mPastKey->host(); - return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * (mConfig.mQuantKey ? 1 : mBytes); + if (mConfig.mUseInt8Kernel) { + return baseAddr + kv_h * UP_DIV(mMaxLength, hP8) * UP_DIV(mHeadDim, lP8) * hP8 * lP8; + } else if (mConfig.mQuantKey) { + return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP; + } else { + return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * mHeadDim * hP * mBytes; + } } char * addrOfValue(int kv_h) { char * baseAddr = mKVCacheInDisk ? mMapValueAddr : mPastValue->host(); - return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP * (mConfig.mQuantValue ? 1 : mBytes); + if (mConfig.mQuantValue) { + return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP; + } else { + return baseAddr + kv_h * UP_DIV(mHeadDim, hP) * mMaxLength * hP * mBytes; + } } char * addrOfScale(int kv_h) { - if (mConfig.mQuantKey == false) + if (mConfig.mUseInt8Kernel) { + return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + } else if (mConfig.mQuantKey) { + return mKeyScale->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; + } else { return nullptr; - char * baseAddr = mDequantKeyScale->host(); - return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * 1 * hP * mBytes; + } } char * addrOfZeroPoint(int kv_h) { - if (mConfig.mQuantKey == false) + if (mConfig.mUseInt8Kernel) { + return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + } else if (mConfig.mQuantKey) { + return mKeyZeroPoint->host() + kv_h * UP_DIV(mMaxLength, hP) * hP * mBytes; + } else { + return nullptr; + } + } + char * addrOfKeySum(int kv_h) { + if (mConfig.mUseInt8Kernel) { + return mKeySum->host() + kv_h * UP_DIV(mMaxLength, hP8) * hP8 * 4; + }else { return nullptr; - char * baseAddr = mDequantKeyZeroPoint->host(); - return baseAddr + kv_h * UP_DIV(mMaxLength, hP) * 1 * hP * mBytes; + } } void onResize(int kv_num_head, int head_dim); void onAlloc(int kv_seq_len); diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index d23e5adb4..d8d06136c 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -15,6 +15,10 @@ if (MNN_LOW_MEMORY) FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/low_memory/*.[sS]) endif() +if (MNN_CPU_WEIGHT_DEQUANT_GEMM) + FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/normal_memory/*.[sS]) +endif() + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?") message(STATUS "Enabling AArch32 Assemblies") add_library(MNNARM32 OBJECT ${MNN_AArch32_SRC} ${MNN_NEON_SRC}) diff --git a/source/backend/cpu/arm/arm32/MNNBGRAToBGRC8.S b/source/backend/cpu/arm/arm32/MNNBGRAToBGRC8.S new file mode 100644 index 000000000..74f47c637 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNBGRAToBGRC8.S @@ -0,0 +1,33 @@ +// +// MNNBGRAToBGRC8.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNBGRAToBGRC8 +// void MNNBGRAToBGRC8(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +L1: +vld4.8 {d0, d1, d2, d3}, [r0]! +vst3.8 {d0, d1, d2}, [r1]! 
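For reference, the NEON loop above handles eight BGRA pixels per iteration (vld4.8 loads four 8-byte lanes, vst3.8 stores three), so `count` appears to be a count of 8-pixel blocks rather than single pixels. A scalar C++ sketch of the same transform, offered only as an illustration and not part of the patch:

```cpp
#include <cstddef>

// Illustrative scalar equivalent of MNNBGRAToBGRC8
// (assumes count = number of 8-pixel blocks, as the d-register loads imply).
static void BGRAToBGR_Reference(const unsigned char* source, unsigned char* dest, size_t count) {
    for (size_t i = 0; i < count * 8; ++i) {
        dest[3 * i + 0] = source[4 * i + 0]; // B
        dest[3 * i + 1] = source[4 * i + 1]; // G
        dest[3 * i + 2] = source[4 * i + 2]; // R
        // source[4 * i + 3] (alpha) is discarded
    }
}
```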
+subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNBGRAToGRAYFast.S b/source/backend/cpu/arm/arm32/MNNBGRAToGRAYFast.S new file mode 100644 index 000000000..30d53059c --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNBGRAToGRAYFast.S @@ -0,0 +1,43 @@ +// +// MNNBGRAToGRAYFast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNBGRAToGRAYFast +// void MNNBGRAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.i8 d4, #7 +vmov.i8 d5, #38 +vmov.i8 d6, #19 + +L1: +vld4.8 {d0, d1, d2, d3}, [r0]! +vmull.u16 q4, d0, d4 // b*7 +vmlal.u16 q4, d1, d5 // g*38 +vmlal.u16 q4, d2, d6 // r*19 + +vqshrn.u16 d8, q4, #6 +vst1.u8 {d8}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNBGRToBGR555Fast.S b/source/backend/cpu/arm/arm32/MNNBGRToBGR555Fast.S new file mode 100644 index 000000000..c2c48546c --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNBGRToBGR555Fast.S @@ -0,0 +1,46 @@ +// +// MNNBGRToBGR555Fast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNBGRToBGR555Fast +// void MNNBGRToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.s8 q15, #8 +vneg.s8 q15, q15 + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vand.u8 d2, d2, d30 // r & ~7 +vand.u8 d1, d1, d30 // g & ~7 +vshr.u8 d0, d0, #3 // b >> 3 +vshll.u8 q2, d2, #7 +vshll.u8 q3, d1, #2 +vmovl.u8 q8, d0 +vorr.u8 q2, q2, q3 +vorr.u8 q2, q2, q8 + +vst1.16 {q2}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNBGRToBGR565Fast.S b/source/backend/cpu/arm/arm32/MNNBGRToBGR565Fast.S new file mode 100644 index 000000000..db21624b7 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNBGRToBGR565Fast.S @@ -0,0 +1,51 @@ +// +// MNNBGRToBGR565Fast.S +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNBGRToBGR565Fast +// void MNNBGRToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + + +push {lr} +vmov.s8 q15, #8 +vneg.s8 q15, q15 +vmov.s8 q14, #4 +vneg.s8 q14, q14 + +L1: +vld3.8 {d0, d1, d2}, [r0]! // b, g, r + +vand.u8 d2, d2, d30 // r & ~7 +vand.u8 d1, d1, d28 // g & ~3 +vshr.u8 d0, d0, #3 // b >> 3 + +vshll.u8 q2, d2, #7 +vshl.u8 q2, q2, #1 +vshll.u8 q3, d1, #3 +vmovl.u8 q8, d0 + +vorr.u8 q2, q2, q3 +vorr.u8 q2, q2, q8 + +vst1.16 {q2}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNBGRToGRAYFast.S b/source/backend/cpu/arm/arm32/MNNBGRToGRAYFast.S new file mode 100644 index 000000000..0fb87d2a9 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNBGRToGRAYFast.S @@ -0,0 +1,46 @@ +// +// MNNBGRToGRAYFast.S +// +// Created by MNN on 2024/08/28. 
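The grayscale and color-packing kernels in this group all use small fixed-point constants that are easier to check against a scalar reference. A per-pixel sketch of what MNNBGRAToGRAYFast/MNNBGRToGRAYFast and MNNBGRToBGR555Fast compute, with the BGR565 variant below differing only in keeping six green bits; this is illustrative only and not part of the patch:

```cpp
// Grayscale: weights 7/64 (B), 38/64 (G), 19/64 (R) approximate 0.114/0.587/0.299;
// they sum to 64, so the >>6 narrowing in the assembly cannot overflow 8 bits.
static inline unsigned char BGRToGray(unsigned char b, unsigned char g, unsigned char r) {
    return (unsigned char)((7 * b + 38 * g + 19 * r) >> 6);
}

// BGR555: 5 bits per channel, r in bits 10..14, g in bits 5..9, b in bits 0..4.
static inline unsigned short BGRToBGR555(unsigned char b, unsigned char g, unsigned char r) {
    return (unsigned short)(((r & 0xF8) << 7) | ((g & 0xF8) << 2) | (b >> 3));
}

// BGR565 (the file that follows): same idea, but green keeps 6 bits.
static inline unsigned short BGRToBGR565(unsigned char b, unsigned char g, unsigned char r) {
    return (unsigned short)(((r & 0xF8) << 8) | ((g & 0xFC) << 3) | (b >> 3));
}
```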
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNBGRToGRAYFast +// void MNNBGRToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +// b*7 +// g*38 +// r*19 + +push {lr} + +vmov.i8 d4, #7 +vmov.i8 d5, #38 +vmov.i8 d6, #19 + +L1: +vld3.8 {d0, d1, d2}, [r0]! // b,g,r +vmull.u8 q8, d0, d4 +vmlal.u8 q8, d1, d5 +vmlal.u8 q8, d2, d6 + +vqshrn.u16 d16, q8, #6 +vst1.8 {d16}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNC3ToC4Fast.S b/source/backend/cpu/arm/arm32/MNNC3ToC4Fast.S new file mode 100644 index 000000000..ff9e20724 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNC3ToC4Fast.S @@ -0,0 +1,34 @@ +// +// MNNC3ToC4Fast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNC3ToC4Fast +// void MNNC3ToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.i8 d3, #255 +L1: +vld3.8 {d0, d1, d2}, [r0]! +vst4.u8 {d0, d1, d2, d3}, [r1]! +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNC3ToXYZFast.S b/source/backend/cpu/arm/arm32/MNNC3ToXYZFast.S new file mode 100644 index 000000000..08e6a6c53 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNC3ToXYZFast.S @@ -0,0 +1,95 @@ +// +// MNNC3ToXYZFast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNC3ToXYZFast +// void MNNC3ToXYZFast(const unsigned char* source, unsigned char* dest, size_t count, int32_t* c); +// Auto Load: r0: source, r1: dest, r2: count, r3: c + +push {lr} +vpush {q4-q7} + +// q4-q6, const +vld1.32 {d8[0]}, [r3]! // C0 +vld1.32 {d8[1]}, [r3]! // C1 +vld1.32 {d9[0]}, [r3]! // C2 +vld1.32 {d9[1]}, [r3]! // C3 +vld1.32 {d10[0]}, [r3]! // C4 +vld1.32 {d10[1]}, [r3]! // C5 +vld1.32 {d11[0]}, [r3]! // C6 +vld1.32 {d11[1]}, [r3]! // C7 +vld1.32 {d12[0]}, [r3]! // C8 + +vmov.u16 q15, #128 + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vmovl.u8 q2, d0 // r: uint8_t -> uint16_t +vmovl.u8 q3, d1 +vmovl.u8 q13, d2 + +vmovl.u16 q7, d4 // r +vmovl.u16 q8, d5 // r +vmovl.u16 q9, d6 // g +vmovl.u16 q10, d7 // g +vmovl.u16 q11, d26 // b +vmovl.u16 q12, d27 // b + +// r*C0, g*C1, b*C2 +vmul.s32 q0, q7, d8[0] +vmul.s32 q1, q8, d8[0] +vmla.s32 q0, q9, d8[1] +vmla.s32 q1, q10, d8[1] +vmla.s32 q0, q11, d9[0] +vmla.s32 q1, q12, d9[0] + +// r*C3, g*C4, b*C5 +vmul.s32 q2, q7, d9[1] +vmul.s32 q3, q8, d9[1] +vmla.s32 q2, q9, d10[0] +vmla.s32 q3, q10, d10[0] +vmla.s32 q2, q11, d10[1] +vmla.s32 q3, q12, d10[1] + +// r*C6, g*C7, b*C8 +vmul.s32 q13, q7, d11[0] +vmul.s32 q14, q8, d11[0] +vmla.s32 q13, q9, d11[1] +vmla.s32 q14, q10, d11[1] +vmla.s32 q13, q11, d12[0] +vmla.s32 q14, q12, d12[0] + +vrshrn.u32 d0, q0, #12 +vrshrn.u32 d1, q1, #12 +vrshrn.u32 d2, q2, #12 +vrshrn.u32 d3, q3, #12 +vrshrn.u32 d4, q13, #12 +vrshrn.u32 d5, q14, #12 + +vqmovn.u16 d0, q0 +vqmovn.u16 d1, q1 +vqmovn.u16 d2, q2 + +vst3.8 {d0, d1, d2}, [r1]! 
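+// per 8 pixels: out[k] = saturate_u8((C[3k]*R + C[3k+1]*G + C[3k+2]*B + 2048) >> 12) for the three coefficient rows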
+ +subs r2, r2, #1 +bne L1 + +End: +vpop {q4-q7} +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNC3ToYUVFast.S b/source/backend/cpu/arm/arm32/MNNC3ToYUVFast.S new file mode 100644 index 000000000..fb37aea9b --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNC3ToYUVFast.S @@ -0,0 +1,98 @@ +// +// MNNC3ToYUVFast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNC3ToYUVFast +// void MNNC3ToYUVFast(const unsigned char* source, unsigned char* dest, size_t count, int32_t* c); +// Auto Load: r0: source, r1: dest, r2: count, r3: c + +push {lr} +vpush {q4-q7} + +// q4-q6, const +vld1.32 {d8[0]}, [r3]! // C0 +vld1.32 {d8[1]}, [r3]! // C1 +vld1.32 {d9[0]}, [r3]! // C2 +vld1.32 {d9[1]}, [r3]! // C3 +vld1.32 {d10[0]}, [r3]! // C4 +vld1.32 {d10[1]}, [r3]! // C5 +vld1.32 {d11[0]}, [r3]! // C6 +vld1.32 {d11[1]}, [r3]! // C7 +vld1.32 {d12[0]}, [r3]! // C8 + +vmov.u16 q15, #128 + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vmovl.u8 q2, d0 // r: uint8_t -> uint16_t +vmovl.u8 q3, d1 +vmovl.u8 q13, d2 + +vmovl.u16 q7, d4 // r +vmovl.u16 q8, d5 // r +vmovl.u16 q9, d6 // g +vmovl.u16 q10, d7 // g +vmovl.u16 q11, d26 // b +vmovl.u16 q12, d27 // b + +// r*C0, g*C1, b*C2 +vmul.s32 q0, q7, d8[0] +vmul.s32 q1, q8, d8[0] +vmla.s32 q0, q9, d8[1] +vmla.s32 q1, q10, d8[1] +vmla.s32 q0, q11, d9[0] +vmla.s32 q1, q12, d9[0] + +// r*C3, g*C4, b*C5 +vmul.s32 q2, q7, d9[1] +vmul.s32 q3, q8, d9[1] +vmla.s32 q2, q9, d10[0] +vmla.s32 q3, q10, d10[0] +vmla.s32 q2, q11, d10[1] +vmla.s32 q3, q12, d10[1] + +// r*C6, g*C7, b*C8 +vmul.s32 q13, q7, d11[0] +vmul.s32 q14, q8, d11[0] +vmla.s32 q13, q9, d11[1] +vmla.s32 q14, q10, d11[1] +vmla.s32 q13, q11, d12[0] +vmla.s32 q14, q12, d12[0] + +vrshrn.u32 d0, q0, #14 +vrshrn.u32 d1, q1, #14 +vrshrn.u32 d2, q2, #14 +vrshrn.u32 d3, q3, #14 +vrshrn.u32 d4, q13, #14 +vrshrn.u32 d5, q14, #14 + +vadd.u16 q1, q1, q15 +vadd.u16 q2, q2, q15 + +vqmovn.u16 d0, q0 +vqmovn.u16 d1, q1 +vqmovn.u16 d2, q2 + +vst3.8 {d0, d1, d2}, [r1]! 
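+// same matrix form as the XYZ kernel but 14-bit fixed point: (sum + 8192) >> 14, then +128 added to the U and V channels before saturation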
+ +subs r2, r2, #1 +bne L1 + +End: +vpop {q4-q7} +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNFloat2Int8.S b/source/backend/cpu/arm/arm32/MNNFloat2Int8.S index b8702173c..07446d42e 100644 --- a/source/backend/cpu/arm/arm32/MNNFloat2Int8.S +++ b/source/backend/cpu/arm/arm32/MNNFloat2Int8.S @@ -22,26 +22,49 @@ vcvt.s32.f32 \x, q13 .endm asm_function MNNFloat2Int8 -//void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, ssize_t aMin, ssize_t aMax, ssize_t zeroPoint); -//r0:src, r1:dst, r2:sizeQuad, r3:scale, r4:aMin, r5:aMax, r6:zeroPoint - +//void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, ssize_t aMin, ssize_t aMax, float* zeroPoint, ssize_t quanParamVec); +// Auto load: r0:src, r1:dst, r2:sizeQuad, r3:scale +// Load from sp: aMin, aMax, lr: zeroPoint, r12: quanParamVec push {lr} vmov.f32 q10, #0.5 vmov.f32 q11, #-0.5 - -ldr r12, [sp, #4] -vld1.32 {q15}, [r3] +vmov.s32 q1, #1 +// scale +vld1.32 {d30[0]}, [r3] +vdup.32 q15, d30[0] // min +ldr r12, [sp, #4] vdup.s8 d28, r12 // max ldr r12, [sp, #8] vdup.s8 d29, r12 // zeropoint -ldr r12, [sp, #12] -vdup.s32 q9, r12 -vcvt.f32.s32 q9, q9 - +ldr lr, [sp, #12] +vld1.32 {d18[0]}, [lr] +vdup.32 q9, d18[0] + +// quanParamVec +ldr r12, [sp, #16] +cmp r12, #3 +bne LOAD_VEC_ZERO +vld1.f32 {q9}, [lr] // load vector zero +vld1.f32 {q15}, [r3] // load vector scale +b COMPUTE + +LOAD_VEC_ZERO: +cmp r12, #2 +bne LOAD_VEC_SCALE +vld1.f32 {q9}, [lr] // load vector zero +b COMPUTE + +LOAD_VEC_SCALE: +cmp r12, #1 +bne COMPUTE +vld1.f32 {q15}, [r3] // load vector scale + + +COMPUTE: cmp r2, #3 ble FL1 diff --git a/source/backend/cpu/arm/arm32/MNNGRAYToC3Fast.S b/source/backend/cpu/arm/arm32/MNNGRAYToC3Fast.S new file mode 100644 index 000000000..401b0b009 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNGRAYToC3Fast.S @@ -0,0 +1,35 @@ +// +// MNNGRAYToC3Fast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNGRAYToC3Fast +// void MNNGRAYToC3Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +L1: +vld1.8 {d0}, [r0]! +vmov d1, d0 +vmov d2, d0 +vst3.u8 {d0, d1, d2}, [r1]! +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNGRAYToC4Fast.S b/source/backend/cpu/arm/arm32/MNNGRAYToC4Fast.S new file mode 100644 index 000000000..aaca8f5b1 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNGRAYToC4Fast.S @@ -0,0 +1,36 @@ +// +// MNNGRAYToC4Fast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNGRAYToC4Fast +// void MNNGRAYToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.i8 d3, #255 +L1: +vld1.8 {d0}, [r0]! +vmov d1, d0 +vmov d2, d0 +vst4.u8 {d0, d1, d2, d3}, [r1]! 
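+// the 8 gray bytes are replicated into three planes and stored with the constant alpha 255 kept in d3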
+subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S index 72ff71423..8b62af530 100644 --- a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S @@ -51,7 +51,7 @@ ldr r8, [r6, #0] ldr lr, [r6, #4] vpush {q4-q7} - +sub sp, sp, #36 ldr r7, [r6, #16] // r7: useInt8 @@ -418,6 +418,7 @@ L1LoopCheck: bne L1LoopDz End: +add sp, sp, #36 vpop {q4-q7} pop {r4-r8, r10, pc} diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S index 25c9e5359..8d9d0ef63 100644 --- a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S @@ -42,6 +42,7 @@ ldr r8, [r6, #0] ldr lr, [r6, #4] vpush {q4-q7} +sub sp, sp, #36 // Only int8 output use this kernel. @@ -301,6 +302,7 @@ L1LoopCheck: bne L1LoopDz End: +add sp, sp, #36 vpop {q4-q7} pop {r4-r8, r10, pc} diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S index f7988025b..0e3966b9e 100644 --- a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S @@ -51,7 +51,7 @@ ldr r8, [r6, #0] ldr lr, [r6, #4] vpush {q4-q7} - +sub sp, sp, #36 // Branch1: input is int8_t, output is float32, DO NOT USE "scale". // Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue". // Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue". @@ -398,6 +398,7 @@ L1LoopCheck: bne L1LoopDz End: +add sp, sp, #36 vpop {q4-q7} pop {r4-r8, r10, pc} diff --git a/source/backend/cpu/arm/arm32/MNNRGBAToBGRAFast.S b/source/backend/cpu/arm/arm32/MNNRGBAToBGRAFast.S new file mode 100644 index 000000000..5eb583031 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBAToBGRAFast.S @@ -0,0 +1,38 @@ +// +// MNNRGBAToBGRAFast.S +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBAToBGRAFast +// void MNNRGBAToBGRAFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +L1: +vld4.8 {d0, d1, d2, d3}, [r0]! // r,g,b,a + +// swap d0,d2 +vmov.32 d4, d2 +vmov.32 d2, d0 +vmov.32 d0, d4 +vst4.8 {d0, d1, d2, d3}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBAToBGRFast.S b/source/backend/cpu/arm/arm32/MNNRGBAToBGRFast.S new file mode 100644 index 000000000..5a709f900 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBAToBGRFast.S @@ -0,0 +1,38 @@ +// +// MNNRGBAToBGRFast.S +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBAToBGRFast +// void MNNRGBAToBGRFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +L1: +vld4.8 {d0, d1, d2, d3}, [r0]! 
// r,g,b,a + +// swap d0,d2 +vmov.32 d4, d2 +vmov.32 d2, d0 +vmov.32 d0, d4 +vst3.8 {d0, d1, d2}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBAToGRAYFast.S b/source/backend/cpu/arm/arm32/MNNRGBAToGRAYFast.S new file mode 100644 index 000000000..d54f02a59 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBAToGRAYFast.S @@ -0,0 +1,43 @@ +// +// MNNRGBAToGRAYFast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBAToGRAYFast +// void MNNRGBAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.i8 d4, #7 +vmov.i8 d5, #38 +vmov.i8 d6, #19 + +L1: +vld4.8 {d0, d1, d2, d3}, [r0]! +vmull.u8 q8, d2, d4 // b*7 +vmlal.u8 q8, d1, d5 // g*38 +vmlal.u8 q8, d0, d6 // r*19 + +vqshrn.u16 d16, q8, #6 +vst1.8 {d16}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBToBGR555Fast.S b/source/backend/cpu/arm/arm32/MNNRGBToBGR555Fast.S new file mode 100644 index 000000000..ce328ea1d --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBToBGR555Fast.S @@ -0,0 +1,46 @@ +// +// MNNRGBToBGR555Fast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBToBGR555Fast +// void MNNRGBToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.s8 q15, #8 +vneg.s8 q15, q15 + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vand.u8 d0, d0, d30 // r & ~7 +vand.u8 d1, d1, d30 // g & ~7 +vshr.u8 d2, d2, #3 // b >> 3 +vshll.u8 q2, d0, #7 +vshll.u8 q3, d1, #2 +vmovl.u8 q8, d2 +vorr.u8 q2, q2, q3 +vorr.u8 q2, q2, q8 + +vst1.16 {q2}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBToBGR565Fast.S b/source/backend/cpu/arm/arm32/MNNRGBToBGR565Fast.S new file mode 100644 index 000000000..2cc804876 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBToBGR565Fast.S @@ -0,0 +1,54 @@ +// +// MNNRGBToBGR565Fast.S +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBToBGR565Fast +// void MNNRGBToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +// b*7 +// g*38 +// r*19 + +push {lr} +vmov.s8 q15, #8 +vneg.s8 q15, q15 +vmov.s8 q14, #4 +vneg.s8 q14, q14 + +L1: +vld3.8 {d0, d1, d2}, [r0]! // r,g,b + +vand.u8 d0, d0, d30 // r & ~7 +vand.u8 d1, d1, d28 // g & ~3 +vshr.u8 d2, d2, #3 // b >> 3 + +vshll.u8 q2, d0, #7 +vshl.u8 q2, q2, #1 +vshll.u8 q3, d1, #3 +vmovl.u8 q8, d2 + +vorr.u8 q2, q2, q3 +vorr.u8 q2, q2, q8 + +vst1.16 {q2}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBToBGRC8.S b/source/backend/cpu/arm/arm32/MNNRGBToBGRC8.S new file mode 100644 index 000000000..f097b94bf --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBToBGRC8.S @@ -0,0 +1,36 @@ +// +// MNNRGBToBGRC8.S +// MNN +// +// Created by MNN on 2024/08/28. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBToBGRC8 +// void MNNRGBToBGRC8(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vmov d3, d2 +vmov d4, d1 +vmov d5, d0 +vst3.8 {d3, d4, d5}, [r1]! +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNRGBToGRAYFast.S b/source/backend/cpu/arm/arm32/MNNRGBToGRAYFast.S new file mode 100644 index 000000000..258cb8892 --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNRGBToGRAYFast.S @@ -0,0 +1,43 @@ +// +// MNNRGBToGRAYFast.S +// MNN +// +// Created by MNN on 2024/08/28. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNRGBToGRAYFast +// void MNNRGBToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +// Auto Load: r0: source, r1: dest, r2: count + +push {lr} + +vmov.i8 d4, #7 +vmov.i8 d5, #38 +vmov.i8 d6, #19 + +L1: +vld3.8 {d0, d1, d2}, [r0]! +vmull.u8 q8, d2, d4 // b*7 +vmlal.u8 q8, d1, d5 // g*38 +vmlal.u8 q8, d0, d6 // r*19 + +vqshrn.u16 d16, q8, #6 +vst1.8 {d16}, [r1]! + +subs r2, r2, #1 +bne L1 + +End: +pop {pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm64/MNNBGRAToBGR.S b/source/backend/cpu/arm/arm64/MNNBGRAToBGR.S new file mode 100644 index 000000000..14a684fdf --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNBGRAToBGR.S @@ -0,0 +1,129 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNBGRAToBGRC8(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNBGRAToBGRC8 +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
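+// d8-d15 are callee-saved under AAPCS64; they are spilled here because v8-v15 are used in the unrolled copy loops below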
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +L12: +cmp x2, #12 +blt L8 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +ld4 {v28.16b, v29.16b, v30.16b, v31.16b}, [x0], #64 +sub x2, x2, #12 +mov v16.16b, v0.16b +mov v17.16b, v1.16b +mov v18.16b, v2.16b +mov v19.16b, v4.16b +mov v20.16b, v5.16b +mov v21.16b, v6.16b + +mov v22.16b, v8.16b +mov v23.16b, v9.16b +mov v24.16b, v10.16b +mov v25.16b, v12.16b +mov v26.16b, v13.16b +mov v27.16b, v14.16b + +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 + +mov v4.16b, v28.16b +mov v5.16b, v29.16b +mov v6.16b, v30.16b +mov v8.16b, v0.16b +mov v9.16b, v1.16b +mov v10.16b, v2.16b + + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +st3 {v19.16b, v20.16b, v21.16b}, [x1], #48 +st3 {v22.16b, v23.16b, v24.16b}, [x1], #48 +st3 {v25.16b, v26.16b, v27.16b}, [x1], #48 +st3 {v4.16b, v5.16b, v6.16b}, [x1], #48 +st3 {v8.16b, v9.16b, v10.16b}, [x1], #48 + +b L12 + + +L8: +cmp x2, #8 +blt L4 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +sub x2, x2, #8 +mov v16.16b, v0.16b +mov v17.16b, v1.16b +mov v18.16b, v2.16b +mov v19.16b, v4.16b +mov v20.16b, v5.16b +mov v21.16b, v6.16b + +mov v22.16b, v8.16b +mov v23.16b, v9.16b +mov v24.16b, v10.16b +mov v25.16b, v12.16b +mov v26.16b, v13.16b +mov v27.16b, v14.16b + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +st3 {v19.16b, v20.16b, v21.16b}, [x1], #48 +st3 {v22.16b, v23.16b, v24.16b}, [x1], #48 +st3 {v25.16b, v26.16b, v27.16b}, [x1], #48 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v6.16b, v7.16b, v8.16b, v9.16b}, [x0], #64 +sub x2, x2, #4 +mov v10.16b, v0.16b +mov v11.16b, v1.16b +mov v12.16b, v2.16b +mov v13.16b, v6.16b +mov v14.16b, v7.16b +mov v15.16b, v8.16b + +st3 {v10.16b, v11.16b, v12.16b}, [x1], #48 +st3 {v13.16b, v14.16b, v15.16b}, [x1], #48 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +mov v4.16b, v0.16b +mov v5.16b, v1.16b +mov v6.16b, v2.16b +sub x2, x2, #2 +st3 {v4.16b, v5.16b, v6.16b}, [x1], #48 +b L2 + +L1: +cmp x2, #1 +blt End +ld4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x0], #32 +mov v5.8b, v0.8b +mov v6.8b, v1.8b +mov v7.8b, v2.8b +st3 {v5.8b, v6.8b, v7.8b}, [x1], #24 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNBGRAToGRAY.S b/source/backend/cpu/arm/arm64/MNNBGRAToGRAY.S new file mode 100644 index 000000000..edf9f80fd --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNBGRAToGRAY.S @@ -0,0 +1,92 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNBGRAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNBGRAToGRAYFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
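+// gray ≈ (7*B + 38*G + 19*R) >> 6; the weights are held in v29/v30/v31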
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v29.16b, #7 +movi v30.16b, #38 +movi v31.16b, #19 + +L4: +cmp x2, #4 +blt L2 + +sub x2, x2, #4 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v14.16b, v15.16b, v16.16b, v17.16b}, [x0], #64 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +umull v18.8h, v14.8b, v29.8b // b*7 +umlal v18.8h, v15.8b, v30.8b // g*38 +umlal v18.8h, v16.8b, v31.8b // r*19 + +umull2 v21.8h, v14.16b, v29.16b // b*7 +umlal2 v21.8h, v15.16b, v30.16b // g*38 +umlal2 v21.8h, v16.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 +uqshrn v5.8b, v18.8h, #6 +uqshrn2 v5.16b, v21.8h, #6 + +st1 {v4.16b, v5.16b}, [x1], #32 +b L4 + +L2: +cmp x2, #2 +blt L1 + +sub x2, x2, #2 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 + +st1 {v4.16b}, [x1], #16 +b L2 + +L1: +cmp x2, #1 +blt End +ld4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x0], #32 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +uqshrn v10.8b, v4.8h, #6 + +st1 {v10.8b}, [x1], #8 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNBGRToBGR555.S b/source/backend/cpu/arm/arm64/MNNBGRToBGR555.S new file mode 100644 index 000000000..d4c8ddcd7 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNBGRToBGR555.S @@ -0,0 +1,169 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNBGRToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNBGRToBGR555Fast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! 
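+// BGR555 packing: out16 = ((R & ~7) << 7) | ((G & ~7) << 2) | (B >> 3)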
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #8 +neg v31.16b, v31.16b + +L6: +cmp x2, #6 +blt L4 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +ushr v0.16b, v0.16b, #3 // b >> 3 +and v13.16b, v13.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v31.16b // g & ~7 +ushr v11.16b, v11.16b, #3 // b >> 3 +and v26.16b, v26.16b, v31.16b // r & ~7 +and v25.16b, v25.16b, v31.16b // g & ~7 +ushr v24.16b, v24.16b, #3 // b >> 3 +sub x2, x2, #6 + +ushll v3.8h, v2.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v0.16b + +ushll v14.8h, v13.8b, #7 +ushll v15.8h, v12.8b, #2 +uxtl v16.8h, v11.8b +ushll2 v17.8h, v13.16b, #7 +ushll2 v18.8h, v12.16b, #2 +uxtl2 v19.8h, v11.16b + +ushll v6.8h, v26.8b, #7 +ushll v7.8h, v25.8b, #2 +uxtl v27.8h, v24.8b +ushll2 v28.8h, v26.16b, #7 +ushll2 v29.8h, v25.16b, #2 +uxtl2 v30.8h, v24.16b + +orr v0.16b, v3.16b, v4.16b +orr v0.16b, v0.16b, v5.16b +orr v1.16b, v8.16b, v9.16b +orr v1.16b, v1.16b, v10.16b + +orr v2.16b, v14.16b, v15.16b +orr v2.16b, v2.16b, v16.16b +orr v3.16b, v17.16b, v18.16b +orr v3.16b, v3.16b, v19.16b + +orr v4.16b, v6.16b, v7.16b +orr v4.16b, v4.16b, v27.16b +orr v5.16b, v28.16b, v29.16b +orr v5.16b, v5.16b, v30.16b + +st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 +st1 {v4.8h, v5.8h}, [x1], #32 + +b L6 + +L4: +cmp x2, #4 +blt L2 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +ushr v0.16b, v0.16b, #3 // b >> 3 +and v13.16b, v13.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v31.16b // g & ~7 +ushr v11.16b, v11.16b, #3 // b >> 3 +sub x2, x2, #4 + +ushll v3.8h, v2.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v0.16b + +ushll v14.8h, v13.8b, #7 +ushll v15.8h, v12.8b, #2 +uxtl v16.8h, v11.8b +ushll2 v17.8h, v13.16b, #7 +ushll2 v18.8h, v12.16b, #2 +uxtl2 v19.8h, v11.16b + + +orr v20.16b, v3.16b, v4.16b +orr v20.16b, v20.16b, v5.16b +orr v21.16b, v8.16b, v9.16b +orr v21.16b, v21.16b, v10.16b + +orr v22.16b, v14.16b, v15.16b +orr v22.16b, v22.16b, v16.16b +orr v23.16b, v17.16b, v18.16b +orr v23.16b, v23.16b, v19.16b + +st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x1], #64 + +b L4 + +L2: +cmp x2, #2 +blt L1 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +sub x2, x2, #2 +ushr v0.16b, v0.16b, #3 // b >> 3 + +ushll v3.8h, v2.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v0.16b + +orr v6.16b, v3.16b, v4.16b +orr v6.16b, v6.16b, v5.16b +orr v7.16b, v8.16b, v9.16b +orr v7.16b, v7.16b, v10.16b + +st1 {v6.8h, v7.8h}, [x1], #32 + +b L2 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +and v2.8b, v2.8b, v31.8b // r & ~7 +and v1.8b, v1.8b, v31.8b // g & ~7 +ushr v0.8b, v0.8b, #3 // b >> 3 +ushll v2.8h, v2.8b, #7 +ushll v1.8h, v1.8b, #2 +uxtl v0.8h, v0.8b +orr v3.16b, v0.16b, v1.16b +orr v3.16b, v3.16b, v2.16b + +st1 {v3.8h}, [x1], #16 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNBGRToBGR565.S 
b/source/backend/cpu/arm/arm64/MNNBGRToBGR565.S new file mode 100644 index 000000000..0210c0f0c --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNBGRToBGR565.S @@ -0,0 +1,187 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNBGRToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNBGRToBGR565Fast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #8 +neg v31.16b, v31.16b + +L6: +cmp x2, #6 +blt L4 + +movi v30.16b, #4 +neg v30.16b, v30.16b + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~3 +ushr v0.16b, v0.16b, #3 // b >> 3 +and v13.16b, v13.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v30.16b // g & ~3 +ushr v11.16b, v11.16b, #3 // b >> 3 +and v26.16b, v26.16b, v31.16b // r & ~7 +and v25.16b, v25.16b, v30.16b // g & ~3 +ushr v24.16b, v24.16b, #3 // b >> 3 +sub x2, x2, #6 + +ushll v3.8h, v2.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v0.16b + +ushll v14.8h, v13.8b, #7 +shl v14.8h, v14.8h, #1 +ushll v15.8h, v12.8b, #3 +uxtl v16.8h, v11.8b +ushll2 v17.8h, v13.16b, #7 +shl v17.8h, v17.8h, #1 +ushll2 v18.8h, v12.16b, #3 +uxtl2 v19.8h, v11.16b + +ushll v6.8h, v26.8b, #7 +shl v6.8h, v6.8h, #1 +ushll v7.8h, v25.8b, #3 +uxtl v27.8h, v24.8b +ushll2 v28.8h, v26.16b, #7 +shl v28.8h, v28.8h, #1 +ushll2 v29.8h, v25.16b, #3 +uxtl2 v30.8h, v24.16b + +orr v0.16b, v3.16b, v4.16b +orr v0.16b, v0.16b, v5.16b +orr v1.16b, v8.16b, v9.16b +orr v1.16b, v1.16b, v10.16b + +orr v2.16b, v14.16b, v15.16b +orr v2.16b, v2.16b, v16.16b +orr v3.16b, v17.16b, v18.16b +orr v3.16b, v3.16b, v19.16b + +orr v4.16b, v6.16b, v7.16b +orr v4.16b, v4.16b, v27.16b +orr v5.16b, v28.16b, v29.16b +orr v5.16b, v5.16b, v30.16b + +st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 +st1 {v4.8h, v5.8h}, [x1], #32 + +b L6 + +L4: +movi v30.16b, #4 +neg v30.16b, v30.16b +cmp x2, #4 +blt L2 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~3 +ushr v0.16b, v0.16b, #3 // b >> 3 +and v13.16b, v13.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v30.16b // g & ~3 +ushr v11.16b, v11.16b, #3 // b >> 3 +sub x2, x2, #4 + +ushll v3.8h, v2.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v0.16b + +ushll v14.8h, v13.8b, #7 +shl v14.8h, v14.8h, #1 +ushll v15.8h, v12.8b, #3 +uxtl v16.8h, v11.8b +ushll2 v17.8h, v13.16b, #7 +shl v17.8h, v17.8h, #1 +ushll2 v18.8h, v12.16b, #3 +uxtl2 v19.8h, v11.16b + + +orr v20.16b, v3.16b, v4.16b +orr v20.16b, v20.16b, v5.16b +orr v21.16b, v8.16b, v9.16b +orr v21.16b, v21.16b, v10.16b + +orr v22.16b, v14.16b, v15.16b +orr v22.16b, v22.16b, v16.16b +orr v23.16b, v17.16b, v18.16b +orr v23.16b, v23.16b, v19.16b + +st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x1], #64 + +b L4 + +L2: +cmp x2, #2 +blt L1 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +and v2.16b, v2.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~3 +sub x2, x2, #2 +ushr v0.16b, v0.16b, #3 // b >> 3 + +ushll v3.8h, v2.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl 
v5.8h, v0.8b +ushll2 v8.8h, v2.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v0.16b + +orr v6.16b, v3.16b, v4.16b +orr v6.16b, v6.16b, v5.16b +orr v7.16b, v8.16b, v9.16b +orr v7.16b, v7.16b, v10.16b + +st1 {v6.8h, v7.8h}, [x1], #32 + +b L2 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +and v2.8b, v2.8b, v31.8b // r & ~7 +and v1.8b, v1.8b, v30.8b // g & ~3 +ushr v0.8b, v0.8b, #3 // b >> 3 +ushll v2.8h, v2.8b, #7 +shl v2.8h, v2.8h, #1 +ushll v1.8h, v1.8b, #3 +uxtl v0.8h, v0.8b +orr v3.16b, v0.16b, v1.16b +orr v3.16b, v3.16b, v2.16b + +st1 {v3.8h}, [x1], #16 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNBGRToGRAY.S b/source/backend/cpu/arm/arm64/MNNBGRToGRAY.S new file mode 100644 index 000000000..cd746a40a --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNBGRToGRAY.S @@ -0,0 +1,92 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNBGRToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNBGRToGRAYFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v29.16b, #7 +movi v30.16b, #38 +movi v31.16b, #19 + +L4: +cmp x2, #4 +blt L2 + +sub x2, x2, #4 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v14.16b, v15.16b, v16.16b}, [x0], #48 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +umull v18.8h, v14.8b, v29.8b // b*7 +umlal v18.8h, v15.8b, v30.8b // g*38 +umlal v18.8h, v16.8b, v31.8b // r*19 + +umull2 v21.8h, v14.16b, v29.16b // b*7 +umlal2 v21.8h, v15.16b, v30.16b // g*38 +umlal2 v21.8h, v16.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 +uqshrn v5.8b, v18.8h, #6 +uqshrn2 v5.16b, v21.8h, #6 + +st1 {v4.16b, v5.16b}, [x1], #32 +b L4 + +L2: +cmp x2, #2 +blt L1 + +sub x2, x2, #2 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 + +st1 {v4.16b}, [x1], #16 +b L2 + +L1: +cmp x2, #1 +blt End +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +uqshrn v10.8b, v4.8h, #6 + +st1 {v10.8b}, [x1], #8 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNC3ToC4Fast.S b/source/backend/cpu/arm/arm64/MNNC3ToC4Fast.S new file mode 100644 index 000000000..2c24bed03 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNC3ToC4Fast.S @@ -0,0 +1,116 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNC3ToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNC3ToC4Fast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
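+// v3, v7, v11, ..., v31 are preset to 255 and act as the alpha plane for every 4-channel store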
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v3.16b, #255 +movi v7.16b, #255 +movi v11.16b, #255 +movi v15.16b, #255 +movi v19.16b, #255 +movi v23.16b, #255 +movi v27.16b, #255 +movi v31.16b, #255 + +L16: +cmp x2, #16 +blt L12 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v4.16b, v5.16b, v6.16b}, [x0], #48 +ld3 {v8.16b, v9.16b, v10.16b}, [x0], #48 +ld3 {v12.16b, v13.16b, v14.16b}, [x0], #48 +ld3 {v16.16b, v17.16b, v18.16b}, [x0], #48 +ld3 {v20.16b, v21.16b, v22.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +ld3 {v28.16b, v29.16b, v30.16b}, [x0], #48 +sub x2, x2, #16 + +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 +st4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x1], #64 +st4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x1], #64 +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +st4 {v20.16b, v21.16b, v22.16b, v23.16b}, [x1], #64 +st4 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 +st4 {v28.16b, v29.16b, v30.16b, v31.16b}, [x1], #64 +b L16 + +L12: +cmp x2, #12 +blt L8 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v4.16b, v5.16b, v6.16b}, [x0], #48 +ld3 {v8.16b, v9.16b, v10.16b}, [x0], #48 +ld3 {v12.16b, v13.16b, v14.16b}, [x0], #48 +ld3 {v16.16b, v17.16b, v18.16b}, [x0], #48 +ld3 {v20.16b, v21.16b, v22.16b}, [x0], #48 +sub x2, x2, #12 + +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 +st4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x1], #64 +st4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x1], #64 +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +st4 {v20.16b, v21.16b, v22.16b, v23.16b}, [x1], #64 + +b L12 + + +L8: +cmp x2, #8 +blt L4 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v4.16b, v5.16b, v6.16b}, [x0], #48 +ld3 {v8.16b, v9.16b, v10.16b}, [x0], #48 +ld3 {v12.16b, v13.16b, v14.16b}, [x0], #48 +sub x2, x2, #8 + +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 +st4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x1], #64 +st4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x1], #64 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v4.16b, v5.16b, v6.16b}, [x0], #48 +sub x2, x2, #4 + +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +sub x2, x2, #2 + +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +b L2 + +L1: +cmp x2, #1 +blt End +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 + +st4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x1], #32 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNC3ToXYZFast.S b/source/backend/cpu/arm/arm64/MNNC3ToXYZFast.S new file mode 100644 index 000000000..dba224df8 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNC3ToXYZFast.S @@ -0,0 +1,88 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNC3ToXYZFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNC3ToXYZFast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! 
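+// the nine coefficients c[0..8] are broadcast into v23-v31 with ld1r; outputs use the rounding shift uqrshrn #12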
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +ld1r {v23.4s}, [x3], #4 +ld1r {v24.4s}, [x3], #4 +ld1r {v25.4s}, [x3], #4 +ld1r {v26.4s}, [x3], #4 +ld1r {v27.4s}, [x3], #4 +ld1r {v28.4s}, [x3], #4 +ld1r {v29.4s}, [x3], #4 +ld1r {v30.4s}, [x3], #4 +ld1r {v31.4s}, [x3], #4 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +ushll v0.8h, v0.8b, #0 // r: uint8_t -> uint16_t +ushll v1.8h, v1.8b, #0 +ushll v2.8h, v2.8b, #0 + +uxtl v3.4s, v0.4h // r +uxtl2 v4.4s, v0.8h // r +uxtl v5.4s, v1.4h // g +uxtl2 v6.4s, v1.8h // g +uxtl v7.4s, v2.4h // b +uxtl2 v8.4s, v2.8h // b + +// r*C0, g*C1, b*C2 +mul v9.4s, v3.4s, v23.4s +mul v10.4s, v4.4s, v23.4s +mla v9.4s, v5.4s, v24.4s +mla v10.4s, v6.4s, v24.4s +mla v9.4s, v7.4s, v25.4s +mla v10.4s, v8.4s, v25.4s + +// r*C3, g*C4, b*C5 +mul v15.4s, v3.4s, v26.4s +mul v16.4s, v4.4s, v26.4s +mla v15.4s, v5.4s, v27.4s +mla v16.4s, v6.4s, v27.4s +mla v15.4s, v7.4s, v28.4s +mla v16.4s, v8.4s, v28.4s + +// r*C6, g*C7, b*C8 +mul v21.4s, v3.4s, v29.4s +mul v22.4s, v4.4s, v29.4s +mla v21.4s, v5.4s, v30.4s +mla v22.4s, v6.4s, v30.4s +mla v21.4s, v7.4s, v31.4s +mla v22.4s, v8.4s, v31.4s + +uqrshrn v11.4h, v9.4s, #12 +uqrshrn2 v11.8h, v10.4s, #12 +uqrshrn v12.4h, v15.4s, #12 +uqrshrn2 v12.8h, v16.4s, #12 +uqrshrn v13.4h, v21.4s, #12 +uqrshrn2 v13.8h, v22.4s, #12 + +uqxtn v14.8b, v11.8h +uqxtn v15.8b, v12.8h +uqxtn v16.8b, v13.8h + + +st3 {v14.8b, v15.8b, v16.8b}, [x1], #24 +sub x2, x2, #1 +b L1 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNC3ToYUVFast.S b/source/backend/cpu/arm/arm64/MNNC3ToYUVFast.S new file mode 100644 index 000000000..8bd316552 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNC3ToYUVFast.S @@ -0,0 +1,92 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNC3ToYUVFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNC3ToYUVFast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! 
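+// identical coefficient setup to the XYZ kernel, but shifted by 14 bits with +128 (v17) added to the U/V outputs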
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +ld1r {v23.4s}, [x3], #4 +ld1r {v24.4s}, [x3], #4 +ld1r {v25.4s}, [x3], #4 +ld1r {v26.4s}, [x3], #4 +ld1r {v27.4s}, [x3], #4 +ld1r {v28.4s}, [x3], #4 +ld1r {v29.4s}, [x3], #4 +ld1r {v30.4s}, [x3], #4 +ld1r {v31.4s}, [x3], #4 +movi v17.8h, #128 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +ushll v0.8h, v0.8b, #0 // r: uint8_t -> uint16_t +ushll v1.8h, v1.8b, #0 +ushll v2.8h, v2.8b, #0 + +uxtl v3.4s, v0.4h // r +uxtl2 v4.4s, v0.8h // r +uxtl v5.4s, v1.4h // g +uxtl2 v6.4s, v1.8h // g +uxtl v7.4s, v2.4h // b +uxtl2 v8.4s, v2.8h // b + +// r*C0, g*C1, b*C2 +mul v9.4s, v3.4s, v23.4s +mul v10.4s, v4.4s, v23.4s +mla v9.4s, v5.4s, v24.4s +mla v10.4s, v6.4s, v24.4s +mla v9.4s, v7.4s, v25.4s +mla v10.4s, v8.4s, v25.4s + +// r*C3, g*C4, b*C5 +mul v15.4s, v3.4s, v26.4s +mul v16.4s, v4.4s, v26.4s +mla v15.4s, v5.4s, v27.4s +mla v16.4s, v6.4s, v27.4s +mla v15.4s, v7.4s, v28.4s +mla v16.4s, v8.4s, v28.4s + +// r*C6, g*C7, b*C8 +mul v21.4s, v3.4s, v29.4s +mul v22.4s, v4.4s, v29.4s +mla v21.4s, v5.4s, v30.4s +mla v22.4s, v6.4s, v30.4s +mla v21.4s, v7.4s, v31.4s +mla v22.4s, v8.4s, v31.4s + +uqrshrn v11.4h, v9.4s, #14 +uqrshrn2 v11.8h, v10.4s, #14 +uqrshrn v12.4h, v15.4s, #14 +uqrshrn2 v12.8h, v16.4s, #14 +uqrshrn v13.4h, v21.4s, #14 +uqrshrn2 v13.8h, v22.4s, #14 + +add v12.8h, v12.8h, v17.8h +add v13.8h, v13.8h, v17.8h + +uqxtn v14.8b, v11.8h +uqxtn v15.8b, v12.8h +uqxtn v16.8b, v13.8h + + +st3 {v14.8b, v15.8b, v16.8b}, [x1], #24 +sub x2, x2, #1 +b L1 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNFloat2Int8.S b/source/backend/cpu/arm/arm64/MNNFloat2Int8.S index 98816cfde..8b5a4e42c 100644 --- a/source/backend/cpu/arm/arm64/MNNFloat2Int8.S +++ b/source/backend/cpu/arm/arm64/MNNFloat2Int8.S @@ -14,21 +14,35 @@ .align 5 asm_function MNNFloat2Int8 -//void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, size_t zeroPoint); -//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint +//void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, float* zeroPoint, ssize_t quanParamVec); +//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint, x7: quanParamVec stp d14, d15, [sp, #-64]! 
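+// quanParamVec (x7): 1 = per-channel scale vector, 2 = per-channel zeroPoint vector, 3 = both; otherwise the scalar values broadcast by ld1r are kept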
stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] -ld1 {v31.4s}, [x3] +ld1r {v31.4s}, [x3] dup v30.16b, w4 dup v29.16b, w5 // copy zero point -dup v28.4s, w6 -scvtf v28.4s, v28.4s +ld1r {v28.4s}, [x6] + +cmp x7, #3 +bne LOAD_SCALE_VEC +ld1 {v31.4s}, [x3] // scale +ld1 {v28.4s}, [x6] // zero +b FL32 +LOAD_SCALE_VEC: +cmp x7, #1 +bne LOAD_ZERO_VEC +ld1 {v31.4s}, [x3] // scale +b FL32 +LOAD_ZERO_VEC: +cmp x7, #2 +bne FL32 +ld1 {v28.4s}, [x6] // zero FL32: cmp x2, #32 @@ -44,58 +58,53 @@ ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 // ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 // ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], #64 fmul v0.4s, v0.4s, v31.4s -fadd v0.4s, v0.4s, v28.4s fmul v1.4s, v1.4s, v31.4s -fadd v1.4s, v1.4s, v28.4s fmul v2.4s, v2.4s, v31.4s -fadd v2.4s, v2.4s, v28.4s fmul v3.4s, v3.4s, v31.4s -fadd v3.4s, v3.4s, v28.4s - fmul v4.4s, v4.4s, v31.4s -fadd v4.4s, v4.4s, v28.4s fmul v5.4s, v5.4s, v31.4s -fadd v5.4s, v5.4s, v28.4s fmul v6.4s, v6.4s, v31.4s -fadd v6.4s, v6.4s, v28.4s fmul v7.4s, v7.4s, v31.4s -fadd v7.4s, v7.4s, v28.4s - fmul v8.4s, v8.4s, v31.4s -fadd v8.4s, v8.4s, v28.4s fmul v9.4s, v9.4s, v31.4s -fadd v9.4s, v9.4s, v28.4s fmul v10.4s, v10.4s, v31.4s -fadd v10.4s, v10.4s, v28.4s fmul v11.4s, v11.4s, v31.4s -fadd v11.4s, v11.4s, v28.4s - fmul v12.4s, v12.4s, v31.4s -fadd v12.4s, v12.4s, v28.4s fmul v13.4s, v13.4s, v31.4s -fadd v13.4s, v13.4s, v28.4s fmul v14.4s, v14.4s, v31.4s -fadd v14.4s, v14.4s, v28.4s fmul v15.4s, v15.4s, v31.4s -fadd v15.4s, v15.4s, v28.4s - - fmul v16.4s, v16.4s, v31.4s -fadd v16.4s, v16.4s, v28.4s fmul v17.4s, v17.4s, v31.4s -fadd v17.4s, v17.4s, v28.4s fmul v18.4s, v18.4s, v31.4s -fadd v18.4s, v18.4s, v28.4s fmul v19.4s, v19.4s, v31.4s -fadd v19.4s, v19.4s, v28.4s - fmul v20.4s, v20.4s, v31.4s -fadd v20.4s, v20.4s, v28.4s fmul v21.4s, v21.4s, v31.4s -fadd v21.4s, v21.4s, v28.4s fmul v22.4s, v22.4s, v31.4s -fadd v22.4s, v22.4s, v28.4s fmul v23.4s, v23.4s, v31.4s + +fadd v0.4s, v0.4s, v28.4s +fadd v1.4s, v1.4s, v28.4s +fadd v2.4s, v2.4s, v28.4s +fadd v3.4s, v3.4s, v28.4s +fadd v4.4s, v4.4s, v28.4s +fadd v5.4s, v5.4s, v28.4s +fadd v6.4s, v6.4s, v28.4s +fadd v7.4s, v7.4s, v28.4s +fadd v8.4s, v8.4s, v28.4s +fadd v9.4s, v9.4s, v28.4s +fadd v10.4s, v10.4s, v28.4s +fadd v11.4s, v11.4s, v28.4s +fadd v12.4s, v12.4s, v28.4s +fadd v13.4s, v13.4s, v28.4s +fadd v14.4s, v14.4s, v28.4s +fadd v15.4s, v15.4s, v28.4s +fadd v16.4s, v16.4s, v28.4s +fadd v17.4s, v17.4s, v28.4s +fadd v18.4s, v18.4s, v28.4s +fadd v19.4s, v19.4s, v28.4s +fadd v20.4s, v20.4s, v28.4s +fadd v21.4s, v21.4s, v28.4s +fadd v22.4s, v22.4s, v28.4s fadd v23.4s, v23.4s, v28.4s fcvtas v0.4s, v0.4s @@ -171,21 +180,21 @@ sqxtn2 v4.16b, v5.8h sqxtn2 v6.16b, v7.8h fmul v8.4s, v8.4s, v31.4s -fadd v8.4s, v8.4s, v28.4s fmul v9.4s, v9.4s, v31.4s -fadd v9.4s, v9.4s, v28.4s fmul v10.4s, v10.4s, v31.4s -fadd v10.4s, v10.4s, v28.4s fmul v11.4s, v11.4s, v31.4s -fadd v11.4s, v11.4s, v28.4s - fmul v12.4s, v12.4s, v31.4s -fadd v12.4s, v12.4s, v28.4s fmul v13.4s, v13.4s, v31.4s -fadd v13.4s, v13.4s, v28.4s fmul v14.4s, v14.4s, v31.4s -fadd v14.4s, v14.4s, v28.4s fmul v15.4s, v15.4s, v31.4s + +fadd v8.4s, v8.4s, v28.4s +fadd v9.4s, v9.4s, v28.4s +fadd v10.4s, v10.4s, v28.4s +fadd v11.4s, v11.4s, v28.4s +fadd v12.4s, v12.4s, v28.4s +fadd v13.4s, v13.4s, v28.4s +fadd v14.4s, v14.4s, v28.4s fadd v15.4s, v15.4s, v28.4s fcvtas v8.4s, v8.4s @@ -207,8 +216,8 @@ sqxtn v19.4h, v14.4s sqxtn2 v19.8h, v15.4s smin v24.16b, v24.16b, v29.16b -smax v24.16b, v24.16b, v30.16b smin v25.16b, v26.16b, 
v29.16b +smax v24.16b, v24.16b, v30.16b smax v25.16b, v25.16b, v30.16b sqxtn v20.8b, v16.8h @@ -217,18 +226,18 @@ sqxtn v21.8b, v18.8h sqxtn2 v21.16b, v19.8h smin v26.16b, v0.16b, v29.16b -smax v26.16b, v26.16b, v30.16b smin v27.16b, v2.16b, v29.16b +smax v26.16b, v26.16b, v30.16b smax v27.16b, v27.16b, v30.16b smin v12.16b, v4.16b, v29.16b -smax v12.16b, v12.16b, v30.16b smin v13.16b, v6.16b, v29.16b +smax v12.16b, v12.16b, v30.16b smax v13.16b, v13.16b, v30.16b smin v14.16b, v20.16b, v29.16b -smax v14.16b, v14.16b, v30.16b smin v15.16b, v21.16b, v29.16b +smax v14.16b, v14.16b, v30.16b smax v15.16b, v15.16b, v30.16b st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x1], #64 @@ -248,39 +257,37 @@ ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 fmul v0.4s, v0.4s, v31.4s -fadd v0.4s, v0.4s, v28.4s fmul v1.4s, v1.4s, v31.4s -fadd v1.4s, v1.4s, v28.4s fmul v2.4s, v2.4s, v31.4s -fadd v2.4s, v2.4s, v28.4s fmul v3.4s, v3.4s, v31.4s -fadd v3.4s, v3.4s, v28.4s - fmul v4.4s, v4.4s, v31.4s -fadd v4.4s, v4.4s, v28.4s fmul v5.4s, v5.4s, v31.4s -fadd v5.4s, v5.4s, v28.4s fmul v6.4s, v6.4s, v31.4s -fadd v6.4s, v6.4s, v28.4s fmul v7.4s, v7.4s, v31.4s -fadd v7.4s, v7.4s, v28.4s - fmul v8.4s, v8.4s, v31.4s -fadd v8.4s, v8.4s, v28.4s fmul v9.4s, v9.4s, v31.4s -fadd v9.4s, v9.4s, v28.4s fmul v10.4s, v10.4s, v31.4s -fadd v10.4s, v10.4s, v28.4s fmul v11.4s, v11.4s, v31.4s -fadd v11.4s, v11.4s, v28.4s - fmul v12.4s, v12.4s, v31.4s -fadd v12.4s, v12.4s, v28.4s fmul v13.4s, v13.4s, v31.4s -fadd v13.4s, v13.4s, v28.4s fmul v14.4s, v14.4s, v31.4s -fadd v14.4s, v14.4s, v28.4s fmul v15.4s, v15.4s, v31.4s + +fadd v0.4s, v0.4s, v28.4s +fadd v1.4s, v1.4s, v28.4s +fadd v2.4s, v2.4s, v28.4s +fadd v3.4s, v3.4s, v28.4s +fadd v4.4s, v4.4s, v28.4s +fadd v5.4s, v5.4s, v28.4s +fadd v6.4s, v6.4s, v28.4s +fadd v7.4s, v7.4s, v28.4s +fadd v8.4s, v8.4s, v28.4s +fadd v9.4s, v9.4s, v28.4s +fadd v10.4s, v10.4s, v28.4s +fadd v11.4s, v11.4s, v28.4s +fadd v12.4s, v12.4s, v28.4s +fadd v13.4s, v13.4s, v28.4s +fadd v14.4s, v14.4s, v28.4s fadd v15.4s, v15.4s, v28.4s fcvtas v0.4s, v0.4s @@ -350,21 +357,21 @@ FLLoop8: ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 fmul v0.4s, v0.4s, v31.4s -fadd v0.4s, v0.4s, v28.4s fmul v1.4s, v1.4s, v31.4s -fadd v1.4s, v1.4s, v28.4s fmul v2.4s, v2.4s, v31.4s -fadd v2.4s, v2.4s, v28.4s fmul v3.4s, v3.4s, v31.4s -fadd v3.4s, v3.4s, v28.4s - fmul v4.4s, v4.4s, v31.4s -fadd v4.4s, v4.4s, v28.4s fmul v5.4s, v5.4s, v31.4s -fadd v5.4s, v5.4s, v28.4s fmul v6.4s, v6.4s, v31.4s -fadd v6.4s, v6.4s, v28.4s fmul v7.4s, v7.4s, v31.4s + +fadd v0.4s, v0.4s, v28.4s +fadd v1.4s, v1.4s, v28.4s +fadd v2.4s, v2.4s, v28.4s +fadd v3.4s, v3.4s, v28.4s +fadd v4.4s, v4.4s, v28.4s +fadd v5.4s, v5.4s, v28.4s +fadd v6.4s, v6.4s, v28.4s fadd v7.4s, v7.4s, v28.4s fcvtas v0.4s, v0.4s @@ -405,15 +412,14 @@ cmp x2, #3 ble FL1 FLLoop4: -ld1 {v0.4s, v1.4s}, [x0], #32 +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 fmul v0.4s, v0.4s, v31.4s -fadd v0.4s, v0.4s, v28.4s -ld1 {v2.4s, v3.4s}, [x0], #32 fmul v1.4s, v1.4s, v31.4s -fadd v1.4s, v1.4s, v28.4s fmul v2.4s, v2.4s, v31.4s -fadd v2.4s, v2.4s, v28.4s fmul v3.4s, v3.4s, v31.4s +fadd v0.4s, v0.4s, v28.4s +fadd v1.4s, v1.4s, v28.4s +fadd v2.4s, v2.4s, v28.4s fadd v3.4s, v3.4s, v28.4s fcvtas v0.4s, v0.4s diff --git a/source/backend/cpu/arm/arm64/MNNGRAYToC3Fast.S b/source/backend/cpu/arm/arm64/MNNGRAYToC3Fast.S new file mode 100644 index 000000000..852e3c5aa --- /dev/null +++ 
b/source/backend/cpu/arm/arm64/MNNGRAYToC3Fast.S @@ -0,0 +1,124 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNGRAYToC3Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNGRAYToC3Fast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +L12: +cmp x2, #12 +blt L8 +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld1 {v21.16b, v22.16b}, [x0], #32 +sub x2, x2, #12 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b + +mov v13.16b, v2.16b +mov v14.16b, v2.16b +mov v15.16b, v2.16b + +mov v17.16b, v3.16b +mov v18.16b, v3.16b +mov v19.16b, v3.16b + +mov v23.16b, v21.16b +mov v24.16b, v21.16b +mov v25.16b, v21.16b + +mov v27.16b, v22.16b +mov v28.16b, v22.16b +mov v29.16b, v22.16b + +st3 {v5.16b, v6.16b, v7.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +st3 {v13.16b, v14.16b, v15.16b}, [x1], #48 +st3 {v17.16b, v18.16b, v19.16b}, [x1], #48 +st3 {v23.16b, v24.16b, v25.16b}, [x1], #48 +st3 {v27.16b, v28.16b, v29.16b}, [x1], #48 +b L12 + + +L8: +cmp x2, #8 +blt L4 +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +sub x2, x2, #8 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b + +mov v13.16b, v2.16b +mov v14.16b, v2.16b +mov v15.16b, v2.16b + +mov v17.16b, v3.16b +mov v18.16b, v3.16b +mov v19.16b, v3.16b + +st3 {v5.16b, v6.16b, v7.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +st3 {v13.16b, v14.16b, v15.16b}, [x1], #48 +st3 {v17.16b, v18.16b, v19.16b}, [x1], #48 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld1 {v0.16b, v1.16b}, [x0], #32 +sub x2, x2, #4 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b + +st3 {v5.16b, v6.16b, v7.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld1 {v0.16b}, [x0], #16 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b +sub x2, x2, #2 +st3 {v5.16b, v6.16b, v7.16b}, [x1], #48 +b L2 + +L1: +cmp x2, #1 +blt End +ld1 {v0.8b}, [x0], #8 +mov v5.8b, v0.8b +mov v6.8b, v0.8b +mov v7.8b, v0.8b +st3 {v5.8b, v6.8b, v7.8b}, [x1], #24 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNGRAYToC4Fast.S b/source/backend/cpu/arm/arm64/MNNGRAYToC4Fast.S new file mode 100644 index 000000000..e13691d2f --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNGRAYToC4Fast.S @@ -0,0 +1,139 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNGRAYToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNGRAYToC4Fast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
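+// unrolled for 12/8/4/2/1 input vectors; v31 is preset to 255 and copied in as the alpha plane of every st4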
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +movi v31.16b, #255 + +L12: +cmp x2, #12 +blt L8 +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld1 {v21.16b, v22.16b}, [x0], #32 +sub x2, x2, #12 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b +mov v8.16b, v31.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b +mov v12.16b, v31.16b + +mov v13.16b, v2.16b +mov v14.16b, v2.16b +mov v15.16b, v2.16b +mov v16.16b, v31.16b + +mov v17.16b, v3.16b +mov v18.16b, v3.16b +mov v19.16b, v3.16b +mov v20.16b, v31.16b + +mov v23.16b, v21.16b +mov v24.16b, v21.16b +mov v25.16b, v21.16b +mov v26.16b, v31.16b + +mov v27.16b, v22.16b +mov v28.16b, v22.16b +mov v29.16b, v22.16b +mov v30.16b, v31.16b + +st4 {v5.16b, v6.16b, v7.16b, v8.16b}, [x1], #64 +st4 {v9.16b, v10.16b, v11.16b, v12.16b}, [x1], #64 +st4 {v13.16b, v14.16b, v15.16b, v16.16b}, [x1], #64 +st4 {v17.16b, v18.16b, v19.16b, v20.16b}, [x1], #64 +st4 {v23.16b, v24.16b, v25.16b, v26.16b}, [x1], #64 +st4 {v27.16b, v28.16b, v29.16b, v30.16b}, [x1], #64 +b L12 + + +L8: +cmp x2, #8 +blt L4 +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +sub x2, x2, #8 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b +mov v8.16b, v31.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b +mov v12.16b, v31.16b + +mov v13.16b, v2.16b +mov v14.16b, v2.16b +mov v15.16b, v2.16b +mov v16.16b, v31.16b + +mov v17.16b, v3.16b +mov v18.16b, v3.16b +mov v19.16b, v3.16b +mov v20.16b, v31.16b + +st4 {v5.16b, v6.16b, v7.16b, v8.16b}, [x1], #64 +st4 {v9.16b, v10.16b, v11.16b, v12.16b}, [x1], #64 +st4 {v13.16b, v14.16b, v15.16b, v16.16b}, [x1], #64 +st4 {v17.16b, v18.16b, v19.16b, v20.16b}, [x1], #64 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld1 {v0.16b, v1.16b}, [x0], #32 +sub x2, x2, #4 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b +mov v8.16b, v31.16b + +mov v9.16b, v1.16b +mov v10.16b, v1.16b +mov v11.16b, v1.16b +mov v12.16b, v31.16b + +st4 {v5.16b, v6.16b, v7.16b, v8.16b}, [x1], #64 +st4 {v9.16b, v10.16b, v11.16b, v12.16b}, [x1], #64 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld1 {v0.16b}, [x0], #16 +mov v5.16b, v0.16b +mov v6.16b, v0.16b +mov v7.16b, v0.16b +mov v8.16b, v31.16b +sub x2, x2, #2 +st4 {v5.16b, v6.16b, v7.16b, v8.16b}, [x1], #64 +b L2 + +L1: +cmp x2, #1 +blt End +ld1 {v0.8b}, [x0], #8 +mov v5.8b, v0.8b +mov v6.8b, v0.8b +mov v7.8b, v0.8b +mov v8.8b, v31.8b +st4 {v5.8b, v6.8b, v7.8b, v8.8b}, [x1], #32 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S index d1fdd68bd..339bbd37e 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S @@ -127,7 +127,7 @@ stp x23, x24, [sp, #(16 * 8)] ldr x27, [x6, #64] // blockNum mul x27, x27, x3 // blockNum * src_depth_quad_perblock -lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT +lsl x15, x27, #5 // x15 = src_depth_quad * UNIT * SRC_UNIT ldr w28, [x6, #24] // useInt8 ldr x25, [x6, #40] // xKernelSum @@ -135,9 +135,9 @@ ldr x26, [x6, #48] // weightQuantBias ldr x24, [x6, #80] // extraScale add x23, x6, #16 // int8 max ptr -mov x21, #4 // sizeof(int8_t) * UNIT +mov x21, #4 // sizeof(int8_t) * pack cbnz w28, Start -mov x21, #16 // sizeof(float) * UNIT +mov x21, #16 // sizeof(float) * pack ldr x23, 
[x6, #56] // fp32minmax Start: mov x22, #48 // src_steps @@ -148,7 +148,6 @@ TILE_12: cmp x5, #2 blt L4LoopDz_TILE_12 L8LoopDz_TILE_12: - //ld1 {v0.4s, v1.4s}, [x9], #32 // bias mov x11, x1 mov x13, x3 mov x20, x0 // tag dst address @@ -162,13 +161,13 @@ L8LoopDz_TILE_12: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v3.16b}, [x2], x15 // weight + ld1 {v3.16b, v4.16b}, [x2], #32 // weight ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - ld1 {v4.16b}, [x2], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] @@ -181,7 +180,7 @@ L8LoopDz_TILE_12: .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] - sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] @@ -194,8 +193,7 @@ L8LoopDz_TILE_12: bne L8LoopSz_TILE_12 L8LoopSzEnd_TILE_12: - // add x2, x2, x15 - add x2, x27, x15, LSL #1 + add x2, x27, x15 sub x5, x5, #2 L8Tile12Quan: @@ -352,7 +350,7 @@ L8LoopDz_TILE_12: L8Tile12LoopCheck: cmp x5, #1 bgt L8LoopDz_TILE_12 - blt End + cbz x5, End L4LoopDz_TILE_12: SET_BIAS v8, v9, v10, v11 @@ -360,7 +358,7 @@ L4LoopDz_TILE_12: SET_BIAS v16, v17, v18, v19 L4LoopSz_TILE_12: - ld1 {v3.16b}, [x2], #16 // weight + ld1 {v3.16b}, [x2] // weight ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -370,6 +368,7 @@ L4LoopDz_TILE_12: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + add x2, x2, #32 // weight offset=lp*hp=32 subs x3, x3, #1 .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] @@ -497,18 +496,18 @@ L8LoopDz_TILE_8: SET_BIAS v20, v21, v22, v23 L8LoopSz_TILE_8: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - ld1 {v4.16b}, [x12], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] @@ -521,8 +520,7 @@ L8LoopDz_TILE_8: bne L8LoopSz_TILE_8 L8LoopSzEnd_TILE_8: - //add x12, x12, x15 - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile8Quan: @@ -652,12 +650,13 @@ L4LoopDz_TILE_8: SET_BIAS v12, v13, v14, v15 L4LoopSz_TILE_8: - ld1 {v3.16b}, [x12], #16 // weight + ld1 {v3.16b}, [x12] // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + add x12, x12, #32 // 
weight offset=lp*hp subs x13, x13, #1 .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] @@ -772,15 +771,14 @@ L8LoopDz_TILE_4: SET_BIAS v12, v13, v14, v15 L8LoopSz_TILE_4: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.16b}, [x11], x22 // src - ld1 {v4.16b}, [x12], #16 // weight .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] @@ -788,8 +786,7 @@ L8LoopDz_TILE_4: bne L8LoopSz_TILE_4 L8LoopSzEnd_TILE_4: - //add x12, x12, x15 - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile4Quan: @@ -879,9 +876,10 @@ L4LoopDz_TILE_4: SET_BIAS v8, v9, v10, v11 L4LoopSz_TILE_4: - ld1 {v3.16b}, [x12], #16 // weight + ld1 {v3.16b}, [x12] // weight ld1 {v0.16b}, [x11], x22 // src subs x13, x13, #1 + add x12, x12, #32 // weight offset = lp*hp .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] @@ -974,17 +972,15 @@ L8LoopDz_TILE_1: movi v8.16b, #0 movi v9.16b, #0 L8LoopSz_TILE_1: - ld1 {v3.16b}, [x12], x15 // weight + ld1 {v3.16b, v4.16b}, [x12], #32 // weight ld1 {v0.s}[0], [x11], x22 // src - ld1 {v4.16b}, [x12], #16 // weight .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] bne L8LoopSz_TILE_1 L8LoopSzEnd_TILE_1: - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile1Quan: @@ -1067,9 +1063,10 @@ L4LoopDz_TILE_1: mov x13, x3 movi v8.16b, #0 L4LoopSz_TILE_1: - ld1 {v3.16b}, [x12], #16 // weight + ld1 {v3.16b}, [x12] // weight ld1 {v0.s}[0], [x11], x22 // src subs x13, x13, #1 + add x12, x12, #32 // weight offset = lp*hp .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] bne L4LoopSz_TILE_1 @@ -1132,11 +1129,11 @@ cbz x24, Tile1_End_Offset add x24, x24, #4 Tile1_End_Offset: - sub x7, x7, #1 + subs x7, x7, #1 add x0, x0, x21 add x1, x1, #4 add x25, x25, #4 - b TILE_1 + bne TILE_1 End: ldp x23, x24, [sp, #(16 * 8)] diff --git a/source/backend/cpu/arm/arm64/MNNPackC2.S b/source/backend/cpu/arm/arm64/MNNPackC2.S new file mode 100644 index 000000000..3a66bafd9 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNPackC2.S @@ -0,0 +1,107 @@ +// +// MNNPackInt8C2.S +// MNN +// +// Created by MNN on 2019/02/02. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNPackInt8C2 +//void MNNPackInt8C2(float* dst, const float* src, size_t area, size_t depth, int32_t* areaOffset) +//Auto load: +//x0:dst, x1:src, x2:area, x3:depth, x4: areaOffset, x5: areaOffset + +ldr w10, [x4, #4] // dstDepthOffset +ldr w9, [x4, #0] // srcDepthOffset +uxtw x10, w10 +uxtw x9, w9 + +//x12: srcDepthOffset:area*sizeof(float) +mov x12, #4 +mul x12, x9, x12 + +//r10 -> 2 * (dstArea * sizeof(float) - area * sizeof(float)) +mov x5, #8 +sub x10, x10, x2 +mul x10, x5, x10 + +//r9 -> (srcArea * sizeof(float) - area * sizeof(float)) +mov x6, #4 +sub x9, x9, x2 +mul x9, x6, x9 + +UpL2: +cmp x3, #1 +ble UpL1 + +UpL2Loop: +add x5, x1, x12 +mov x8, x2 +cmp x8, #3 +ble UpL2AreaRemain +UpL2AreaLoop: +ld1 {v0.4s}, [x1], #16 +ld1 {v1.4s}, [x5], #16 + +st2 {v0.4s, v1.4s}, [x0], #32 +sub x8, x8, #4 +cmp x8, #4 +bge UpL2AreaLoop + +cmp x8, #0 +beq UpL2AreaRemainEnd +UpL2AreaRemain: +ld1 {v0.s}[0], [x1], #4 +ld1 {v0.s}[1], [x5], #4 + +st1 {v0.d}[0], [x0], #8 + +subs x8, x8, #1 +bne UpL2AreaRemain + +UpL2AreaRemainEnd: +sub x3, x3, #2 +add x1, x5, x9 +cmp x3, #2 +add x0, x10, x0 +bge UpL2Loop + +UpL1: +cmp x3, #0 +beq UpEnd +mov x8, x2 +cmp x8, #3 +ble UpL1AreaRemain +UpL1AreaLoop: +ld1 {v0.4s}, [x1], #16 +movi v1.4s, #0 + +st2 {v0.4s, v1.4s}, [x0], #32 +sub x8, x8, #4 +cmp x8, #4 +bge UpL1AreaLoop + +cmp x8, #0 +beq UpL1AreaRemainEnd +UpL1AreaRemain: +movi v0.4s, #0 +ld1 {v0.s}[0], [x1], #4 + +st1 {v0.d}[0], [x0], #8 + +subs x8, x8, #1 +bne UpL1AreaRemain + +UpL1AreaRemainEnd: + +UpEnd: + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBAToBGRAFast.S b/source/backend/cpu/arm/arm64/MNNRGBAToBGRAFast.S new file mode 100644 index 000000000..884d48e8f --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBAToBGRAFast.S @@ -0,0 +1,147 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNRGBAToBGRAFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBAToBGRAFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +L10: +cmp x2, #10 +blt L8 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +ld4 {v28.16b, v29.16b, v30.16b, v31.16b}, [x0], #64 +sub x2, x2, #10 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b +mov v19.16b, v3.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b +mov v23.16b, v7.16b + +mov v24.16b, v10.16b +mov v25.16b, v9.16b +mov v26.16b, v8.16b +mov v27.16b, v11.16b + +mov v0.16b, v14.16b +mov v1.16b, v13.16b +mov v2.16b, v12.16b +mov v3.16b, v15.16b + +mov v4.16b, v30.16b +mov v5.16b, v29.16b +mov v6.16b, v28.16b +mov v7.16b, v31.16b + +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +st4 {v20.16b, v21.16b, v22.16b, v23.16b}, [x1], #64 +st4 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 +st4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 + +b L10 + + +L8: +cmp x2, #8 +blt L4 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +sub x2, x2, #8 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b +mov v19.16b, v3.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b +mov v23.16b, v7.16b + +mov v24.16b, v10.16b +mov v25.16b, v9.16b +mov v26.16b, v8.16b +mov v27.16b, v11.16b + +mov v28.16b, v14.16b +mov v29.16b, v13.16b +mov v30.16b, v12.16b +mov v31.16b, v15.16b + +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +st4 {v20.16b, v21.16b, v22.16b, v23.16b}, [x1], #64 +st4 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 +st4 {v28.16b, v29.16b, v30.16b, v31.16b}, [x1], #64 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +sub x2, x2, #4 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b +mov v19.16b, v3.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b +mov v23.16b, v7.16b + +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +st4 {v20.16b, v21.16b, v22.16b, v23.16b}, [x1], #64 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +sub x2, x2, #2 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b +mov v19.16b, v3.16b + +st4 {v16.16b, v17.16b, v18.16b, v19.16b}, [x1], #64 +b L2 + +L1: +cmp x2, #1 +blt End +ld4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x0], #32 + +mov v16.8b, v2.8b +mov v17.8b, v1.8b +mov v18.8b, v0.8b +mov v19.8b, v3.8b + +st4 {v16.8b, v17.8b, v18.8b, v19.8b}, [x1], #32 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBAToBGRFast.S b/source/backend/cpu/arm/arm64/MNNRGBAToBGRFast.S new file mode 100644 index 000000000..d894c0c13 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBAToBGRFast.S @@ -0,0 +1,134 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNRGBAToBGRFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBAToBGRFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +L10: +cmp x2, #10 +blt L8 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +ld4 {v28.16b, v29.16b, v30.16b, v31.16b}, [x0], #64 +sub x2, x2, #10 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b + +mov v24.16b, v10.16b +mov v25.16b, v9.16b +mov v26.16b, v8.16b + +mov v0.16b, v14.16b +mov v1.16b, v13.16b +mov v2.16b, v12.16b + +mov v4.16b, v30.16b +mov v5.16b, v29.16b +mov v6.16b, v28.16b + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +st3 {v20.16b, v21.16b, v22.16b}, [x1], #48 +st3 {v24.16b, v25.16b, v26.16b}, [x1], #48 +st3 {v0.16b, v1.16b, v2.16b}, [x1], #48 +st3 {v4.16b, v5.16b, v6.16b}, [x1], #48 + +b L10 + + +L8: +cmp x2, #8 +blt L4 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +ld4 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +ld4 {v12.16b, v13.16b, v14.16b, v15.16b}, [x0], #64 +sub x2, x2, #8 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b + +mov v24.16b, v10.16b +mov v25.16b, v9.16b +mov v26.16b, v8.16b + +mov v28.16b, v14.16b +mov v29.16b, v13.16b +mov v30.16b, v12.16b + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +st3 {v20.16b, v21.16b, v22.16b}, [x1], #48 +st3 {v24.16b, v25.16b, v26.16b}, [x1], #48 +st3 {v28.16b, v29.16b, v30.16b}, [x1], #48 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v4.16b, v5.16b, v6.16b, v7.16b}, [x0], #64 +sub x2, x2, #4 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b + +mov v20.16b, v6.16b +mov v21.16b, v5.16b +mov v22.16b, v4.16b + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +st3 {v20.16b, v21.16b, v22.16b}, [x1], #48 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +sub x2, x2, #2 + +mov v16.16b, v2.16b +mov v17.16b, v1.16b +mov v18.16b, v0.16b + +st3 {v16.16b, v17.16b, v18.16b}, [x1], #48 +b L2 + +L1: +cmp x2, #1 +blt End +ld4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x0], #32 + +mov v16.8b, v2.8b +mov v17.8b, v1.8b +mov v18.8b, v0.8b + +st3 {v16.8b, v17.8b, v18.8b}, [x1], #24 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBAToGRAYFast.S b/source/backend/cpu/arm/arm64/MNNRGBAToGRAYFast.S new file mode 100644 index 000000000..d83e3c8a1 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBAToGRAYFast.S @@ -0,0 +1,96 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNRGBAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBAToGRAYFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v29.16b, #19 +movi v30.16b, #38 +movi v31.16b, #7 + +// b*7 +// g*38 +// r*19 + +L4: +cmp x2, #4 +blt L2 + +sub x2, x2, #4 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 +ld4 {v14.16b, v15.16b, v16.16b, v17.16b}, [x0], #64 + +umull v4.8h, v0.8b, v29.8b +umlal v4.8h, v1.8b, v30.8b +umlal v4.8h, v2.8b, v31.8b + +umull2 v7.8h, v0.16b, v29.16b +umlal2 v7.8h, v1.16b, v30.16b +umlal2 v7.8h, v2.16b, v31.16b + +umull v18.8h, v14.8b, v29.8b +umlal v18.8h, v15.8b, v30.8b +umlal v18.8h, v16.8b, v31.8b + +umull2 v21.8h, v14.16b, v29.16b +umlal2 v21.8h, v15.16b, v30.16b +umlal2 v21.8h, v16.16b, v31.16b + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 +uqshrn v5.8b, v18.8h, #6 +uqshrn2 v5.16b, v21.8h, #6 + +st1 {v4.16b, v5.16b}, [x1], #32 +b L4 + +L2: +cmp x2, #2 +blt L1 + +sub x2, x2, #2 +ld4 {v0.16b, v1.16b, v2.16b, v3.16b}, [x0], #64 + +umull v4.8h, v0.8b, v29.8b +umlal v4.8h, v1.8b, v30.8b +umlal v4.8h, v2.8b, v31.8b + +umull2 v7.8h, v0.16b, v29.16b +umlal2 v7.8h, v1.16b, v30.16b +umlal2 v7.8h, v2.16b, v31.16b + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 + +st1 {v4.16b}, [x1], #16 +b L2 + +L1: +cmp x2, #1 +blt End +ld4 {v0.8b, v1.8b, v2.8b, v3.8b}, [x0], #32 + +umull v4.8h, v0.8b, v29.8b +umlal v4.8h, v1.8b, v30.8b +umlal v4.8h, v2.8b, v31.8b + +uqshrn v10.8b, v4.8h, #6 + +st1 {v10.8b}, [x1], #8 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBToBGR.S b/source/backend/cpu/arm/arm64/MNNRGBToBGR.S new file mode 100644 index 000000000..f12cf78e0 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBToBGR.S @@ -0,0 +1,126 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBToBGRC8 +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +L12: +cmp x2, #12 +blt L8 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v6.16b, v7.16b, v8.16b}, [x0], #48 +ld3 {v12.16b, v13.16b, v14.16b}, [x0], #48 +ld3 {v15.16b, v16.16b, v17.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +ld3 {v27.16b, v28.16b, v29.16b}, [x0], #48 +sub x2, x2, #12 +mov v3.16b, v2.16b +mov v4.16b, v1.16b +mov v5.16b, v0.16b +mov v9.16b, v8.16b +mov v10.16b, v7.16b +mov v11.16b, v6.16b + +mov v18.16b, v14.16b +mov v19.16b, v13.16b +mov v20.16b, v12.16b +mov v21.16b, v17.16b +mov v22.16b, v16.16b +mov v23.16b, v15.16b + +mov v0.16b, v26.16b +mov v1.16b, v25.16b +mov v2.16b, v24.16b +mov v6.16b, v29.16b +mov v7.16b, v28.16b +mov v8.16b, v27.16b +st3 {v3.16b, v4.16b, v5.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +st3 {v18.16b, v19.16b, v20.16b}, [x1], #48 +st3 {v21.16b, v22.16b, v23.16b}, [x1], #48 +st3 {v0.16b, v1.16b, v2.16b}, [x1], #48 +st3 {v6.16b, v7.16b, v8.16b}, [x1], #48 + +b L12 + + +L8: +cmp x2, #8 +blt L4 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v6.16b, v7.16b, v8.16b}, [x0], #48 +ld3 {v12.16b, v13.16b, v14.16b}, [x0], #48 +ld3 {v15.16b, v16.16b, v17.16b}, [x0], #48 +sub x2, x2, #8 +mov v3.16b, v2.16b +mov v4.16b, v1.16b +mov v5.16b, v0.16b +mov v9.16b, v8.16b +mov v10.16b, v7.16b +mov v11.16b, v6.16b + +mov v18.16b, v14.16b +mov v19.16b, v13.16b +mov v20.16b, v12.16b +mov v21.16b, v17.16b +mov v22.16b, v16.16b +mov v23.16b, v15.16b + +st3 {v3.16b, v4.16b, v5.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +st3 {v18.16b, v19.16b, v20.16b}, [x1], #48 +st3 {v21.16b, v22.16b, v23.16b}, [x1], #48 +b L8 + +L4: +cmp x2, #4 +blt L2 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v6.16b, v7.16b, v8.16b}, [x0], #48 +sub x2, x2, #4 +mov v3.16b, v2.16b +mov v4.16b, v1.16b +mov v5.16b, v0.16b +mov v9.16b, v8.16b +mov v10.16b, v7.16b +mov v11.16b, v6.16b + +st3 {v3.16b, v4.16b, v5.16b}, [x1], #48 +st3 {v9.16b, v10.16b, v11.16b}, [x1], #48 +b L4 + +L2: +cmp x2, #2 +blt L1 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +mov v3.16b, v2.16b +mov v4.16b, v1.16b +mov v5.16b, v0.16b +sub x2, x2, #2 +st3 {v3.16b, v4.16b, v5.16b}, [x1], #48 +b L2 + +L1: +cmp x2, #1 +blt End +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +mov v3.8b, v2.8b +mov v4.8b, v1.8b +mov v5.8b, v0.8b +st3 {v3.8b, v4.8b, v5.8b}, [x1], #24 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBToBGR555.S b/source/backend/cpu/arm/arm64/MNNRGBToBGR555.S new file mode 100644 index 000000000..d34a588c9 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBToBGR555.S @@ -0,0 +1,169 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNRGBToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBToBGR555Fast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #8 +neg v31.16b, v31.16b + +L6: +cmp x2, #6 +blt L4 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +ushr v2.16b, v2.16b, #3 // b >> 3 +and v11.16b, v11.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v31.16b // g & ~7 +ushr v13.16b, v13.16b, #3 // b >> 3 +and v24.16b, v24.16b, v31.16b // r & ~7 +and v25.16b, v25.16b, v31.16b // g & ~7 +ushr v26.16b, v26.16b, #3 // b >> 3 +sub x2, x2, #6 + +ushll v3.8h, v0.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v2.16b + +ushll v14.8h, v11.8b, #7 +ushll v15.8h, v12.8b, #2 +uxtl v16.8h, v13.8b +ushll2 v17.8h, v11.16b, #7 +ushll2 v18.8h, v12.16b, #2 +uxtl2 v19.8h, v13.16b + +ushll v6.8h, v24.8b, #7 +ushll v7.8h, v25.8b, #2 +uxtl v27.8h, v26.8b +ushll2 v28.8h, v24.16b, #7 +ushll2 v29.8h, v25.16b, #2 +uxtl2 v30.8h, v26.16b + +orr v0.16b, v3.16b, v4.16b +orr v0.16b, v0.16b, v5.16b +orr v1.16b, v8.16b, v9.16b +orr v1.16b, v1.16b, v10.16b + +orr v2.16b, v14.16b, v15.16b +orr v2.16b, v2.16b, v16.16b +orr v3.16b, v17.16b, v18.16b +orr v3.16b, v3.16b, v19.16b + +orr v4.16b, v6.16b, v7.16b +orr v4.16b, v4.16b, v27.16b +orr v5.16b, v28.16b, v29.16b +orr v5.16b, v5.16b, v30.16b + +st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 +st1 {v4.8h, v5.8h}, [x1], #32 + +b L6 + +L4: +cmp x2, #4 +blt L2 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +ushr v2.16b, v2.16b, #3 // b >> 3 +and v11.16b, v11.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v31.16b // g & ~7 +ushr v13.16b, v13.16b, #3 // b >> 3 +sub x2, x2, #4 + +ushll v3.8h, v0.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v2.16b + +ushll v14.8h, v11.8b, #7 +ushll v15.8h, v12.8b, #2 +uxtl v16.8h, v13.8b +ushll2 v17.8h, v11.16b, #7 +ushll2 v18.8h, v12.16b, #2 +uxtl2 v19.8h, v13.16b + + +orr v20.16b, v3.16b, v4.16b +orr v20.16b, v20.16b, v5.16b +orr v21.16b, v8.16b, v9.16b +orr v21.16b, v21.16b, v10.16b + +orr v22.16b, v14.16b, v15.16b +orr v22.16b, v22.16b, v16.16b +orr v23.16b, v17.16b, v18.16b +orr v23.16b, v23.16b, v19.16b + +st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x1], #64 + +b L4 + +L2: +cmp x2, #2 +blt L1 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v31.16b // g & ~7 +sub x2, x2, #2 +ushr v2.16b, v2.16b, #3 // b >> 3 + +ushll v3.8h, v0.8b, #7 +ushll v4.8h, v1.8b, #2 +uxtl v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +ushll2 v9.8h, v1.16b, #2 +uxtl2 v10.8h, v2.16b + +orr v6.16b, v3.16b, v4.16b +orr v6.16b, v6.16b, v5.16b +orr v7.16b, v8.16b, v9.16b +orr v7.16b, v7.16b, v10.16b + +st1 {v6.8h, v7.8h}, [x1], #32 + +b L2 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +and v0.8b, v0.8b, v31.8b // r & ~7 +and v1.8b, v1.8b, v31.8b // g & ~7 +ushr v2.8b, v2.8b, #3 // b >> 3 +ushll v0.8h, v0.8b, #7 +ushll v1.8h, v1.8b, #2 +uxtl v2.8h, v2.8b +orr v0.16b, v0.16b, v1.16b +orr v0.16b, v0.16b, v2.16b + +st1 {v0.8h}, [x1], #16 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBToBGR565.S 
b/source/backend/cpu/arm/arm64/MNNRGBToBGR565.S new file mode 100644 index 000000000..359ba392b --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBToBGR565.S @@ -0,0 +1,187 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNRGBToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBToBGR565Fast +// x0: source, x1: dest, x2: count, x3: c +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #8 +neg v31.16b, v31.16b + +L6: +cmp x2, #6 +blt L4 + +movi v30.16b, #4 +neg v30.16b, v30.16b + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +ld3 {v24.16b, v25.16b, v26.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~3 +ushr v2.16b, v2.16b, #3 // b >> 3 +and v11.16b, v11.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v30.16b // g & ~3 +ushr v13.16b, v13.16b, #3 // b >> 3 +and v24.16b, v24.16b, v31.16b // r & ~7 +and v25.16b, v25.16b, v30.16b // g & ~3 +ushr v26.16b, v26.16b, #3 // b >> 3 +sub x2, x2, #6 + +ushll v3.8h, v0.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v2.16b + +ushll v14.8h, v11.8b, #7 +shl v14.8h, v14.8h, #1 +ushll v15.8h, v12.8b, #3 +uxtl v16.8h, v13.8b +ushll2 v17.8h, v11.16b, #7 +shl v17.8h, v17.8h, #1 +ushll2 v18.8h, v12.16b, #3 +uxtl2 v19.8h, v13.16b + +ushll v6.8h, v24.8b, #7 +shl v6.8h, v6.8h, #1 +ushll v7.8h, v25.8b, #3 +uxtl v27.8h, v26.8b +ushll2 v28.8h, v24.16b, #7 +shl v28.8h, v28.8h, #1 +ushll2 v29.8h, v25.16b, #3 +uxtl2 v30.8h, v26.16b + +orr v0.16b, v3.16b, v4.16b +orr v0.16b, v0.16b, v5.16b +orr v1.16b, v8.16b, v9.16b +orr v1.16b, v1.16b, v10.16b + +orr v2.16b, v14.16b, v15.16b +orr v2.16b, v2.16b, v16.16b +orr v3.16b, v17.16b, v18.16b +orr v3.16b, v3.16b, v19.16b + +orr v4.16b, v6.16b, v7.16b +orr v4.16b, v4.16b, v27.16b +orr v5.16b, v28.16b, v29.16b +orr v5.16b, v5.16b, v30.16b + +st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 +st1 {v4.8h, v5.8h}, [x1], #32 + +b L6 + +L4: +movi v30.16b, #4 +neg v30.16b, v30.16b +cmp x2, #4 +blt L2 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v11.16b, v12.16b, v13.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~3 +ushr v2.16b, v2.16b, #3 // b >> 3 +and v11.16b, v11.16b, v31.16b // r & ~7 +and v12.16b, v12.16b, v30.16b // g & ~3 +ushr v13.16b, v13.16b, #3 // b >> 3 +sub x2, x2, #4 + +ushll v3.8h, v0.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v2.16b + +ushll v14.8h, v11.8b, #7 +shl v14.8h, v14.8h, #1 +ushll v15.8h, v12.8b, #3 +uxtl v16.8h, v13.8b +ushll2 v17.8h, v11.16b, #7 +shl v17.8h, v17.8h, #1 +ushll2 v18.8h, v12.16b, #3 +uxtl2 v19.8h, v13.16b + + +orr v20.16b, v3.16b, v4.16b +orr v20.16b, v20.16b, v5.16b +orr v21.16b, v8.16b, v9.16b +orr v21.16b, v21.16b, v10.16b + +orr v22.16b, v14.16b, v15.16b +orr v22.16b, v22.16b, v16.16b +orr v23.16b, v17.16b, v18.16b +orr v23.16b, v23.16b, v19.16b + +st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x1], #64 + +b L4 + +L2: +cmp x2, #2 +blt L1 + +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +and v0.16b, v0.16b, v31.16b // r & ~7 +and v1.16b, v1.16b, v30.16b // g & ~7 +sub x2, x2, #2 +ushr v2.16b, v2.16b, #3 // b >> 3 + +ushll v3.8h, v0.8b, #7 +shl v3.8h, v3.8h, #1 +ushll v4.8h, v1.8b, #3 +uxtl 
v5.8h, v2.8b +ushll2 v8.8h, v0.16b, #7 +shl v8.8h, v8.8h, #1 +ushll2 v9.8h, v1.16b, #3 +uxtl2 v10.8h, v2.16b + +orr v6.16b, v3.16b, v4.16b +orr v6.16b, v6.16b, v5.16b +orr v7.16b, v8.16b, v9.16b +orr v7.16b, v7.16b, v10.16b + +st1 {v6.8h, v7.8h}, [x1], #32 + +b L2 + +L1: +cmp x2, #1 +blt End + +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 +and v0.8b, v0.8b, v31.8b // r & ~7 +and v1.8b, v1.8b, v30.8b // g & ~7 +ushr v2.8b, v2.8b, #3 // b >> 3 +ushll v0.8h, v0.8b, #7 +shl v0.8h, v0.8h, #1 +ushll v1.8h, v1.8b, #3 +uxtl v2.8h, v2.8b +orr v0.16b, v0.16b, v1.16b +orr v0.16b, v0.16b, v2.16b + +st1 {v0.8h}, [x1], #16 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNRGBToGRAYFast.S b/source/backend/cpu/arm/arm64/MNNRGBToGRAYFast.S new file mode 100644 index 000000000..09ffb3ac7 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNRGBToGRAYFast.S @@ -0,0 +1,92 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +// void MNNRGBToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +asm_function MNNRGBToGRAYFast +// x0: source, x1: dest, x2: count +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v29.16b, #19 +movi v30.16b, #38 +movi v31.16b, #7 + +L4: +cmp x2, #4 +blt L2 + +sub x2, x2, #4 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 +ld3 {v14.16b, v15.16b, v16.16b}, [x0], #48 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +umull v18.8h, v14.8b, v29.8b // b*7 +umlal v18.8h, v15.8b, v30.8b // g*38 +umlal v18.8h, v16.8b, v31.8b // r*19 + +umull2 v21.8h, v14.16b, v29.16b // b*7 +umlal2 v21.8h, v15.16b, v30.16b // g*38 +umlal2 v21.8h, v16.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 +uqshrn v5.8b, v18.8h, #6 +uqshrn2 v5.16b, v21.8h, #6 + +st1 {v4.16b, v5.16b}, [x1], #32 +b L4 + +L2: +cmp x2, #2 +blt L1 + +sub x2, x2, #2 +ld3 {v0.16b, v1.16b, v2.16b}, [x0], #48 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +umull2 v7.8h, v0.16b, v29.16b // b*7 +umlal2 v7.8h, v1.16b, v30.16b // g*38 +umlal2 v7.8h, v2.16b, v31.16b // r*19 + +uqshrn v4.8b, v4.8h, #6 +uqshrn2 v4.16b, v7.8h, #6 + +st1 {v4.16b}, [x1], #16 +b L2 + +L1: +cmp x2, #1 +blt End +ld3 {v0.8b, v1.8b, v2.8b}, [x0], #24 + +umull v4.8h, v0.8b, v29.8b // b*7 +umlal v4.8h, v1.8b, v30.8b // g*38 +umlal v4.8h, v2.8b, v31.8b // r*19 + +uqshrn v10.8b, v4.8h, #6 + +st1 {v10.8b}, [x1], #8 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif diff --git a/source/backend/cpu/arm/arm64/MNNSamplerC3BilinearOpt.S b/source/backend/cpu/arm/arm64/MNNSamplerC3BilinearOpt.S new file mode 100644 index 000000000..b809984c3 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNSamplerC3BilinearOpt.S @@ -0,0 +1,171 @@ +// +// MNNSamplerC3BilinearOpt.S +// MNN +// +// Created by MNN on 2018/11/20. 
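For reference alongside the new color-conversion kernels (MNNRGBToBGR555Fast, MNNRGBToBGR565Fast, MNNRGBToGRAYFast / MNNRGBAToGRAYFast), a minimal scalar C++ sketch of the per-pixel arithmetic the NEON code implements; the function names and the assumption of 8-bit RGB input are illustrative, and the vector kernels process pixels in blocks rather than one at a time:

```cpp
#include <cstdint>

// 5-5-5 packing with blue in the low bits: (r & ~7) shifted left by 7 equals (r >> 3) << 10.
inline uint16_t packBGR555(uint8_t r, uint8_t g, uint8_t b) {
    return static_cast<uint16_t>(((r >> 3) << 10) | ((g >> 3) << 5) | (b >> 3));
}

// 5-6-5 packing: green keeps 6 bits (the g & ~3 path in the 565 kernel).
inline uint16_t packBGR565(uint8_t r, uint8_t g, uint8_t b) {
    return static_cast<uint16_t>(((r >> 3) << 11) | ((g >> 2) << 5) | (b >> 3));
}

// Grayscale with the 19/38/7 weights and 6-bit shift loaded into v29/v30/v31
// (19/64 ≈ 0.297, 38/64 ≈ 0.594, 7/64 ≈ 0.109, close to the BT.601 luma weights).
inline uint8_t rgbToGray(uint8_t r, uint8_t g, uint8_t b) {
    return static_cast<uint8_t>((19 * r + 38 * g + 7 * b) >> 6);
}
```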
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +//void MNNSamplerC3BilinearOpt(const unsigned char* source, unsigned char* dest, float* points, size_t count, size_t iw, size_t ih, size_t yStride); +asm_function MNNSamplerC3BilinearOpt + +//Auto: x0:source, x1:dest, x2:points, x3:count +//x4: xMax, x5: yMax, x6:yStride + +movi v19.4s, #0 + +ld1 {v0.2s, v1.2s}, [x2] +//L4: +//cmp x3, #4 +//blt L1 +//dup v16.4s, w4 +//dup v17.4s, w5 +//movi v3.2s, #4 +//scvtf v3.2s, v3.2s +//fmul v3.2s, v3.2s, v1.2s +//dup v25.4s, v3.s[0] +//dup v26.4s, v3.s[1] +// +//fadd v2.2s, v0.2s, v1.2s +//mov v4.s[0], v0.s[0] +//fadd v3.2s, v2.2s, v1.2s +//mov v5.s[0], v0.s[1] +//mov v4.s[1], v2.s[0] +//mov v5.s[1], v2.s[1] +//mov v4.s[2], v3.s[0] +//fadd v2.2s, v3.2s, v1.2s +//mov v5.s[2], v3.s[1] +//mov v4.s[3], v2.s[0] +//mov v5.s[3], v2.s[1] +// +//dup v23.4s, w6 +//movi v24.4s, #4 +//dup v22.2d, x0 +// +//L4Loop: +//fcvtns v6.4s, v4.4s +//fcvtns v7.4s, v5.4s +// +//smin v6.4s, v6.4s, v16.4s +//smin v7.4s, v7.4s, v17.4s +//smax v6.4s, v6.4s, v19.4s +//smax v7.4s, v7.4s, v19.4s +// +//mul v7.4s, v7.4s, v23.4s +//mla v7.4s, v6.4s, v24.4s +//uxtl v6.2d, v7.2s +//uxtl2 v7.2d, v7.4s +//add v6.2d, v6.2d, v22.2d +//add v7.2d, v7.2d, v22.2d +// +//mov x12, v6.d[0] +//mov x13, v6.d[1] +//ld1 {v3.s}[0], [x12] +//mov x12, v7.d[0] +//ld1 {v3.s}[1], [x13] +//fadd v5.4s, v26.4s, v5.4s +//mov x13, v7.d[1] +//ld1 {v3.s}[2], [x12] +//fadd v4.4s, v25.4s, v4.4s +//ld1 {v3.s}[3], [x13] +// +//st1 {v3.4s}, [x1], #16 +// +// +//sub x3, x3, #4 +//cmp x3, #4 +//bge L4Loop +// +//mov v0.s[0], v4.s[0] +//mov v0.s[1], v5.s[0] + + +L1: +cmp x3, #0 +beq End +mov v16.s[0], w4 +mov v16.s[1], w5 // v16:[xMax, yMax] +mov w12, #3 +mov v7.s[0], w12 // bpp=4 +mov v7.s[1], w6 // yStride +dup v20.2d, x0 + +L1Loop: + +fcvtzs v2.2s, v0.2s // [x0, y0] +frintm v4.2s, v0.2s +smax v2.2s, v2.2s, v19.2s // max(0, y) +fcvtps v3.2s, v0.2s // [x1, y1] +fabd v4.2s, v0.2s, v4.2s // (xF, yF) +smax v3.2s, v3.2s, v19.2s +smin v2.2s, v2.2s, v16.2s +smin v3.2s, v3.2s, v16.2s +mul v2.2s, v2.2s, v7.2s // [bpp * x0, y0 * yStride] +mul v3.2s, v3.2s, v7.2s // [bpp * x1, y1 * yStride] +mov v2.s[2], v3.s[0] // v2: [bpp*x0, y0*yStride, bpp*x1, y0*yStride] +mov v3.s[2], v2.s[0] // v3: [bpp*x1, y1*yStride, bpp*x0, y1*yStride] +mov v2.s[3], v2.s[1] +mov v3.s[3], v3.s[1] + +uaddlp v2.2d, v2.4s // [c00, c01] +uaddlp v3.2d, v3.4s // [c11, c10] + +add v2.2d, v20.2d, v2.2d +add v3.2d, v20.2d, v3.2d +mov x4, v2.d[0] +mov x5, v2.d[1] +ld1 {v5.h}[0], [x4], #2 +ld1 {v5.b}[2], [x4] +ld1 {v5.h}[2], [x5], #2 +ld1 {v5.b}[6], [x5] +mov x4, v3.d[0] +uxtl v5.8h, v5.8b +mov x5, v3.d[1] +ld1 {v6.h}[0], [x4], #2 +ld1 {v6.b}[2], [x4] +ld1 {v6.h}[2], [x5], #2 +ld1 {v6.b}[6], [x5] +uxtl v6.8h, v6.8b +//Now v2, v3 is of no use + +//v2: LT, v3: RT, v5: LB, v6:BT +uxtl v2.4s, v5.4h // c00 +uxtl2 v3.4s, v5.8h // c01 + +ucvtf v2.4s, v2.4s +uxtl v5.4s, v6.4h // c11 +ucvtf v3.4s, v3.4s +uxtl2 v6.4s, v6.8h // c10 +ucvtf v5.4s, v5.4s +ucvtf v6.4s, v6.4s + +fsub v3.4s, v3.4s, v2.4s +fsub v5.4s, v5.4s, v6.4s +fmla v2.4s, v3.4s, v4.s[0] // (c01-c00)*xF+c00 +fmla v6.4s, v5.4s, v4.s[0] // (c11-c10)*xF+c10 + +fsub v6.4s, v6.4s, v2.4s +fmla v2.4s, v6.4s, v4.s[1] + +fcvtzs v2.4s, v2.4s +uqxtn v2.4h, v2.4s +uqxtn v2.8b, v2.8h + +fadd v0.2s, v0.2s, v1.2s +subs x3, x3, #1 +st1 {v2.h}[0], [x1], #2 +st1 {v2.b}[0], [x1], #1 + + +bne L1Loop + +End: + +ret +#endif diff --git 
a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S index fa8258b66..90ad5673b 100644 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S @@ -129,24 +129,17 @@ beq L2Dz cmp x8, #1 beq L1Dz -//cmp w13, #1 -//bne L4LoopDz -//sub x4, x4, #8 // post->scale != nullptr && post->useInt8 == 1. L4LoopDz: mov x8, x1 mov x22, x2 - ld1 {v0.16b, v1.16b}, [x2], #32 // weight + ld1 {v10.16b, v11.16b}, [x2], #32 // weight ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 // src // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b @@ -207,17 +200,13 @@ L4LoopDz: L4LoopSz: ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b @@ -355,19 +344,15 @@ sub x4, x4, #8 L3LoopDz: mov x8, x1 mov x22, x2 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 add x1, x1, #16 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b @@ -418,17 +403,13 @@ L3LoopDz: L3LoopSz: ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b @@ -548,20 +529,15 @@ L2Dz: L2LoopDz: mov x8, x1 mov x22, x2 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 ld1 {v4.16b, v5.16b}, [x1], #32 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b - - + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b + smull v8.8h, v0.8b, v4.8b smull v9.8h, 
v1.8b, v4.8b smull v10.8h, v2.8b, v4.8b @@ -595,17 +571,13 @@ L2LoopDz: L2LoopSz: ld1 {v4.16b, v5.16b}, [x1], #32 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b @@ -699,17 +671,14 @@ L1Dz: L1LoopDz: mov x8, x1 mov x22, x2 - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b + dup v16.4s, wzr dup v17.4s, wzr ld1 {v4.16b}, [x1], #16 @@ -739,19 +708,14 @@ L1LoopDz: sadalp v22.4s, v14.8h sadalp v23.4s, v15.8h - ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v10.16b, v11.16b}, [x2], #32 add x1, x1, #48 // int4->int8 movi v8.16b, #15 - ushr v10.16b, v0.16b, #4 - and v11.16b, v0.16b, v8.16b - ushr v12.16b, v1.16b, #4 - and v13.16b, v1.16b, v8.16b - zip1 v0.16b, v10.16b, v11.16b - zip2 v1.16b, v10.16b, v11.16b - zip1 v2.16b, v12.16b, v13.16b - zip2 v3.16b, v12.16b, v13.16b - + ushr v0.16b, v10.16b, #4 + and v2.16b, v10.16b, v8.16b + ushr v1.16b, v11.16b, #4 + and v3.16b, v11.16b, v8.16b smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b smull v10.8h, v2.8b, v4.8b diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S index fa9bc1f43..49b9567cc 100644 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S @@ -126,13 +126,13 @@ stp x23, x24, [sp, #(16 * 8)] ldr x27, [x6, #64] // blockNum mul x27, x27, x3 // blockNum * src_depth_quad_perblock -lsl x15, x27, #3 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) +lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) ldr x25, [x6, #40] // xKernelSum ldr x26, [x6, #48] // weightQuantBias ldr x24, [x6, #80] // extraScale -mov x21, #16 // sizeof(float) * UNIT +mov x21, #16 // sizeof(float) * pack ldr x23, [x6, #56] // fp32minmax Start: mov x22, #48 // src_steps @@ -158,13 +158,11 @@ L8LoopDz_TILE_12: SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: - ld1 {v3.d}[0], [x2], x15 // weight - ld1 {v4.d}[0], [x2], #8 + ld1 {v5.16b}, [x2], #16 // weight ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -175,21 +173,16 @@ L8LoopDz_TILE_12: .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b .inst 0x4f82e070 // sdot 
v16.4s, v3.16b, v2.4b[0] .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] .inst 0x4fa2e873 // sdot v19.4s, v3.16b, v2.4b[3] - .inst 0x4f80e094 // sdot v20.4s, v4.16b, v0.4b[0] .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] - sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] @@ -202,8 +195,7 @@ L8LoopDz_TILE_12: bne L8LoopSz_TILE_12 L8LoopSzEnd_TILE_12: - // add x2, x2, x15 - add x2, x27, x15, LSL #1 + add x2, x27, x15 sub x5, x5, #2 L8Tile12Quan: @@ -313,7 +305,7 @@ L8LoopDz_TILE_12: L8Tile12LoopCheck: cmp x5, #1 bgt L8LoopDz_TILE_12 - blt End + cbz x5, End L4LoopDz_TILE_12: SET_BIAS v8, v9, v10, v11 @@ -322,12 +314,10 @@ L4LoopDz_TILE_12: movi v7.16b, #15 L4LoopSz_TILE_12: - ld1 {v3.d}[0], [x2], #8 // weight + ld1 {v5.16b}, [x2], #16 // weight ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -437,27 +427,22 @@ L8LoopDz_TILE_8: SET_BIAS v20, v21, v22, v23 L8LoopSz_TILE_8: - ld1 {v3.d}[0], [x12], x15 // weight - ld1 {v4.d}[0], [x12], #8 + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] - sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] @@ -471,7 +456,7 @@ L8LoopDz_TILE_8: L8LoopSzEnd_TILE_8: //add x12, x12, x15 - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile8Quan: @@ -567,12 +552,10 @@ L4LoopDz_TILE_8: SET_BIAS v12, v13, v14, v15 L4LoopSz_TILE_8: - ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b, v1.16b}, [x11], x22 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -652,7 +635,7 @@ Tile8_End_Offset: TILE_4: cmp x7, #4 - blt TILE_1 + blt TILE_1_Init mov x10, x0 mov x12, x2 mov x14, x5 @@ -672,24 +655,18 @@ L8LoopDz_TILE_4: SET_BIAS v12, v13, v14, v15 L8LoopSz_TILE_4: - ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b}, [x11], x22 // src - ld1 {v4.d}[0], [x12], #8 // weight // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 + and v4.16b, v5.16b, v7.16b .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] .inst 0x4f80e86a // sdot 
v10.4s, v3.16b, v0.4b[2] .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] - // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b + subs x13, x13, #1 - sub x12, x12, x15 .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] @@ -698,7 +675,7 @@ L8LoopDz_TILE_4: L8LoopSzEnd_TILE_4: //add x12, x12, x15 - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile4Quan: @@ -764,12 +741,10 @@ L4LoopDz_TILE_4: SET_BIAS v8, v9, v10, v11 L4LoopSz_TILE_4: - ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.16b}, [x11], x22 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 subs x13, x13, #1 .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] @@ -826,9 +801,14 @@ Tile4_End_Offset: add x1, x1, #16 add x25, x25, #16 -TILE_1: +TILE_1_Init: cbz x7, End movi v7.16b, #15 + cbz x23, TILE_1 + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + sub x23, x23, #4 +TILE_1: mov x10, x0 mov x12, x2 mov x14, x5 @@ -845,28 +825,64 @@ L8LoopDz_TILE_1: movi v8.16b, #0 movi v9.16b, #0 - L8LoopSz_TILE_1: - ld1 {v3.d}[0], [x12], x15 // weight + cmp x13, #4 + blt L8LoopSz_TILE_1_lu1 + //lsl x22, x22, #2 + + L8LoopSz_TILE_1_lu4: + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x12], #64 // weight: hu=0,1,2,3,pack=0~7 ld1 {v0.s}[0], [x11], x22 // src - ld1 {v4.d}[0], [x12], #8 // weight + ld1 {v0.s}[1], [x11], x22 + ld1 {v0.s}[2], [x11], x22 + ld1 {v0.s}[3], [x11], x22 + + sub x13, x13, #4 // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v12.16b, v3.16b, #4 + and v22.16b, v3.16b, v7.16b - .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + ushr v15.16b, v4.16b, #4 + and v23.16b, v4.16b, v7.16b + + ushr v18.16b, v5.16b, #4 + and v24.16b, v5.16b, v7.16b + + ushr v21.16b, v6.16b, #4 + and v25.16b, v6.16b, v7.16b + + cmp x13, #4 + .inst 0x4f80e188 // sdot v8.4s, v12.16b, v0.4b[0] + .inst 0x4f80e2c9 // sdot v9.4s, v22.16b, v0.4b[0] + .inst 0x4fa0e1e8 // sdot v8.4s, v15.16b, v0.4b[1] + .inst 0x4fa0e2e9 // sdot v9.4s, v23.16b, v0.4b[1] + .inst 0x4f80ea48 // sdot v8.4s, v18.16b, v0.4b[2] + .inst 0x4f80eb09 // sdot v9.4s, v24.16b, v0.4b[2] + .inst 0x4fa0eaa8 // sdot v8.4s, v21.16b, v0.4b[3] + .inst 0x4fa0eb29 // sdot v9.4s, v25.16b, v0.4b[3] + bge L8LoopSz_TILE_1_lu4 + + cbz x13, L8LoopSzEnd_TILE_1 + + L8LoopSz_TILE_1_lu1: + ld1 {v5.16b}, [x12], #16 // weight + ld1 {v0.s}[0], [x11], x22 // src + //ld1 {v4.d}[0], [x12], #8 // weight subs x13, x13, #1 // int4->int8 - ushr v5.16b, v4.16b, #4 - and v6.16b, v4.16b, v7.16b - zip1 v4.16b, v5.16b, v6.16b - sub x12, x12, x15 + ushr v3.16b, v5.16b, #4 + and v12.16b, v5.16b, v7.16b + + //ushr v10.16b, v4.16b, #4 + //and v11.16b, v4.16b, v7.16b + //zip1 v12.16b, v10.16b, v11.16b - .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] - bne L8LoopSz_TILE_1 + //sub x12, x12, x15 + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4f80e189 // sdot v9.4s, v12.16b, v0.4b[0] + bne L8LoopSz_TILE_1_lu1 L8LoopSzEnd_TILE_1: - add x12, x27, x15, LSL #1 + add x12, x27, x15 sub x14, x14, #2 L8Tile1Quan: @@ -903,9 +919,6 @@ L8LoopDz_TILE_1: TILE1_POST: cbz x23, TILE1_STORE - ld1r {v26.4s}, [x23], #4 // f32 min - ld1r {v27.4s}, [x23] // f32 max - sub x23, x23, #4 fmin v8.4s, v8.4s, v27.4s fmin v9.4s, v9.4s, v27.4s fmax v8.4s, v8.4s, 
v26.4s @@ -926,12 +939,10 @@ L4LoopDz_TILE_1: mov x13, x3 movi v8.16b, #0 L4LoopSz_TILE_1: - ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v5.16b}, [x12], #16 // weight ld1 {v0.s}[0], [x11], x22 // src // int4->int8 - ushr v5.16b, v3.16b, #4 - and v6.16b, v3.16b, v7.16b - zip1 v3.16b, v5.16b, v6.16b + ushr v3.16b, v5.16b, #4 subs x13, x13, #1 .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] @@ -965,9 +976,6 @@ L4LoopDz_TILE_1: TILE1_L4_POST: cbz x23, TILE1_L4_STORE - ld1r {v26.4s}, [x23], #4 // f32 min - ld1r {v27.4s}, [x23] // f32 max - sub x23, x23, #4 fmax v8.4s, v8.4s, v26.4s fmin v8.4s, v8.4s, v27.4s TILE1_L4_STORE: @@ -978,11 +986,11 @@ cbz x24, Tile1_End_Offset add x24, x24, #4 Tile1_End_Offset: - sub x7, x7, #1 + subs x7, x7, #1 add x0, x0, x21 add x1, x1, #4 add x25, x25, #4 - b TILE_1 + bne TILE_1 End: ldp x23, x24, [sp, #(16 * 8)] diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S index b4cc330c2..891196103 100644 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S @@ -138,6 +138,7 @@ LoopDz8_TILE_10: mov x11, x1 // src mov x12, x2 // weight mov x13, x3 // src_depth_quad + movi v2.16b, #15 SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 @@ -146,7 +147,6 @@ LoopDz8_TILE_10: LoopSz_TILE_10: ld1 {v0.16b, v1.16b}, [x12], #32 // weight - movi v2.16b, #15 ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 ld1 {v7.16b}, [x11], #16 // int4->int8 @@ -1066,7 +1066,64 @@ LoopDz_TILE_1: movi v17.4s, #0 movi v18.4s, #0 movi v19.4s, #0 -LoopSz_TILE_1: +cmp x13, #4 +blt LoopSz_TILE_1_lu1 + +LoopSz1_TILE_1_lu4: + ld1 {v5.16b, v6.16b, v7.16b, v8.16b}, [x12], #64 // weight + ld1 {v9.16b, v10.16b, v11.16b, v12.16b}, [x12], #64 + ld1 {v0.8b}, [x11], x22 // src + ld1 {v1.8b}, [x11], x22 + ld1 {v2.8b}, [x11], x22 + ld1 {v3.8b}, [x11], x22 + + // int4->int8 + ushr v4.16b, v5.16b, #4 + ushr v14.16b, v6.16b, #4 + and v13.16b, v5.16b, v28.16b + and v15.16b, v6.16b, v28.16b + + ushr v20.16b, v7.16b, #4 + ushr v21.16b, v8.16b, #4 + and v22.16b, v7.16b, v28.16b + and v23.16b, v8.16b, v28.16b + + ushr v24.16b, v9.16b, #4 + ushr v25.16b, v10.16b, #4 + and v26.16b, v9.16b, v28.16b + and v27.16b, v10.16b, v28.16b + + ushr v5.16b, v11.16b, #4 + ushr v6.16b, v12.16b, #4 + and v7.16b, v11.16b, v28.16b + and v8.16b, v12.16b, v28.16b + + sub x13, x13, #4 + + .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b + .inst 0x4e8ea411 // smmla v17.4s, v0.16b, v14.16b + .inst 0x4e8da412 // smmla v18.4s, v0.16b, v13.16b + .inst 0x4e8fa413 // smmla v19.4s, v0.16b, v15.16b + + .inst 0x4e94a430 // smmla v16.4s, v1.16b, v20.16b + .inst 0x4e95a431 // smmla v17.4s, v1.16b, v21.16b + .inst 0x4e96a432 // smmla v18.4s, v1.16b, v22.16b + .inst 0x4e97a433 // smmla v19.4s, v1.16b, v23.16b + cmp x13, #4 + .inst 0x4e98a450 // smmla v16.4s, v2.16b, v24.16b + .inst 0x4e99a451 // smmla v17.4s, v2.16b, v25.16b + .inst 0x4e9aa452 // smmla v18.4s, v2.16b, v26.16b + .inst 0x4e9ba453 // smmla v19.4s, v2.16b, v27.16b + + .inst 0x4e85a470 // smmla v16.4s, v3.16b, v5.16b + .inst 0x4e86a471 // smmla v17.4s, v3.16b, v6.16b + .inst 0x4e87a472 // smmla v18.4s, v3.16b, v7.16b + .inst 0x4e88a473 // smmla v19.4s, v3.16b, v8.16b + + bge LoopSz1_TILE_1_lu4 + cbz x13, LoopSzEnd_TILE_1 + +LoopSz_TILE_1_lu1: ld1 {v2.8b}, [x11], x22 // src // int4->int8 ld1 {v0.16b, 
v1.16b}, [x12], #32 // weight @@ -1080,7 +1137,7 @@ LoopSz_TILE_1: .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b .inst 0x4e8aa452 // smmla v18.4s, v2.16b, v10.16b .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b - bne LoopSz_TILE_1 + bne LoopSz_TILE_1_lu1 LoopSzEnd_TILE_1: add x25, x25, x15 sub x24, x24, #2 diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S b/source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMulRemain_int4.S similarity index 100% rename from source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S rename to source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMulRemain_int4.S diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S b/source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMulRemain_int8.S similarity index 100% rename from source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S rename to source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMulRemain_int8.S diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S b/source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMul_int4.S similarity index 100% rename from source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S rename to source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMul_int4.S diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S b/source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMul_int8.S similarity index 100% rename from source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S rename to source/backend/cpu/arm/arm64/normal_memory/MNNPackedMatMul_int8.S diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index d806e0cb9..df1b70970 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -35,8 +35,8 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) { } #endif -#ifdef MNN_LOW_MEMORY #ifndef __aarch64__ +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM static void _MNNPackedMatMulRemain_int4(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, int aStride, const float* k, const float* b) { auto B = reinterpret_cast(fB); auto h = parameter[2]; @@ -191,6 +191,9 @@ void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t auto aStride = parameter[0] / sizeof(float); _MNNPackedMatMulRemain_int8(C, A, B, eSize, parameter, postParameters, bias, aStride, k, b); } +#endif // MNN_CPU_WEIGHT_DEQUANT_GEMM + +#ifdef MNN_LOW_MEMORY void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { // source: (ic/4, N, 4) auto srcStep = pack * realSize; @@ -261,8 +264,8 @@ void MNNDynamicUpdateConvBiasScale(float* newbias, float* newscale, float* oldbi } } -#endif // not __aarch64__ #endif // LOW_MEMORY +#endif // not __aarch64__ static void MNNSumByAxisLForMatmul_A(float* dest, int8_t* source, const float* scale, ssize_t realDstCount, SumByAxisParams sumParams) { @@ -3422,12 +3425,14 @@ void MNNCoreFunctionInit() { gCoreFunction->supportSDot = gCPUInfo.dot; gCoreFunction->supportI8mm = gCPUInfo.i8mm; gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A; -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM // Weight Dequant Gemm Kernels gCoreFunction->MNNPackedMatMul_int4 = MNNPackedMatMul_int4; gCoreFunction->MNNPackedMatMulRemain_int4 = MNNPackedMatMulRemain_int4; 
gCoreFunction->MNNPackedMatMul_int8 = MNNPackedMatMul_int8; gCoreFunction->MNNPackedMatMulRemain_int8 = MNNPackedMatMulRemain_int8; +#endif +#ifdef MNN_LOW_MEMORY // Dynamic Quant Helper Functions gCoreFunction->MNNAbsMax = MNNAbsMaxFP32; gCoreFunction->MNNDynamicQuant = MNNDynamicQuantFP32; @@ -3470,10 +3475,11 @@ void MNNUnpackC2(double* dst, const double* src, size_t area, size_t depth, int* void MNNUnpackC2Float(float* dst, const float* src, size_t area, size_t depth, int* areaOffset, int pack) { MNNUnpackC2Common(dst, src, area, depth, areaOffset, pack); } - +#ifndef __aarch64__ void MNNPackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) { MNNPackC2Common(dst, src, area, depth, areaOffset); } +#endif void MNNUnpackInt8C2(float* dst, const float* src, size_t area, size_t depth, int* areaOffset) { MNNUnpackC2Common(dst, src, area, depth, areaOffset); diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 25fb13a8f..46b1f0739 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -21,8 +21,10 @@ namespace MNN { ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op): CPUConvolution(op->main_as_Convolution2D()->common(), backend) {} ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr res): CPUConvolution(op->main_as_Convolution2D()->common(), backend), mResourceInt8(res) { - mMutableResource.reset(new MutableResourceInt8(res, backend)); - mValid = mMutableResource->mValid; + if (!res->mDynamicQuant) { + mMutableResource.reset(new MutableResourceInt8(res, backend)); + mValid = mMutableResource->mValid; + } } ConvInt8TiledExecutor::~ConvInt8TiledExecutor() { @@ -34,7 +36,9 @@ bool ConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** dst) } ErrorCode ConvInt8TiledExecutor::onResize(const std::vector& inputs, const std::vector& outputs) { - mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); + if (nullptr != mMutableResource) { + mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); + } CPUConvolution::onResize(inputs, outputs); ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, static_cast(backend())->functions(), static_cast(backend())->int8Functions()); return NO_ERROR; @@ -234,18 +238,17 @@ static void GetResourceInt8(std::shared_ptr resour } } -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr quanCommon) : mDynamicQuantExe(true), ConvInt8TiledExecutor(backend, op) { +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr quanCommon) : ConvInt8TiledExecutor(backend, op) { auto convOp = op->main_as_Convolution2D(); auto core = static_cast(backend)->int8Functions(); auto gcore = static_cast(backend)->functions(); mResourceInt8.reset(new CPUConvolution::ResourceInt8); + mResourceInt8->mDynamicQuant = true; GetResourceInt8(mResourceInt8, quanCommon, convOp, backend); - mMutableResource.reset(new MutableResourceInt8(mResourceInt8, backend)); // dynamic quant int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); int pack = gcore->pack; - bool needPermuteInt4weight = ((UNIT == 8 && SRC_UNIT == 8 && DST_XUNIT ==10) || (UNIT == 64 
&& SRC_UNIT == 4 && DST_XUNIT ==4)); auto weightLength = quanCommon->weight.size(); int kernelCount = mCommon->kernelX() * mCommon->kernelY(); int oc = convOp->common()->outputCount(); @@ -264,9 +267,9 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O std::vector shape; if (SRC_UNIT > pack) { MNN_ASSERT(SRC_UNIT % pack == 0); - shape = {UP_DIV(oc, UNIT), UP_DIV(UP_DIV(ic, pack) * kernelCount, SRC_UNIT / pack), UNIT, SRC_UNIT}; + shape = {UP_DIV(oc, UNIT), UP_DIV(UP_DIV(ic, pack) * kernelCount, SRC_UNIT / pack), UNIT * SRC_UNIT / 2}; } else { - shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; + shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT * SRC_UNIT / 2}; } mResourceInt8->mWeightInt8.reset(Tensor::createDevice(shape)); @@ -280,32 +283,30 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O ::memset(dstPtr, 0, mResourceInt8->mWeightInt8->size()); // Pack two int4-weight to one int8-weight. - if (false == needPermuteInt4weight) { - for (int i = 0; i < hU; i++) { - for (int j = 0; j < lU; j++) { - for (int k = 0; k < hP; k++) { - for (int id = 0; id < lP / 2; ++id) { - dstPtr[(i * lU * lP * hP + j * hP * lP + k * lP) / 2 + id] = srcPtr[((i * hP + k) * lP * lU + (j * lP)) / 2 + id]; - } - } + int cnt = lP * hP / 4; + int L = lU * lP; + for (int i = 0; i < hU; ++i) { + for (int j = 0; j < lU; ++j) { + for (int k = 0; k < cnt; ++k) { + int dstIndx0 = (i * lU * lP * hP + j * lP * hP) / 2 + (2 * k); + + int hpId0 = (2 * k + 1) / lP; + int lpId0 = (2 * k) % lP; + int hpId1 = (2 * (k + cnt) + 1) / lP; + int lpId1 = (2 * (k + cnt)) % lP; + int srcIndx0 = ((i * hP + hpId0) * L + (j * lP + lpId0)) / 2; + int srcIndx1 = ((i * hP + hpId1) * L + (j * lP + lpId1)) / 2; + int s0 = (srcPtr[srcIndx0] >> 4); + int s1 = (srcPtr[srcIndx0] & 15); + int s2 = (srcPtr[srcIndx1] >> 4); + int s3 = (srcPtr[srcIndx1] & 15); + int d0 = s0 * 16 + s2; + int d1 = s1 * 16 + s3; + + dstPtr[dstIndx0] = d0; + dstPtr[dstIndx0 + 1] = d1; } } - } else { - for (int i = 0; i < hU; i++) { - for (int j = 0; j < lU; j++) { - auto dst_ptr = dstPtr + (i * lU * lP * hP + j * hP * lP) / 2; - for (int k = 0; k < 16; k++) { - int col = k % 4; - int row = k / 4; - uint8_t s0 = srcPtr[((i * hP + row + 0) * lP * lU + j * lP) / 2 + col]; - uint8_t s1 = srcPtr[((i * hP + row + 4) * lP * lU + j * lP) / 2 + col]; - uint8_t d0 = (s0 & 0xf0) | (s1 >> 4); - uint8_t d1 = (s0 << 4) | (s1 & 0x0f); - dst_ptr[k * 2 + 0] = d0; - dst_ptr[k * 2 + 1] = d1; - } - } - } } } else { // std::shared_ptr srcWeight; @@ -331,27 +332,19 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O std::shared_ptr weightLow(Tensor::create({halflen})); auto dstint4Ptr = weightLow->host(); auto srcint4Ptr = mResourceInt8->mWeightInt8->host(); - if (false == needPermuteInt4weight) { - for (int i=0; i < halflen; ++i) { - int s0 = srcint4Ptr[2 * i + 0]; - int s1 = srcint4Ptr[2 * i + 1]; + int permuteUnit = UNIT * SRC_UNIT; + int halfPermuteStride = static_cast(permuteUnit / 2); + for (int i = 0; i < leng / permuteUnit; ++i) { + auto src0 = srcint4Ptr + i * permuteUnit; + auto dst0 = dstint4Ptr + i * halfPermuteStride; + for (int j = 0; j < halfPermuteStride; ++j) { + int s0 = src0[j]; + int s1 = src0[j + halfPermuteStride]; int d = (s0 + 8) * 16 + (s1 + 8); - dstint4Ptr[i] = d; - } - } else { - int permuteUnit = UNIT * SRC_UNIT; - int halfPermuteStride = static_cast(permuteUnit / 2); - for (int i = 0; i < leng / permuteUnit; ++i) { - auto 
src0 = srcint4Ptr + i * permuteUnit; - auto dst0 = dstint4Ptr + i * halfPermuteStride; - for (int j = 0; j < halfPermuteStride; ++j) { - int s0 = src0[j]; - int s1 = src0[j + halfPermuteStride]; - int d = (s0 + 8) * 16 + (s1 + 8); - dst0[j] = d; - } + dst0[j] = d; } } + // Update int4 weight to mWeightInt8. mResourceInt8->mWeightInt8 = weightLow; } else { @@ -372,8 +365,68 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O gcore->MNNFp32ToLowp(mResourceInt8->mReluThreshold.data(), reinterpret_cast(mResourceInt8->mReluThreshold.data()), 2); } } +static void _computeAlphaScale(Backend* backend, const Convolution2D* conv2d, std::shared_ptr resourceInt8) { + /* Used to compute weight quant scale and bias and weightKernelSum of type float. */ + bool quanBuffer = (conv2d->quanParameter() != nullptr && conv2d->quanParameter()->buffer() != nullptr); + MNN_ASSERT(quanBuffer || resourceInt8); + auto core = static_cast(backend)->functions(); + // common parameters + int outputCount = conv2d->common()->outputCount(); + int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); + int ocUp4 = ROUND_UP(outputCount, core->pack); + int8_t* weightOrigin; -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr res) : mDynamicQuantExe(false), ConvInt8TiledExecutor(backend, op, res) { + // Save weight quant scale and bias: wf=scale*wi+bias + std::shared_ptr scaleBias(Tensor::createDevice({2 * ocUp4 * core->bytes})); + auto success = backend->onAcquireBuffer(scaleBias.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc dequant scaleBias memory error\n"); + return; + } + auto alphaPtr = scaleBias->host(); + auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); + ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); + + // Load quant scale and bias + weightOrigin = resourceInt8->mWeightInt8->host(); + auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 + auto wScale = resourceInt8->mOriginScale->host(); + int h = ocUp4; + if (core->bytes == 2) { + std::unique_ptr tmp(new int16_t[h]); + core->MNNFp32ToLowp(wScale, tmp.get(), h); + for (int i=0; i< h; ++i) { + reinterpret_cast(alphaPtr)[i] = tmp[i]; + reinterpret_cast(biasPtr)[i] = (-1.f) * wZero[i] * tmp[i]; + } + } else { + for (int i=0; i< h; ++i) { + alphaPtr[i] = wScale[i]; + biasPtr[i] = (-1.f) * wZero[i] * wScale[i]; + } + } + resourceInt8->mOriginScale = scaleBias; + + // Compute float weightKernelSum + resourceInt8->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); + success = backend->onAcquireBuffer(resourceInt8->mWeightKernelSum.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc dequant mWeightKernelSum memory error\n"); + return; + } + auto weightKernelSum = resourceInt8->mWeightKernelSum->host(); + for (int i = 0; i < outputCount; ++i) { + int sum = 0; + for (int j = 0; j < LSize; ++j) { + sum = sum + static_cast(weightOrigin[j + i * LSize]); + } + auto scale = alphaPtr[i]; + auto bias = biasPtr[i]; + weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; + } +} + +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, std::shared_ptr res) : ConvInt8TiledExecutor(backend, op, res) { std::shared_ptr weightOrigin = mResourceInt8->mWeightInt8; auto convOp = op->main_as_Convolution2D(); mValid = _reorderWeightInside(backend, convOp->common(), weightOrigin, mResourceInt8->mWeightInt8); @@ -393,11 +446,11 @@ 
DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O mGemmKernel = core->Int8GemmKernelFast; } #endif - CPUConvolution::makeResourceNew(backend, convOp, mResourceInt8); + _computeAlphaScale(backend, convOp, mResourceInt8); } DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Op* op, const DenseConvInt8TiledExecutor& exe) - : ConvInt8TiledExecutor(backend, op, exe.mResourceInt8), mGemmKernel(exe.mGemmKernel), mDynamicQuantExe(exe.mDynamicQuantExe) { + : ConvInt8TiledExecutor(backend, op, exe.mResourceInt8), mGemmKernel(exe.mGemmKernel) { } DenseConvInt8TiledExecutor::~DenseConvInt8TiledExecutor() { @@ -427,14 +480,14 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input && outputs[0]->width() == inputs[0]->width() && outputs[0]->height() == inputs[0]->height() && mCommon->strideX() == 1 && mCommon->strideY() == 1 && mCommon->padX() == 0 && mCommon->padY() == 0 && outputs[0]->height() == 1 && outputs[0]->width() == 1; - mUseBatchQuan &= mDynamicQuantExe; + mUseBatchQuan &= mResourceInt8->mDynamicQuant; mUseBatchQuan &= (inputs[0]->batch() > 1); auto core = static_cast(backend())->int8Functions(); auto gcore =static_cast(backend())->functions(); int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - if (mDynamicQuantExe == false) { + if (mResourceInt8->mDynamicQuant == false) { mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); CPUConvolution::onResize(inputs, outputs); ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core); @@ -537,7 +590,7 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input if (!success || mBlitInfo.invalid()) { return OUT_OF_MEMORY; } - if (false == mDynamicQuantExe) { + if (false == mResourceInt8->mDynamicQuant) { bufferAlloc->free(mBlitInfo); backend()->onReleaseBuffer(mInputDeqScales.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); @@ -591,9 +644,6 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu int UNIT__, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT__, &SRC_UNIT, &DST_XUNIT); auto blitProc = core->MNNPackC4Int8ForMatMul_A; - if ( mDynamicQuantExe && gcore->bytes == 2 && core->MNNPackC4Int8ForMatMul_A_ARM86FP16) { - blitProc = core->MNNPackC4Int8ForMatMul_A_ARM86FP16; - } const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow; const int batch = input->batch(); const int PackUnit = gcore->pack; @@ -618,12 +668,16 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto weightDequantBias = mResourceInt8->mOriginScale->host() + alphaSize * 4; auto outputDataPtr = output->host(); - auto biasPtr = mMutableResource->mBiasFloat->host(); - auto scalePtr = mMutableResource->mScaleFloat->host(); - - auto inputZeroPoint = mMutableResource->mInputZeroPoint; + uint8_t* biasPtr = nullptr; + uint8_t* scalePtr = nullptr; + int32_t inputZeroPoint = 0; auto inputScalePtr = mInputDeqScales->host(); - (reinterpret_cast(inputScalePtr))[0] = mMutableResource->mInputScale; + if (nullptr != mMutableResource.get()) { + biasPtr = mMutableResource->mBiasFloat->host(); + scalePtr = mMutableResource->mScaleFloat->host(); + inputZeroPoint = mMutableResource->mInputZeroPoint; + (reinterpret_cast(inputScalePtr))[0] = mMutableResource->mInputScale; + } auto SingleDynamicQuant = [&] () { const auto floatptr = input->host(); 
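The `_computeAlphaScale` helper above stores the weight dequant parameters in affine form (wf = scale * wi + bias, with bias = -scale * zeroPoint) and precomputes a per-output-channel float kernel sum as scale * Σwi + L * bias. The following is a tiny standalone sketch of why that closed form equals summing the dequantized weights; the numbers are made up and none of this is MNN API.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Affine dequant parameters as stored above: wf = scale * wi + bias, bias = -scale * zero.
    const float scale = 0.02f;
    const float zero  = 3.0f;
    const float bias  = -scale * zero;
    std::vector<int8_t> wi = {5, -7, 12, 0, 9}; // one output channel, L = 5 weights

    float direct = 0.f;
    int   intSum = 0;
    for (int8_t w : wi) {
        direct += scale * w + bias; // dequantize weight by weight
        intSum += w;
    }
    // Closed form used for the weight kernel sum: scale * sum(wi) + L * bias.
    float closed = scale * intSum + static_cast<float>(wi.size()) * bias;
    printf("direct=%f closed=%f\n", direct, closed); // both print the same value
    return 0;
}
```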
@@ -631,7 +685,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto inputsize = static_cast(backend())->getTensorSize(inputs[0]); float quantscale = 0.f; float dequantscale = 0.f; - int zeropoint = 0; + float zeropoint = 0; /* Count max and min value to compute input scale and zeropoint */ auto maxMinValPtr = mTempMaxMinValueBuffer->host(); @@ -675,14 +729,14 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu float range = maxVal - minVal; quantscale = 255.0f / range; dequantscale = range / 255.0f; - zeropoint = static_cast(roundf(-minVal * 255.f / range) - 128.0f); + zeropoint = roundf(-minVal * 255.f / range) - 128.0f; std::vectorqsVec(PackUnit, quantscale); auto sizeDiv = UP_DIV(inputsize, PackUnit); int inputPlane = input->batch() * mIm2ColParamter.iw * mIm2ColParamter.ih; if (gcore->bytes == 2 && gcore->pack == 8 && inputPlane > 1) { // C8->C4 - mQuantAndReorderFunc(floatptr, int8ptr, inputPlane, qsVec.data(), -128, 127, (ssize_t)zeropoint, UP_DIV(input->channel(), PackUnit), 4 * inputPlane); + mQuantAndReorderFunc(floatptr, int8ptr, inputPlane, &quantscale, -128, 127, &zeropoint, UP_DIV(input->channel(), PackUnit), 4 * inputPlane); } else { - mQuantFunc(floatptr, int8ptr, sizeDiv, qsVec.data(), -128, 127, (ssize_t)zeropoint); + mQuantFunc(floatptr, int8ptr, sizeDiv, &quantscale, -128, 127, &zeropoint, 0); } /* bias float */ @@ -691,7 +745,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu #else int offset = 0; #endif - auto biasfp32 = mMutableResource->mResource->mOriginBias->host(); + auto biasfp32 = mResourceInt8->mOriginBias->host(); auto weightDequantScale = mResourceInt8->mOriginScale->host(); float zerofp32 = (zeropoint + offset) * dequantscale; @@ -750,14 +804,14 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu inputZeroPoint = 0; inputScalePtr = (uint8_t*)dequantPtr; inputDataPtr = mQuantInput->host(); - biasPtr = mMutableResource->mResource->mOriginBias->host(); + biasPtr = mResourceInt8->mOriginBias->host(); scalePtr = mResourceInt8->mOriginScale->host(); }; ssize_t oneScale = 1; if (mUseBatchQuan) { BatchDynamicQuant(); oneScale = 0; - } else if (mDynamicQuantExe) { + } else if (mResourceInt8->mDynamicQuant) { SingleDynamicQuant(); } else { // offline quant. 
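In the SingleDynamicQuant path above, the input scale and zero point are derived from the observed min/max, and the zero point is now kept as a float and passed by pointer to the quant kernels. Below is a self-contained sketch of that parameter derivation and of quantizing one value the same way the scalar fallback does (round, add zero point, clamp); the input range is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // Observed input range (hypothetical values).
    float minVal = -1.5f, maxVal = 2.5f;
    float range        = maxVal - minVal;
    float quantscale   = 255.0f / range;                 // float -> int8 scale
    float dequantscale = range / 255.0f;                 // int8 -> float scale
    float zeropoint    = std::round(-minVal * 255.f / range) - 128.0f; // kept as float

    float x = 0.7f;
    float q = std::min(127.f, std::max(-128.f, std::round(x * quantscale) + zeropoint));
    printf("q=%d, dequant=%f\n", static_cast<int>(q), (q - zeropoint) * dequantscale);
    return 0;
}
```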
diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp index c5fc5f4d3..bebeaa5c4 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp @@ -61,8 +61,8 @@ class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor { DenseConvInt8TiledExecutor(Backend* backend, const Op* op, const DenseConvInt8TiledExecutor& exe); decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel; - std::function mQuantFunc; - std::function mQuantAndReorderFunc = nullptr; + std::function mQuantFunc; + std::function mQuantAndReorderFunc = nullptr; std::function mSumByAxisLFunc; std::shared_ptr mQuantInput; std::shared_ptr mDynamicBias; @@ -76,7 +76,6 @@ class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor { int mThreadNums; int mBlockNum; int mOcPerThread; - bool mDynamicQuantExe; bool mSplitByOc; bool mUseBatchQuan; }; diff --git a/source/backend/cpu/compute/ConvInt8Winograd.cpp b/source/backend/cpu/compute/ConvInt8Winograd.cpp index 433b88812..a460c1db8 100644 --- a/source/backend/cpu/compute/ConvInt8Winograd.cpp +++ b/source/backend/cpu/compute/ConvInt8Winograd.cpp @@ -189,6 +189,17 @@ ErrorCode ConvInt8Winograd::onResize(const std::vector &inputs, const core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); UNIT = gcore->pack; int pack = gcore->pack; + + mFusedBias.reset(Tensor::createDevice({ROUND_UP(outputs[0]->channel(), pack)})); + mValid &= backend()->onAcquireBuffer(mFusedBias.get(), Backend::STATIC); + if (!mValid) { + return OUT_OF_MEMORY; + } + auto fusedBiasPtr = mFusedBias->host(); + ::memset(fusedBiasPtr, 0, mFusedBias->size()); + for (int i = 0; i < outputs[0]->channel(); ++i) { + fusedBiasPtr[i] = mResource->mOriginBias->host()[i] / mResource->mOutputScale + static_cast(mResource->mOutputZeroPoint); + } auto input = mInputFloat.get(), output = outputs[0]; int batch = input->batch(), ic = input->channel(), oc = output->channel(); @@ -235,9 +246,10 @@ static void mergeAddBiasScaleQuantize(const std::vector& inputs, Tensor for (int i = 1; i < inputs.size(); ++i) { core->MNNMatrixAdd(mergeFloat, mergeFloat, inputs[i]->host(), plane * countC4, 0, 0, 0, 1); } - std::vector fakeScale(countC4 * pack, 1); - core->MNNScaleAndAddBias(mergeFloat, mergeFloat, quanParam->biasFloat, fakeScale.data(), plane, countC4); - coreInt8->MNNFloat2Int8(mergeFloat, output->host(), plane * countC4, quanParam->scale, quanParam->minValue, quanParam->maxValue, zeroPoint); + auto zeroPointPtr = quanParam->biasFloat; + for (int i = 0; i < countC4; ++i) { + coreInt8->MNNFloat2Int8(mergeFloat + i * plane * pack, output->host() + i * plane * pack, plane, quanParam->scale, quanParam->minValue, quanParam->maxValue, zeroPointPtr + i * pack, 2); + } } // AVX: 8 -> 16, arm32/64: 4 -> 16, AVX512: 16 -> 16, arm82: 4 -> 4 @@ -246,6 +258,10 @@ static void _reorderCommon(float* dst, const float* src, size_t area, size_t dep MNNPackC4((float*)dst, (const float*)src, area, depth, areaOffset); return; } + if (uFrom == 1 && uTo == 2) { + MNNPackInt8C2((float*)dst, (const float*)src, area, depth, areaOffset); + return; + } size_t srcOffset = areaOffset[0], dstOffset = areaOffset[1]; int z = 0; if (uFrom == 2 && uTo == 4) { @@ -318,10 +334,11 @@ ErrorCode ConvInt8Winograd::onExecute(const std::vector &inputs, const tmp_outputs.push_back(unit.output.get()); } QuanPostTreatParameters quanParam; - scale.assign(pack, 1.0 / outputQuant[0]); - quanParam.scale = scale.data(); + float outputdequantScale = 1.0 / 
mResource->mOutputScale; + quanParam.scale = &outputdequantScale; // For winograd Int8, will not treat origin bias to int32, use float directly - quanParam.biasFloat = mResource->mOriginBias->host(); + // quanParam.biasFloat = mResource->mOriginBias->host(); + quanParam.biasFloat = mFusedBias->host(); quanParam.maxValue = outputQuant[3]; if (mResource->mRelu) { quanParam.minValue = outputQuant[1]; @@ -501,6 +518,13 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector auto tFunction = [&](int tId) { auto _srcOrigin = mTempInputBuffer->host() + tId * mTempInputBuffer->stride(0); auto _dstOrigin = mTempOutputBuffer->host() + tId * mTempOutputBuffer->stride(0); + QuanPostTreatParameters quanParam; + quanParam.useInt8 = 0; + quanParam.srcKernelSum = xkernelSum.data(); + quanParam.weightQuanBias = wKernelSum.data(); + quanParam.fp32minmax = reluThred.data(); + quanParam.extraScale = nullptr; + for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) { int xIndex = (int)tIndex * DST_XUNIT; int xReamin = totalCount - xIndex; @@ -518,8 +542,8 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector auto _srcInt8Ptr = _srcOrigin + i * mTempInputBuffer->stride(1); auto scaleVec = mWinoResource->transInputScales->host() + i * pack; - int zeroPoint = mWinoResource->transInputZeroPoints[i]; - coreInt8->MNNFloat2Int8(buffer2 + i * DST_XUNIT * ic_4 * pack, (pack == SRC_UNIT ? _srcInt8Ptr: (int8_t*)buffer0), ic_4 * DST_XUNIT, scaleVec, -127, 127, zeroPoint); + float zeroPoint = static_cast(mWinoResource->transInputZeroPoints[i]); + coreInt8->MNNFloat2Int8(buffer2 + i * DST_XUNIT * ic_4 * pack, (pack == SRC_UNIT ? _srcInt8Ptr: (int8_t*)buffer0), ic_4 * DST_XUNIT, scaleVec, -127, 127, &zeroPoint, 0); if (pack != SRC_UNIT) { int areaOffset[] = {DST_XUNIT, DST_XUNIT}, byte = sizeof(float); _reorderCommon((float*)_srcInt8Ptr, buffer0, DST_XUNIT, UP_DIV(ic, byte), areaOffset, pack / byte, SRC_UNIT / byte); @@ -527,14 +551,12 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector auto _dstFloatPtr = _dstOrigin + i * dc_4 * xC * pack; auto _weightInt8Ptr = weight + i * mWinoResource->weight->stride(0); - QuanPostTreatParameters quanParam; + quanParam.biasFloat = (mWinoResource->offsets->host() + i * mWinoResource->offsets->stride(0)); - quanParam.useInt8 = 0; - quanParam.srcKernelSum = xkernelSum.data(); - quanParam.weightQuanBias = wKernelSum.data(); - quanParam.fp32minmax = reluThred.data(); quanParam.scale = mWinoResource->scales->host() + i * dc_4 * pack; quanParam.extraScale = nullptr; + quanParam.bias = nullptr; + quanParam.blockNum = 1; gemmFunc((int8_t*)_dstFloatPtr, _srcInt8Ptr, _weightInt8Ptr, mTempInputBuffer->length(2), xC * pack * sizeof(float), dc_4, &quanParam, xC); } #ifndef MNN_WINO_TRANFORM_TEST_CLOSE diff --git a/source/backend/cpu/compute/ConvInt8Winograd.hpp b/source/backend/cpu/compute/ConvInt8Winograd.hpp index b876059fa..c3f2d58d5 100644 --- a/source/backend/cpu/compute/ConvInt8Winograd.hpp +++ b/source/backend/cpu/compute/ConvInt8Winograd.hpp @@ -36,6 +36,7 @@ class ConvInt8Winograd : public CPUConvolution { std::vector mUnits; std::shared_ptr mResource; std::shared_ptr mInputFloat; + std::shared_ptr mFusedBias; struct WinoResource { std::shared_ptr weight; diff --git a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp index 738d85826..d09a3f6fd 100644 --- a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp +++ 
b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp @@ -82,10 +82,18 @@ Execution* ConvolutionFloatFactory::create(const std::vector& inputs, c return new ConvolutionTiledExecutorMultiInput(conv2d->common(), backend); } #ifdef MNN_LOW_MEMORY - bool lowMemory = static_cast(backend)->memoryMode() != BackendConfig::Memory_High && static_cast(backend)->functions()->MNNPackedMatMul_int8 != nullptr; + bool lowMemory = static_cast(backend)->memoryMode() == BackendConfig::Memory_Low; + if (static_cast(backend)->functions()->bytes == 2 && static_cast(backend)->int8Functions()->MNNGemmInt8AddBiasScale_Unit_FP16 == nullptr) { + // Fall back to fp32 + return nullptr; + } #else bool lowMemory = false; #endif + +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM + lowMemory = lowMemory || (static_cast(backend)->memoryMode() != BackendConfig::Memory_High); +#endif const float* originWeight = nullptr; const float* originBias = nullptr; int originWeightSize = 0; diff --git a/source/backend/cpu/compute/GemmInt8Executor.cpp b/source/backend/cpu/compute/GemmInt8Executor.cpp index a73afdba8..e314f9fcf 100644 --- a/source/backend/cpu/compute/GemmInt8Executor.cpp +++ b/source/backend/cpu/compute/GemmInt8Executor.cpp @@ -13,11 +13,42 @@ #include "core/TensorUtils.hpp" namespace MNN { +static void _makeResource(Backend* backend, std::shared_ptr resource, const MNN::Op *op, std::shared_ptr resourceInt8) { + /* Used to compute weight quant scale and bias and weightKernelSum of type float. */ + auto conv2d = op->main_as_Convolution2D(); + bool quanBuffer = (conv2d->quanParameter() != nullptr && conv2d->quanParameter()->buffer() != nullptr); + MNN_ASSERT(quanBuffer || resourceInt8); + resource->backend = backend; + auto core = static_cast(backend)->functions(); + // common parameters + int outputCount = conv2d->common()->outputCount(); + int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); + int ocUp4 = ROUND_UP(outputCount, core->pack); + int8_t* weightOrigin; + + // Save weight quant scale and bias: wf=scale*wi+bias + resource->mDequantize.mScaleBias.reset(Tensor::createDevice({2 * ocUp4 * core->bytes})); + auto success = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc denquant scaleBias memory error\n"); + return; + } + auto alphaPtr = resource->mDequantize.mScaleBias->host(); + auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); + ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); + auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 + auto wScale = resourceInt8->mOriginScale->host(); + int h = ocUp4; + for (int i=0; i< h; ++i) { + alphaPtr[i] = wScale[i]; + biasPtr[i] = (-1.f) * wZero[i] * wScale[i]; + } +} GemmInt8Executor::GemmInt8Executor(Backend* bn, std::shared_ptr resource, const Op *op, decltype(CoreInt8Functions::Int8GemmKernel) gemmKernel, std::vector bias) : CPUConvolution(op->main_as_Convolution2D()->common(), bn), mResourceInt8(resource), mMutableResource(resource, bn), mGemmKernel(gemmKernel), mQuantBias(bias){ mResource.reset(new Resource); - CPUConvolution::makeResource(bn, mResource, op, mResourceInt8); + _makeResource(bn, mResource, op, mResourceInt8); } GemmInt8Executor::~GemmInt8Executor() { diff --git a/source/backend/cpu/compute/IdstConvolutionInt8.cpp b/source/backend/cpu/compute/IdstConvolutionInt8.cpp index 05a9df338..20ce94af3 100644 --- a/source/backend/cpu/compute/IdstConvolutionInt8.cpp +++ 
b/source/backend/cpu/compute/IdstConvolutionInt8.cpp @@ -175,7 +175,7 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector& inputs, con mQuanScale, mQuanScale }; - int8_t zeroPoint = 0; + float zeroPoint = 0; std::vector fakeScale(ocC4 * PackUnit, 1.0f); QuanPostTreatParameters quanParam; @@ -199,7 +199,7 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector& inputs, con auto srcOrigin = input->host() + input->stride(0) * batchIndex; auto dstOrigin = output->host() + output->stride(0) * batchIndex; - MNNFloat2Int8(srcOrigin, srcCopy, inputTotalSize / 4, quantScale, mAMin, mAMax, zeroPoint); + MNNFloat2Int8(srcOrigin, srcCopy, inputTotalSize / 4, &mQuanScale, mAMin, mAMax, &zeroPoint, 0); int tileCount = UP_DIV(count, DST_XUNIT); threadNumber = std::max(((CPUBackend*)backend())->threadNumber(), 1); diff --git a/source/backend/cpu/compute/ImageProcessFunction.cpp b/source/backend/cpu/compute/ImageProcessFunction.cpp index d84d2f5e6..340b2386d 100644 --- a/source/backend/cpu/compute/ImageProcessFunction.cpp +++ b/source/backend/cpu/compute/ImageProcessFunction.cpp @@ -24,6 +24,23 @@ void MNNSamplerC4NearestOpt(const unsigned char* source, unsigned char* dest, fl void MNNSamplerC1NearestOpt(const unsigned char* source, unsigned char* dest, float* points, size_t count, size_t iw, size_t ih, size_t yStride); void MNNBlitC1ToFloatRGBA(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); void MNNBlitC3ToFloatRGBA(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); +void MNNRGBToBGRC8(const unsigned char* source, unsigned char* dest, size_t count); +void MNNBGRAToBGRC8(const unsigned char* source, unsigned char* dest, size_t count); +void MNNGRAYToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNGRAYToC3Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNC3ToC4Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNBGRAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNRGBToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNRGBAToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNBGRToGRAYFast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNC3ToYUVFast(const unsigned char* source, unsigned char* dest, size_t count, int32_t* c); +void MNNC3ToXYZFast(const unsigned char* source, unsigned char* dest, size_t count, int32_t* c); +void MNNRGBToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNBGRToBGR555Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNBGRToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNRGBToBGR565Fast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNRGBAToBGRAFast(const unsigned char* source, unsigned char* dest, size_t count); +void MNNRGBAToBGRFast(const unsigned char* source, unsigned char* dest, size_t count); } void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) { @@ -31,16 +48,7 @@ void MNNGRAYToC4(const unsigned char* source, unsigned char* dest, size_t count) #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - auto gray = vld1_u8(source + 8 * i); - - uint8x8x4_t rgba; - rgba.val[0] = gray; - rgba.val[1] = gray; - rgba.val[2] = gray; - rgba.val[3] = vdup_n_u8(255); 
- vst4_u8(dest + 32 * i, rgba); - } + MNNGRAYToC4Fast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -57,15 +65,7 @@ void MNNGRAYToC3(const unsigned char* source, unsigned char* dest, size_t count) #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - auto gray = vld1_u8(source + 8 * i); - - uint8x8x3_t rgba; - rgba.val[0] = gray; - rgba.val[1] = gray; - rgba.val[2] = gray; - vst3_u8(dest + 24 * i, rgba); - } + MNNGRAYToC3Fast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -81,16 +81,7 @@ void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - uint8x8x3_t c3 = vld3_u8(source + 24 * i); - - uint8x8x4_t c4; - c4.val[0] = c3.val[0]; - c4.val[1] = c3.val[1]; - c4.val[2] = c3.val[2]; - c4.val[3] = vdup_n_u8(255); - vst4_u8(dest + 32 * i, c4); - } + MNNC3ToC4Fast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -105,15 +96,9 @@ void MNNC3ToC4(const unsigned char* source, unsigned char* dest, size_t count) { void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t count) { int sta = 0; #ifdef MNN_USE_NEON - int countD8 = (int)count / 8; + auto countD8 = count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - uint8x8x4_t rgba = vld4_u8(source + 32 * i); - auto t = rgba.val[0]; - rgba.val[0] = rgba.val[2]; - rgba.val[2] = t; - vst4_u8(dest + 32 * i, rgba); - } + MNNRGBAToBGRAFast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -128,17 +113,9 @@ void MNNRGBAToBGRA(const unsigned char* source, unsigned char* dest, size_t coun void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count) { int sta = 0; #ifdef MNN_USE_NEON - int countD8 = (int)count / 8; + auto countD8 = count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - uint8x8x4_t rgba = vld4_u8(source + 32 * i); - - uint8x8x3_t bgr; - bgr.val[0] = rgba.val[2]; - bgr.val[1] = rgba.val[1]; - bgr.val[2] = rgba.val[0]; - vst3_u8(dest + 24 * i, bgr); - } + MNNRGBAToBGRFast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -152,18 +129,11 @@ void MNNRGBAToBGR(const unsigned char* source, unsigned char* dest, size_t count void MNNRGBToBGR(const unsigned char* source, unsigned char* dest, size_t count) { int sta = 0; #ifdef MNN_USE_NEON - int countD8 = (int)count / 8; - if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - uint8x8x3_t rgba = vld3_u8(source + 24 * i); - uint8x8x3_t bgr; - bgr.val[0] = rgba.val[2]; - bgr.val[1] = rgba.val[1]; - bgr.val[2] = rgba.val[0]; - vst3_u8(dest + 24 * i, bgr); - } + int countD8 = (int)count / 8; + if (countD8 > 0) { + MNNRGBToBGRC8(source, dest, countD8); sta = countD8 * 8; - } + } #endif for (int i = sta; i < count; ++i) { dest[3 * i + 0] = source[3 * i + 2]; @@ -177,15 +147,7 @@ void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - for (int i = 0; i < countD8; ++i) { - uint8x8x4_t bgra = vld4_u8(source + 32 * i); - - uint8x8x3_t bgr; - bgr.val[0] = bgra.val[0]; - bgr.val[1] = bgra.val[1]; - bgr.val[2] = bgra.val[2]; - vst3_u8(dest + 24 * i, bgr); - } + MNNBGRAToBGRC8(source, dest, countD8); sta = countD8 * 8; } #endif @@ -198,23 +160,13 @@ void MNNBGRAToBGR(const unsigned char* source, unsigned char* dest, size_t count void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { int sta = 0; - /* 
-#ifdef MNN_USE_NEON - int countD8 = (int)count / 8; - if (countD8 > 0) { - auto rC = vdup_n_u8(19); - auto gC = vdup_n_u8(38); - auto bC = vdup_n_u8(7); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld4_u8(source + 32 * i); - auto res = vmull_u8(rC, rgb.val[2]) + vmull_u8(gC, rgb.val[1]) + vmull_u8(bC, rgb.val[0]); - auto resU8 = vshrn_n_u16(res, 6); - vst1_u8(dest + 8 * i, resU8); - } + #if defined MNN_USE_NEON + int countD8 = (int)count / 8; + if (countD8 > 0) { + MNNBGRAToGRAYFast(source, dest, countD8); sta = countD8 * 8; - } -#endif - */ + } + #endif for (int i = sta; i < count; ++i) { int r = source[4 * i + 2]; int g = source[4 * i + 1]; @@ -228,23 +180,14 @@ void MNNBGRAToGRAY(const unsigned char* source, unsigned char* dest, size_t coun void MNNRGBAToGRAY(const unsigned char* source, unsigned char* dest, size_t count) { int sta = 0; - /* -#ifdef MNN_USE_NEON + +#if defined MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - auto rC = vdup_n_u8(19); - auto gC = vdup_n_u8(38); - auto bC = vdup_n_u8(7); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld4_u8(source + 32 * i); - auto res = vmull_u8(rC, rgb.val[0]) + vmull_u8(gC, rgb.val[1]) + vmull_u8(bC, rgb.val[2]); - auto resU8 = vshrn_n_u16(res, 6); - vst1_u8(dest + 8 * i, resU8); - } + MNNRGBAToGRAYFast(source, dest, countD8); sta = countD8 * 8; } #endif - */ for (int i = sta; i < count; ++i) { int r = source[4 * i + 0]; @@ -291,28 +234,15 @@ void MNNC3ToYUV(const unsigned char* source, unsigned char* dest, size_t count, C3 = coeffs[r1], C4 = coeffs[g1], C5 = coeffs[b1], C6 = coeffs[r2], C7 = coeffs[g2], C8 = coeffs[b2]; int sta = 0; - /* -#ifdef MNN_USE_NEON + +#if defined MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - auto rC0 = vdup_n_u8(C0), rC1 = vdup_n_u8(C1), rC2 = vdup_n_u8(C2), - rC3 = vdup_n_u8(C3), rC4 = vdup_n_u8(C4), rC5 = vdup_n_u8(C5), - rC6 = vdup_n_u8(C6), rC7 = vdup_n_u8(C7), rC8 = vdup_n_u8(C8); - auto delta = vdup_n_u8(128); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld4_u8(source + 24 * i); - uint8x8x3_t yuv; - yuv.val[0] = CV_MUL_SHIFT(rC0, rC1, rC2, 14); - yuv.val[1] = CV_MUL_SHIFT(rC3, rC4, rC5, 14); - yuv.val[2] = CV_MUL_SHIFT(rC6, rC7, rC8, 14); - yuv.val[1] = vadd_u8(yuv.val[1], delta); - yuv.val[2] = vadd_u8(yuv.val[2], delta); - vst3_u8(dest + 24 * i, yuv); - } + int32_t c[] = {C0, C1, C2, C3, C4, C5, C6, C7, C8}; + MNNC3ToYUVFast(source, dest, countD8, c); sta = countD8 * 8; } #endif - */ for (int i = sta; i < count; ++i) { int r = source[3 * i + 0]; int g = source[3 * i + 1]; @@ -342,25 +272,16 @@ void MNNC3ToXYZ(const unsigned char* source, unsigned char* dest, size_t count, C3 = coeffs[r1], C4 = coeffs[4], C5 = coeffs[b1], C6 = coeffs[r2], C7 = coeffs[7], C8 = coeffs[b2]; int sta = 0; - /* -#ifdef MNN_USE_NEON + +#if defined MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - auto rC0 = vdup_n_u8(C0), rC1 = vdup_n_u8(C1), rC2 = vdup_n_u8(C2), - rC3 = vdup_n_u8(C3), rC4 = vdup_n_u8(C4), rC5 = vdup_n_u8(C5), - rC6 = vdup_n_u8(C6), rC7 = vdup_n_u8(C7), rC8 = vdup_n_u8(C8); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld4_u8(source + 24 * i); - uint8x8x3_t xyz; - xyz.val[0] = CV_MUL_SHIFT(rC0, rC1, rC2, 12); - xyz.val[1] = CV_MUL_SHIFT(rC3, rC4, rC5, 12); - xyz.val[2] = CV_MUL_SHIFT(rC6, rC7, rC8, 12); - vst3_u8(dest + 24 * i, xyz); - } + int32_t c[] = {C0, C1, C2, C3, C4, C5, C6, C7, C8}; + MNNC3ToXYZFast(source, dest, countD8, c); sta = countD8 * 8; } #endif - */ + for (int i = sta; i < count; ++i) { int r = source[3 * i + 0]; int g = 
source[3 * i + 1]; @@ -403,6 +324,18 @@ void MNNC3ToHSV(const unsigned char* source, unsigned char* dest, size_t count, void MNNC3ToBGR555(const unsigned char* source, unsigned char* dest, size_t count, bool bgr) { int i = 0; + int countD8 = (int)count / 8; +#if defined MNN_USE_NEON + if (countD8 > 0) { + if (bgr) { + MNNBGRToBGR555Fast(source, dest, countD8); + } else { + MNNRGBToBGR555Fast(source, dest, countD8); + } + + i = countD8 * 8; + } +#endif for (; i < count; ++i) { int r = source[3 * i + 0]; int g = source[3 * i + 1]; @@ -414,6 +347,17 @@ void MNNC3ToBGR555(const unsigned char* source, unsigned char* dest, size_t coun void MNNC3ToBGR565(const unsigned char* source, unsigned char* dest, size_t count, bool bgr) { int i = 0; +#if defined MNN_USE_NEON + auto countD8 = count / 8; + if (countD8 > 0) { + if (bgr) { + MNNBGRToBGR565Fast(source, dest, countD8); + } else { + MNNRGBToBGR565Fast(source, dest, countD8); + } + i = countD8 * 8; + } +#endif for (; i < count; ++i) { int r = source[3 * i + 0]; int g = source[3 * i + 1]; @@ -428,15 +372,7 @@ void MNNRGBToGRAY(const unsigned char* source, unsigned char* dest, size_t count #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - auto rC = vdup_n_u8(19); - auto gC = vdup_n_u8(38); - auto bC = vdup_n_u8(7); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld3_u8(source + 24 * i); - auto res = vmull_u8(rC, rgb.val[0]) + vmull_u8(gC, rgb.val[1]) + vmull_u8(bC, rgb.val[2]); - auto resU8 = vshrn_n_u16(res, 6); - vst1_u8(dest + 8 * i, resU8); - } + MNNRGBToGRAYFast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -457,15 +393,7 @@ void MNNBRGToGRAY(const unsigned char* source, unsigned char* dest, size_t count #ifdef MNN_USE_NEON int countD8 = (int)count / 8; if (countD8 > 0) { - auto rC = vdup_n_u8(19); - auto gC = vdup_n_u8(38); - auto bC = vdup_n_u8(7); - for (int i = 0; i < countD8; ++i) { - auto rgb = vld3_u8(source + 24 * i); - auto res = vmull_u8(rC, rgb.val[2]) + vmull_u8(gC, rgb.val[1]) + vmull_u8(bC, rgb.val[0]); - auto resU8 = vshrn_n_u16(res, 6); - vst1_u8(dest + 8 * i, resU8); - } + MNNBGRToGRAYFast(source, dest, countD8); sta = countD8 * 8; } #endif @@ -839,7 +767,7 @@ static void _sampleBilinearCommon(const unsigned char* source, unsigned char* de float v = (1.0f - xF) * (1.0f - yF) * c00 + xF * (1.0f - yF) * c01 + yF * (1.0 - xF) * c10 + xF * yF * (c11); v = std::min(std::max(v, 0.0f), 255.0f); - dest[bpp * i + b] = (unsigned char)v; + dest[bpp * i + b] = (unsigned char)roundf(v); } curPoints.fY += dy; curPoints.fX += dx; diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.cpp b/source/backend/cpu/compute/Int8FunctionsOpt.cpp index 50fad7e6a..497ef3bf9 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.cpp +++ b/source/backend/cpu/compute/Int8FunctionsOpt.cpp @@ -37,6 +37,8 @@ void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, const i const QuanPostTreatParameters* post, size_t realDstCount); void MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); +void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); +void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); #if defined(MNN_LOW_MEMORY) 
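The grayscale conversions above now dispatch to hand-written `*Fast` helpers, while the NEON code they replace weighted the channels as 19/38/7 with a right shift of 6, i.e. roughly 0.297*R + 0.594*G + 0.109*B. A standalone scalar sketch of that fixed-point formula follows (illustrative only; whether the new assembly helpers use byte-identical constants is not shown in this hunk).

```cpp
#include <cstdint>
#include <cstdio>

// Fixed-point grayscale: gray = (19*R + 38*G + 7*B) >> 6, coefficients summing to 64.
static uint8_t rgbToGray(uint8_t r, uint8_t g, uint8_t b) {
    return static_cast<uint8_t>((19 * r + 38 * g + 7 * b) >> 6);
}

int main() {
    // Pure white stays 255 and pure black stays 0 because 19 + 38 + 7 == 64.
    printf("white=%d black=%d mid=%d\n",
           rgbToGray(255, 255, 255), rgbToGray(0, 0, 0), rgbToGray(128, 64, 32));
    return 0;
}
```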
// int4 weight gemmInt8 kernel void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, @@ -48,7 +50,7 @@ void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const // Tools to dynamic-quant fp16-input data. #ifdef MNN_USE_ARMV82 void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, - ssize_t maxValue, ssize_t zeroPoint); + ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec); // int8 weight gemmInt8 kernel to return fp16-output data. void MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); @@ -59,7 +61,7 @@ void MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16(int8_t* dst, const int8_t* src, co void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, - ssize_t aMax, ssize_t zeroPoint, size_t ocQuad, size_t offset); + ssize_t aMax, const float* zeroPoint, size_t ocQuad, size_t offset); #endif #endif #endif // __aarch64__ @@ -1514,8 +1516,8 @@ static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, int w8[64]; // 64=GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT for (int k = 0; k < 32; ++k) { - w8[2 * k] = (weight_sz[k]>>4); - w8[2 * k + 1] = (weight_sz[k] & c); + w8[k] = (weight_sz[k]>>4); + w8[k + 32] = (weight_sz[k] & c); } for (int j = 0; j < GEMM_INT8_UNIT; ++j) { @@ -1642,10 +1644,28 @@ static void MNNLineDepthWiseInt8AddBiasScaleUnit3x3(int8_t* dst, const int8_t* s #ifndef MNN_USE_NEON void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, - ssize_t maxValue, ssize_t zeroPoint) { + ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec) { + // quanParamVec: + // 00: scale is vector + // 10: zero is vector + // 11: both are vector + float scale4[4] = {scalep[0], scalep[0], scalep[0], scalep[0] }; + float zero4[4] = {zeroPoint[0], zeroPoint[0], zeroPoint[0], zeroPoint[0]}; + if (quanParamVec % 2 == 1) { + scale4[0] = scalep[0]; + scale4[1] = scalep[1]; + scale4[2] = scalep[2]; + scale4[3] = scalep[3]; + } + if (quanParamVec >> 1 == 1) { + zero4[0] = zeroPoint[0]; + zero4[1] = zeroPoint[1]; + zero4[2] = zeroPoint[2]; + zero4[3] = zeroPoint[3]; + } for (int i = 0; i < sizeQuad; ++i) { for (int j=0; j<4; ++j) { - int v = (int)roundf(src[4*i+j] * scalep[j]) + zeroPoint; + int v = (int)roundf(src[4*i+j] * scale4[j]) + zero4[j]; if (v > maxValue) { v = maxValue; } @@ -2103,7 +2123,7 @@ static void MNNGetGemmUnit(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { } static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { - *UNIT = 4; + *UNIT = 8; *SRC_UNIT = 4; *DST_XUNIT = 12; } @@ -2226,6 +2246,7 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 4>; // ConvDepthwise gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3; + core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM82; #if defined(MNN_LOW_MEMORY) #ifdef MNN_USE_ARMV82 gCoreFunc->DynamicQuanInput_ARM82 = 
DynamicQuanInput_ARM82; @@ -2241,6 +2262,7 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm; + core->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A_ARM86; #if defined(MNN_LOW_MEMORY) gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit; #ifdef MNN_USE_ARMV82 @@ -2250,7 +2272,6 @@ void MNNCoreInt8FunctionInit() { #endif // Im2Col gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; - gCoreFunc->MNNPackC4Int8ForMatMul_A_ARM86FP16 = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; } #endif MNNInt8FunctionInit(); diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index da974619c..6860c0643 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -48,7 +48,7 @@ struct QuanPostTreatParameters { float* weightQuanBias; float* fp32minmax; ssize_t blockNum = 1; - const int32_t* bias; + const int32_t* bias = nullptr; const float* extraScale = nullptr; const float* extraBias = nullptr; }; @@ -61,7 +61,7 @@ struct QuanPrePostParameters{ ssize_t maxValue; }; void MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, - ssize_t maxValue, ssize_t zeroPoint); + ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec); void MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint); void MNNInt8FunctionInit(); void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap); @@ -84,11 +84,10 @@ struct CoreInt8Functions { void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT); void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el); - void(*MNNPackC4Int8ForMatMul_A_ARM86FP16)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr; void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, - const QuanPostTreatParameters* post, size_t realDstCount); + const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, - const QuanPostTreatParameters* post, size_t realDstCount); + const QuanPostTreatParameters* post, size_t realDstCount) = nullptr; void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDstCount); // sparse @@ -102,9 +101,9 @@ struct CoreInt8Functions { size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); void(*ConvDepthwise3x3LineInt8_ARM82)(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, 
size_t dilateY_step, int8_t* idxOrder) = nullptr; - void(*DynamicQuanInput_ARM82)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint) = nullptr; - void (*DynamicQuanInputAndReorder_ARM82)(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, ssize_t aMax, ssize_t zeroPoint, size_t ocQuad, size_t offset) = nullptr; - void(*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint); + void(*DynamicQuanInput_ARM82)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec) = nullptr; + void (*DynamicQuanInputAndReorder_ARM82)(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, ssize_t aMax, const float* zeroPoint, size_t ocQuad, size_t offset) = nullptr; + void(*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec); void(*MNNInt8ScaleToFloat)(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint); void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber); diff --git a/source/backend/cpu/x86_x64/AVX2Backend.cpp b/source/backend/cpu/x86_x64/AVX2Backend.cpp index 167a5f984..ed263e366 100644 --- a/source/backend/cpu/x86_x64/AVX2Backend.cpp +++ b/source/backend/cpu/x86_x64/AVX2Backend.cpp @@ -366,6 +366,7 @@ void AVX2Backend::onCopyBuffer(const Tensor* srcTensor, const Tensor* dstTensor) CPUBackend::onCopyBuffer(srcTensor, dstTensor); return; } + _resetDynamicMemory(); if (getDataType(srcTensor) != getDataType(dstTensor)) { auto dimType = Tensor::CAFFE; switch (TensorUtils::getDescribe(srcTensor)->dimensionFormat) { diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index e48d00981..3bafc7573 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -39,11 +39,14 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain; -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM coreFunction->MNNPackedMatMul_int4 = _AVX_MNNPackedMatMul_int4; coreFunction->MNNPackedMatMulRemain_int4 = _AVX_MNNPackedMatMulRemain_int4; coreFunction->MNNPackedMatMul_int8 = _AVX_MNNPackedMatMul_int8; coreFunction->MNNPackedMatMulRemain_int8 = _AVX_MNNPackedMatMulRemain_int8; +#endif + +#ifdef MNN_LOW_MEMORY coreFunction->MNNAbsMax = _AVX_MNNAbsMaxFP32; #endif coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; diff --git a/source/backend/cpu/x86_x64/CMakeLists.txt b/source/backend/cpu/x86_x64/CMakeLists.txt index 631f12069..d9b462266 100644 --- a/source/backend/cpu/x86_x64/CMakeLists.txt +++ b/source/backend/cpu/x86_x64/CMakeLists.txt @@ -95,6 +95,12 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(X86_64)|(x64)|(X64)|(amd64)|(AMD64) target_compile_options(MNNAVX PRIVATE -DMNN_LOW_MEMORY) target_compile_options(MNNAVXFMA PRIVATE -DMNN_LOW_MEMORY) endif() + if (MNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNNX8664 PRIVATE -DMNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNNSSE PRIVATE -DMNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNNAVX PRIVATE 
-DMNN_CPU_WEIGHT_DEQUANT_GEMM) + target_compile_options(MNNAVXFMA PRIVATE -DMNN_CPU_WEIGHT_DEQUANT_GEMM) + endif() list(APPEND MNN_OBJECTS_TO_LINK $ $ $ $) if (MSVC AND WIN_USE_ASM) target_compile_options(MNNAVX PRIVATE -DMNN_X86_USE_ASM) diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index ca87c0464..54effc2cb 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -50,11 +50,14 @@ void MNNFunctionInit() { coreFunction->MNNGetMatMulPackMode = _SSEMNNGetMatMulPackMode; coreFunction->MNNPackedMatMul = _SSE_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _SSE_MNNPackedMatMulRemain; -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM coreFunction->MNNPackedMatMul_int4 = _SSE_MNNPackedMatMul_int4; coreFunction->MNNPackedMatMulRemain_int4 = _SSE_MNNPackedMatMulRemain_int4; coreFunction->MNNPackedMatMul_int8 = _SSE_MNNPackedMatMul_int8; coreFunction->MNNPackedMatMulRemain_int8 = _SSE_MNNPackedMatMulRemain_int8; +#endif + +#ifdef MNN_LOW_MEMORY coreFunction->MNNAbsMax = _SSE_MNNAbsMaxFP32; #endif coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A; diff --git a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp index 214010c6f..c21411b48 100644 --- a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp @@ -37,7 +37,7 @@ void _AVX_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM void _AVX_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, @@ -46,13 +46,16 @@ void _AVX_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const s const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif + +#ifdef MNN_LOW_MEMORY void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); #endif void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _AVX_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8); void _AVX_MNNSoftmax(float* dest, const float* source, size_t size); -void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint); +void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec); void _AVX_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t sizeQuad, ssize_t zeroPoint); void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, 
const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); void _AVX_MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); diff --git a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp index d19863b14..516f247cc 100644 --- a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp @@ -31,7 +31,7 @@ void _AVX_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); } -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM void _AVX_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackedMatMul_Main_int4(C, A, B, parameter, k, b); @@ -60,17 +60,9 @@ void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); } } -static __m128i _load_int4_to_int8(const uint8_t* src) { - uint8_t c = 0xf; - uint8_t temp[16]; - for (int i = 0; i < 8; ++i) { - temp[2 * i] = (src[i] >> 4); - temp[2 * i +1] = (src[i] & c); - } - auto int8_tx16 = _mm_loadu_si128((const __m128i*)temp); - return int8_tx16; -} +#endif +#ifdef MNN_LOW_MEMORY void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { // source: (ic/8, N, 8) auto srcStep = pack * realSize; diff --git a/source/backend/cpu/x86_x64/avx/GemmFunction.hpp b/source/backend/cpu/x86_x64/avx/GemmFunction.hpp index bf299722c..8a3accc18 100644 --- a/source/backend/cpu/x86_x64/avx/GemmFunction.hpp +++ b/source/backend/cpu/x86_x64/avx/GemmFunction.hpp @@ -816,7 +816,7 @@ static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE } } -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM //----------------------- MatMul(float, int4) Functions ---------------------------// #define LOAD_WEIGHT_ALPHA_BIAS_int4x4 \ diff --git a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp index 1a6b60746..450714416 100644 --- a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp @@ -53,10 +53,8 @@ D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v)); #define LOAD_INT4_TO_INT8 \ auto w_int4 = _mm_loadu_si128((__m128i const*)weight_sz);\ -auto w_int4_high = _mm_and_si128(mask, _mm_srli_epi16(w_int4, 4));\ -auto w_int4_low = _mm_and_si128(mask, w_int4);\ -auto w_0 = _mm_unpacklo_epi8(w_int4_high, w_int4_low);\ -auto w_1 = _mm_unpackhi_epi8(w_int4_high, w_int4_low); +auto w_0 = _mm_and_si128(mask, _mm_srli_epi16(w_int4, 4));\ +auto w_1 = _mm_and_si128(mask, w_int4); void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { MNN_ASSERT(post->useInt8==0); @@ -1316,15 +1314,22 @@ void _AVX_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* srcO, } } -void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint) { +void _AVX_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, 
ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec) { auto zero = _mm256_set1_epi32(0); auto minValue = _mm256_set1_ps(minV); auto maxValue = _mm256_set1_ps(maxV); - auto zeroPointValue = _mm256_set1_ps(zeroPoint); + auto zeroPointValue = _mm256_set1_ps(zeroPoint[0]); auto offset = _mm256_set1_epi32(128); auto plus = _mm256_set1_ps(0.5f); auto minus = _mm256_set1_ps(-0.5f); - auto scaleValue = _mm256_loadu_ps(scalep); + auto scaleValue = _mm256_set1_ps(scalep[0]); + + if (quanParamVec & 1) { + scaleValue = _mm256_loadu_ps(scalep); + } + if (quanParamVec >> 1) { + zeroPointValue = _mm256_loadu_ps(zeroPoint); + } for (int i = 0; i < sizeQuad; ++i) { auto f0 = _mm256_loadu_ps(src + 8 * i); diff --git a/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp b/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp index 6eb8a5379..fd80b6dc8 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp @@ -201,16 +201,23 @@ void _AVX512_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* sr src += src_w_step; } } -void _AVX512_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint) { +void _AVX512_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec) { auto zero = _mm256_set1_epi32(0); auto minValue = _mm256_set1_ps(minV); auto maxValue = _mm256_set1_ps(maxV); - auto zeroPointValue = _mm256_set1_ps(zeroPoint); + auto zeroPointValue = _mm256_set1_ps(zeroPoint[0]); auto offset = _mm256_set1_epi32(128); auto plus = _mm256_set1_ps(0.5f); auto minus = _mm256_set1_ps(-0.5f); - auto scaleValue0 = _mm256_loadu_ps(scalep); - auto scaleValue1 = _mm256_loadu_ps(scalep + 8); + auto scaleValue0 = _mm256_set1_ps(scalep[0]); + auto scaleValue1 = scaleValue0; + if (quanParamVec & 1) { + scaleValue0 = _mm256_loadu_ps(scalep); + scaleValue1 = _mm256_loadu_ps(scalep + 8); + } + if (quanParamVec >> 1) { + zeroPointValue = _mm256_loadu_ps(zeroPoint); + } for (int i = 0; i < sizeQuad; ++i) { auto f0 = _mm256_loadu_ps(src + PACK_UNIT * i); diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index 4f1525087..7e8fff748 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -50,7 +50,7 @@ void _SSE_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t const float* postParameters, const float* bias, const float* k, const float* b); void _SSE_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM void _SSE_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _SSE_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, @@ -59,6 +59,8 @@ void _SSE_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const s const float* postParameters, const float* bias, const float* k, const float* b); void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); 
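The `MNNFloat2Int8` variants above all gain a `quanParamVec` flag; as the scalar fallback reads it, bit 0 means `scalep` is a per-lane vector and bit 1 means `zeroPoint` is a per-lane vector, otherwise element 0 is broadcast. The sketch below is a standalone scalar re-implementation of that convention for a pack of 4; it is illustrative and not the MNN kernel itself.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static void float2Int8(const float* src, int8_t* dst, size_t sizeQuad,
                       const float* scalep, float minV, float maxV,
                       const float* zeroPoint, int quanParamVec) {
    float scale4[4], zero4[4];
    for (int j = 0; j < 4; ++j) {
        scale4[j] = (quanParamVec & 1) ? scalep[j]    : scalep[0];   // bit 0: scale vector
        zero4[j]  = (quanParamVec & 2) ? zeroPoint[j] : zeroPoint[0]; // bit 1: zero vector
    }
    for (size_t i = 0; i < sizeQuad; ++i) {
        for (int j = 0; j < 4; ++j) {
            float v = std::round(src[4 * i + j] * scale4[j]) + zero4[j];
            dst[4 * i + j] = static_cast<int8_t>(std::min(maxV, std::max(minV, v)));
        }
    }
}

int main() {
    float src[4]   = {0.5f, -0.25f, 1.0f, 2.0f};
    float scale[4] = {127.f, 127.f, 127.f, 127.f};
    float zero     = 0.f;
    int8_t dst[4];
    float2Int8(src, dst, 1, scale, -128.f, 127.f, &zero, /*quanParamVec=*/1);
    printf("%d %d %d %d\n", dst[0], dst[1], dst[2], dst[3]); // 64 -32 127 127
    return 0;
}
```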
+#endif +#ifdef MNN_LOW_MEMORY void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); @@ -71,7 +73,7 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); void _SSE_MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8); void _SSE_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose); -void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint); +void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, const float* zeroPoint, ssize_t quanParamVec); void _SSE_MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint); void _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); diff --git a/source/backend/cpu/x86_x64/sse/GemmInt8.cpp b/source/backend/cpu/x86_x64/sse/GemmInt8.cpp index 77702c2d4..f1fb9b338 100644 --- a/source/backend/cpu/x86_x64/sse/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmInt8.cpp @@ -300,14 +300,14 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep } } } +#define LOAD_INT4_TO_INT8 \ + auto w0_int4 = _mm_loadu_si128(reinterpret_cast(weight_sz));\ + auto w1_int4 = _mm_loadu_si128(reinterpret_cast(weight_sz + 16));\ + auto w0 = _mm_and_si128(mask, _mm_srli_epi16(w0_int4, 4));\ + auto w1 = _mm_and_si128(mask, _mm_srli_epi16(w1_int4, 4));\ + auto w2 = _mm_and_si128(mask, w0_int4);\ + auto w3 = _mm_and_si128(mask, w1_int4); -static inline void _load_int4_to_int8(const uint8_t* src, int8_t* dst) { - uint8_t c = 0xf; - for (int i = 0; i < 32; ++i) { - dst[2 * i] = (src[i] >> 4); - dst[2 * i +1] = (src[i] & c); - } -} void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { MNN_ASSERT(post->useInt8 == 0); @@ -335,6 +335,7 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const __m128 kernelSum1 = _mm_setzero_ps(); __m128 kernelSum2 = _mm_setzero_ps(); __m128 kernelSum3 = _mm_setzero_ps(); + const auto mask = _mm_set1_epi8(0xf); if (GEMM_INT8_DST_XUNIT == realDst) { kernelSum0 = _mm_load_ps1(post->srcKernelSum); kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1); @@ -402,13 +403,7 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const const auto weight_sz = weight_dz + weight_step_Y * sz; const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT; - int8_t tmp_w[64]; - _load_int4_to_int8((uint8_t*)weight_sz, tmp_w); - - auto w0 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 0)); - auto w1 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 1)); - auto w2 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 2)); - auto w3 = 
_mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 3)); + LOAD_INT4_TO_INT8; auto s0 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 0)); auto s1 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 1)); @@ -480,12 +475,6 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep E1 = _mm_hadd_epi32(E2, E3); d3 = _mm_hadd_epi32(E0, E1); auto scaleValue = _mm_loadu_ps(scale_dz); - // auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz)); - // d0 = _mm_add_epi32(d0, biasValue); - // d1 = _mm_add_epi32(d1, biasValue); - // d2 = _mm_add_epi32(d2, biasValue); - // d3 = _mm_add_epi32(d3, biasValue); - //auto biasValue = _mm_loadu_ps((float*)(bias_dz)); auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz); __m128 f0 = _mm_cvtepi32_ps(d0); __m128 f1 = _mm_cvtepi32_ps(d1); @@ -584,14 +573,20 @@ void _SSE_MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroP } // require SSE 4.1 -void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint) { +void _SSE_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, const float* zeroPoint, ssize_t quanParamVec) { __m128i zero = _mm_set1_epi32(0); __m128 minValue = _mm_set1_ps(minV); __m128 maxValue = _mm_set1_ps(maxV); - __m128 zeroPointValue = _mm_set1_ps(zeroPoint); + __m128 zeroPointValue = _mm_set1_ps(zeroPoint[0]); __m128 plus = _mm_set1_ps(0.5f); __m128 minus = _mm_set1_ps(-0.5f); - __m128 scaleValue = _mm_loadu_ps(scalep); + __m128 scaleValue = _mm_set1_ps(scalep[0]); + if (quanParamVec & 1) { + scaleValue = _mm_loadu_ps(scalep); + } + if (quanParamVec >> 1) { + zeroPointValue = _mm_loadu_ps(zeroPoint); + } auto offset = _mm_set1_epi32(128); for (int i = 0; i < sizeQuad; ++i) { diff --git a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp index 8e5a32896..336019603 100644 --- a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp @@ -27,7 +27,7 @@ void _SSE_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t _SSE_GemmPostTreat(C, eSize, parameter, postParameters, bias); } -#ifdef MNN_LOW_MEMORY +#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM //----------------------- MatMul(float, int4) Functions ---------------------------// void _SSE_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { @@ -66,7 +66,9 @@ void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s _SSE_GemmPostTreat(C, eSize, parameter, postParameters, bias); } } +#endif +#ifdef MNN_LOW_MEMORY // Dynamic quant void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { // source: (ic/4, N, 4) diff --git a/source/backend/cuda/core/CUDABackend.cpp b/source/backend/cuda/core/CUDABackend.cpp index 1cefb8a2b..e72155724 100644 --- a/source/backend/cuda/core/CUDABackend.cpp +++ b/source/backend/cuda/core/CUDABackend.cpp @@ -79,7 +79,7 @@ bool CUDARuntimeWrapper::onSetCache(const void* buffer, size_t size) {//set Cach return mCUDARuntime->setCache(std::make_pair(buffer, size)); } -Backend* CUDARuntimeWrapper::onCreate(const BackendConfig* config) const { +Backend* CUDARuntimeWrapper::onCreate(const BackendConfig* config, Backend* origin) const { #ifdef LOG_VERBOSE MNN_PRINT("cudaruntime:%p, create CUDABackend\n", this); 
#endif diff --git a/source/backend/cuda/core/CUDABackend.hpp b/source/backend/cuda/core/CUDABackend.hpp index 3c3fb2402..03737a10b 100644 --- a/source/backend/cuda/core/CUDABackend.hpp +++ b/source/backend/cuda/core/CUDABackend.hpp @@ -31,7 +31,7 @@ class MNN_PUBLIC CUDARuntimeWrapper : public Runtime { public: CUDARuntimeWrapper(BackendConfig::PrecisionMode precision, BackendConfig::PowerMode power, BackendConfig::MemoryMode memory, int deviceId = 0); virtual ~CUDARuntimeWrapper(); - virtual Backend *onCreate(const BackendConfig* config) const override; + virtual Backend *onCreate(const BackendConfig* config, Backend* origin) const override; virtual void onGabageCollect(int level) override; bool isCreateError() const { return mIsCreateError; diff --git a/source/backend/hiai/backend/NPUBackend.cpp b/source/backend/hiai/backend/NPUBackend.cpp index 1b4f45fca..33159aa67 100644 --- a/source/backend/hiai/backend/NPUBackend.cpp +++ b/source/backend/hiai/backend/NPUBackend.cpp @@ -552,7 +552,7 @@ namespace MNN { NPURuntime::~NPURuntime() {} - Backend* NPURuntime::onCreate(const BackendConfig* config) const { + Backend* NPURuntime::onCreate(const BackendConfig* config, Backend* origin) const { return new NPUBackend(this); } diff --git a/source/backend/hiai/backend/NPUBackend.hpp b/source/backend/hiai/backend/NPUBackend.hpp index 4ee14a513..cfada3d13 100644 --- a/source/backend/hiai/backend/NPUBackend.hpp +++ b/source/backend/hiai/backend/NPUBackend.hpp @@ -251,7 +251,7 @@ namespace MNN { NPURuntime(const Backend::Info& info); virtual ~NPURuntime(); virtual CompilerType onGetCompilerType() const override; - virtual Backend* onCreate(const BackendConfig* conf) const override; + virtual Backend* onCreate(const BackendConfig* conf, Backend* origin) const override; virtual void onGabageCollect(int level) override; // If buffer is not nullptr, try copy cache, else delete cache virtual bool onSetCache(const void* buffer, size_t size) override { diff --git a/source/backend/metal/MetalAttention.mm b/source/backend/metal/MetalAttention.mm index 2c6eed591..e1d1ef28f 100644 --- a/source/backend/metal/MetalAttention.mm +++ b/source/backend/metal/MetalAttention.mm @@ -239,21 +239,11 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { auto context = (__bridge MNNMetalContext *)mtbn->context(); mParamQKV = [context newDeviceBuffer:sizeof(Param) access:CPUWriteOnly]; mParamSoftmax = [context newDeviceBuffer:4 * sizeof(int) access:CPUWriteOnly]; - + mTempQK.reset(Tensor::createDevice({0, 0})); + mTempSoftMax.reset(Tensor::createDevice({0, 0})); } void AttentionBufExecution::reallocKVCache() { - if (mCache->mPastLength < mCache->mMaxLength || nullptr == mTempQK || (!mIsDecode)) { - if (mIsDecode) { - mTempQK.reset(Tensor::createDevice({mNumHead, mCache->mMaxLength})); - mTempSoftMax.reset(Tensor::createDevice({mNumHead, mCache->mMaxLength})); - } else { - mTempQK.reset(Tensor::createDevice({mNumHead, mCache->mPastLength, mCache->mPastLength})); - mTempSoftMax.reset(Tensor::createDevice({mNumHead, mCache->mPastLength, mCache->mPastLength})); - } - backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC); - backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); - } if (!mKVCache || mCache->mPastLength < mCache->mMaxLength) { return; } @@ -378,6 +368,31 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { int group_size = mNumHead / mKvNumHead; reallocKVCache(); + bool needMalloc = mTempQK->length(0) != mNumHead; + if (mIsDecode) { + if 
(mTempQK->length(1) != mCache->mMaxLength) { + needMalloc = true; + } + mTempQK->setLength(0, mNumHead); + mTempQK->setLength(1, mCache->mMaxLength); + mTempSoftMax->setLength(0, mNumHead); + mTempSoftMax->setLength(1, mCache->mMaxLength); + } else { + if (mTempQK->length(1) != mCache->mPastLength * mCache->mPastLength) { + needMalloc = true; + } + mTempQK->setLength(0, mNumHead); + mTempQK->setLength(1, mCache->mPastLength * mCache->mPastLength); + mTempSoftMax->setLength(0, mNumHead); + mTempSoftMax->setLength(1, mCache->mPastLength * mCache->mPastLength); + } + if (needMalloc) { + auto res = backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC) && backend()->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); + if (!res) { + MNN_ERROR("MNN::Metal: OUT_OF_MEMORY when execute attention metal\n"); + return; + } + } // Update Parameters { @@ -456,7 +471,6 @@ virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override { mCache->mPastLength += 1; mCache->mKv_seq_len = mCache->mPastLength + 1; } - return; } diff --git a/source/backend/metal/MetalBackend.hpp b/source/backend/metal/MetalBackend.hpp index e01913a38..22eee335f 100644 --- a/source/backend/metal/MetalBackend.hpp +++ b/source/backend/metal/MetalBackend.hpp @@ -56,7 +56,7 @@ class MetalRuntime : public Runtime { std::map>, std::tuple, std::vector, uint32_t>>& getTunedThreadGroup() { return mTunedThreadGroup; }; - virtual Backend *onCreate(const BackendConfig* config) const override; + virtual Backend *onCreate(const BackendConfig* config, Backend* origin) const override; virtual void onGabageCollect(int level) override; virtual CompilerType onGetCompilerType() const override { return Compiler_Loop; @@ -71,10 +71,16 @@ class MetalRuntime : public Runtime { const MNN::Op* op) override; virtual bool onMeasure(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op, Runtime::OpInfo& dstInfo) const override; + SingleBufferWithAllocator* buffer(int index) const { + return &mDynamic[index]; + } + BufferAllocator* createDynamicAllocator(int index, bool secondResize) const; private: MetalRuntime(void* context); void* mContext = nullptr; - std::shared_ptr mStatic; + mutable std::shared_ptr mStatic; + mutable std::shared_ptr mStaticCache; + mutable std::vector mDynamic; MetalTuneLevel mTuneLevel = Wide; std::map>, std::tuple, std::vector, uint32_t>> mTunedThreadGroup; @@ -226,8 +232,6 @@ class MetalBackend : public Backend { id _commandQueue; const MetalRuntime* mRuntime; - id mShapeH2D; - id mShapeD2H; mutable NSUInteger mEncoderCount = 0; mutable bool mOpEncoderSet = false;//whether has set encoder mutable bool mSupportDeferEncode = true; @@ -240,6 +244,7 @@ class MetalBackend : public Backend { std::shared_ptr mStaticBufferPool; private: + void _resetDynamicMemory() const; CopyPipeline _makeCopyInfo(const Tensor *src, const Tensor *dst, id shape, int castType) const; mutable id mHostBuffer = nullptr; diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 6f73629bb..268db6fde 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -10,6 +10,7 @@ #define MNN_METAL #import #define METAL_CONST_BUFFER_LIMIT 128 +#define METAL_SEPERATE_MAX_COUNT 2 #if MNN_METAL_ENABLED #import "backend/metal/MNNMetalContext.h" #import "core/Macro.h" @@ -35,12 +36,16 @@ static void _MetalApplyTensor(uint8_t* host, size_t offset, Tensor* t) { auto des = TensorUtils::getDescribe(t); des->extra.offset = offset; } -static BufferAllocator* 
_createBufferAllocator(const Runtime* runtime, BufferAllocator* origin, bool secondResize) { - if (runtime->hint().memoryAllocatorType == Runtime::Allocator_Defer && secondResize) { - return new DeferBufferAllocator(BufferAllocator::Allocator::createRecurse(origin), 1024, _MetalApplyTensor); +BufferAllocator* MetalRuntime::createDynamicAllocator(int index, bool secondResize) const { + if (hint().memoryAllocatorType == Runtime::Allocator_Defer && secondResize) { + return new DeferBufferAllocator(buffer(index), 1024, _MetalApplyTensor); } - return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(origin), 1024); + if (mStaticCache.get() != nullptr) { + return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticCache.get()), 1024); + } + return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStatic.get()), 1024); } + struct TunedInfo { std::vector> mInfos; }; @@ -70,11 +75,9 @@ static void _MetalApplyTensor(uint8_t* host, size_t offset, Tensor* t) { { mRuntime = runtime; auto ctx = (__bridge MNNMetalContext *)runtime->context(); - mBufferPool.reset(_createBufferAllocator(runtime, staticMem.get(), false)); + mBufferPool.reset(runtime->createDynamicAllocator(0, false)); mCurrentAllocator = mBufferPool.get(); mStaticBufferPool = staticMem; - mShapeH2D = getConstBuffer(4 * sizeof(int)); - mShapeD2H = getConstBuffer(4 * sizeof(int)); mUseFloatAsFp16 = usefp16AsFp32; mIsIphone = ctx.isIphone; if (runtime->getCommandQueue() == nil) { @@ -207,6 +210,9 @@ MemChunk chunk() override { bool MetalBackend::onClearBuffer() { mCurrentAllocator->release(true); + if (nullptr != mRuntime->mStaticCache.get()) { + mStaticBufferPool = mRuntime->mStaticCache; + } return true; } @@ -238,8 +244,15 @@ MemChunk chunk() override { mComputeEncoder = nil; } } +void MetalBackend::_resetDynamicMemory() const { + mCurrentAllocator->apply(); + if (nullptr != mBufferPoolShapeImmutable.get()) { + mBufferPoolShapeImmutable->apply(); + } +} void MetalBackend::onExecuteBegin() const { + _resetDynamicMemory(); mEncoderCount = 0; } void MetalBackend::onExecuteEnd() const { @@ -263,8 +276,8 @@ MemChunk chunk() override { return false; } if (maxIndex == 2 && mBufferPoolShapeImmutable.get() == nullptr) { - mBufferPoolShapeImmutable.reset(_createBufferAllocator(mRuntime, mStaticBufferPool.get(), true)); - mBufferPool.reset(_createBufferAllocator(mRuntime, mStaticBufferPool.get(), true)); + mBufferPoolShapeImmutable.reset(mRuntime->createDynamicAllocator(1, true)); + mBufferPool.reset(mRuntime->createDynamicAllocator(0, true)); } if (1 == index) { mCurrentAllocator = mBufferPoolShapeImmutable.get(); @@ -315,9 +328,7 @@ MemChunk chunk() override { } id MetalBackend::getHostBuffer(size_t size) const { - if (size < METAL_CONST_BUFFER_LIMIT) { - size = METAL_CONST_BUFFER_LIMIT; - } + size = UP_DIV(size, METAL_CONST_BUFFER_LIMIT) * METAL_CONST_BUFFER_LIMIT; // reuse if (nullptr != mHostBuffer && mHostBuffer.length >= size) { return mHostBuffer; @@ -703,7 +714,7 @@ static void _execute(id encoder, const MetalBackend::C if(!mFrameEncodeCache) { commit_net(); } - + _resetDynamicMemory(); onCopyBuffer(src, dst, nil, nil); } @@ -983,6 +994,10 @@ static void _execute(id encoder, const MetalBackend::C auto ctx = (__bridge MNNMetalContext *)mContext; std::shared_ptr allocator(new MetalRuntimeAllocator([ctx device])); mStatic.reset(new EagerBufferAllocator(allocator)); + mDynamic.resize(METAL_SEPERATE_MAX_COUNT); + for (auto& buf : mDynamic) { + buf.root = allocator; + } mTunedInfo = new 
TunedInfo; } @@ -1067,7 +1082,11 @@ static void _execute(id encoder, const MetalBackend::C float MetalRuntime::onGetMemoryInMB() { auto staticMemoryInMB = mStatic->totalSize() / 1024.0f / 1024.0f; - return staticMemoryInMB; + float dynamicMemoryInMB = 0.0f; + for (auto& buf : mDynamic) { + dynamicMemoryInMB += buf.currentSize / 1024.0f / 1024.0f; + } + return staticMemoryInMB + dynamicMemoryInMB; } void MetalRuntime::onMaskOpReady(const std::vector& inputs, const std::vector& outputs, @@ -1153,7 +1172,36 @@ static bool _checkTensorInfo(const MetalCache::TensorInfoT* dst, const Tensor* s return true; } -Backend* MetalRuntime::onCreate(const BackendConfig* config) const { +class MetalWrapAllocator : public BufferAllocator::Allocator { +private: + std::shared_ptr mOrigin; + id mDevice; +public: + MetalWrapAllocator(std::shared_ptr origin, id device) : mOrigin(origin), mDevice(device) {} + virtual ~ MetalWrapAllocator() { + // Do nothing + } + virtual MemChunk onAlloc(size_t size, size_t align) override { + auto mem = mOrigin->onAlloc(size, align); + MNN_ASSERT(mem.second == 0); + id buffer = [mDevice newBufferWithBytesNoCopy:mem.first length:size options:MTLResourceStorageModeShared deallocator:nil]; + auto wrap = new MetalRuntimeAllocator::MetalBufferAlloc(buffer); + return MemChunk((void *)wrap, 0); + } + virtual void onRelease(MemChunk chunk) override { + auto mem = (MetalRuntimeAllocator::MetalBufferAlloc *)chunk.first; + mOrigin->onRelease(MemChunk(mem->getBuffer().contents)); + delete mem; + } +}; +Backend* MetalRuntime::onCreate(const BackendConfig* config, Backend* origin) const { + if (hint().weightMemoryPath.size() > 0 && mStaticCache.get() == nullptr) { + auto ctx = (__bridge MNNMetalContext *)mContext; + auto mmap = BufferAllocator::Allocator::createMmap(hint().weightMemoryPath.c_str(), "metal.weight"); + std::shared_ptr mmapMem(new MetalWrapAllocator(mmap, [ctx device])); + mStaticCache = mStatic; + mStatic.reset(new EagerBufferAllocator(mmapMem, 32, 1024 * 1024 * 1024)); + } BackendConfig::PrecisionMode precision = mDefaultConfig.precision; if (nullptr != config) { precision = config->precision; @@ -1164,6 +1212,11 @@ static bool _checkTensorInfo(const MetalCache::TensorInfoT* dst, const Tensor* s void MetalRuntime::onGabageCollect(int level) { mStatic->release(false); + if (level >= 100) { + for (auto& buf : mDynamic) { + buf.release(); + } + } } std::pair MetalRuntime::onGetCache() {//make Cache diff --git a/source/backend/metal/MetalBinary.mm b/source/backend/metal/MetalBinary.mm index b9482ce84..6854a5578 100755 --- a/source/backend/metal/MetalBinary.mm +++ b/source/backend/metal/MetalBinary.mm @@ -40,9 +40,9 @@ void MetalBinary::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { auto input0 = inputs[0], input1 = inputs[1], output = outputs[0]; [encoder setComputePipelineState:mPipeline]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input0->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input0)->extra.offset atIndex:0]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input1->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input1)->extra.offset atIndex:1]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:2]; + MetalBackend::setTensor(input0, encoder, 0); + MetalBackend::setTensor(input1, encoder, 1); + MetalBackend::setTensor(output, encoder, 2); [encoder 
setBuffer:mConstBuffer offset:0 atIndex:3]; [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; } diff --git a/source/backend/metal/MetalConvolutionDepthwise.mm b/source/backend/metal/MetalConvolutionDepthwise.mm index 85b17c88f..cb3036225 100755 --- a/source/backend/metal/MetalConvolutionDepthwise.mm +++ b/source/backend/metal/MetalConvolutionDepthwise.mm @@ -73,7 +73,13 @@ mConstBuffer, (id)(((MetalRuntimeAllocator::MetalBufferAlloc *)mWeight->deviceId()))->getBuffer(), ((MetalRuntimeAllocator::MetalBufferAlloc *)mBias->deviceId())->getBuffer(), nil]; const Tensor* weight = mWeight.get(); const Tensor* bias = mBias.get(); - int buffer_offset[] = {TensorUtils::getDescribe(input)->extra.offset, TensorUtils::getDescribe(output)->extra.offset, TensorUtils::getDescribe(weight)->extra.offset, TensorUtils::getDescribe(bias)->extra.offset, 0}; + int buffer_offset[] = { + TensorUtils::getDescribe(input)->extra.offset, + TensorUtils::getDescribe(output)->extra.offset, + 0, + TensorUtils::getDescribe(weight)->extra.offset, + TensorUtils::getDescribe(bias)->extra.offset + }; std::string name = "conv_depthwise"; MetalRuntime *rt = (MetalRuntime *)backend->runtime(); diff --git a/source/backend/metal/MetalUnary.mm b/source/backend/metal/MetalUnary.mm index bc66f77a2..72b91f874 100755 --- a/source/backend/metal/MetalUnary.mm +++ b/source/backend/metal/MetalUnary.mm @@ -122,8 +122,8 @@ kernel void main0(const device T *in [[buffer(0)]], \ void MetalUnary::onEncode(const std::vector &inputs, const std::vector &outputs, id encoder) { auto input = inputs[0], output = outputs[0]; [encoder setComputePipelineState:mPipeline]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)input->deviceId())->getBuffer() offset:TensorUtils::getDescribe(input)->extra.offset atIndex:0]; - [encoder setBuffer:(id)((MetalRuntimeAllocator::MetalBufferAlloc *)output->deviceId())->getBuffer() offset:TensorUtils::getDescribe(output)->extra.offset atIndex:1]; + MetalBackend::setTensor(input, encoder, 0); + MetalBackend::setTensor(output, encoder, 1); [encoder setBuffer:mConstBuffer offset:0 atIndex:2]; [encoder dispatchThreadgroups:mThreads.first threadsPerThreadgroup:mThreads.second]; } diff --git a/source/backend/nnapi/backend/NNAPIBackend.cpp b/source/backend/nnapi/backend/NNAPIBackend.cpp index 81ac425e0..a60b0b90f 100644 --- a/source/backend/nnapi/backend/NNAPIBackend.cpp +++ b/source/backend/nnapi/backend/NNAPIBackend.cpp @@ -549,7 +549,7 @@ namespace MNN { NNAPIRuntime::~NNAPIRuntime() {} - Backend* NNAPIRuntime::onCreate(const BackendConfig* config) const { + Backend* NNAPIRuntime::onCreate(const BackendConfig* config, Backend* origin) const { return new NNAPIBackend(this); } diff --git a/source/backend/nnapi/backend/NNAPIBackend.hpp b/source/backend/nnapi/backend/NNAPIBackend.hpp index 17e947973..ac6b462b3 100644 --- a/source/backend/nnapi/backend/NNAPIBackend.hpp +++ b/source/backend/nnapi/backend/NNAPIBackend.hpp @@ -50,7 +50,7 @@ namespace MNN { NNAPIRuntime(const Backend::Info& info); virtual ~NNAPIRuntime(); virtual CompilerType onGetCompilerType() const override; - virtual Backend* onCreate(const BackendConfig* conf) const override; + virtual Backend* onCreate(const BackendConfig* conf, Backend* origin) const override; virtual void onGabageCollect(int level) override; virtual std::pair onGetCache() override { return std::make_pair(mCacheBuffer, mCacheSize); diff --git a/source/backend/opencl/core/BufferConvertor.cpp b/source/backend/opencl/core/BufferConvertor.cpp 
index 1d649a0b8..1f6abd82b 100644 --- a/source/backend/opencl/core/BufferConvertor.cpp +++ b/source/backend/opencl/core/BufferConvertor.cpp @@ -170,82 +170,6 @@ bool converNCHWOrNHWCBufferToNC4HW4OrNC16HW16Buffer(const Tensor *input, Tensor return true; } -bool convertNC4HW4BufferToNC4HW4Buffer(const Tensor *input, Tensor *output, OpenCLRuntime *runtime, TransType formatTrans, bool needWait, bool svmFlag, bool srcswap, bool dstswap) { - std::vector outputShape = tensorShapeFormat(input); - uint32_t outputGlobalWorkSize[2] = {static_cast(UP_DIV(outputShape[3], 4) * outputShape[2]), - static_cast(outputShape[0] * outputShape[1])}; - std::set buildOptions; - std::string kernelName = "nc4hw4_buffer_to_nc4hw4_buffer"; - switch (formatTrans) { - case InpTrans: - AddBuildOptionOfDataType(input, output, buildOptions, runtime->isSupportedFP16(), true, false); - break; - case OutTrans: - AddBuildOptionOfDataType(input, output, buildOptions, runtime->isSupportedFP16(), false, true); - break; - default: - AddBuildOptionOfDataType(input, output, buildOptions, runtime->isSupportedFP16(), true, true); - break; - } - auto convertBufferKernelW = runtime->buildKernelWithCache("buffer_convert_buf", kernelName, buildOptions); - auto convertBufferKernel = convertBufferKernelW->get(); - uint32_t idx = 0; - int outputImageShape[2] = {input->height(), input->width()}; - int channelC4 = UP_DIV(input->channel(), 4); - int batch = input->batch(); - int srcStride[2] = { - channelC4, - 1 - }; - int dstStride[2] = { - channelC4, - 1 - }; - if (srcswap) { - srcStride[0] = 1; - srcStride[1] = batch; - } - if (dstswap) { - dstStride[0] = 1; - dstStride[1] = batch; - } - cl_int ret = CL_SUCCESS; - ret |= convertBufferKernel.setArg(idx++, outputGlobalWorkSize[0]); - ret |= convertBufferKernel.setArg(idx++, outputGlobalWorkSize[1]); -#ifdef MNN_OPENCL_SVM_ENABLE - if(svmFlag == true) - { - ret |= clSetKernelArgSVMPointer(convertBufferKernel.get(), idx++, (const void *)input->buffer().device); - } - else -#endif - { - ret |= convertBufferKernel.setArg(idx++, openCLBuffer(input)); - } - ret |= convertBufferKernel.setArg(idx++, sizeof(outputImageShape), outputImageShape); - ret |= convertBufferKernel.setArg(idx++, sizeof(srcStride), srcStride); - ret |= convertBufferKernel.setArg(idx++, sizeof(dstStride), dstStride); - ret |= convertBufferKernel.setArg(idx++, openCLBuffer(output)); - MNN_CHECK_CL_SUCCESS(ret, "setArg convertNC4HW4BufferToNC4HW4Buffer"); - - const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(convertBufferKernelW)); - const std::vector lws = {16, std::max((uint32_t)1, maxWorkGroupSize / 16)}; - cl::Event event; - cl_int res; - std::vector roundUpGroupWorkSize(lws.size()); - for (size_t i = 0; i < lws.size(); ++i) { - roundUpGroupWorkSize[i] = ROUND_UP(outputGlobalWorkSize[i], lws[i]); - } - res = runtime->commandQueue().enqueueNDRangeKernel(convertBufferKernel, cl::NullRange, - cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1]), - cl::NDRange(lws[0], lws[1]), nullptr, &event); - MNN_CHECK_CL_SUCCESS(res, "nc4hw4_buffer_to_nc4hw4_buffer"); - if (true == needWait) { - event.wait(); - } - return true; -} - #ifdef MNN_SUPPORT_INTEL_SUBGROUP bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *output, const std::string Name, OpenCLRuntime *runtime, TransType formatTrans, bool needWait, bool svmFlag, @@ -511,6 +435,145 @@ bool BufferConvertor::convertToNC4HW4Buffer(const Tensor *buffer, const OpenCLBu #endif return true; } + +bool convertBufferToBuffer(Tensor 
*input, Tensor *output, OpenCLRuntime *runtime, bool toDevice, bool toHost, bool needWait, bool svmFlag) { + std::vector outputShape = tensorShapeFormat(input); + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + auto srcDimensionFormat = TensorUtils::getDescribe(input)->dimensionFormat; + auto dstDimensionFormat = TensorUtils::getDescribe(output)->dimensionFormat; + if (MNN_DATA_FORMAT_NC4HW4 == dstDimensionFormat && srcDimensionFormat != dstDimensionFormat && (outputShape[3] % 4) != 0){ + int region[] = {outputShape[0], ROUND_UP(outputShape[3], 4), outputShape[1], outputShape[2]};//nchw + + auto kernelW = runtime->buildKernelWithCache("raster_buf", "buffer_set_zero", {}, output, output); + auto kernel = kernelW->get(); + uint32_t lws[2] = {8, 8}; + uint32_t gws[2] = {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8, (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}; + + int global_dim0 = region[2] * region[3]; + int global_dim1 = region[0] * region[1]; + + uint32_t idx = 0; + cl_int res = CL_SUCCESS; + res |= kernel.setArg(idx++, global_dim0); + res |= kernel.setArg(idx++, global_dim1); + res |= kernel.setArg(idx++, openCLBuffer(output)); + MNN_CHECK_CL_SUCCESS(res, "setArg buffer_set_zero"); + + res = runtime->commandQueue().enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1]), + cl::NDRange(lws[0], lws[1]), nullptr, nullptr); + MNN_CHECK_CL_SUCCESS(res, "buffer_set_zero"); + } + if (srcDimensionFormat == dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat){ + int size = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + uint32_t gws[2] = {static_cast(UP_DIV(size, 4)), static_cast(1)}; + std::set buildOptions; + if(size % 4 != 0){ + buildOptions.emplace("-DPACK_LEAVE"); + } + AddBuildOptionOfDataType(input, output, buildOptions, runtime->isSupportedFP16(), toDevice, toHost); + auto convertBufferKernelW = runtime->buildKernelWithCache("buffer_convert_buf", "buffer_copy_to_buffer", buildOptions); + auto convertBufferKernel = convertBufferKernelW->get(); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= convertBufferKernel.setArg(idx++, gws[0]); + ret |= convertBufferKernel.setArg(idx++, gws[1]); +#ifdef MNN_OPENCL_SVM_ENABLE + if(svmFlag == true && toDevice) { + ret |= clSetKernelArgSVMPointer(convertBufferKernel.get(), idx++, (const void *)input->deviceId()); + } + else +#endif + { + ret |= convertBufferKernel.setArg(idx++, openCLBuffer(input)); + } +#ifdef MNN_OPENCL_SVM_ENABLE + if(svmFlag == true && toHost) { + ret |= clSetKernelArgSVMPointer(convertBufferKernel.get(), idx++, (const void *)output->deviceId()); + } + else +#endif + { + ret |= convertBufferKernel.setArg(idx++, openCLBuffer(output)); + } + ret |= convertBufferKernel.setArg(idx++, size); + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); + + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(convertBufferKernelW)); + const std::vector lws = {16, std::max((uint32_t)1, maxWorkGroupSize / 16)}; + cl::Event event; + cl_int res; + std::vector roundUpGroupWorkSize(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundUpGroupWorkSize[i] = ROUND_UP(gws[i], lws[i]); + } + + res = runtime->commandQueue().enqueueNDRangeKernel(convertBufferKernel, cl::NullRange, + cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1]), + cl::NDRange(lws[0], lws[1]), nullptr, &event); + MNN_CHECK_CL_SUCCESS(res, "buffer_convert_to_buffer"); + + if (true == needWait) { + event.wait(); + } + } else{ + 
uint32_t gws[3] = {static_cast(shape[2] * shape[3]), + static_cast(shape[1]), + static_cast(shape[0])}; + std::set buildOptions; + buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(srcDimensionFormat)); + buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(dstDimensionFormat)); + AddBuildOptionOfDataType(input, output, buildOptions, runtime->isSupportedFP16(), toDevice, toHost); + auto convertBufferKernelW = runtime->buildKernelWithCache("buffer_convert_buf", "buffer_convert_to_buffer", buildOptions); + auto convertBufferKernel = convertBufferKernelW->get(); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= convertBufferKernel.setArg(idx++, gws[0]); + ret |= convertBufferKernel.setArg(idx++, gws[1]); + ret |= convertBufferKernel.setArg(idx++, gws[2]); +#ifdef MNN_OPENCL_SVM_ENABLE + if(svmFlag == true && toDevice) { + ret |= clSetKernelArgSVMPointer(convertBufferKernel.get(), idx++, (const void *)input->deviceId()); + } + else +#endif + { + ret |= convertBufferKernel.setArg(idx++, openCLBuffer(input)); + } + + ret |= convertBufferKernel.setArg(idx++, sizeof(shape), shape); +#ifdef MNN_OPENCL_SVM_ENABLE + if(svmFlag == true && toHost) { + ret |= clSetKernelArgSVMPointer(convertBufferKernel.get(), idx++, (const void *)output->deviceId()); + } + else +#endif + { + ret |= convertBufferKernel.setArg(idx++, openCLBuffer(output)); + } + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); + + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(convertBufferKernelW)); + const std::vector lws = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + cl::Event event; + cl_int res; + std::vector roundUpGroupWorkSize(lws.size()); + for (size_t i = 0; i < lws.size(); ++i) { + roundUpGroupWorkSize[i] = ROUND_UP(gws[i], lws[i]); + } + + res = runtime->commandQueue().enqueueNDRangeKernel(convertBufferKernel, cl::NullRange, + cl::NDRange(roundUpGroupWorkSize[0], roundUpGroupWorkSize[1], roundUpGroupWorkSize[2]), + cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); + MNN_CHECK_CL_SUCCESS(res, "buffer_convert_to_buffer"); + + if (true == needWait) { + event.wait(); + } + } + return true; +} + } // namespace OpenCL } // namespace MNN #endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/core/BufferConvertor.hpp b/source/backend/opencl/core/BufferConvertor.hpp index 71514acbf..b1843226e 100644 --- a/source/backend/opencl/core/BufferConvertor.hpp +++ b/source/backend/opencl/core/BufferConvertor.hpp @@ -26,14 +26,13 @@ bool convertNC4HW4OrNC16HW16BufferToNCHWOrNHWCBuffer(const Tensor *input, Tensor enum TransType {InpTrans = 0, OutTrans = 1, NoTrans = 2}; -bool convertNC4HW4BufferToNC4HW4Buffer(const Tensor *input, Tensor *output, - OpenCLRuntime *runtime, TransType formatTrans = NoTrans, bool needWait = false, bool svmFlag = false, bool srcswap = false, bool dstswap = false); - #ifdef MNN_SUPPORT_INTEL_SUBGROUP bool convertNC4HW4BufferBetweenNC16HW16Buffer(const Tensor *input, Tensor *output, const std::string Name, OpenCLRuntime *runtime, TransType formatTrans = NoTrans, bool needWait = false, bool svmFlag = false, bool srcswap = false, bool dstswap = false); #endif + +bool convertBufferToBuffer(Tensor *input, Tensor *output, OpenCLRuntime *runtime, bool toDevice, bool toHost, bool needWait = false, bool svmFlag = false); class BufferConvertor { public: diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index f05c0b2e5..67e0a1a81 100644 --- 
a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -191,7 +191,7 @@ std::pair CLRuntime::onGetCache() { return mOpenCLRuntime->makeCache(mTunedInfo); } -Backend* CLRuntime::onCreate(const BackendConfig* config) const { +Backend* CLRuntime::onCreate(const BackendConfig* config, Backend* origin) const { // FIXME: Use config info return new OpenCLBackend(mImagePool, mBufferPool, this); } @@ -413,6 +413,9 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp } bool OpenCLBackend::onSelectDynamicAllocator(int index, int maxIndex) { + if (mUseRecordQueue && false == mDevideOpRecord){ + return false; + } if (maxIndex > 2) { return false; } @@ -702,25 +705,7 @@ void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTens } } else #endif - { - switch (data_format) { - case MNN_DATA_FORMAT_NHWC: - OpenCL::convertNC4HW4OrNC16HW16BufferToNCHWOrNHWCBuffer(srcTensor, const_cast(dstTensor), - "nc4hw4_buffer_to_nhwc_buffer", mOpenCLRuntime.get(), true, false, svmFlag); - break; - case MNN_DATA_FORMAT_NCHW: - OpenCL::convertNC4HW4OrNC16HW16BufferToNCHWOrNHWCBuffer(srcTensor, const_cast(dstTensor), - "nc4hw4_buffer_to_nchw_buffer", mOpenCLRuntime.get(), true, false, svmFlag); - break; - case MNN_DATA_FORMAT_NC4HW4: - OpenCL::convertNC4HW4BufferToNC4HW4Buffer(srcTensor, const_cast(dstTensor), - mOpenCLRuntime.get(), OutTrans, false, svmFlag, false, true); - break; - default: - MNN_PRINT("output data format not support!\n"); - break; - } - } + OpenCL::convertBufferToBuffer(const_cast(srcTensor), const_cast(dstTensor), mOpenCLRuntime.get(), false, true, true, svmFlag); } else #endif /* MNN_OPENCL_BUFFER_CLOSED */ @@ -755,18 +740,41 @@ void CLRuntime::convertFromDevice(const Tensor* srcTensor, const Tensor* dstTens void OpenCLBackend::copyFromDevice(const Tensor* srcTensor, const Tensor* dstTensor) const{ auto needSize = dstTensor->size(); - + auto shape = tensorShapeFormat(srcTensor); + auto srcDimensionFormat = TensorUtils::getDescribe(srcTensor)->dimensionFormat; + auto dstDimensionFormat = TensorUtils::getDescribe(dstTensor)->dimensionFormat; + auto memType = dstTensor->buffer().flags; + bool directCopy = BUFFER == mOpenCLRuntime->getGpuMemType() + && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1) + && MNN::MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat + && (getDataType(srcTensor) == getDataType(dstTensor)) + && memType != MNN_FORWARD_OPENCL + && memType != MNN_FORWARD_OPENGL; + if (mOpenCLRuntime->isSupportedFP16()) { // Fp16 + if (dstTensor->getType().code == halide_type_float) { + directCopy = false; + } + } + if(mOpenCLRuntime->isSupportedIntelSubgroup()){ + int cPack = TensorUtils::getTensorChannelPack(srcTensor); + if (cPack == 16){ + directCopy = false; + } + } void* hostPtr = dstTensor->host(); + if(directCopy){ + mOpenCLRuntime->commandQueue().enqueueReadBuffer(openCLBuffer(srcTensor), CL_TRUE, 0, needSize, hostPtr); + return; + } _allocHostBuffer(needSize, dstTensor); MNN::Tensor interTensor(dstTensor, dstTensor->getDimensionType(), false); interTensor.buffer().device = (uint64_t)mHostBuffer.second.get(); - - MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(dstTensor)->dimensionFormat; + TensorUtils::getDescribe(&interTensor)->dimensionFormat = dstDimensionFormat; //Convert format - mCLRuntime->convertFromDevice(srcTensor, (const Tensor*)&interTensor, data_format, false); + mCLRuntime->convertFromDevice(srcTensor, 
(const Tensor*)&interTensor, dstDimensionFormat, false); mOpenCLRuntime->printEventTime(); cl_int res; @@ -808,18 +816,7 @@ void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor } }else #endif - { - if (MNN_DATA_FORMAT_NHWC == data_format) { - OpenCL::converNCHWOrNHWCBufferToNC4HW4OrNC16HW16Buffer(srcTensor, const_cast(dstTensor), "nhwc_buffer_to_nc4hw4_buffer",mOpenCLRuntime.get(), true, false, svmFlag); - } else if (MNN_DATA_FORMAT_NCHW == data_format) { - OpenCL::converNCHWOrNHWCBufferToNC4HW4OrNC16HW16Buffer(srcTensor, const_cast(dstTensor), "nchw_buffer_to_nc4hw4_buffer",mOpenCLRuntime.get(), true, false, svmFlag); - } else if (MNN_DATA_FORMAT_NC4HW4 == data_format) { - OpenCL::convertNC4HW4BufferToNC4HW4Buffer(srcTensor, const_cast(dstTensor), mOpenCLRuntime.get(), InpTrans, false, svmFlag, true, false); - } else { - MNN_PRINT("input data format not support\n"); - MNN_ASSERT(false); - } - } + OpenCL::convertBufferToBuffer(const_cast(srcTensor), const_cast(dstTensor), mOpenCLRuntime.get(), true, false, false, svmFlag); } else #endif /* MNN_OPENCL_BUFFER_CLOSED */ @@ -853,28 +850,47 @@ void CLRuntime::convertToDevice(const Tensor* srcTensor, const Tensor* dstTensor void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTensor) const{ auto needSize = srcTensor->size(); auto shape = tensorShapeFormat(srcTensor); + auto srcDimensionFormat = TensorUtils::getDescribe(srcTensor)->dimensionFormat; + auto dstDimensionFormat = TensorUtils::getDescribe(dstTensor)->dimensionFormat; + auto memType = srcTensor->buffer().flags; + void* hostPtr = srcTensor->host(); // 1*1*1*1 don't need convert - if(BUFFER == mOpenCLRuntime->getGpuMemType() && 1 == shape[0] * shape[1] * shape[2] * shape[3]){ - void *tmpPtr; - void *hostPtr = srcTensor->host(); - if(srcTensor->getType().code == halide_type_float && mOpenCLRuntime->isSupportedFP16()) { - needSize /= 2; - void *tmpPtr = malloc(needSize); - ((half_float::half*)tmpPtr)[0] = (half_float::half)(((float*)hostPtr)[0]); - mOpenCLRuntime->commandQueue().enqueueWriteBuffer(openCLBuffer(dstTensor), CL_TRUE, 0, needSize, tmpPtr); - free(tmpPtr); - } else { - mOpenCLRuntime->commandQueue().enqueueWriteBuffer(openCLBuffer(dstTensor), CL_TRUE, 0, needSize, hostPtr); + if(srcTensor->getType().code == halide_type_float && mOpenCLRuntime->isSupportedFP16() && 1 == shape[0] * shape[1] * shape[2] * shape[3]){ + needSize /= 2; + void *tmpPtr = malloc(needSize); + ((half_float::half*)tmpPtr)[0] = (half_float::half)(((float*)hostPtr)[0]); + mOpenCLRuntime->commandQueue().enqueueWriteBuffer(openCLBuffer(dstTensor), CL_TRUE, 0, needSize, tmpPtr); + free(tmpPtr); + return; + } + + bool directCopy = BUFFER == mOpenCLRuntime->getGpuMemType() + && (srcDimensionFormat == dstDimensionFormat || srcTensor->dimensions() <= 1) + && MNN_DATA_FORMAT_NC4HW4 != dstDimensionFormat && MNN_DATA_FORMAT_NC4HW4 != srcDimensionFormat + && (getDataType(srcTensor) == getDataType(dstTensor)) + && memType != MNN_FORWARD_OPENCL + && memType != MNN_FORWARD_OPENGL; + if (mOpenCLRuntime->isSupportedFP16()) { // Fp16 + if (dstTensor->getType().code == halide_type_float) { + directCopy = false; + } + } + if(mOpenCLRuntime->isSupportedIntelSubgroup()){ + int cPack = TensorUtils::getTensorChannelPack(dstTensor); + if (cPack == 16){ + directCopy = false; } + } + if(directCopy){ + mOpenCLRuntime->commandQueue().enqueueWriteBuffer(openCLBuffer(dstTensor), CL_TRUE, 0, needSize, hostPtr); return; } - void* hostPtr = srcTensor->host(); - _allocHostBuffer(needSize, 
srcTensor); MNN::Tensor interTensor(srcTensor, srcTensor->getDimensionType(), false); interTensor.buffer().device = (uint64_t)mHostBuffer.second.get(); + TensorUtils::getDescribe(&interTensor)->dimensionFormat = srcDimensionFormat; #ifdef ENABLE_OPENCL_TIME_PROFILER mOpenCLRuntime->commandQueue().finish(); @@ -891,8 +907,7 @@ void OpenCLBackend::copyToDevice(const Tensor* srcTensor, const Tensor* dstTenso #endif //Covert format - MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(srcTensor)->dimensionFormat; - mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, data_format, false); + mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, srcDimensionFormat, false); return; } @@ -904,6 +919,7 @@ void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dst mCLRuntime->copyBetweenDevice(srcTensor, dstTensor); } else { const Tensor* copyTensor = MNN_FORWARD_CPU != srcMemtype ? srcTensor : dstTensor; + MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(copyTensor)->dimensionFormat; int memType = MNN_FORWARD_CPU != srcMemtype ? srcMemtype : dstMemtype; if(MNN_FORWARD_OPENCL != memType && MNN_FORWARD_OPENGL != memType){ MNN_PRINT("Unsupport ForwardType %d for OpenCL backend!\n", memType); @@ -916,6 +932,7 @@ void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dst _allocHostBuffer(0, copyTensor); MNN::Tensor interTensor(copyTensor, copyTensor->getDimensionType(), false); + TensorUtils::getDescribe(&interTensor)->dimensionFormat = data_format; if(MNN_FORWARD_OPENCL == memType ){ interTensor.buffer().device = (uint64_t)mDeviceBuffer; }else if(MNN_FORWARD_OPENGL == memType){ @@ -924,7 +941,6 @@ void OpenCLBackend::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dst interTensor.buffer().device = (uint64_t)mHostBuffer.second.get(); } //Covert format - MNN_DATA_FORMAT data_format = TensorUtils::getDescribe(copyTensor)->dimensionFormat; if(MNN_FORWARD_CPU != srcMemtype){ mCLRuntime->convertToDevice((const Tensor*)&interTensor, dstTensor, data_format, false, srcMemtype); }else{ @@ -937,7 +953,7 @@ void CLRuntime::copyBetweenDevice(const Tensor* srcTensor, const Tensor* dstTens #ifndef MNN_OPENCL_BUFFER_CLOSED if(mOpenCLRuntime->getGpuMemType() == BUFFER) { - OpenCL::convertNC4HW4BufferToNC4HW4Buffer(srcTensor, const_cast(dstTensor), mOpenCLRuntime.get(), NoTrans); + OpenCL::convertBufferToBuffer(const_cast(srcTensor), const_cast(dstTensor), mOpenCLRuntime.get(), true, true); } else #endif /* MNN_OPENCL_BUFFER_CLOSED */ @@ -1166,7 +1182,7 @@ class CLRuntimeCreator : public RuntimeCreator { } }; -DataType OpenCLBackend::getDataType(const Tensor* tensor) { +DataType OpenCLBackend::getDataType(const Tensor* tensor) const{ auto des = TensorUtils::getDescribe(tensor); if (nullptr == des->quantAttr.get()) { return DataType_DT_FLOAT; diff --git a/source/backend/opencl/core/OpenCLBackend.hpp b/source/backend/opencl/core/OpenCLBackend.hpp index 6e50d25ba..3f0abcefb 100644 --- a/source/backend/opencl/core/OpenCLBackend.hpp +++ b/source/backend/opencl/core/OpenCLBackend.hpp @@ -48,7 +48,7 @@ class CLRuntime : public Runtime { CLRuntime(const Backend::Info& info, int platformSize, int platformId, int deviceId = 0, void *contextPtr = nullptr, void *glshared = nullptr); virtual ~CLRuntime(); - virtual Backend* onCreate(const BackendConfig* config) const override; + virtual Backend* onCreate(const BackendConfig* config, Backend* origin) const override; virtual void onReset(int numberThread, const BackendConfig* config, bool full) 
override; virtual void onGabageCollect(int level) override; virtual float onGetMemoryInMB() override; @@ -122,7 +122,7 @@ class OpenCLBackend : public Backend { } float getBytes(const Tensor* tensor); - DataType getDataType(const Tensor* tensor); + DataType getDataType(const Tensor* tensor) const; cl_channel_type fpType(); int fpBytes(); diff --git a/source/backend/opencl/core/OpenCLGemmTune.cpp b/source/backend/opencl/core/OpenCLGemmTune.cpp index 388fba6f0..8f1b63c50 100644 --- a/source/backend/opencl/core/OpenCLGemmTune.cpp +++ b/source/backend/opencl/core/OpenCLGemmTune.cpp @@ -127,24 +127,86 @@ static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t return true; } + +static bool GemmlocalWSTune(const std::map, std::pair, uint32_t>>>> &tuneMap, const std::vector &gemmSize, std::vector& res, OpenCLRuntime *runtime){ + auto iter = tuneMap.find("Xgemm_tune"); + if(iter == tuneMap.end()){ + return false; + } + auto gwsAndLws = iter->second; + uint32_t minPoint = UINT_MAX; + int index = -1; + for(int i = 0; i < gwsAndLws.size(); ++i){ + // Layout+Precision, Batch, Bias+GroupSize must equall + if(gemmSize[3] != gwsAndLws[i].first[3] || gemmSize[4] != gwsAndLws[i].first[4] || gemmSize[5] != gwsAndLws[i].first[5]){ + continue; + } + auto combinations = gwsAndLws[i].second.first; + uint32_t kwg = combinations[0]; + uint32_t kwi = combinations[1]; + uint32_t mdima = combinations[2]; + uint32_t mdimc = combinations[3]; + uint32_t mwg = combinations[4]; + uint32_t ndimb = combinations[5]; + uint32_t ndimc = combinations[6]; + uint32_t nwg = combinations[7]; + uint32_t sa = combinations[8]; + uint32_t sb = combinations[9]; + uint32_t strm = combinations[10]; + uint32_t strn = combinations[11]; + uint32_t vwm = combinations[12]; + uint32_t vwn = combinations[13]; + + if(!isCandidateValid(kwg, kwi, mwg, mdimc, vwm, nwg, ndimc, vwn, mdima, ndimb, sa, sb, runtime, gemmSize)) { + continue; + } + uint32_t point = 0; + for(int j = 0; j < 3; ++j){ + point += std::abs(static_cast(gemmSize[j]) - static_cast(gwsAndLws[i].first[j])); + } + + if(point < minPoint){ + index = i; + minPoint = point; + } + } + if(index != -1){ + res = gwsAndLws[index].second.first; + } else{ + return false; + } + return true; +} std::vector getGemmParams(const std::vector &gemmSize, const std::vector tensorMemory, OpenCLRuntime *runtime) { - MNN_ASSERT(gemmSize.size() == 6); // M, N, K, Layout, Batch, Bias + MNN_ASSERT(gemmSize.size() == 6); // M, N, K, Layout+Precision, Batch, Bias+GroupSize MNN_ASSERT(gemmSize[0] % 16 == 0); MNN_ASSERT(gemmSize[1] % 16 == 0); MNN_ASSERT(gemmSize[2] % 4 == 0); - MNN_ASSERT((gemmSize[5] == 0 && tensorMemory.size() == 3) || (gemmSize[5] >= 1 && tensorMemory.size() == 4)); + int layoutType = gemmSize[3] % 10; + int mixPrecision = gemmSize[3] / 10; + int biasType = gemmSize[5] % 10; + int groupSize = gemmSize[5] / 10 + 1; + MNN_ASSERT((biasType == 0 && tensorMemory.size() == 3) || (biasType >= 1 && tensorMemory.size() == 4)); auto& tunedGemmParams = runtime->tunedGemmParamsMap(); + auto& tuneLws = runtime->getTuneLwsMap(); std::vector info(gemmSize); - uint32_t isFp16 = runtime->isSupportedFP16(); - info.emplace_back(isFp16); + uint32_t precisionType = runtime->getPrecisionLevel(); + if(precisionType == 2 && mixPrecision > 0) { + precisionType = 0; + } + info.emplace_back(precisionType); if (tunedGemmParams.find(info) != tunedGemmParams.end()) { return tunedGemmParams[info]; } + std::vector tuneLwsRes; + if(GemmlocalWSTune(tuneLws, gemmSize, tuneLwsRes, runtime)){ + return 
tuneLwsRes; + } if(runtime->getCLTuneLevel() == None) { auto getMaxDivisor = [](uint32_t num) -> uint32_t { @@ -201,6 +263,8 @@ std::vector getGemmParams(const std::vector &gemmSize, const totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 2, 8});//2 totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 0, 8, 8}); totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 0, 0, 2, 8}); + totalCombinations.push_back({16, 2, 4, 4, 32, 8, 8, 32, 0, 0, 0, 0, 8, 2}); + totalCombinations.push_back({16, 2, 4, 4, 16, 8, 8, 32, 0, 0, 0, 0, 4, 2}); if(runtime->getCLTuneLevel() < Fast) { totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 8, 8});//4 @@ -226,14 +290,17 @@ std::vector getGemmParams(const std::vector &gemmSize, const totalCombinations.push_back({16, 2, 8, 8, 32, 8, 8, 32, 0, 0, 1, 0, 2, 4}); totalCombinations.push_back({16, 2, 8, 8, 16, 8, 8, 32, 0, 0, 1, 1, 2, 4}); + totalCombinations.push_back({16, 2, 4, 4, 16, 8, 8, 64, 0, 0, 0, 0, 2, 8}); + totalCombinations.push_back({16, 2, 4, 4, 64, 8, 8, 32, 0, 0, 1, 0, 4, 4}); + totalCombinations.push_back({16, 2, 4, 4, 32, 8, 8, 64, 0, 0, 0, 1, 2, 4}); } } else { // get all combinations std::vector> candidates = { {16, 32}, // KWG {2}, // KWI - {8, 16}, // MDIMA - {8, 16}, // MDIMC + {4, 8, 16}, // MDIMA + {4, 8, 16}, // MDIMC {16, 32, 64, 128}, // MWG {8, 16}, // NDIMB {8, 16}, // NDIMC @@ -284,7 +351,7 @@ std::vector getGemmParams(const std::vector &gemmSize, const buildOptions.emplace("-DVWM=" + std::to_string(vwm)); buildOptions.emplace("-DVWN=" + std::to_string(vwn)); - if(gemmSize[3] >= 4) { + if(layoutType >= 4) { buildOptions.emplace(" -DOUTPUTMN"); } if(runtime->getGpuType() == GpuType::ADRENO) { @@ -292,24 +359,29 @@ std::vector getGemmParams(const std::vector &gemmSize, const buildOptions.emplace(" -DRELAX_WORKGROUP_SIZE=1"); } - if(gemmSize[5] >= 1) { - buildOptions.emplace(" -DBIAS_TYPE=" + std::to_string((int)gemmSize[5])); + if(biasType >= 1) { + buildOptions.emplace(" -DBIAS_TYPE=" + std::to_string((int)biasType)); + } + if(mixPrecision > 0) { + buildOptions.emplace("-DPRECISION_COMPUTE=float -DCONVERT_PRECISION_COMPUTE=convert_float"); + buildOptions.emplace("-DPRECISION_COMPUTE2=float2 -DCONVERT_PRECISION_COMPUTE2=convert_float2"); + buildOptions.emplace("-DPRECISION_COMPUTE4=float4 -DCONVERT_PRECISION_COMPUTE4=convert_float4"); + buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8"); + buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16"); } int localM = mdimc; int localN = ndimc; - std::shared_ptr kernel = runtime->buildKernel("matmul_params_buf", "Xgemm", buildOptions); - if(kernel == nullptr) { - continue; - } + std::shared_ptr kernel; if(gemmSize[4] > 1) { kernel = runtime->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); - if(kernel == nullptr) { - continue; - } + } else { + kernel = runtime->buildKernel("matmul_params_buf", "Xgemm", buildOptions); + } + if(kernel == nullptr) { + continue; } - if(localM * localN > runtime->getMaxWorkGroupSize(kernel)) { continue; } @@ -326,52 +398,56 @@ std::vector getGemmParams(const std::vector &gemmSize, const // A: [n, l, e] // B: [n, l, h] - cl::Event event; - int idx = 0; + int cost_time; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= kernel->get().setArg(idx++, static_cast(gemmSize[0])); ret |= kernel->get().setArg(idx++, static_cast(gemmSize[1])); ret |= kernel->get().setArg(idx++, static_cast(gemmSize[2])); 
ret |= kernel->get().setArg(idx++, alpha); ret |= kernel->get().setArg(idx++, beta); + + int stride[4] = {(int)gemmSize[0], (int)gemmSize[1], (int)gemmSize[1], (int)gemmSize[1]}; + if(layoutType < 4) { + stride[2] = gemmSize[0]; // output: [N, M] + } if(gemmSize[4] > 1) { int batch_offset_a = gemmSize[0] * gemmSize[2]; int batch_offset_b = gemmSize[1] * gemmSize[2]; int batch_offset_c = gemmSize[0] * gemmSize[1]; + int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0}; + int group[4] = {1, (int)groupSize, 1, (int)gemmSize[4]}; ret |= kernel->get().setArg(idx++, tensorMemory[0]); - ret |= kernel->get().setArg(idx++, batch_offset_a); ret |= kernel->get().setArg(idx++, tensorMemory[1]); - ret |= kernel->get().setArg(idx++, batch_offset_b); - if(gemmSize[5] == 1) { + if(biasType > 0) { ret |= kernel->get().setArg(idx++, tensorMemory[3]); - ret |= kernel->get().setArg(idx++, gemmSize[1]); - } else if(gemmSize[5] > 1) { - MNN_ERROR("BatchGemm with bias type > 1 (elementwise) not supported! please check\n"); } ret |= kernel->get().setArg(idx++, tensorMemory[2]); - ret |= kernel->get().setArg(idx++, batch_offset_c); - + ret |= kernel->get().setArg(idx++, sizeof(batch_offset), batch_offset); + ret |= kernel->get().setArg(idx++, sizeof(stride), stride); + ret |= kernel->get().setArg(idx++, sizeof(group), group); + MNN_CHECK_CL_SUCCESS(ret, "setArg getGemmParams XgemmBatchhed Kernel"); + cl::Event event; auto res = CL_SUCCESS; res = runtime->commandQueue().enqueueNDRangeKernel(kernel->get(), cl::NullRange, {globalWorkSize[0], globalWorkSize[1], globalWorkSize[2]}, {localWorkSize[0], localWorkSize[1], localWorkSize[2]}, nullptr, &event); if (res != CL_SUCCESS) { MNN_PRINT("XgemmBatched params tune error: %d\n", res); continue; } + + cost_time = (int)runtime->getCostTime(&event); } else { int offset_a = 0; int offset_b = 0; int offset_c = 0; int offset[4] = {0, 0, 0, 0}; - int stride[4] = {(int)gemmSize[0], (int)gemmSize[1], (int)gemmSize[1], (int)gemmSize[1]}; - if(gemmSize[3] < 4) { - stride[2] = gemmSize[0]; // output: [N, M] - } + ret |= kernel->get().setArg(idx++, tensorMemory[0]); ret |= kernel->get().setArg(idx++, tensorMemory[1]); - if(gemmSize[5] >= 1) { + if(biasType >= 1) { ret |= kernel->get().setArg(idx++, tensorMemory[3]); } ret |= kernel->get().setArg(idx++, tensorMemory[2]); @@ -380,17 +456,17 @@ std::vector getGemmParams(const std::vector &gemmSize, const MNN_CHECK_CL_SUCCESS(ret, "setArg getGemmParams Xgemm Kernel"); + cl::Event event; auto res = CL_SUCCESS; res = runtime->commandQueue().enqueueNDRangeKernel(kernel->get(), cl::NullRange, {globalWorkSize[0], globalWorkSize[1]}, {localWorkSize[0], localWorkSize[1]}, nullptr, &event); if (res != CL_SUCCESS) { MNN_PRINT("Xgemm params tune error: %d\n", res); continue; } + cost_time = (int)runtime->getCostTime(&event); } - - int cost_time = (int)runtime->getCostTime(&event); - if(cost_time < min_cost) { + if(cost_time > 0 && cost_time < min_cost) { min_cost = cost_time; params_prefer[0] = kwg; params_prefer[1] = kwi; diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp index 51ba62619..d42961c1e 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp @@ -499,7 +499,9 @@ uint32_t OpenCLRuntime::MaxThreadsPerDevice() const { uint32_t OpenCLRuntime::MaxWorkGroupSize() const { return mMaxWorkGroupSize; } - +uint32_t OpenCLRuntime::getPrecisionLevel() const { + return mPrecisionLevel; +} 
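
The tuner changes in this area fold two extra knobs into the existing six-element `gemmSize` key: entry 3 now carries the layout plus a mixed-precision flag (`layout + 10 * mixPrecision`), and entry 5 carries the bias type plus the group count (`biasType + 10 * (groupSize - 1)`). The decode side of the sketch below mirrors the tuner's own arithmetic; the struct and the encode helper are hypothetical and only illustrate the packing.

```cpp
#include <cstdint>
#include <vector>

// Sketch of the packed gemmSize convention used by the OpenCL GEMM tuner
// (decode mirrors the tuner; the encode helper is illustrative only).
// gemmSize = { M, N, K, layout + 10*mixPrecision, batch, biasType + 10*(groupSize - 1) }
struct GemmTuneKey {
    uint32_t M, N, K;
    uint32_t layout;        // layout >= 4 adds -DOUTPUTMN when building the kernel
    uint32_t mixPrecision;  // > 0: force fp32 compute even with fp16 storage
    uint32_t batch;
    uint32_t biasType;      // 0: no bias tensor, >= 1: bias tensor appended
    uint32_t groupSize;     // >= 1
};

inline std::vector<uint32_t> encodeGemmSize(const GemmTuneKey& k) {
    return { k.M, k.N, k.K,
             k.layout + 10u * k.mixPrecision,
             k.batch,
             k.biasType + 10u * (k.groupSize - 1u) };
}

inline GemmTuneKey decodeGemmSize(const std::vector<uint32_t>& g) {
    GemmTuneKey k;
    k.M = g[0]; k.N = g[1]; k.K = g[2];
    k.layout       = g[3] % 10;      // layoutType in the tuner
    k.mixPrecision = g[3] / 10;
    k.batch        = g[4];
    k.biasType     = g[5] % 10;
    k.groupSize    = g[5] / 10 + 1;
    return k;
}
```

Note that the cached-parameter reuse path (`GemmlocalWSTune`) only considers stored records whose entries 3, 4 and 5 match the request exactly; among those it picks the record whose M/N/K is closest to the requested shape, and still re-validates the candidate before reuse.
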
uint32_t OpenCLRuntime::maxFreq() const { return mMaxFreq; } @@ -548,11 +550,11 @@ std::shared_ptr OpenCLRuntime::buildKernelWithCache(const std::strin const std::set &buildOptions, const Tensor *input, const Tensor *output, bool useCache) { std::string buildOptionsStr; if (mPrecisionLevel == 2) {// Fp16 Memory and fp16 compute - buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=half -DCOMPUTE_FLOAT2=half2 -DCOMPUTE_FLOAT3=half3 -DCOMPUTE_FLOAT4=half4 -DCOMPUTE_FLOAT8=half8 -DCOMPUTE_FLOAT16=half16 -DCONVERT_COMPUTE_FLOAT2=convert_half2 -DCONVERT_COMPUTE_FLOAT4=convert_half4 -DCONVERT_COMPUTE_FLOAT8=convert_half8 -DCONVERT_COMPUTE_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16"; + buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=half -DCOMPUTE_FLOAT2=half2 -DCOMPUTE_FLOAT3=half3 -DCOMPUTE_FLOAT4=half4 -DCOMPUTE_FLOAT8=half8 -DCOMPUTE_FLOAT16=half16 -DCONVERT_COMPUTE_FLOAT=convert_half -DCONVERT_COMPUTE_FLOAT2=convert_half2 -DCONVERT_COMPUTE_FLOAT4=convert_half4 -DCONVERT_COMPUTE_FLOAT8=convert_half8 -DCONVERT_COMPUTE_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DCONVERT_FLOAT=convert_half -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT3=convert_half3 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DMNN_SUPPORT_FP16"; } else if (mPrecisionLevel == 0) {// Fp16 Memory and fp32 compute - buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DMNN_SUPPORT_FP16"; + buildOptionsStr = "-DFLOAT=half -DFLOAT2=half2 -DFLOAT3=half3 -DFLOAT4=half4 -DFLOAT8=half8 -DFLOAT16=half16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT=convert_float -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DCONVERT_FLOAT=convert_half -DCONVERT_FLOAT2=convert_half2 -DCONVERT_FLOAT3=convert_half3 -DCONVERT_FLOAT4=convert_half4 -DCONVERT_FLOAT8=convert_half8 -DCONVERT_FLOAT16=convert_half16 -DRI_F=read_imageh -DWI_F=write_imageh -DMNN_SUPPORT_FP16"; } else {// Fp32 Memory and fp32 compute - buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT2=convert_float2 
-DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16"; + buildOptionsStr = "-DFLOAT=float -DFLOAT2=float2 -DFLOAT3=float3 -DFLOAT4=float4 -DFLOAT8=float8 -DFLOAT16=float16 -DCOMPUTE_FLOAT=float -DCOMPUTE_FLOAT2=float2 -DCOMPUTE_FLOAT3=float3 -DCOMPUTE_FLOAT4=float4 -DCOMPUTE_FLOAT8=float8 -DCOMPUTE_FLOAT16=float16 -DCONVERT_COMPUTE_FLOAT=convert_float -DCONVERT_COMPUTE_FLOAT2=convert_float2 -DCONVERT_COMPUTE_FLOAT4=convert_float4 -DCONVERT_COMPUTE_FLOAT8=convert_float8 -DCONVERT_COMPUTE_FLOAT16=convert_float16 -DRI_F=read_imagef -DFLOAT16=float16 -DWI_F=write_imagef -DCONVERT_FLOAT=convert_float -DCONVERT_FLOAT2=convert_float2 -DCONVERT_FLOAT3=convert_float3 -DCONVERT_FLOAT4=convert_float4 -DCONVERT_FLOAT8=convert_float8 -DCONVERT_FLOAT16=convert_float16"; } if(nullptr != input){ @@ -975,6 +977,7 @@ bool OpenCLRuntime::setCache(std::pair cache) { params[v] = tun->paramInfo()->data()[v]; } mTunedGemmParams.insert(std::make_pair(info, params)); + mTuneLws["Xgemm_tune"].push_back(std::make_pair(info, std::make_pair(params, 0))); } } @@ -1026,6 +1029,8 @@ void OpenCLRuntime::printEventTime(){ conv_time += kernel_time; } else if (mEvents[i].first.length() >= 11 && mEvents[i].first.substr(0, 11) == "Convolution") { conv_time += kernel_time; + } else if (mEvents[i].first.length() >= 8 && mEvents[i].first.substr(0, 8) == "Strassen") { + conv_time += kernel_time; } if((mEvents[i].first.length() >= 10 && mEvents[i].first.substr(0, 10) == "While-gemm")) { loop_bg_time += kernel_time; @@ -1043,6 +1048,10 @@ void OpenCLRuntime::printEventTime(){ wino_gemm_time += kernel_time; conv_time += kernel_time; } + if((mEvents[i].first.length() >= 6 && mEvents[i].first.substr(0, 6) == "Raster")) { + raster_num++; + raster_time += kernel_time; + } kernels[i] = std::make_pair(mEvents[i].first, kernel_time); } @@ -1063,7 +1072,7 @@ void OpenCLRuntime::printEventTime(){ MNN_PRINT("kernel time = %d us %s\n", kernels[i].second, kernels[i].first.c_str()); } mEvents.clear(); - MNN_PRINT("total kernel time = %d us, conv time = %d us (gemm2:%d us, gemm1:%d us, 1x1:%d us, ori:%d us, wino: %d us, other: %d us), while gemm time = %d us (core gemm time: %d us, softmax:%d us), ori softmax: %d us\n", mKernelTime, conv_time, conv_gemm2_buf_time, conv_gemm1_buf_time, conv_1x1_buf_time, conv_ori_buf_time, wino_gemm_time, conv_time-conv_gemm2_buf_time-conv_gemm1_buf_time-conv_1x1_buf_time-conv_ori_buf_time-wino_gemm_time, loop_bg_time, loop_bg_gemm_time, loop_softmax_time, ori_softmax_time); + MNN_PRINT("total kernel time = %d us, conv time = %d us (gemm2:%d us, gemm1:%d us, 1x1:%d us, ori:%d us, wino: %d us, other: %d us), while gemm time = %d us (core gemm time: %d us, softmax:%d us), ori softmax: %d us, raster[%d] time: %d us\n", mKernelTime, conv_time, conv_gemm2_buf_time, conv_gemm1_buf_time, conv_1x1_buf_time, conv_ori_buf_time, wino_gemm_time, conv_time-conv_gemm2_buf_time-conv_gemm1_buf_time-conv_1x1_buf_time-conv_ori_buf_time-wino_gemm_time, loop_bg_time, loop_bg_gemm_time, loop_softmax_time, ori_softmax_time, raster_num, raster_time); #endif } } // namespace MNN diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp index 2586a6559..b5dfa5918 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.hpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.hpp @@ -109,6 +109,7 @@ class OpenCLRuntime { float getCLVersion() { return mCLVersion; } + uint32_t getPrecisionLevel() const; bool isSupportGL(){ return 
mIsSupportGL; } diff --git a/source/backend/opencl/execution/buffer/ArgMaxBufExecution.cpp b/source/backend/opencl/execution/buffer/ArgMaxBufExecution.cpp index d1faba8eb..d2676ddcb 100644 --- a/source/backend/opencl/execution/buffer/ArgMaxBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ArgMaxBufExecution.cpp @@ -19,7 +19,7 @@ ArgMaxBufExecution::ArgMaxBufExecution(const std::string &compute, const MNN::Op mOpenCLBackend = static_cast(backend); std::set buildOptions = mBuildOptions; buildOptions.emplace("-DARGMAX_LOCAL_SIZE=512"); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("argmax_buf", "argmax_channel_buf", buildOptions); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("argmax_buf", "argmax_buf", buildOptions); mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } @@ -32,12 +32,30 @@ int ArgMaxBufExecution::getLocalSize(int size, int maxGroupSize){ } ErrorCode ArgMaxBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { - mUnits.resize(1); - auto &unit = mUnits[0]; + mUnits.clear(); auto runtime = mOpenCLBackend->getOpenCLRuntime(); auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize); auto input = inputs[0]; auto output = outputs[0]; + + const auto layout = TensorUtils::getDescribe(input)->dimensionFormat; + mNeedUnpackC4 = layout == MNN_DATA_FORMAT_NC4HW4; + if (mNeedUnpackC4) { + int inputTotalSize = 1, outputTotalSize = 1; + for (int i = 1; i < input->dimensions(); ++i) { + inputTotalSize *= input->length(i); + } + for (int i = 1; i < output->dimensions(); ++i) { + outputTotalSize *= output->length(i); + } + mTempInputTensor.reset(Tensor::createDevice({inputTotalSize})); + mTempOutputTensor.reset(Tensor::createDevice({outputTotalSize})); + mOpenCLBackend->onAcquireBuffer(mTempInputTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempOutputTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempInputTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempOutputTensor.get(), Backend::DYNAMIC); + + } if(mAxis < 0){ mAxis = input->dimensions() + mAxis; } @@ -51,74 +69,111 @@ ErrorCode ArgMaxBufExecution::onEncode(const std::vector& inputs, const } int dim = input->length(mAxis); - std::vector inputShape = tensorShapeFormat(input); - std::vector outputShape = tensorShapeFormat(output); + // NC4HW4 -> NCHW + if(mNeedUnpackC4){ + Unit unit; + std::vector outputShape = tensorShapeFormat(input); + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + std::set buildOptions; + buildOptions.emplace("-DINPUT_FORMAT=MNN_DATA_FORMAT_NC4HW4"); + buildOptions.emplace("-DOUTPUT_FORMAT=MNN_DATA_FORMAT_NCHW"); + unit.kernel = runtime->buildKernel("buffer_convert_buf", "buffer_convert_to_buffer", buildOptions, input, output); + mGlobalWorkSize = {static_cast(shape[2] * shape[3]), static_cast(shape[1]), static_cast(shape[0])}; + cl_int ret = CL_SUCCESS; + uint32_t idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, sizeof(shape), shape); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTempInputTensor.get())); + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); - int batch = inputShape.at(0); - int 
inputHeight = inputShape.at(1); - int inputWidth = inputShape.at(2); - int inputChannels = inputShape.at(3); - int inputChannelBlocks = (inputChannels + 3) / 4; - int outputBatch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int outputChannels = outputShape.at(3); - int outputChannelBlocks = (outputChannels + 3) / 4; - - int localSize = getLocalSize(dim, MaxLocalSize); - if(localSize < 4){ - localSize = 1; + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalSize = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; + mUnits.emplace_back(unit); } - std::set buildOptions = mBuildOptions; - buildOptions.emplace("-DARGMAX_LOCAL_SIZE=" + std::to_string(localSize)); - std::string kernelName; - if(batch * inputHeight * inputChannels == outside && 1 == inside && dim == inputWidth){ - kernelName = "argmax_width_buf"; - unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); - mGlobalWorkSize = {static_cast(localSize), static_cast(outputHeight), static_cast(outputBatch * outputChannelBlocks)}; - }else if(batch * inputChannels == outside && inputWidth == inside && dim == inputHeight){ - kernelName = "argmax_height_buf"; - unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); - mGlobalWorkSize = {static_cast(localSize), static_cast(outputWidth), static_cast(outputBatch * outputChannelBlocks)}; - }else if(batch == outside && inputWidth * inputHeight == inside && dim == inputChannels){ - if(output->buffer().dimensions == 1){ - buildOptions.emplace("-DARGMAX_CHANNEL_DIM1"); + + // Argmax + { + Unit unit; + int localSize = getLocalSize(dim, MaxLocalSize); + if(localSize < 4){ + localSize = 1; } - kernelName = "argmax_channel_buf"; - unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); - mGlobalWorkSize = {static_cast(localSize), static_cast(outputWidth * outputHeight), static_cast(outputBatch * outputChannels)}; - }else if(1 == outside && inputWidth * inputHeight * inputChannels == inside && dim == batch){ - kernelName = "argmax_batch_buf"; - unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); - mGlobalWorkSize = {static_cast(localSize), static_cast(outputWidth * outputHeight), static_cast(outputChannelBlocks)}; + std::set buildOptions = mBuildOptions; + buildOptions.emplace("-DARGMAX_LOCAL_SIZE=" + std::to_string(localSize)); + std::string kernelName; + if(inside % 4 == 0){ + kernelName = "argmax_v4_buf"; + unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); + mGlobalWorkSize = {static_cast(localSize), static_cast(UP_DIV(inside, 4)), static_cast(outside)}; + }else { + kernelName = "argmax_buf"; + unit.kernel = runtime->buildKernel("argmax_buf", kernelName, buildOptions); + mGlobalWorkSize = {static_cast(localSize), static_cast(inside), static_cast(outside)}; + } + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalSize = {(uint32_t)(localSize), 1, 1}; + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + if(mNeedUnpackC4){ + ret |= 
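
// The rewritten ArgMax path above flattens the tensor into (outside, dim, inside)
// around the reduce axis and dispatches argmax_v4_buf when `inside` is a multiple
// of 4, falling back to argmax_buf otherwise. A scalar reference of the same
// decomposition (float input, int indices; the function name is illustrative):
#include <vector>

static std::vector<int> argmaxReference(const float* src, int outside, int dim, int inside) {
    std::vector<int> out(outside * inside, 0);
    for (int o = 0; o < outside; ++o) {
        for (int i = 0; i < inside; ++i) {
            const float* base = src + o * dim * inside + i;
            int bestIdx = 0;
            float bestVal = base[0];
            for (int d = 1; d < dim; ++d) {
                const float v = base[d * inside];  // stride `inside` walks the reduce axis
                if (v > bestVal) { bestVal = v; bestIdx = d; }
            }
            out[o * inside + i] = bestIdx;
        }
    }
    return out;  // the GPU kernels produce the same indices, vectorized over `inside`
}
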
unit.kernel->get().setArg(idx++, openCLBuffer(mTempInputTensor.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTempOutputTensor.get())); + }else{ + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + } + ret |= unit.kernel->get().setArg(idx++, inside); + ret |= unit.kernel->get().setArg(idx++, outside); + ret |= unit.kernel->get().setArg(idx++, dim); + MNN_CHECK_CL_SUCCESS(ret, "setArg ArgMaxBufExecution"); + + if(localSize == 1){ + mLocalSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; + } + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; + mUnits.emplace_back(unit); } - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - mLocalSize = {(uint32_t)(localSize), 1, 1}; - - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, inputWidth); - ret |= unit.kernel->get().setArg(idx++, inputHeight); - ret |= unit.kernel->get().setArg(idx++, inputChannels); - ret |= unit.kernel->get().setArg(idx++, batch); - ret |= unit.kernel->get().setArg(idx++, inputChannelBlocks); - ret |= unit.kernel->get().setArg(idx++, outputWidth); - ret |= unit.kernel->get().setArg(idx++, outputHeight); - ret |= unit.kernel->get().setArg(idx++, outputChannels); - ret |= unit.kernel->get().setArg(idx++, outputChannelBlocks); - MNN_CHECK_CL_SUCCESS(ret, "setArg ArgMaxBufExecution"); + + // NCHW -> NC4HW4 + if(mNeedUnpackC4){ + Unit unit; + std::vector outputShape = tensorShapeFormat(output); + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + std::set buildOptions; + buildOptions.emplace("-DINPUT_FORMAT=MNN_DATA_FORMAT_NCHW"); + buildOptions.emplace("-DOUTPUT_FORMAT=MNN_DATA_FORMAT_NC4HW4"); + unit.kernel = runtime->buildKernel("buffer_convert_buf", "buffer_convert_to_buffer", buildOptions, input, output); + mGlobalWorkSize = {static_cast(shape[2] * shape[3]), static_cast(shape[1]), static_cast(shape[0])}; + cl_int ret = CL_SUCCESS; + uint32_t idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTempOutputTensor.get())); + ret |= unit.kernel->get().setArg(idx++, sizeof(shape), shape); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); - if(localSize == 1){ - mLocalSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalSize = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], 
mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; + mUnits.emplace_back(unit); } - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; return NO_ERROR; } diff --git a/source/backend/opencl/execution/buffer/ArgMaxBufExecution.hpp b/source/backend/opencl/execution/buffer/ArgMaxBufExecution.hpp index 760f909ce..9ce5cd79a 100644 --- a/source/backend/opencl/execution/buffer/ArgMaxBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/ArgMaxBufExecution.hpp @@ -29,6 +29,9 @@ class ArgMaxBufExecution : public CommonExecution { std::set mBuildOptions; int mAxis; OpenCLBackend *mOpenCLBackend; + std::shared_ptr mTempInputTensor; + std::shared_ptr mTempOutputTensor; + bool mNeedUnpackC4; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp index 2ca359ecf..2302714cc 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp @@ -13,17 +13,17 @@ namespace MNN { namespace OpenCL { -AttentionBufImpl::AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cahce) - : mKVCache(kv_cahce){ +KVCacheCLManager::KVCacheCLManager(Backend *backend, bool kv_cahce) : mKVCache(kv_cahce){ mOpenCLBackend = static_cast(backend); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_channel", {"-DSOFTMAX_LOCAL_SIZE=512"}); - mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } -void AttentionBufImpl::allocKVCache() { +void KVCacheCLManager::allocKVCache() { if (!mKVCache || mPastLength < mMaxLength) { return; } + if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ + mByte = 2; + } mMaxLength = mPastLength + mExpandChunk; size_t buffer_size = UP_DIV(mMaxLength, 4) * mKvNumHead * mHeadDim * 4 * mByte; // past_key: [1, numhead, headdim, maxlen] @@ -32,9 +32,9 @@ void AttentionBufImpl::allocKVCache() { mPastValue.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); } -void AttentionBufImpl::reallocKVCache() { +bool KVCacheCLManager::reallocKVCache() { if (!mKVCache || mPastLength < mMaxLength) { - return; + return false; } size_t old_size = mKvNumHead * UP_DIV(mMaxLength, 4) * mHeadDim * 4 * mByte; @@ -70,40 +70,47 @@ void AttentionBufImpl::reallocKVCache() { mPastKey.reset(new_key); mPastValue.reset(new_value); - mTempQK.reset(Tensor::createDevice({UP_DIV(mMaxLength, 4) * mNumHead * 4})); - mTempSoftMax.reset(Tensor::createDevice({UP_DIV(mMaxLength, 4) * mNumHead * 4})); + return true; +} + +int AttentionBufExecution::getLocalSize(int size, int maxGroupSize){ + int local_size = 1; + while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){ + local_size *= 2; + } + return local_size; +} + +void AttentionBufExecution::reallocKVCache() { + int maxLength = mKVCacheCLManager->maxLength(); + int numHead = mKVCacheCLManager->numHead(); + mTempQK.reset(Tensor::createDevice({UP_DIV(maxLength, 4) * numHead * 4})); + mTempSoftMax.reset(Tensor::createDevice({UP_DIV(maxLength, 4) * numHead * 4})); mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::STATIC); mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::STATIC); // reset 
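
// allocKVCache above sizes each cache buffer as
// UP_DIV(maxLength, 4) * kvNumHead * headDim * 4 * bytesPerElement, i.e. the
// sequence axis is padded to a multiple of 4, and bytesPerElement drops to 2
// when the runtime reports FP16 support. A small helper reproducing that
// arithmetic (the sample numbers below are only illustrative):
#include <cstddef>

static inline size_t upDiv(size_t a, size_t b) { return (a + b - 1) / b; }

static size_t kvCacheBytes(size_t maxLength, size_t kvNumHead, size_t headDim, bool fp16) {
    const size_t bytesPerElement = fp16 ? 2 : 4;
    return upDiv(maxLength, 4) * kvNumHead * headDim * 4 * bytesPerElement;
}
// e.g. maxLength=2048, kvNumHead=8, headDim=128, fp16=true gives 4 MiB,
// allocated once for the past keys and once for the past values.
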
memory for args if(mOpenCLBackend->isUseRecordQueue()){ mQkUpdateInfo.update_kernel_args[1].arg_value = &openCLBuffer(mTempQK.get())(); - mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mPastKey.get()))(); + mQkUpdateInfo.update_kernel_args[2].arg_value = &(*(mKVCacheCLManager->key()))(); mSoftMaxUpdateInfo.update_kernel_args[0].arg_value = &openCLBuffer(mTempQK.get())(); mSoftMaxUpdateInfo.update_kernel_args[1].arg_value = &openCLBuffer(mTempSoftMax.get())(); mQkvUpdateInfo.update_kernel_args[0].arg_value = &openCLBuffer(mTempSoftMax.get())(); - mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mPastValue.get()))(); + mQkvUpdateInfo.update_kernel_args[1].arg_value = &(*(mKVCacheCLManager->value()))(); }else{ cl_int ret = CL_SUCCESS; ret |= mKernel_qk->get().setArg(5, openCLBuffer(mTempQK.get())); - ret |= mKernel_qk->get().setArg(6, *mPastKey.get()); + ret |= mKernel_qk->get().setArg(6, *mKVCacheCLManager->key()); ret |= mKernel_softmax->get().setArg(3, openCLBuffer(mTempQK.get())); ret |= mKernel_softmax->get().setArg(4, openCLBuffer(mTempSoftMax.get())); ret |= mKernel_qkv->get().setArg(3, openCLBuffer(mTempSoftMax.get())); - ret |= mKernel_qkv->get().setArg(6, *mPastValue.get()); + ret |= mKernel_qkv->get().setArg(6, *mKVCacheCLManager->value()); MNN_CHECK_CL_SUCCESS(ret, "reset memory arg for AttentionBufExecution"); } + mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::STATIC); + mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::STATIC); } -int AttentionBufImpl::getLocalSize(int size, int maxGroupSize){ - int local_size = 1; - while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){ - local_size *= 2; - } - return local_size; -} - -ErrorCode AttentionBufImpl::onResize(Backend *backend, const std::vector &inputs, const std::vector &outputs) { - mOpenCLBackend = static_cast(backend); +ErrorCode AttentionBufExecution::onResize(const std::vector &inputs, const std::vector &outputs) { mOpenCLBackend->startRecord(mRecording); //clear update arg vector, if prefill and decode use the same one mOpRecordUpdateInfo.clear(); @@ -124,195 +131,605 @@ ErrorCode AttentionBufImpl::onResize(Backend *backend, const std::vectorgetOpenCLRuntime(); auto shape = query->shape(); + int batch = shape[0]; int seq_len = shape[1]; - mNumHead = shape[2]; - mKvNumHead = key->shape()[2]; - int group_size = mNumHead / mKvNumHead; - mHeadDim = shape[3]; - mScale = 1.0 / sqrt(mHeadDim); + int numHead = shape[2]; + int kvNumHead = key->shape()[2]; + int headDim = shape[3]; + int group_size = numHead / kvNumHead; + float scale = 1.0 / sqrt(headDim); mIsDecode = seq_len == 1; - mIsFirstDecode = true; - if (mPastLength == 0 || seq_len > 1) { - mPastLength = seq_len; - } - mKv_seq_len = mPastLength; - if(mIsDecode){ - mKv_seq_len = mPastLength + 1; - } - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ - mByte = 2; + mIsAddMask = (mask->getType() == halide_type_of()); + mLongPrefill = false; + if(false == mIsDecode){ + mKVCacheCLManager->setArgs(seq_len, numHead, kvNumHead, headDim); + mKVCacheCLManager->allocKVCache(); + + if(seq_len > 512) { + mLongPrefill = true; + mAlignQ = 128; + mAlignKV = 128; + mAlignHDK = 4; + mAlignHDN = 128; + + mTempQ.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(headDim, mAlignHDK) * batch * numHead})); + mTempK.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignKV) * ROUND_UP(headDim, mAlignHDK) * batch * numHead})); + mTempV.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignKV) * ROUND_UP(headDim, mAlignHDN) * batch * 
numHead})); + if(mIsAddMask) { + mTempMask.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(seq_len, mAlignKV) * batch})); + } else { + mTempMask.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(seq_len, mAlignKV) * batch})); + } + mTempQK.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(seq_len, mAlignKV) * batch * numHead})); + mTempSoftMax.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(seq_len, mAlignKV) * batch * numHead})); + mTempQKV.reset(Tensor::createDevice({ROUND_UP(seq_len, mAlignQ) * ROUND_UP(headDim, mAlignHDN) * batch * numHead})); + + } else { + mTempQK.reset(Tensor::createDevice({UP_DIV(seq_len, 4) * seq_len * numHead * 4})); + mTempSoftMax.reset(Tensor::createDevice({UP_DIV(seq_len, 4) * seq_len * numHead * 4})); + } + mKv_seq_len = mKVCacheCLManager->kvLength(); + } else { + mKv_seq_len = mKVCacheCLManager->kvLength() + 1; + int maxLength = mKVCacheCLManager->maxLength(); + mTempQK.reset(Tensor::createDevice({UP_DIV(maxLength, 4) * numHead * 4})); + mTempSoftMax.reset(Tensor::createDevice({UP_DIV(maxLength, 4) * numHead * 4})); } - allocKVCache(); - if (mIsDecode) { - mTempQK.reset(Tensor::createDevice({UP_DIV(mMaxLength, 4) * mNumHead * 4})); - mTempSoftMax.reset(Tensor::createDevice({UP_DIV(mMaxLength, 4) * mNumHead * 4})); + + if(mLongPrefill) { + mOpenCLBackend->onAcquireBuffer(mTempQ.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempK.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempV.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempMask.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); + + mOpenCLBackend->onReleaseBuffer(mTempQ.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempK.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempMask.get(), Backend::DYNAMIC); + + mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); + + mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); + + mOpenCLBackend->onAcquireBuffer(mTempQKV.get(), Backend::DYNAMIC); + + mOpenCLBackend->onReleaseBuffer(mTempV.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempQKV.get(), Backend::DYNAMIC); + } else { - mTempQK.reset(Tensor::createDevice({UP_DIV(mPastLength, 4) * mPastLength * mNumHead * 4})); - mTempSoftMax.reset(Tensor::createDevice({UP_DIV(mPastLength, 4) * mPastLength * mNumHead * 4})); + mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); } - mOpenCLBackend->onAcquireBuffer(mTempQK.get(), Backend::DYNAMIC); - mOpenCLBackend->onAcquireBuffer(mTempSoftMax.get(), Backend::DYNAMIC); - // query * key -> div -> select - { - std::set buildOption; - if(!mIsDecode){ - buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); + + if(mLongPrefill) { + // query: [batch, seqLenQ, headNum, headDim] -> mTempQ: [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)] + // key: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempK: [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] + // value: [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] -> mTempV: [batch*headNum/group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDK] + // key & value -> pastKey & pastValue (copy) + { + std::set buildOption; + if((headDim % 4) != 0){ + 
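
// When seq_len exceeds 512 the prefill path switches to tiled kernels, so the
// scratch tensors above are padded to the tile alignments (mAlignQ = mAlignKV =
// mAlignHDN = 128, mAlignHDK = 4 in the code). A sketch of those padded element
// counts for prefill, where the kv length equals seq_len (sizes are illustrative):
#include <cstddef>

static inline size_t roundUpTo(size_t x, size_t align) { return ((x + align - 1) / align) * align; }

struct PrefillScratchElems { size_t q, k, v, qk; };

static PrefillScratchElems prefillScratchElems(size_t seqLen, size_t headDim,
                                               size_t batch, size_t numHead) {
    const size_t alignQ = 128, alignKV = 128, alignHDK = 4, alignHDN = 128;
    PrefillScratchElems s;
    s.q  = roundUpTo(seqLen, alignQ)  * roundUpTo(headDim, alignHDK) * batch * numHead;
    s.k  = roundUpTo(seqLen, alignKV) * roundUpTo(headDim, alignHDK) * batch * numHead;
    s.v  = roundUpTo(seqLen, alignKV) * roundUpTo(headDim, alignHDN) * batch * numHead;
    s.qk = roundUpTo(seqLen, alignQ)  * roundUpTo(seqLen, alignKV)   * batch * numHead;
    return s;
}
// e.g. seqLen=1000 pads to 1024 on both attention axes, so the QK scratch holds
// 1024 * 1024 elements per (batch, head) pair before being released back to the
// dynamic memory pool.
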
buildOption.emplace("-DHEADDIM_LEAVE"); + } + if((seq_len % 4) != 0){ + buildOption.emplace("-DSEQLEN_LEAVE"); + } + + int seq_len_pack_q = ROUND_UP(seq_len, mAlignQ); + int seq_len_pack_kv = ROUND_UP(mKv_seq_len, mAlignKV); + + int head_dim_pack_qk = ROUND_UP(headDim, mAlignHDK); + int head_dim_pack_v = ROUND_UP(headDim, mAlignHDN); + + int tile[4] = {mAlignQ, mAlignKV, mAlignHDK, mAlignHDN}; + int shape[4] = {seq_len, mKv_seq_len, numHead, headDim}; + int param[4] = {group_size, batch, 0, 0}; + mKernel_rearrange = runtime->buildKernel("attention_buf", "rearrange_qkv", buildOption, inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_rearrange)); + + mGlobalWorkSizeRearrg = {static_cast(ALIMAX(UP_DIV(seq_len_pack_q, 4), UP_DIV(seq_len_pack_kv, 4))), \ + static_cast(ALIMAX(UP_DIV(head_dim_pack_qk, 4), UP_DIV(head_dim_pack_v, 4))), \ + static_cast(batch*numHead)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[0]); + ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[1]); + ret |= mKernel_rearrange->get().setArg(index++, mGlobalWorkSizeRearrg[2]); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(query)); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(key)); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(value)); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(mTempQ.get())); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(mTempK.get())); + ret |= mKernel_rearrange->get().setArg(index++, openCLBuffer(mTempV.get())); + ret |= mKernel_rearrange->get().setArg(index++, *mKVCacheCLManager->key()); + ret |= mKernel_rearrange->get().setArg(index++, *mKVCacheCLManager->value()); + ret |= mKernel_rearrange->get().setArg(index++, tile); + ret |= mKernel_rearrange->get().setArg(index++, shape); + ret |= mKernel_rearrange->get().setArg(index++, param); + + MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_qkv"); + mLocalWorkSizeRearrg = localWS3DDefault(mGlobalWorkSizeRearrg, maxWorkGroupSize, runtime, "rearrange_qkv", mKernel_rearrange).first; + mGlobalWorkSizeRearrg[0] = ROUND_UP(mGlobalWorkSizeRearrg[0], std::max((uint32_t)1, mLocalWorkSizeRearrg[0])); + mGlobalWorkSizeRearrg[1] = ROUND_UP(mGlobalWorkSizeRearrg[1], std::max((uint32_t)1, mLocalWorkSizeRearrg[1])); + mGlobalWorkSizeRearrg[2] = ROUND_UP(mGlobalWorkSizeRearrg[2], std::max((uint32_t)1, mLocalWorkSizeRearrg[2])); + mOpenCLBackend->recordKernel3d(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg); } - if((mHeadDim % 4) != 0){ - buildOption.emplace("-DHEADDIM_LEAVE"); + + // mask rearaange + { + std::set buildOption; + + int seq_len_pack_q = ROUND_UP(seq_len, mAlignQ); + int seq_len_pack_kv = ROUND_UP(mKv_seq_len, mAlignKV); + int shape[4] = {seq_len, mKv_seq_len, mAlignQ, mAlignKV}; + + mKernel_mask = runtime->buildKernel("attention_buf", "rearrange_mask", buildOption, inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_mask)); + + mGlobalWorkSizeMask = {static_cast(UP_DIV(seq_len_pack_q, 4)), \ + static_cast(UP_DIV(seq_len_pack_kv, 4)), \ + static_cast(batch)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_mask->get().setArg(index++, mGlobalWorkSizeMask[0]); + ret |= mKernel_mask->get().setArg(index++, mGlobalWorkSizeMask[1]); + ret |= mKernel_mask->get().setArg(index++, mGlobalWorkSizeMask[2]); + ret |= mKernel_mask->get().setArg(index++, 
openCLBuffer(mask)); + ret |= mKernel_mask->get().setArg(index++, openCLBuffer(mTempMask.get())); + ret |= mKernel_mask->get().setArg(index++, shape); + + MNN_CHECK_CL_SUCCESS(ret, "setArg rearrange_mask"); + mLocalWorkSizeMask = localWS3DDefault(mGlobalWorkSizeMask, maxWorkGroupSize, runtime, "rearrange_mask", mKernel_mask).first; + mGlobalWorkSizeMask[0] = ROUND_UP(mGlobalWorkSizeMask[0], std::max((uint32_t)1, mLocalWorkSizeMask[0])); + mGlobalWorkSizeMask[1] = ROUND_UP(mGlobalWorkSizeMask[1], std::max((uint32_t)1, mLocalWorkSizeMask[1])); + mGlobalWorkSizeMask[2] = ROUND_UP(mGlobalWorkSizeMask[2], std::max((uint32_t)1, mLocalWorkSizeMask[2])); + mOpenCLBackend->recordKernel3d(mKernel_mask, mGlobalWorkSizeMask, mLocalWorkSizeMask); } - if(mask->getType() == halide_type_of()){ - buildOption.emplace("-DADD_MASK"); + + { + // Q : [batch*headNum, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenQ, mAlignQ)] -> [B, K, M] + // K : [batch*headNum/group, ROUND_UP(headDim, mAlignHDK), ROUND_UP(seqLenKV, mAlignKV)] -> [B, K, N] + // QV: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ), ROUND_UP(seqLenKV, mAlignKV)] -> [B, M, N] + int loop = batch * numHead; + int e_pack = ROUND_UP(seq_len, mAlignQ); + int h_pack = ROUND_UP(mKv_seq_len, mAlignKV); + int l_pack = ROUND_UP(headDim, mAlignHDK); + + std::set buildOptions; + + int biasType = 5;// int value mask + if(mIsAddMask) { + biasType = 2; + } + uint32_t layout = 14; // 10 means mix-precision, 4 means layput + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)(biasType + 10*(group_size-1))}, {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get()), openCLBuffer(mTempMask.get())}, mOpenCLBackend->getOpenCLRuntime()); + + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; + buildOptions.emplace("-DKWG=" + std::to_string(KWG)); + buildOptions.emplace("-DKWI=" + std::to_string(KWI)); + buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); + buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC)); + buildOptions.emplace("-DMWG=" + std::to_string(MWG)); + buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB)); + buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC)); + buildOptions.emplace("-DNWG=" + std::to_string(NWG)); + buildOptions.emplace("-DSA=" + std::to_string(SA)); + buildOptions.emplace("-DSB=" + std::to_string(SB)); + buildOptions.emplace("-DSTRM=" + std::to_string(STRM)); + buildOptions.emplace("-DSTRN=" + std::to_string(STRN)); + buildOptions.emplace("-DVWM=" + std::to_string(VWM)); + buildOptions.emplace("-DVWN=" + std::to_string(VWN)); + if(layout >= 4) { + buildOptions.emplace("-DOUTPUTMN"); + } + + int tileM = MWG; + int tileN = NWG; + int localM = MDIMC; + int localN = NDIMC; + + if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { + buildOptions.emplace("-DUSE_CL_MAD=1"); + buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); + } + buildOptions.emplace("-DONLY_HAVE_ALPHA"); + buildOptions.emplace("-DBIAS_TYPE=" + std::to_string(biasType)); + + buildOptions.emplace("-DPRECISION_COMPUTE=float -DCONVERT_PRECISION_COMPUTE=convert_float"); + buildOptions.emplace("-DPRECISION_COMPUTE2=float2 -DCONVERT_PRECISION_COMPUTE2=convert_float2"); + buildOptions.emplace("-DPRECISION_COMPUTE4=float4 -DCONVERT_PRECISION_COMPUTE4=convert_float4"); + 
buildOptions.emplace("-DPRECISION_COMPUTE8=float8 -DCONVERT_PRECISION_COMPUTE8=convert_float8"); + buildOptions.emplace("-DPRECISION_COMPUTE16=float16 -DCONVERT_PRECISION_COMPUTE16=convert_float16"); + + mKernel_qk = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); + + int out_per_thread_m = tileM / localM; + int out_per_thread_n = tileN / localN; + + mGlobalWorkSizeQk = {static_cast(e_pack/out_per_thread_m), static_cast(h_pack/out_per_thread_n), static_cast(loop)}; + mLocalWorkSizeQk = {static_cast(localM), static_cast(localN), 1}; + + float alpha = scale; + float beta = 0.0f; + int batch_offset_a = e_pack * l_pack; + int batch_offset_b = h_pack * l_pack; + int batch_offset_c = e_pack * h_pack; + + int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0}; + int stride[4] = {e_pack, h_pack, h_pack, h_pack}; + int group[4] = {1, group_size, 1, numHead}; + + int idx = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_qk->get().setArg(idx++, static_cast(e_pack)); + ret |= mKernel_qk->get().setArg(idx++, static_cast(h_pack)); + ret |= mKernel_qk->get().setArg(idx++, static_cast(l_pack)); + ret |= mKernel_qk->get().setArg(idx++, alpha); + ret |= mKernel_qk->get().setArg(idx++, beta); + ret |= mKernel_qk->get().setArg(idx++, openCLBuffer(mTempQ.get())); + ret |= mKernel_qk->get().setArg(idx++, openCLBuffer(mTempK.get())); + ret |= mKernel_qk->get().setArg(idx++, openCLBuffer(mTempMask.get())); + ret |= mKernel_qk->get().setArg(idx++, openCLBuffer(mTempQK.get())); + ret |= mKernel_qk->get().setArg(idx++, batch_offset); + ret |= mKernel_qk->get().setArg(idx++, stride); + ret |= mKernel_qk->get().setArg(idx++, group); + MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qk Kernel"); + mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk); } - buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size)); - mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask", buildOption, inputs[0], outputs[0]); - mGlobalWorkSizeQk = {static_cast(UP_DIV(seq_len, 4)), static_cast(mNumHead), static_cast(UP_DIV(mKv_seq_len, 4))}; - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); - mGlobalWorkSizeQk2 = UP_DIV(mKv_seq_len, 4); - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[0]); - ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]); - ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk2); - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query)); - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(key)); - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQK.get())); - ret |= mKernel_qk->get().setArg(index++, *mPastKey.get()); - ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mask)); - ret |= mKernel_qk->get().setArg(index++, mScale); - ret |= mKernel_qk->get().setArg(index++, seq_len); - ret |= mKernel_qk->get().setArg(index++, mKv_seq_len); - ret |= mKernel_qk->get().setArg(index++, mNumHead); - ret |= mKernel_qk->get().setArg(index++, mKvNumHead); - ret |= mKernel_qk->get().setArg(index++, mHeadDim); - MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask"); - - mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask", mKernel_qk).first; - mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); - mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], 
std::max((uint32_t)1, mLocalWorkSizeQk[1])); - mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2])); - mQkUpdateInfo.update_kernel_args.push_back({0, 2, sizeof(mGlobalWorkSizeQk2), &mGlobalWorkSizeQk2}); - mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLBuffer(mTempQK.get())()}); - mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastKey.get()))()}); - mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKv_seq_len), &mKv_seq_len}); - mQkGlobal_size[0] = mGlobalWorkSizeQk[0]; - mQkGlobal_size[1] = mGlobalWorkSizeQk[1]; - mQkGlobal_size[2] = mGlobalWorkSizeQk[2]; - mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size}); - mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo); - } - - // softmax - { - auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(512)); - int localSize = getLocalSize(mKv_seq_len, MaxLocalSize); - if(localSize < 4){ - localSize = 1; + // softmax + { + // QV: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ), ROUND_UP(seqLenKV, mAlignKV)] + // Sotmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ), ROUND_UP(seqLenKV, mAlignKV)] + // axis : 2 (last dim) + int softmaxShape[4]; + softmaxShape[0] = batch*numHead; + softmaxShape[1] = ROUND_UP(seq_len, mAlignQ); + softmaxShape[2] = ROUND_UP(mKv_seq_len, mAlignKV); + + auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(256)); + int localSize = getLocalSize(softmaxShape[2], MaxLocalSize); + if(localSize < 4){ + localSize = 1; + } + + std::set buildOption; + buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); + + mKernel_softmax = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, inputs[0], outputs[0]); + mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(softmaxShape[1]), static_cast(softmaxShape[0])}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]); + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]); + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]); + ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempQK.get())); + ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_softmax->get().setArg(index++, mKv_seq_len); + ret |= mKernel_softmax->get().setArg(index++, softmaxShape); + MNN_CHECK_CL_SUCCESS(ret, "setArg Attention softmax"); + + mLocalWorkSizeSoftMax = {static_cast(localSize), 1, 1}; + mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax); + } + { + // Sotmax: [Batch * numHead, ROUND_UP(seqLenQ, mAlignQ), ROUND_UP(seqLenKV, mAlignKV)] + // Trans: [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ)] + int loop = batch * numHead; + int transDimW = ROUND_UP(seq_len, mAlignQ); + int transDimH = ROUND_UP(mKv_seq_len, mAlignKV); + + std::set buildOptions; + mKernel_trans = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, inputs[0], outputs[0]); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans)); + + mGlobalWorkSizeTrans = {(uint32_t)transDimW/8, (uint32_t)transDimH/8, (uint32_t)(loop)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= 
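
// The trans_3d_buf kernel being configured here transposes the softmax result
// per batch from [loop, M, N] to [loop, N, M] so it can feed the following QK*V
// batched GEMM in the layout that kernel expects (M and N stand for the padded
// query and kv lengths). A scalar reference of that reordering:
static void trans3dReference(const float* src, float* dst, int loop, int M, int N) {
    for (int b = 0; b < loop; ++b) {
        const float* s = src + b * M * N;
        float* d = dst + b * N * M;
        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                d[n * M + m] = s[m * N + n];
            }
        }
    }
}
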
mKernel_trans->get().setArg(index++, mGlobalWorkSizeTrans[0]); + ret |= mKernel_trans->get().setArg(index++, mGlobalWorkSizeTrans[1]); + ret |= mKernel_trans->get().setArg(index++, mGlobalWorkSizeTrans[2]); + ret |= mKernel_trans->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_trans->get().setArg(index++, openCLBuffer(mTempQK.get())); + ret |= mKernel_trans->get().setArg(index++, loop); + ret |= mKernel_trans->get().setArg(index++, transDimW); + ret |= mKernel_trans->get().setArg(index++, transDimH); + MNN_CHECK_CL_SUCCESS(ret, "setArg Attention transpose"); + mLocalWorkSizeTrans = localWS3DDefault(mGlobalWorkSizeTrans, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "trans_3d_buf", mKernel_trans).first; + + mGlobalWorkSizeTrans[0] = ROUND_UP(mGlobalWorkSizeTrans[0], std::max((uint32_t)1, mLocalWorkSizeTrans[0])); + mGlobalWorkSizeTrans[1] = ROUND_UP(mGlobalWorkSizeTrans[1], std::max((uint32_t)1, mLocalWorkSizeTrans[1])); + mGlobalWorkSizeTrans[2] = ROUND_UP(mGlobalWorkSizeTrans[2], std::max((uint32_t)1, mLocalWorkSizeTrans[2])); + + mOpenCLBackend->recordKernel3d(mKernel_trans, mGlobalWorkSizeTrans, mLocalWorkSizeTrans); } - int past_len4 = UP_DIV(mKv_seq_len, 4); - mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len; - mSoftmaxShape[0] = mNumHead; - mSoftmaxShape[1] = past_len4; - mSoftmaxShape[2] = 1; - mSoftmaxShape[3] = mPastLength; - std::set buildOption; - buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - if(!mIsDecode){ - mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_width", buildOption, inputs[0], outputs[0]); - mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(past_len4), static_cast(mNumHead)}; - } else{ - mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_channel", buildOption, inputs[0], outputs[0]); - mSoftmaxShape[3] = 1; - mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(1), static_cast(mNumHead)}; + + // qk * value + { + // Trans: [Batch * numHead, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(seqLenQ, mAlignQ)] -> [B, K, M] + // V : [Batch * numHead / group, ROUND_UP(seqLenKV, mAlignKV), ROUND_UP(headDim, mAlignHDN)] -> [B, K, N] + // QKV : [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ)] -> [B, N, M] + + int loop = batch * numHead; + int e_pack = ROUND_UP(seq_len, mAlignQ); + int l_pack = ROUND_UP(mKv_seq_len, mAlignKV); + int h_pack = ROUND_UP(headDim, mAlignHDN); + + std::set buildOptions; + + uint32_t layout = 0; + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempQK.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime()); + + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; + buildOptions.emplace("-DKWG=" + std::to_string(KWG)); + buildOptions.emplace("-DKWI=" + std::to_string(KWI)); + buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); + buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC)); + buildOptions.emplace("-DMWG=" + std::to_string(MWG)); + buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB)); + buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC)); + buildOptions.emplace("-DNWG=" + std::to_string(NWG)); + buildOptions.emplace("-DSA=" + std::to_string(SA)); + buildOptions.emplace("-DSB=" + std::to_string(SB)); + 
buildOptions.emplace("-DSTRM=" + std::to_string(STRM)); + buildOptions.emplace("-DSTRN=" + std::to_string(STRN)); + buildOptions.emplace("-DVWM=" + std::to_string(VWM)); + buildOptions.emplace("-DVWN=" + std::to_string(VWN)); + if(layout >= 4) { + buildOptions.emplace("-DOUTPUTMN"); + } + + int tileM = MWG; + int tileN = NWG; + int localM = MDIMC; + int localN = NDIMC; + + if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { + buildOptions.emplace("-DUSE_CL_MAD=1"); + buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); + } + + mKernel_qkv = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); + + int out_per_thread_m = tileM / localM; + int out_per_thread_n = tileN / localN; + + mGlobalWorkSizeQkv = {static_cast(e_pack/out_per_thread_m), static_cast(h_pack/out_per_thread_n), static_cast(loop)}; + mLocalWorkSizeQkv = {static_cast(localM), static_cast(localN), 1}; + + float alpha = 1.0f; + float beta = 0.0f; + int batch_offset_a = e_pack * l_pack; + int batch_offset_b = h_pack * l_pack; + int batch_offset_c = e_pack * h_pack; + int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0}; + int stride[4] = {e_pack, h_pack, e_pack, h_pack}; + int group[4] = {1, group_size, 1, numHead}; + + int idx = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_qkv->get().setArg(idx++, static_cast(e_pack)); + ret |= mKernel_qkv->get().setArg(idx++, static_cast(h_pack)); + ret |= mKernel_qkv->get().setArg(idx++, static_cast(l_pack)); + ret |= mKernel_qkv->get().setArg(idx++, alpha); + ret |= mKernel_qkv->get().setArg(idx++, beta); + ret |= mKernel_qkv->get().setArg(idx++, openCLBuffer(mTempQK.get())); + ret |= mKernel_qkv->get().setArg(idx++, openCLBuffer(mTempV.get())); + ret |= mKernel_qkv->get().setArg(idx++, openCLBuffer(mTempQKV.get())); + ret |= mKernel_qkv->get().setArg(idx++, batch_offset); + ret |= mKernel_qkv->get().setArg(idx++, stride); + ret |= mKernel_qkv->get().setArg(idx++, group); + MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qkv Kernel"); + mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv); } - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]); - ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]); - ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]); - ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempQK.get())); - ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); - ret |= mKernel_softmax->get().setArg(index++, mSoftMaxRemainChannels); - ret |= mKernel_softmax->get().setArg(index++, mSoftmaxShape); - MNN_CHECK_CL_SUCCESS(ret, "setArg softmax"); - - mLocalWorkSizeSoftMax = {static_cast(localSize), 1, 1}; - if(localSize == 1){ - mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax).first; + // transpose to output + { + // QKV : [Batch * numHead, ROUND_UP(headDim, mAlignHDN), ROUND_UP(seqLenQ, mAlignQ)] -> [B, N, M] + // output: [batch, seqLenQ/4, headNum, headDim, seqLenQ_4] + std::set buildOption; + + mKernel_clip = runtime->buildKernel("attention_buf", "qkv_transpose_output", buildOption, inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_clip)); + + mGlobalWorkSizeClip = {static_cast(UP_DIV(seq_len, 4)), 
static_cast(UP_DIV(headDim, 4)), static_cast(batch*numHead)}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_clip->get().setArg(index++, mGlobalWorkSizeClip[0]); + ret |= mKernel_clip->get().setArg(index++, mGlobalWorkSizeClip[1]); + ret |= mKernel_clip->get().setArg(index++, mGlobalWorkSizeClip[2]); + ret |= mKernel_clip->get().setArg(index++, openCLBuffer(mTempQKV.get())); + ret |= mKernel_clip->get().setArg(index++, openCLBuffer(outputs[0])); + ret |= mKernel_clip->get().setArg(index++, mAlignQ); + ret |= mKernel_clip->get().setArg(index++, mAlignHDN); + ret |= mKernel_clip->get().setArg(index++, seq_len); + ret |= mKernel_clip->get().setArg(index++, numHead); + ret |= mKernel_clip->get().setArg(index++, headDim); + + mLocalWorkSizeClip = localWS3DDefault(mGlobalWorkSizeClip, maxWorkGroupSize, runtime, "qkv_transpose_output", mKernel_clip).first; + mGlobalWorkSizeClip[0] = ROUND_UP(mGlobalWorkSizeClip[0], std::max((uint32_t)1, mLocalWorkSizeClip[0])); + mGlobalWorkSizeClip[1] = ROUND_UP(mGlobalWorkSizeClip[1], std::max((uint32_t)1, mLocalWorkSizeClip[1])); + mGlobalWorkSizeClip[2] = ROUND_UP(mGlobalWorkSizeClip[2], std::max((uint32_t)1, mLocalWorkSizeClip[2])); + + MNN_CHECK_CL_SUCCESS(ret, "setArg qkv_transpose_output"); + mOpenCLBackend->recordKernel3d(mKernel_clip, mGlobalWorkSizeClip, mLocalWorkSizeClip); } - mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); - mGlobalWorkSizeSoftMax[1] = ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); - mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLBuffer(mTempQK.get())()}); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLBuffer(mTempSoftMax.get())()}); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(mSoftMaxRemainChannels), &mSoftMaxRemainChannels}); - mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(mSoftmaxShape), &mSoftmaxShape}); - mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo); - } - - // qk * value - { - std::set buildOption; - if(!mIsDecode){ - buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); + + } else { + // query * key -> div -> select + { + std::set buildOption; + if(!mIsDecode){ + buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); + } + if((headDim % 4) != 0){ + buildOption.emplace("-DHEADDIM_LEAVE"); + } + if(mask->getType() == halide_type_of()){ + buildOption.emplace("-DADD_MASK"); + } + buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size)); + mKernel_qk = runtime->buildKernel("attention_buf", "matmul_qk_div_mask", buildOption, inputs[0], outputs[0]); + mGlobalWorkSizeQk = {static_cast(UP_DIV(mKv_seq_len, 4)), static_cast(UP_DIV(seq_len, 4)), static_cast(numHead)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qk)); + mGlobalWorkSizeQk0 = UP_DIV(mKv_seq_len, 4); + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk0); + ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[1]); + ret |= mKernel_qk->get().setArg(index++, mGlobalWorkSizeQk[2]); + ret |= mKernel_qk->get().setArg(index++, openCLBuffer(query)); + ret |= mKernel_qk->get().setArg(index++, openCLBuffer(key)); + ret 
|= mKernel_qk->get().setArg(index++, openCLBuffer(mTempQK.get())); + ret |= mKernel_qk->get().setArg(index++, *mKVCacheCLManager->key()); + ret |= mKernel_qk->get().setArg(index++, openCLBuffer(mask)); + ret |= mKernel_qk->get().setArg(index++, scale); + ret |= mKernel_qk->get().setArg(index++, seq_len); + ret |= mKernel_qk->get().setArg(index++, mKv_seq_len); + ret |= mKernel_qk->get().setArg(index++, numHead); + ret |= mKernel_qk->get().setArg(index++, kvNumHead); + ret |= mKernel_qk->get().setArg(index++, headDim); + MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qk_div_mask"); + + mLocalWorkSizeQk = localWS3DDefault(mGlobalWorkSizeQk, maxWorkGroupSize, runtime, "matmul_qk_div_mask", mKernel_qk).first; + mGlobalWorkSizeQk[0] = ROUND_UP(mGlobalWorkSizeQk[0], std::max((uint32_t)1, mLocalWorkSizeQk[0])); + mGlobalWorkSizeQk[1] = ROUND_UP(mGlobalWorkSizeQk[1], std::max((uint32_t)1, mLocalWorkSizeQk[1])); + mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk[2], std::max((uint32_t)1, mLocalWorkSizeQk[2])); + mQkUpdateInfo.update_kernel_args.push_back({0, 0, sizeof(mGlobalWorkSizeQk0), &mGlobalWorkSizeQk0}); + mQkUpdateInfo.update_kernel_args.push_back({0, 5, sizeof(cl_mem), &openCLBuffer(mTempQK.get())()}); + mQkUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mKVCacheCLManager->key()))()}); + mQkUpdateInfo.update_kernel_args.push_back({0, 10, sizeof(mKv_seq_len), &mKv_seq_len}); + mQkGlobal_size[0] = mGlobalWorkSizeQk[0]; + mQkGlobal_size[1] = mGlobalWorkSizeQk[1]; + mQkGlobal_size[2] = mGlobalWorkSizeQk[2]; + mQkUpdateInfo.update_global_size.push_back({0, mQkGlobal_size}); + mOpRecordUpdateInfo.emplace_back(&mQkUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, &mQkUpdateInfo); } - if((mHeadDim % 4) != 0){ - buildOption.emplace("-DHEADDIM_LEAVE"); + + // softmax + { + int inside = 1; + int outside = numHead * seq_len; + auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(256)); + int localSize = getLocalSize(UP_DIV(mKv_seq_len, 4), MaxLocalSize); + if(localSize < 4){ + localSize = 1; + } + + std::set buildOption; + buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); + mKernel_softmax = runtime->buildKernel("softmax_buf", "softmax_in1_buf", buildOption); + mGlobalWorkSizeSoftMax = {static_cast(localSize), static_cast(inside), static_cast(outside)}; + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_softmax)); + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[0]); + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[1]); + ret |= mKernel_softmax->get().setArg(index++, mGlobalWorkSizeSoftMax[2]); + ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempQK.get())); + ret |= mKernel_softmax->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_softmax->get().setArg(index++, inside); + ret |= mKernel_softmax->get().setArg(index++, outside); + ret |= mKernel_softmax->get().setArg(index++, mKv_seq_len); + MNN_CHECK_CL_SUCCESS(ret, "setArg softmax"); + + mLocalWorkSizeSoftMax = {static_cast(localSize), 1, 1}; + if(localSize == 1){ + mLocalWorkSizeSoftMax = localWS3DDefault(mGlobalWorkSizeSoftMax, maxWorkGroupSize, runtime, "softmax", mKernel_softmax).first; + } + mGlobalWorkSizeSoftMax[0] = ROUND_UP(mGlobalWorkSizeSoftMax[0], std::max((uint32_t)1, mLocalWorkSizeSoftMax[0])); + mGlobalWorkSizeSoftMax[1] = 
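
// The non-tiled path above runs softmax_in1_buf with inside = 1 and
// outside = numHead * seq_len, i.e. one softmax over the kv axis per
// (head, query) pair. A numerically stable scalar reference, assuming each row
// of kvLen scores is stored contiguously (the real kernel additionally handles
// the padded layout and the local-memory reduction):
#include <algorithm>
#include <cmath>

static void softmaxRowsReference(const float* src, float* dst, int outside, int kvLen) {
    for (int o = 0; o < outside; ++o) {
        const float* s = src + o * kvLen;
        float* d = dst + o * kvLen;
        float maxVal = s[0];
        for (int k = 1; k < kvLen; ++k) maxVal = std::max(maxVal, s[k]);
        float sum = 0.0f;
        for (int k = 0; k < kvLen; ++k) { d[k] = std::exp(s[k] - maxVal); sum += d[k]; }
        for (int k = 0; k < kvLen; ++k) d[k] /= sum;
    }
}
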
ROUND_UP(mGlobalWorkSizeSoftMax[1], std::max((uint32_t)1, mLocalWorkSizeSoftMax[1])); + mGlobalWorkSizeSoftMax[2] = ROUND_UP(mGlobalWorkSizeSoftMax[2], std::max((uint32_t)1, mLocalWorkSizeSoftMax[2])); + mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLBuffer(mTempQK.get())()}); + mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 4, sizeof(cl_mem), &openCLBuffer(mTempSoftMax.get())()}); + mSoftMaxUpdateInfo.update_kernel_args.push_back({0, 7, sizeof(mKv_seq_len), &mKv_seq_len}); + mOpRecordUpdateInfo.emplace_back(&mSoftMaxUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, &mSoftMaxUpdateInfo); } - buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size)); - mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv", buildOption, inputs[0], outputs[0]); - auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qkv)); - mGlobalWorkSizeQkv = {static_cast(UP_DIV(seq_len, 4)), static_cast(mNumHead), static_cast(UP_DIV(mHeadDim, 4))}; - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]); - ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]); - ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]); - ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); - ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(value)); - ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0])); - ret |= mKernel_qkv->get().setArg(index++, *mPastValue.get()); - ret |= mKernel_qkv->get().setArg(index++, seq_len); - ret |= mKernel_qkv->get().setArg(index++, mKv_seq_len); - ret |= mKernel_qkv->get().setArg(index++, mNumHead); - ret |= mKernel_qkv->get().setArg(index++, mKvNumHead); - ret |= mKernel_qkv->get().setArg(index++, mHeadDim); - MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv"); - - mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv", mKernel_qkv).first; - mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0])); - mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1])); - mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2])); + // qk * value + { + std::set buildOption; + if(!mIsDecode){ + buildOption.emplace("-DOPENCL_PREFILL_ATTENTION"); + } + if((headDim % 4) != 0){ + buildOption.emplace("-DHEADDIM_LEAVE"); + } + buildOption.emplace("-DNUMHEAD_GROUP_SIZE=" + std::to_string(group_size)); + mKernel_qkv = runtime->buildKernel("attention_buf", "matmul_qkv", buildOption, inputs[0], outputs[0]); + auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mKernel_qkv)); + mGlobalWorkSizeQkv = {static_cast(UP_DIV(headDim, 4)), static_cast(numHead), static_cast(UP_DIV(seq_len, 4))}; + + uint32_t index = 0; + cl_int ret = CL_SUCCESS; + ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[0]); + ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[1]); + ret |= mKernel_qkv->get().setArg(index++, mGlobalWorkSizeQkv[2]); + ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); + ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(value)); + ret |= mKernel_qkv->get().setArg(index++, openCLBuffer(outputs[0])); + ret |= mKernel_qkv->get().setArg(index++, *mKVCacheCLManager->value()); + ret |= mKernel_qkv->get().setArg(index++, seq_len); + ret |= 
mKernel_qkv->get().setArg(index++, mKv_seq_len); + ret |= mKernel_qkv->get().setArg(index++, numHead); + ret |= mKernel_qkv->get().setArg(index++, kvNumHead); + ret |= mKernel_qkv->get().setArg(index++, headDim); + MNN_CHECK_CL_SUCCESS(ret, "setArg matmul_qkv"); + + mLocalWorkSizeQkv = localWS3DDefault(mGlobalWorkSizeQkv, maxWorkGroupSize, runtime, "matmul_qkv", mKernel_qkv).first; + mGlobalWorkSizeQkv[0] = ROUND_UP(mGlobalWorkSizeQkv[0], std::max((uint32_t)1, mLocalWorkSizeQkv[0])); + mGlobalWorkSizeQkv[1] = ROUND_UP(mGlobalWorkSizeQkv[1], std::max((uint32_t)1, mLocalWorkSizeQkv[1])); + mGlobalWorkSizeQkv[2] = ROUND_UP(mGlobalWorkSizeQkv[2], std::max((uint32_t)1, mLocalWorkSizeQkv[2])); + + mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLBuffer(mTempSoftMax.get())()}); + mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mKVCacheCLManager->value()))()}); + mQkvUpdateInfo.update_kernel_args.push_back({0, 8, sizeof(mKv_seq_len), &mKv_seq_len}); + mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo); + mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo); + } - mQkvUpdateInfo.update_kernel_args.push_back({0, 3, sizeof(cl_mem), &openCLBuffer(mTempSoftMax.get())()}); - mQkvUpdateInfo.update_kernel_args.push_back({0, 6, sizeof(cl_mem), &(*(mPastValue.get()))()}); - mQkvUpdateInfo.update_kernel_args.push_back({0, 8, sizeof(mKv_seq_len), &mKv_seq_len}); - mOpRecordUpdateInfo.emplace_back(&mQkvUpdateInfo); - mOpenCLBackend->recordKernel3d(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, &mQkvUpdateInfo); + mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); } - mOpenCLBackend->endRecord(mRecording); - - mOpenCLBackend->onReleaseBuffer(mTempQK.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mTempSoftMax.get(), Backend::DYNAMIC); + return NO_ERROR; } -ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vector &inputs, const std::vector &outputs) { +ErrorCode AttentionBufExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { #ifdef LOG_VERBOSE MNN_PRINT("start AttentionBufExecution onExecute !\n"); #endif - mOpenCLBackend = static_cast(backend); - reallocKVCache(); + if(mIsDecode){ + if(mKVCacheCLManager->reallocKVCache()){ + reallocKVCache(); + } + mKv_seq_len = mKVCacheCLManager->kvLength() + 1; + mGlobalWorkSizeQk0 = UP_DIV(mKv_seq_len, 4); + mQkGlobal_size[0] = ROUND_UP(mGlobalWorkSizeQk0, std::max((uint32_t)1, mLocalWorkSizeQk[0])); + mGlobalWorkSizeQk[0] = mQkGlobal_size[0]; + mKVCacheCLManager->addKvLength(); + } #ifdef ENABLE_OPENCL_TIME_PROFILER + if(mLongPrefill) { + cl::Event event0, event1; + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime(), &event0); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_qkv", event0}); + run3DKernelDefault(mKernel_mask, mGlobalWorkSizeMask, mLocalWorkSizeMask, mOpenCLBackend->getOpenCLRuntime(), &event1); + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_mask", event1}); + } { cl::Event event; run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, @@ -327,6 +744,12 @@ ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vectorgetOpenCLRuntime()->pushEvent({"softmax", event}); } + if(mLongPrefill) { + cl::Event event; + run3DKernelDefault(mKernel_trans, mGlobalWorkSizeTrans, mLocalWorkSizeTrans, 
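
// The decode branch above only refreshes the kv length and the first global
// dimension of the QK kernel each step; the rest of the recorded command stream
// is reused as-is. A sketch of that per-token bookkeeping, with UP_DIV/ROUND_UP
// spelled out and localSize0 standing for the tuned local work size's first
// component:
#include <algorithm>
#include <cstdint>

struct DecodeStep {
    int kvSeqLen;        // past length + 1 for the token being generated
    uint32_t qkArg0;     // UP_DIV(kvSeqLen, 4), passed as the kernel's argument 0
    uint32_t qkLaunch0;  // the same value rounded up to the local work size
};

static DecodeStep nextDecodeStep(int pastLength, uint32_t localSize0) {
    DecodeStep step;
    step.kvSeqLen = pastLength + 1;
    step.qkArg0 = (static_cast<uint32_t>(step.kvSeqLen) + 3) / 4;
    const uint32_t l = std::max<uint32_t>(1, localSize0);
    step.qkLaunch0 = ((step.qkArg0 + l - 1) / l) * l;
    return step;
}
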
mOpenCLBackend->getOpenCLRuntime(), &event); + + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"transpose_softmax", event}); + } { cl::Event event; run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, @@ -334,49 +757,45 @@ ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vectorgetOpenCLRuntime()->pushEvent({"matmul_qkv", event}); } + if(mLongPrefill) { + cl::Event event; + run3DKernelDefault(mKernel_clip, mGlobalWorkSizeClip, mLocalWorkSizeClip, mOpenCLBackend->getOpenCLRuntime(), &event); + + mOpenCLBackend->getOpenCLRuntime()->pushEvent({"rearrange_output", event}); + } #else if(mOpenCLBackend->isUseRecordQueue()){ mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo); - if(mIsDecode){ - if(mIsFirstDecode){ - mIsFirstDecode = false; - }else{ - mPastLength += 1; - mKv_seq_len = mPastLength + 1; - int past_len4 = UP_DIV(mKv_seq_len, 4); - mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len; - mSoftmaxShape[1] = past_len4; - mGlobalWorkSizeQk2 = past_len4; - mQkGlobal_size[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2])); - } - } #ifdef LOG_VERBOSE MNN_PRINT("End AttentionBufExecution onExecute... \n"); #endif return NO_ERROR; } - run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime()); - run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime()); -#endif // decode if(mIsDecode){ - mPastLength += 1; - mKv_seq_len = mPastLength + 1; - int past_len4 = UP_DIV(mKv_seq_len, 4); - mSoftMaxRemainChannels = past_len4 * 4 - mKv_seq_len; - mSoftmaxShape[1] = past_len4; cl_int ret = CL_SUCCESS; - mGlobalWorkSizeQk2 = past_len4; - mGlobalWorkSizeQk[2] = ROUND_UP(mGlobalWorkSizeQk2, std::max((uint32_t)1, mLocalWorkSizeQk[2])); - ret |= mKernel_qk->get().setArg(2, mGlobalWorkSizeQk2); + ret |= mKernel_qk->get().setArg(0, mGlobalWorkSizeQk0); ret |= mKernel_qk->get().setArg(10, mKv_seq_len); - ret |= mKernel_softmax->get().setArg(5, mSoftMaxRemainChannels); - ret |= mKernel_softmax->get().setArg(6, mSoftmaxShape); + ret |= mKernel_softmax->get().setArg(7, mKv_seq_len); ret |= mKernel_qkv->get().setArg(8, mKv_seq_len); MNN_CHECK_CL_SUCCESS(ret, "reset arg for AttentionBufExecution"); } + if(mLongPrefill) { + run3DKernelDefault(mKernel_rearrange, mGlobalWorkSizeRearrg, mLocalWorkSizeRearrg, mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_mask, mGlobalWorkSizeMask, mLocalWorkSizeMask, mOpenCLBackend->getOpenCLRuntime()); + } + run3DKernelDefault(mKernel_qk, mGlobalWorkSizeQk, mLocalWorkSizeQk, mOpenCLBackend->getOpenCLRuntime()); + run3DKernelDefault(mKernel_softmax, mGlobalWorkSizeSoftMax, mLocalWorkSizeSoftMax, mOpenCLBackend->getOpenCLRuntime()); + if(mLongPrefill) { + run3DKernelDefault(mKernel_trans, mGlobalWorkSizeTrans, mLocalWorkSizeTrans, mOpenCLBackend->getOpenCLRuntime()); + } + run3DKernelDefault(mKernel_qkv, mGlobalWorkSizeQkv, mLocalWorkSizeQkv, mOpenCLBackend->getOpenCLRuntime()); + if(mLongPrefill) { + run3DKernelDefault(mKernel_clip, mGlobalWorkSizeClip, mLocalWorkSizeClip, mOpenCLBackend->getOpenCLRuntime()); + } +#endif + #ifdef LOG_VERBOSE MNN_PRINT("end AttentionBufExecution onExecute !\n"); #endif @@ -385,24 +804,23 @@ ErrorCode AttentionBufImpl::onExecute(Backend *backend, const std::vector impl, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mImpl(impl) {} - 
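
// The constructor and onClone changes below hand the KV cache around as a
// std::shared_ptr<KVCacheCLManager>, so a cloned AttentionBufExecution shares
// the same past key/value buffers instead of duplicating them. A minimal
// illustration of that ownership pattern with stand-in types (FakeKVCache and
// FakeAttention are not MNN names):
#include <memory>
#include <utility>

struct FakeKVCache {            // stand-in for KVCacheCLManager
    int pastLength = 0;
    void addKvLength() { ++pastLength; }
};

struct FakeAttention {          // stand-in for AttentionBufExecution
    std::shared_ptr<FakeKVCache> cache;
    explicit FakeAttention(std::shared_ptr<FakeKVCache> c) : cache(std::move(c)) {}
    FakeAttention clone() const { return FakeAttention(cache); }  // shares, does not duplicate
};
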
-ErrorCode AttentionBufExecution::onResize(const std::vector& inputs, const std::vector& outputs) { - return mImpl->onResize(backend(), inputs, outputs); + mKVCacheCLManager.reset(new KVCacheCLManager(backend, kv_cahce)); + mOpenCLBackend = static_cast(backend); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}); + mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } -ErrorCode AttentionBufExecution::onExecute(const std::vector& inputs, const std::vector& outputs) { - return mImpl->onExecute(backend(), inputs, outputs); +AttentionBufExecution::AttentionBufExecution(std::shared_ptr manager, const MNN::Op *op, Backend *backend) : CommonExecution(backend, op), mKVCacheCLManager(manager) { + mOpenCLBackend = static_cast(backend); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}); + mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } bool AttentionBufExecution::onClone(Backend* bn, const Op* op, Execution** dst) { if (nullptr == dst) { return true; } - *dst = new AttentionBufExecution(mImpl, op, bn); + *dst = new AttentionBufExecution(mKVCacheCLManager, op, bn); return true; } diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp index cb33dc05d..1292ace2f 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp @@ -16,33 +16,63 @@ namespace MNN { namespace OpenCL { -class AttentionBufImpl { +class KVCacheCLManager { public: - AttentionBufImpl(const MNN::Op *op, Backend *backend, bool kv_cache); + KVCacheCLManager(Backend *backend, bool kv_cache); - ~AttentionBufImpl() { - if(mRecording != NULL){ -#ifdef MNN_USE_LIB_WRAPPER - clReleaseRecordingQCOM(mRecording); -#endif - } + ~KVCacheCLManager() = default; + void allocKVCache(); + bool reallocKVCache(); + void setArgs(int pastLength, int numHead, int kvNumHead, int headDim){ + mPastLength = pastLength; + mNumHead = numHead; + mKvNumHead = kvNumHead; + mHeadDim = headDim; + } + int kvLength() { + return mPastLength; + } + void addKvLength(){ + mPastLength += 1; + } + int maxLength() { + return mMaxLength; + } + int numHead() { + return mNumHead; + } + const cl::Buffer * key() { + return mPastKey.get(); + } + const cl::Buffer * value() { + return mPastValue.get(); } - ErrorCode onResize(Backend *backend, const std::vector &inputs, const std::vector &outputs); - ErrorCode onExecute(Backend *backend, const std::vector &inputs, const std::vector &outputs); private: - int getLocalSize(int size, int maxGroupSize); - void allocKVCache(); - void reallocKVCache(); bool mKVCache; - float mScale; const int mExpandChunk = 2048; - bool mIsDecode = false; - bool mIsFirstDecode = true; - int mPastLength = 0, mMaxLength = 0, mKv_seq_len = 0, mSoftMaxRemainChannels = 0; std::shared_ptr mPastKey, mPastValue; - std::shared_ptr mTempQK, mTempSoftMax; - int mNumHead = 0, mKvNumHead = 0, mHeadDim = 0, mValueH = 0; + int mPastLength = 0, mMaxLength = 0, mNumHead = 0, mKvNumHead = 0, mHeadDim = 0; + OpenCLBackend *mOpenCLBackend; + int mByte = 4; +}; + +class AttentionBufExecution : public CommonExecution { +public: + AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache); + AttentionBufExecution(std::shared_ptr manager, const MNN::Op 
*op, Backend *backend); + + virtual ~AttentionBufExecution() = default; + virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; + virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; + +private: + + int getLocalSize(int size, int maxGroupSize); + void reallocKVCache(); + bool mIsDecode = false; + int mKv_seq_len = 0; std::shared_ptr mKernel_qk; std::shared_ptr mKernel_softmax; std::shared_ptr mKernel_qkv; @@ -57,26 +87,28 @@ class AttentionBufImpl { RecordUpdateInfo mQkUpdateInfo; RecordUpdateInfo mSoftMaxUpdateInfo; RecordUpdateInfo mQkvUpdateInfo; - int mGlobalWorkSizeQk2 = 0; + int mGlobalWorkSizeQk0 = 0; size_t mQkGlobal_size[3]; - int mSoftmaxShape[4]; - cl_recording_qcom mRecording{NULL}; std::vector mOpRecordUpdateInfo; - int mByte = 4; -}; - -class AttentionBufExecution : public CommonExecution { -public: - AttentionBufExecution(const MNN::Op *op, Backend *backend, bool kv_cache); - AttentionBufExecution(std::shared_ptr impl, const MNN::Op *op, Backend *backend); - - virtual ~AttentionBufExecution() = default; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; - + std::shared_ptr mKVCacheCLManager; + std::shared_ptr mTempQK, mTempSoftMax; private: - std::shared_ptr mImpl; + int mAlignQ, mAlignKV, mAlignHDK, mAlignHDN; + bool mLongPrefill = false; + std::shared_ptr mKernel_rearrange; + std::vector mGlobalWorkSizeRearrg{1, 1, 1}; + std::vector mLocalWorkSizeRearrg{1, 1, 1, 1}; + std::shared_ptr mKernel_mask; + std::vector mGlobalWorkSizeMask{1, 1, 1}; + std::vector mLocalWorkSizeMask{1, 1, 1, 1}; + std::shared_ptr mKernel_trans; + std::vector mGlobalWorkSizeTrans{1, 1, 1}; + std::vector mLocalWorkSizeTrans{1, 1, 1, 1}; + std::shared_ptr mKernel_clip; + std::vector mGlobalWorkSizeClip{1, 1, 1}; + std::vector mLocalWorkSizeClip{1, 1, 1, 1}; + std::shared_ptr mTempQ, mTempK, mTempV, mTempMask, mTempQKV; + bool mIsAddMask = false; }; } // namespace OpenCL } // namespace MNN diff --git a/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp b/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp index 47b75864f..94db4128e 100644 --- a/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/BinaryBufExecution.cpp @@ -246,19 +246,22 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons auto openCLBackend = static_cast(backend()); auto output = outputs[0]; - auto inputShape0 = tensorShapeFormat(inputs[0]); - auto inputShape1 = tensorShapeFormat(inputs[1]); auto outputShape = tensorShapeFormat(output); auto runTime = ((OpenCLBackend *)backend())->getOpenCLRuntime(); #ifdef MNN_SUPPORT_INTEL_SUBGROUP - if (runTime->isSupportedIntelSubgroup()) { + if (runTime->isSupportedIntelSubgroup() && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat) { return SubgroupOnResize(inputs, outputs); } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ - int shape[4] = {outputShape[0], outputShape[1], outputShape[2], UP_DIV(outputShape[3], 4)}; int fullCount[2] = {1, 1}; fullCount[0] = realSize(inputs[0]) == 1 ? 0 : 1; fullCount[1] = realSize(inputs[1]) == 1 ? 
0 : 1; + int totalSize = 0; + if(MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat){ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * ROUND_UP(outputShape[3], 4); + }else{ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + } int activationType = 0; if(mOp->type() == OpType_BinaryOp) { @@ -267,10 +270,8 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons auto &unit = mUnits[0]; std::set buildOptions = mBuildOptions; - int wh_pack = 1; - if((outputShape[1]*outputShape[2]) % 4 == 0) { - wh_pack = 4; - buildOptions.emplace("-DWH_PACK4"); + if(totalSize % 4 != 0) { + buildOptions.emplace("-DPACK_LEAVE"); } if(fullCount[0] == 0) { buildOptions.emplace("-DA_SINGLE"); @@ -281,9 +282,7 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, inputs[0], output); mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - mGlobalWorkSize = {(uint32_t)UP_DIV(outputShape[3], 4) * outputShape[0], - (uint32_t)UP_DIV(outputShape[1]*outputShape[2], wh_pack)}; - + mGlobalWorkSize = {(uint32_t)UP_DIV(totalSize, 4), (uint32_t)1}; uint32_t index = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); @@ -291,13 +290,12 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons ret |= unit.kernel->get().setArg(index++, openCLBuffer(inputs[0])); ret |= unit.kernel->get().setArg(index++, openCLBuffer(inputs[1])); ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(index++, shape); - ret |= unit.kernel->get().setArg(index++, fullCount); + ret |= unit.kernel->get().setArg(index++, totalSize); ret |= unit.kernel->get().setArg(index++, activationType); MNN_CHECK_CL_SUCCESS(ret, "setArg BinaryBufExecution"); std::string name = "binary_buf"; - mLocalWorkSize = {(uint32_t)16, (uint32_t)16}; + mLocalWorkSize = {(uint32_t)16, (uint32_t)1}; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; @@ -307,13 +305,6 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons fullCount[1] = realSize(inputs[i]) == 1 ? 
0 : 1; auto &unit = mUnits[i-1]; - std::set buildOptions = mBuildOptions; - if((outputShape[1]*outputShape[2]) % 4 == 0) { - buildOptions.emplace("-DWH_PACK4"); - } - if(fullCount[1] == 0) { - buildOptions.emplace("-DB_SINGLE"); - } unit.kernel = runTime->buildKernel("binary_buf", "binary_buf", buildOptions, inputs[i], output); uint32_t index = 0; @@ -322,8 +313,7 @@ ErrorCode BinaryBufExecution::onEncode(const std::vector &inputs, cons ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(index++, openCLBuffer(inputs[i])); ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(index++, shape); - ret |= unit.kernel->get().setArg(index++, fullCount); + ret |= unit.kernel->get().setArg(index++, totalSize); ret |= unit.kernel->get().setArg(index++, activationType); MNN_CHECK_CL_SUCCESS(ret, "setArg BinaryBufExecution MultiInput"); @@ -341,7 +331,8 @@ class BinaryBufCreator : public OpenCLBackend::Creator { const MNN::Op *op, Backend *backend) const override { for (int i = 0; i < inputs.size(); ++i) { int channel = inputs[i]->channel(); - if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup()) { + if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup() + && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[i])->dimensionFormat) { TensorUtils::setTensorChannelPack(inputs[i], 16); } } diff --git a/source/backend/opencl/execution/buffer/CastBufExecution.cpp b/source/backend/opencl/execution/buffer/CastBufExecution.cpp index dd4debd80..d4ab150bc 100644 --- a/source/backend/opencl/execution/buffer/CastBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/CastBufExecution.cpp @@ -13,54 +13,49 @@ namespace MNN { namespace OpenCL { CastBufExecution::CastBufExecution(const std::vector &inputs, const std::vector &outputs, const std::string& compute, const MNN::Op* op, Backend* backend) : CommonExecution(backend, op) { - mUnits.resize(1); - auto &unit = mUnits[0]; mBuildOptions.emplace(compute); - auto runtime = static_cast(backend)->getOpenCLRuntime(); - unit.kernel = runtime->buildKernel("cast_buf", "cast_buf", mBuildOptions, inputs[0], outputs[0]); - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); } ErrorCode CastBufExecution::onEncode(const std::vector& inputs, const std::vector& outputs) { + mUnits.resize(1); auto &unit = mUnits[0]; Tensor* input = inputs[0]; Tensor* output = outputs[0]; auto openCLBackend = static_cast(backend()); auto runtime = openCLBackend->getOpenCLRuntime(); - std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); - - int batch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int channels = outputShape.at(3); - - int channelBlocks = (channels + 3) / 4; - + int totalSize = 0; + if(MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat){ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * ROUND_UP(outputShape[3], 4); + }else{ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + } + std::set buildOptions = mBuildOptions; + if(totalSize % 4 != 0) { + buildOptions.emplace("-DPACK_LEAVE"); + } + unit.kernel = runtime->buildKernel("cast_buf", "cast_buf", mBuildOptions, inputs[0], outputs[0]); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mGlobalWorkSize = { - 
static_cast(outputWidth), - static_cast(outputHeight), - static_cast(batch * channelBlocks), + static_cast(UP_DIV(totalSize, 4)), + static_cast(1) }; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, outputWidth); - ret |= unit.kernel->get().setArg(idx++, outputHeight); - ret |= unit.kernel->get().setArg(idx++, channelBlocks); + ret |= unit.kernel->get().setArg(idx++, totalSize); MNN_CHECK_CL_SUCCESS(ret, "setArg CastBufExecution"); std::string kernelName = "cast_buf"; - mLocalSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; - openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; - + mLocalSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; + openCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalSize[0], mLocalSize[1]}; return NO_ERROR; } diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp index ba25bda93..8ba800b26 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp @@ -122,17 +122,26 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st mPaddings[1] == 0 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1); mResource->mConv1x1Opt = isConv1x1; - mResource->mConv1x1C8Opt = mResource->mConv1x1Opt && mResource->mOutputChannel >= 16; + if(mResource->mConv1x1Opt) { + mResource->mAlignK = 4; + mResource->mAlignN = 8; + } bool useConvGemm = isConv1x1 && mResource->mInputChannel > 32 && mResource->mOutputChannel > 64; if (useConvGemm) { - mResource->mConvGemmOptLevel = 2; + mResource->mAlignK = 4; + mResource->mAlignN = 16; + mResource->mConvGemmOptLevel = 1; + if(mResource->mOutputChannel > 1024) { + mResource->mAlignN = 128; + } else if(mResource->mOutputChannel > 512) { + mResource->mAlignN = 64; + } else if(mResource->mOutputChannel > 96) { + mResource->mAlignN = 32; + } } } if (mResource->mConv1x1Opt) { - // Tile Match with mConvGemmOptLevel == 2 - int tileK = 4; - int tileN = 32; - int buffer_size = ROUND_UP(mResource->mOutputChannel, tileN) * ROUND_UP(mResource->mInputChannel, tileK); + int buffer_size = ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) * ROUND_UP(mResource->mInputChannel, mResource->mAlignK); mResource->mFilter.reset( Tensor::createDevice({buffer_size})); mOpenCLBackend->onAcquireBuffer(mResource->mFilter.get(), Backend::STATIC); @@ -153,13 +162,13 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st // [Ci, Co] ( [K, N] ) for (int o = 0; o < mResource->mOutputChannel; o++) { for (int i = 0; i < mResource->mInputChannel; i++) { - ((half_float::half *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, tileN) + o] = (half_float::half)(mFilterDataPtr[o * mResource->mInputChannel + 
i]); + ((half_float::half *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (half_float::half)(mFilterDataPtr[o * mResource->mInputChannel + i]); } } } else { for (int o = 0; o < mResource->mOutputChannel; o++) { for (int i = 0; i < mResource->mInputChannel; i++) { - ((float *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, tileN) + o] = (mFilterDataPtr[o * mResource->mInputChannel + i]); + ((float *)ptrCL)[i * ROUND_UP(mResource->mOutputChannel, mResource->mAlignN) + o] = (mFilterDataPtr[o * mResource->mInputChannel + i]); } } } @@ -257,6 +266,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const mOpenCLBackend->startRecord(mRecording); std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); + const int batch = outputShape.at(0); const int height = outputShape.at(1); const int width = outputShape.at(2); const int outChannel = outputShape.at(3); @@ -279,50 +289,48 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const int M = outputShape.at(0) * area; int N = outputShape.at(3); int K = inputShape.at(3); - - bool isAlign = (K % 8 == 0 && area == 1 && N % 64 == 0 && M % 64 == 0); - bool isLimitSize = (M * 1.0 / 512 * N / 512 * K / 512 <= 1.0) && (1.0 * M * K / N / N >= 16.0); - if(isAlign && isLimitSize) { - mResource->mConvGemmOptLevel = 1; - } else if(M < 128 || 1.0 * M / 512 * N / 512 * K / 256 < 1.0) { + + if(M < 128 || 1.0 * M / 512 * N / 512 * K / 256 < 1.0) { + mResource->mConvGemmOptLevel = 0; + } + if(1.0 * M * N / K / K > 100.0 || 1.0 * M * K / N / N > 100.0) { mResource->mConvGemmOptLevel = 0; } } - - if (mResource->mConvGemmOptLevel == 2) { - // set large tile - int tileM = 16; - int tileN = 32; - int tileK = 4; - + + if (mResource->mConvGemmOptLevel == 1) { int area = height * width; int M = outputShape.at(0) * area; int N = outputShape.at(3); int K = inputShape.at(3); + // set M Align + float ratio = 1.0 * M / 1024.0 * N / 1024.0 * K / 1024.0; + if(M > 1024 && ratio >= 1.0) { + mAlignM = 128; + } else if(M > 512 && ratio >= 0.1) { + mAlignM = 64; + } else if(M > 96){ + mAlignM = 32; + } else { + mAlignM = 16; + } - int alignM = ROUND_UP(M, tileM); - int alignN = ROUND_UP(N, tileN); - int alignK = ROUND_UP(K, tileK); + int alignM = ROUND_UP(M, mAlignM); + int alignN = ROUND_UP(N, mResource->mAlignN); + int alignK = ROUND_UP(K, mResource->mAlignK); // ReArrange input mConvGemmInpTensor.reset(Tensor::createDevice({alignK * alignM})); mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); - - mNeedOutTempTensor = true; mConvGemmOutTensor.reset(Tensor::createDevice({alignN * alignM})); mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); - + mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + { std::set buildOptions; - - int m_pack = 1; - if(area == 1) { - m_pack = 4; - buildOptions.emplace("-DAREA_EQUAL_1"); - } else if(outputShape.at(0) == 1) { - m_pack = 4; - buildOptions.emplace("-DBATCH_EQUAL_1"); - } + + int m_pack = 4; mPreKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_pad", buildOptions); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mPreKernel)); mPreGlobalWorkSize = {static_cast(alignM/m_pack), static_cast(alignK/4)}; @@ -339,14 +347,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= 
mPreKernel->get().setArg(idx++, static_cast(area)); ret |= mPreKernel->get().setArg(idx++, openCLBuffer(input)); ret |= mPreKernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); - MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==2 PreKernel"); + MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==1 PreKernel"); mPreLocalWorkSize = localWS2DDefault(mPreGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "transpose_pad", mPreKernel).first; mOpenCLBackend->recordKernel2d(mPreKernel, mPreGlobalWorkSize, mPreLocalWorkSize); mPreGlobalWorkSize[0] = ROUND_UP(mPreGlobalWorkSize[0], std::max((uint32_t)1, mPreLocalWorkSize[0])); mPreGlobalWorkSize[1] = ROUND_UP(mPreGlobalWorkSize[1], std::max((uint32_t)1, mPreLocalWorkSize[1])); } - + // call gemm strassen { mStrassenComputor.reset(new StrassenMatrixComputor(backend(), 3)); @@ -355,15 +363,19 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const } // call output transpose - if(mNeedOutTempTensor) { + { std::set buildOptions = mResource->mBuildOptions; - if(area == 1) { - buildOptions.emplace("-DAREA_EQUAL_1"); + int pack_m = 1; + if(M % 8 == 0) { + pack_m = 8; + } else if(M % 4 == 0) { + pack_m = 4; } + buildOptions.emplace("-DM_VEC=" + std::to_string(pack_m)); mPostKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_bias", buildOptions); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mPostKernel)); - mPostGlobalWorkSize = {static_cast(M), static_cast(UP_DIV(N, 16))}; + mPostGlobalWorkSize = {static_cast(UP_DIV(M, pack_m)), static_cast(UP_DIV(N, 4))}; int offset = 0; int idx = 0; @@ -379,7 +391,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mPostKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= mPostKernel->get().setArg(idx++, openCLBuffer(output)); - MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==2 PostKernel"); + MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==1 PostKernel"); mPostLocalWorkSize = localWS2DDefault(mPostGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "transpose_bias", mPostKernel).first; mOpenCLBackend->recordKernel2d(mPostKernel, mPostGlobalWorkSize, mPostLocalWorkSize); mPostGlobalWorkSize[0] = ROUND_UP(mPostGlobalWorkSize[0], std::max((uint32_t)1, mPostLocalWorkSize[0])); @@ -388,146 +400,132 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const mOpenCLBackend->endRecord(mRecording); } - mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); - if(mNeedOutTempTensor) { - mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); - } - return NO_ERROR; - } else if (mResource->mConvGemmOptLevel == 1) { - // set small tile - int tileM = 64; - int tileN = 64; - int tileK = 8; - int localM = 16; - int localN = 16; - int M = outputShape.at(0); - int N = outputShape.at(3); - int K = inputShape.at(3); - - std::set buildOptions = mResource->mBuildOptions;; - buildOptions.emplace(" -DBIAS"); + } else if (mResource->mConv1x1Opt) { + if(inputChannels >= 128 && outputShape[0] * outChannel * width * height <= 64){ + mResource->mConv1x1Local = true; + int local_size = 1; + while(local_size * 2 <= 256 && local_size * 2 <= inputChannelBlocks){ + local_size *= 2; + } + mGlobalWorkSize = {static_cast(local_size), static_cast(UP_DIV(outChannel, 4) * width), static_cast(outputShape[0] * height)}; + mLocalWorkSize = {static_cast(local_size), 1, 1}; + + std::set buildOption 
= mResource->mBuildOptions; + buildOption.emplace("-DCONV_LOCAL_SIZE=" + std::to_string(local_size)); + mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", "conv_2d_1x1_local", buildOption); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; - if(N % 128 == 0) { - tileN = 128; - buildOptions.emplace(" -DOPWM=64 -DOPWN=128 -DCPWK=8 -DOPTM=4 -DOPTN=8"); + ret |= mKernel->get().setArg(idx++, UP_DIV(width, 1)); + ret |= mKernel->get().setArg(idx++, openCLBuffer(input)); + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); + ret |= mKernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= mKernel->get().setArg(idx++, batch); + ret |= mKernel->get().setArg(idx++, height); + ret |= mKernel->get().setArg(idx++, width); + ret |= mKernel->get().setArg(idx++, UP_DIV(outChannel, 4)); + ret |= mKernel->get().setArg(idx++, ROUND_UP(outChannel, mResource->mAlignN)); + MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf"); } else { - buildOptions.emplace(" -DOPWM=64 -DOPWN=64 -DCPWK=8 -DOPTM=4 -DOPTN=4"); - } - - - mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_local_buf", "matmul_local_buf", buildOptions); - int out_per_thread_m = tileM / localM; - int out_per_thread_n = tileN / localN; - - mGlobalWorkSize = {static_cast(M/out_per_thread_m), static_cast(N/out_per_thread_n)}; - mLocalWorkSize = {static_cast(localM), static_cast(localN)}; - - int idx = 0; - cl_int ret = CL_SUCCESS; - ret |= mKernel->get().setArg(idx++, static_cast(M)); - ret |= mKernel->get().setArg(idx++, static_cast(N)); - ret |= mKernel->get().setArg(idx++, static_cast(K)); - ret |= mKernel->get().setArg(idx++, openCLBuffer(input)); - ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); - ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); - ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); - - MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf mConvgemmOptLevel==1 Kernel Select"); - } else if (mResource->mConv1x1Opt) { + mResource->mConv1x1Local = false; + // {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1", "conv_2d_1x1_c8h1w4"}; + const int total_kernel = 3; + std::string kernelName[total_kernel] = {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1"}; + int itemC[total_kernel] = {4, 4, 4}; + int itemW[total_kernel] = {4, 2, 1}; + + int M = outputShape.at(0) * outputShape.at(1) * outputShape.at(2); + mResource->mConv1x1C8Opt = (mResource->mOutputChannel >= 16 && M >= 16 && M * mResource->mOutputChannel >= 65536); + + int actual_kernel = total_kernel; + if(mResource->mConv1x1C8Opt) { + actual_kernel = 2; + kernelName[0] = "conv_2d_1x1_c8h1w4"; + itemC[0] = 8; + itemW[0] = 4; + + kernelName[1] = "conv_2d_1x1_c8h1w2"; + itemC[1] = 8; + itemW[1] = 2; + } - int tileN = 32; - // {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1", "conv_2d_1x1_c8h1w4"}; - const int total_kernel = 3; - std::string kernelName[total_kernel] = {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1"}; - int itemC[total_kernel] = {4, 4, 4}; - int itemW[total_kernel] = {4, 2, 1}; + std::shared_ptr kernel[total_kernel]; + std::vector globalWorkSize[total_kernel]; + std::vector localWorkSize[total_kernel]; + std::pair min_cost(INT_MAX, 0);//(min_time, min_index) + for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { + std::set buildOption = 
mResource->mBuildOptions; + if(outputShape.at(3) % itemC[knl_idx] != 0){ + buildOption.emplace("-DCHANNEL_LEAVE"); + } + if((outputShape.at(2) % itemW[knl_idx]) != 0){ + buildOption.emplace("-DBLOCK_LEAVE"); + } + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption); + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), static_cast(outputShape.at(0) * outputShape.at(1))}; + + ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]); + ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]); + ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(width, itemW[knl_idx])); + ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input)); + ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); + ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= kernel[knl_idx]->get().setArg(idx++, height); + ret |= kernel[knl_idx]->get().setArg(idx++, width); + ret |= kernel[knl_idx]->get().setArg(idx++, batch); + ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(outChannel, 4)); + ret |= kernel[knl_idx]->get().setArg(idx++, ROUND_UP(outChannel, mResource->mAlignN)); + + MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf Kernel Select"); + + std::pair, int> retTune; + retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx] + info, kernel[knl_idx]); + if(min_cost.first > retTune.second) { + min_cost.first = retTune.second; + min_cost.second = knl_idx; + mLocalWorkSize = {retTune.first[0], retTune.first[1]}; + } + } - int actual_kernel = total_kernel; - if(mResource->mConv1x1C8Opt) { - actual_kernel = 2; - kernelName[0] = "conv_2d_1x1_c8h1w4"; - itemC[0] = 8; - itemW[0] = 4; - - kernelName[1] = "conv_2d_1x1_c8h1w2"; - itemC[1] = 8; - itemW[1] = 2; - } + std::shared_ptr quanCommon; + int min_index = min_cost.second; + mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - std::shared_ptr kernel[total_kernel]; - std::vector globalWorkSize[total_kernel]; - std::vector localWorkSize[total_kernel]; - std::pair min_cost(INT_MAX, 0);//(min_time, min_index) - for(int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[knl_idx] != 0){ + if(outputShape.at(3) % itemC[min_index] != 0){ buildOption.emplace("-DCHANNEL_LEAVE"); } - if((outputShape.at(2) % itemW[knl_idx]) != 0){ + if((outputShape.at(2) % itemW[min_index]) != 0){ buildOption.emplace("-DBLOCK_LEAVE"); } - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[knl_idx], buildOption); - uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); - - uint32_t idx = 0; + mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption); + uint32_t idx = 0; cl_int ret = CL_SUCCESS; - globalWorkSize[knl_idx] = {static_cast(UP_DIV(outputShape.at(3), itemC[knl_idx]) * UP_DIV(outputShape.at(2), itemW[knl_idx])), 
static_cast(outputShape.at(0) * outputShape.at(1))}; - - ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]); - ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]); - ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(width, itemW[knl_idx])); - ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input)); - ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); - ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); - ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelBlocks)); - ret |= kernel[knl_idx]->get().setArg(idx++, height); - ret |= kernel[knl_idx]->get().setArg(idx++, width); - ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(outChannel, 4)); - ret |= kernel[knl_idx]->get().setArg(idx++, ROUND_UP(outChannel, tileN)); - - MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf Kernel Select"); - - std::pair, int> retTune; - retTune = localWS2DDefault(globalWorkSize[knl_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), kernelName[knl_idx] + info, kernel[knl_idx]); - if(min_cost.first > retTune.second) { - min_cost.first = retTune.second; - min_cost.second = knl_idx; - mLocalWorkSize = {retTune.first[0], retTune.first[1]}; - } - } - - std::shared_ptr quanCommon; - int min_index = min_cost.second; - mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - std::set buildOption = mResource->mBuildOptions; - if(outputShape.at(3) % itemC[min_index] != 0){ - buildOption.emplace("-DCHANNEL_LEAVE"); - } - if((outputShape.at(2) % itemW[min_index]) != 0){ - buildOption.emplace("-DBLOCK_LEAVE"); + ret |= mKernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= mKernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= mKernel->get().setArg(idx++, UP_DIV(width, itemW[min_index])); + ret |= mKernel->get().setArg(idx++, openCLBuffer(input)); + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); + ret |= mKernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= mKernel->get().setArg(idx++, height); + ret |= mKernel->get().setArg(idx++, width); + ret |= mKernel->get().setArg(idx++, batch); + ret |= mKernel->get().setArg(idx++, UP_DIV(outChannel, 4)); + ret |= mKernel->get().setArg(idx++, ROUND_UP(outChannel, mResource->mAlignN)); + MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf"); } - mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("conv_2d_buf", kernelName[min_index], buildOption); - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - - ret |= mKernel->get().setArg(idx++, mGlobalWorkSize[0]); - ret |= mKernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= mKernel->get().setArg(idx++, UP_DIV(width, itemW[min_index])); - ret |= mKernel->get().setArg(idx++, openCLBuffer(input)); - ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); - ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); - ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); - ret |= mKernel->get().setArg(idx++, static_cast(inputChannelBlocks)); - ret |= mKernel->get().setArg(idx++, height); - ret |= mKernel->get().setArg(idx++, width); - ret |= mKernel->get().setArg(idx++, UP_DIV(outChannel, 4)); - ret |= mKernel->get().setArg(idx++, ROUND_UP(outChannel, tileN)); - MNN_CHECK_CL_SUCCESS(ret, 
"setArg Conv1x1Buf"); - - //printf("conv1x1 %d, %d %d, %d %d, %d %d\n", min_index, mGlobalWorkSize[0], mGlobalWorkSize[1], mLocalWorkSize[0], mLocalWorkSize[1], outChannel, width); } else { int inputImageShape[2] = {inputHeight, inputWidth}; int outputImageShape[2] = {height, width}; @@ -574,6 +572,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, inputChannels); ret |= kernel[knl_idx]->get().setArg(idx++, inputChannelBlocks); + ret |= kernel[knl_idx]->get().setArg(idx++, batch); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(strideShape), strideShape); @@ -617,6 +616,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mKernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); ret |= mKernel->get().setArg(idx++, inputChannels); ret |= mKernel->get().setArg(idx++, inputChannelBlocks); + ret |= mKernel->get().setArg(idx++, batch); ret |= mKernel->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= mKernel->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= mKernel->get().setArg(idx++, sizeof(strideShape), strideShape); @@ -630,9 +630,13 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const if (inputs.size() > 1) { backend()->onReleaseBuffer(mResource->mFilter.get(), Backend::DYNAMIC); } - mOpenCLBackend->recordKernel2d(mKernel, mGlobalWorkSize, mLocalWorkSize); - mGlobalWorkSize[0] = ROUND_UP(mGlobalWorkSize[0], std::max((uint32_t)1, mLocalWorkSize[0])); - mGlobalWorkSize[1] = ROUND_UP(mGlobalWorkSize[1], std::max((uint32_t)1, mLocalWorkSize[1])); + if (mResource->mConv1x1Opt && mResource->mConv1x1Local){ + mOpenCLBackend->recordKernel3d(mKernel, mGlobalWorkSize, mLocalWorkSize); + }else{ + mOpenCLBackend->recordKernel2d(mKernel, mGlobalWorkSize, mLocalWorkSize); + mGlobalWorkSize[0] = ROUND_UP(mGlobalWorkSize[0], std::max((uint32_t)1, mLocalWorkSize[0])); + mGlobalWorkSize[1] = ROUND_UP(mGlobalWorkSize[1], std::max((uint32_t)1, mLocalWorkSize[1])); + } mOpenCLBackend->endRecord(mRecording); #ifdef LOG_VERBOSE MNN_PRINT("end ConvExecution onResize !\n"); @@ -663,11 +667,15 @@ ErrorCode ConvBufExecution::onExecute(const std::vector &inputs, const mOpenCLBackend->getOpenCLRuntime()->pushEvent({"ConvBuf2D-gemm2-0", event0}); } - if(mResource->mConvGemmOptLevel == 2) { + if(mResource->mConvGemmOptLevel == 1) { mStrassenComputor->onExecute(); } else { cl::Event event; - runKernel2D(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime(), &event); + if (mResource->mConv1x1Opt && mResource->mConv1x1Local){ + run3DKernelDefault(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime(), &event); + } else{ + runKernel2D(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime(), &event); + } std::string name = "ConvBuf2D"; std::string b = std::to_string(inputs[0]->batch()); std::string ci = std::to_string(inputs[0]->channel()); @@ -708,11 +716,14 @@ ErrorCode ConvBufExecution::onExecute(const std::vector &inputs, const if (mPreKernel) { runKernel2D(mPreKernel, mPreGlobalWorkSize, mPreLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); } - - if(mResource->mConvGemmOptLevel == 2) { + if(mResource->mConvGemmOptLevel == 1) { 
mStrassenComputor->onExecute(); } else { - runKernel2D(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); + if (mResource->mConv1x1Opt && mResource->mConv1x1Local){ + run3DKernelDefault(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); + } else{ + runKernel2D(mKernel, mGlobalWorkSize, mLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); + } } if (mPostKernel) { runKernel2D(mPostKernel, mPostGlobalWorkSize, mPostLocalWorkSize, mOpenCLBackend->getOpenCLRuntime()); @@ -739,7 +750,7 @@ class ConvolutionBufCreator : public OpenCLBackend::Creator { const int outputChannel = outputShape.at(3); const int inputChannels = inputShape.at(3); #ifdef MNN_LOW_MEMORY - { + if (static_cast(backend)->getMemory() == BackendConfig::Memory_Low){ auto conv2dParams = op->main_as_Convolution2D(); if (conv2dParams->quanParameter() != nullptr) { if (((conv2dParams->quanParameter()->type() == 4) || @@ -749,6 +760,12 @@ class ConvolutionBufCreator : public OpenCLBackend::Creator { // Don't support IDST-int8 because of error return nullptr; } + for (int i = 0; i < inputs.size(); ++i) { + TensorUtils::setTensorSupportPack(inputs[i], false); + } + for (int i = 0; i < outputs.size(); ++i) { + TensorUtils::setTensorSupportPack(outputs[i], false); + } return new ConvBufLowMemoryExecution(inputs, outputs, op, backend); } else { //MNN_ERROR("OpenCL Conv buf low memory init error. For Opencl Backend, only support low memory mode of int8 or int4 dequantization currently.\n"); diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp index 96e1ec5aa..b8edea9ef 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp @@ -35,6 +35,7 @@ struct ConvBufResource { std::set mBuildOptions; bool mConv1x1Opt = false; bool mConv1x1C8Opt = false; + bool mConv1x1Local = false; /* 0 -> not use 1 -> use small tile @@ -44,6 +45,8 @@ struct ConvBufResource { std::shared_ptr mRasterExe; bool mUseImage = false; int mNumQuantBit = 0; + int mAlignK = 1; + int mAlignN = 1; }; class ConvBufCommonExecution { @@ -76,7 +79,6 @@ class ConvBufExecution : public ConvBufCommonExecution, public CommonExecution { std::shared_ptr mKernel; std::shared_ptr mConvGemmInpTensor; std::shared_ptr mConvGemmOutTensor; - bool mNeedOutTempTensor = false; std::shared_ptr mPreKernel = nullptr; std::vector mPreGlobalWorkSize{1, 1, 1}; std::vector mPreLocalWorkSize{1, 1, 1, 1}; @@ -84,8 +86,9 @@ class ConvBufExecution : public ConvBufCommonExecution, public CommonExecution { std::vector mPostGlobalWorkSize{1, 1, 1}; std::vector mPostLocalWorkSize{1, 1, 1, 1}; const float* mFilterDataPtr = nullptr; + private: - + int mAlignM = 1; std::shared_ptr mStrassenComputor; }; diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp index c932c0a6c..d31462301 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.cpp @@ -166,13 +166,12 @@ void ConvBufLowMemoryExecution::set1x1WeightLowMemory(int packCout, int packCin, mResource->mUseImage = true; } if(mResource->mUseImage) { - size_t w = ROUND_UP(mResource->mOutputChannel, packCout); - size_t h = UP_DIV(mResource->mInputChannel, packCin); if(mResource->mNumQuantBit == 4){ - mResource->mKernelImage.reset(new 
cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_UNSIGNED_INT16), w, h, 0, nullptr, &res)); - }else{ - mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); + packCin *= 2; } + size_t w = ROUND_UP(mResource->mOutputChannel, packCout); + size_t h = UP_DIV(mResource->mInputChannel, packCin); + mResource->mKernelImage.reset(new cl::Image2D(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE, cl::ImageFormat(CL_RGBA, CL_SIGNED_INT32), w, h, 0, nullptr, &res)); if (nullptr == mResource->mKernelImage.get() || res != CL_SUCCESS) { MNN_ERROR("Alloc Image %d x %d error, code:%d \n", (int)w, (int)h, (int)res); } @@ -229,9 +228,11 @@ void ConvBufLowMemoryExecution::setGeneralWeightLowMemory(void* filterDataPtr, s } // select the fastest kernel for the general cases by tuning void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor * output) { + mUnits.resize(1); auto &unit = mUnits[0]; std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); + const int batch = outputShape.at(0); const int height = outputShape.at(1); const int width = outputShape.at(2); const int outChannel = outputShape.at(3); @@ -286,6 +287,7 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, inputChannels); ret |= kernel[knl_idx]->get().setArg(idx++, inputChannelBlocks); + ret |= kernel[knl_idx]->get().setArg(idx++, batch); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(strideShape), strideShape); @@ -331,6 +333,7 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); ret |= unit.kernel->get().setArg(idx++, inputChannels); ret |= unit.kernel->get().setArg(idx++, inputChannelBlocks); + ret |= unit.kernel->get().setArg(idx++, batch); ret |= unit.kernel->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= unit.kernel->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= unit.kernel->get().setArg(idx++, sizeof(strideShape), strideShape); @@ -346,9 +349,171 @@ void ConvBufLowMemoryExecution::tuneGeneralCaseLowMemory(Tensor * input, Tensor unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; return; } -unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * output) { + +// weight inverse quantization, use xgemm opt +void ConvBufLowMemoryExecution::useFPWeightGemmLowMemory(Tensor * input, Tensor * output) { + mUnits.resize(3); + auto runtime = mOpenCLBackend->getOpenCLRuntime(); + std::vector inputShape = tensorShapeFormat(input); + std::vector outputShape = tensorShapeFormat(output); + int channelPack = 16; + if(mResource->mUseImage && mResource->mNumQuantBit == 4){ + channelPack = 32; + } + int area = inputShape.at(1) * inputShape.at(2); + int M = outputShape.at(0) * area; + int N = mResource->mOutputChannel; + int K = mResource->mInputChannel; + int mAlignK = 4; + int mAlignN = 16; + int mAlignM = 64; + + // set M Align and N Align + if(mResource->mOutputChannel > 1024) { 
+ mAlignN = 128; + } else if(mResource->mOutputChannel > 512) { + mAlignN = 64; + } else if(mResource->mOutputChannel > 96) { + mAlignN = 32; + } + float ratio = 1.0 * M / 1024.0 * N / 1024.0 * K / 1024.0; + if(M > 1024 && ratio >= 1.0) { + mAlignM = 128; + } else if(M > 512 && ratio >= 0.1) { + mAlignM = 64; + } else if(M > 96){ + mAlignM = 32; + } else { + mAlignM = 16; + } + int alignM = ROUND_UP(M, mAlignM); + int alignN = ROUND_UP(N, mAlignN); + int alignK = ROUND_UP(K, mAlignK); + int blockDim = mResource->mInputChannel / mResource->mBlockSize; + + // alloc temp bufer + mConvGemmWeightTensor.reset(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, mAlignN) * ROUND_UP(mResource->mInputChannel, std::max(mAlignK, channelPack))})); + mConvGemmInpTensor.reset(Tensor::createDevice({alignK * alignM})); + mConvGemmOutTensor.reset(Tensor::createDevice({alignN * alignM})); + mOpenCLBackend->onAcquireBuffer(mConvGemmWeightTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmWeightTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + + //weight inverse quantization and rearrange + { + auto &unit = mUnits[0]; + int outputChannelAlign = ROUND_UP(mResource->mOutputChannel, alignN); + int outputChannel4Align = ROUND_UP(mResource->mOutputChannel, 4); + std::set buildOption = mResource->mBuildOptions; + if(mResource->mUseImage){ + buildOption.emplace("-DUSE_IMAGE"); + } + mGlobalWorkSize = {static_cast(UP_DIV(mResource->mInputChannel, channelPack)), static_cast(UP_DIV(mResource->mOutputChannel, 4))}; + unit.kernel = runtime->buildKernel("gemm_conv1x1_buf", "inverse_quant_weight", buildOption); + uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + if(mResource->mUseImage){ + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelImage.get()); + }else{ + ret |= unit.kernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); + } + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmWeightTensor.get())); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannel4Align)); + ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); + MNN_CHECK_CL_SUCCESS(ret, "setArg inverse_quant_weight"); + + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "inverse_quant_weight", unit.kernel).first; + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + } + + // rearange input + { + auto &unit = mUnits[1]; + std::set buildOptions = mResource->mBuildOptions; + + int m_pack = 4; + mGlobalWorkSize = {static_cast(alignM/m_pack), static_cast(alignK/4)}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_pad", buildOptions); + uint32_t maxWorkGroupSize = 
static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); + + int offset = 0; + int idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); + ret |= unit.kernel->get().setArg(idx++, static_cast(alignM)); + ret |= unit.kernel->get().setArg(idx++, static_cast(alignK)); + ret |= unit.kernel->get().setArg(idx++, static_cast(M)); + ret |= unit.kernel->get().setArg(idx++, static_cast(K)); + ret |= unit.kernel->get().setArg(idx++, static_cast(area)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); + MNN_CHECK_CL_SUCCESS(ret, "setArg transpose_pad"); + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_pad", unit.kernel).first; + + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + } + + // call gemm strassen + { + mStrassenComputor.reset(new StrassenMatrixComputor(backend(), 3)); + mStrassenComputor->onEncode(alignM, alignK, alignN, alignM, alignN, alignN, openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mConvGemmWeightTensor.get()), openCLBuffer(mConvGemmOutTensor.get()), false, openCLBuffer(mResource->mBias.get())); + } + + // call output transpose + { + auto &unit = mUnits[2]; + std::set buildOptions = mResource->mBuildOptions; + int pack_m = 1; + if(M % 8 == 0) { + pack_m = 8; + } else if(M % 4 == 0) { + pack_m = 4; + } + buildOptions.emplace("-DM_VEC=" + std::to_string(pack_m)); + unit.kernel = runtime->buildKernel("gemm_buf", "transpose_bias", buildOptions); + uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + + mGlobalWorkSize = {static_cast(UP_DIV(M, pack_m)), static_cast(UP_DIV(N, 4))}; + + int offset = 0; + int idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[0])); + ret |= unit.kernel->get().setArg(idx++, static_cast(mGlobalWorkSize[1])); + ret |= unit.kernel->get().setArg(idx++, static_cast(alignM)); + ret |= unit.kernel->get().setArg(idx++, static_cast(alignN)); + ret |= unit.kernel->get().setArg(idx++, static_cast(M)); + ret |= unit.kernel->get().setArg(idx++, static_cast(N)); + ret |= unit.kernel->get().setArg(idx++, static_cast(area)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + + MNN_CHECK_CL_SUCCESS(ret, "setArg transpose_bias"); + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, runtime, "transpose_bias", unit.kernel).first; + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; + } + + return; +} +void ConvBufLowMemoryExecution::tuneGemvLowMemory(Tensor * input, Tensor * output) { + mUnits.resize(1); auto &unit = mUnits[0]; - unsigned int total_time = 0; std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); const int outChannel = outputShape.at(3); @@ -361,20 +526,17 @@ unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * 
input, Tensor const int blockNum = mResource->mBlockSize; const int blockDim = mResource->mInputChannel / mResource->mBlockSize; - int global_y = batch * height; - const int total_kernel = 5; - std::string kernelName[total_kernel] = {"gemm_conv_c1_buf", "gemm_conv_c2_buf", "gemm_conv_c4_buf", "gemm_conv_c1_image", "gemm_conv_c2_image"}; - int itemC[total_kernel] = {1, 2, 4, 1, 2}; + int global_y = batch * height * width; + const int total_kernel = 3; + std::string kernelName[total_kernel] = {"gemv_conv_c1_buf", "gemv_conv_c2_buf", "gemv_conv_c4_buf"}; + int itemC[total_kernel] = {1, 2, 4}; int actual_kernel = total_kernel; std::shared_ptr kernel[total_kernel]; std::vector globalWorkSize[total_kernel]; std::vector localWorkSize[total_kernel]; std::pair min_cost(INT_MAX, 0);//(min_time, min_index) std::set buildOption = mResource->mBuildOptions; - if(width == 1 && height == 1){ - buildOption.emplace("-DWIDTH_HEIGHT_1"); - } - + if(blockDim % 16 != 0){ buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } else if (mResource->mUseImage && mResource->mNumQuantBit == 4 && blockDim % 32 != 0) { @@ -382,22 +544,15 @@ unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); - if(batch > 1){ - global_y = UP_DIV(batch, 4) * height; - buildOption.emplace("-DBACTH_BLOCK4"); - info += "_BATCH_BLOCK4"; - } - int knl_idx = 0; - actual_kernel = 3; if(mResource->mUseImage){ - knl_idx = 3; - actual_kernel = total_kernel; + buildOption.emplace("-DUSE_IMAGE"); } - for (; knl_idx < actual_kernel; knl_idx++) { - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[knl_idx], buildOption); + for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { + auto option = buildOption; + option.emplace("-DTILE_N=" + std::to_string(itemC[knl_idx])); + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[knl_idx], option); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); - - globalWorkSize[knl_idx] = {static_cast(UP_DIV(outChannel, itemC[knl_idx]) * width), static_cast(global_y)}; + globalWorkSize[knl_idx] = {static_cast(UP_DIV(outChannel, itemC[knl_idx])), static_cast(global_y)}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]); @@ -414,9 +569,7 @@ unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(outputChannelBlocks)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= kernel[knl_idx]->get().setArg(idx++, inputChannels); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(batch)); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(height)); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(width)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(global_y)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(blockNum)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(blockDim)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select"); @@ -428,13 +581,11 @@ unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor mLocalWorkSize = {retTune.first[0], retTune.first[1]}; } } - total_time += min_cost.first; int min_index = min_cost.second; mGlobalWorkSize = 
{globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - - + + buildOption.emplace("-DTILE_N=" + std::to_string(itemC[min_index])); unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemv_conv1x1_buf", kernelName[min_index], buildOption); - //MNN_PRINT("Kernel is %d.\n", min_index); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -451,35 +602,37 @@ unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); - ret |= unit.kernel->get().setArg(idx++, static_cast(batch)); - ret |= unit.kernel->get().setArg(idx++, static_cast(height)); - ret |= unit.kernel->get().setArg(idx++, static_cast(width)); + ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockNum)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf"); mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; - return total_time; + return; } -unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, Tensor * output) { +unsigned int ConvBufLowMemoryExecution::tuneGemmLowMemory(Tensor * input, Tensor * output, std::string option, bool onlyGetTime) { mUnits.resize(3); unsigned int total_time = 0; std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); + int channelPack = 16; + if(mResource->mUseImage && mResource->mNumQuantBit == 4){ + channelPack = 32; + } const int outChannel = outputShape.at(3); const int inputChannels = inputShape.at(3); const int batch = outputShape.at(0); const int width_height = outputShape.at(1) * outputShape.at(2); - const int inputChannelBlocks = UP_DIV(inputChannels, 4); - const int outputChannelBlocks = UP_DIV(outChannel, 4); + const int inputChannelAlign = ROUND_UP(inputChannels, channelPack); + const int outputChannelAlign = ROUND_UP(outChannel, 4); const int blockNum = mResource->mBlockSize; const int blockDim = mResource->mInputChannel / mResource->mBlockSize; - - int global_y = UP_DIV(batch, 4) * width_height; - const int total_kernel = 6; - std::string kernelName[total_kernel] = {"gemm_b4_c1_buf", "gemm_b4_c2_buf", "gemm_b4_c4_buf", "gemm_b4_c1_image", "gemm_b4_c2_image", "gemm_b4_c4_image"}; - int itemC[total_kernel] = {1, 2, 4, 1, 2, 4}; + + int global_y = batch * width_height; + const int total_kernel = 3; + std::string kernelName[total_kernel] = {"gemm_b4_c1_buf", "gemm_b4_c2_buf", "gemm_b4_c4_buf"}; + int itemC[total_kernel] = {1, 2, 4}; int actual_kernel = total_kernel; std::shared_ptr kernel[total_kernel]; std::vector globalWorkSize[total_kernel]; @@ -492,9 +645,13 @@ unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, T // Image weight-int4 use load32 buildOption.emplace("-DINPUT_CHANNEL_LEAVE"); } - std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel); + buildOption.emplace(option); + if(mResource->mUseImage){ + buildOption.emplace("-DUSE_IMAGE"); + } + std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel) + option; // mResource->mInputChannel ROUND_UP to 
blockDim, avoid gemm overstep - mConvGemmInpTensor.reset(Tensor::createDevice({ROUND_UP(batch, 4) * ROUND_UP(ROUND_UP(mResource->mInputChannel, 4), blockDim) * width_height})); + mConvGemmInpTensor.reset(Tensor::createDevice({ROUND_UP(batch, 4) * inputChannelAlign * width_height})); mConvGemmOutTensor.reset(Tensor::createDevice({ROUND_UP(batch, 4) * ROUND_UP(mResource->mOutputChannel, 4) * width_height})); mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); @@ -504,43 +661,37 @@ unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, T // reshape n*c/4*4*hw -> n/4*hw*c*4 { auto &unit = mUnits[0]; - mGlobalWorkSize = {static_cast(UP_DIV(mResource->mInputChannel, 4)), static_cast(UP_DIV(batch, 4)), static_cast(width_height)}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_quant_batch_buf", "reshape_nchw4_nhwc4", buildOption); + mGlobalWorkSize = {static_cast(UP_DIV(inputChannelAlign, 4)), static_cast(UP_DIV(global_y, 4))}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "reshape_nchw4_nhwc4", buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); - ret |= unit.kernel->get().setArg(idx++, static_cast(width_height)); - ret |= unit.kernel->get().setArg(idx++, static_cast(batch)); + ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelAlign)); MNN_CHECK_CL_SUCCESS(ret, "setArg reshape_nc4_cn4"); - std::pair, unsigned int> retTune = localWS3DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "reshape_nchw4_nhwc4", unit.kernel); + std::pair, unsigned int> retTune = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "reshape_nchw4_nhwc4", unit.kernel); total_time += retTune.second; mLocalWorkSize = retTune.first; - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + if(false == onlyGetTime){ + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + } + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; } // gemm { auto &unit = mUnits[1]; - int knl_idx = 0; - actual_kernel = 3; - if(mResource->mUseImage){ - knl_idx = 3; - actual_kernel = total_kernel; - } - for (; knl_idx < actual_kernel; knl_idx++) { - kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_quant_batch_buf", kernelName[knl_idx], buildOption); + for (int knl_idx = 0; knl_idx < actual_kernel; knl_idx++) { + kernel[knl_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName[knl_idx], buildOption); 
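// ---------------------------------------------------------------------------
// Illustrative sketch only, not part of this patch: the surrounding
// tuneGemmLowMemory hunk builds one kernel per output-tile width
// (itemC = 1/2/4, i.e. "gemm_b4_c1_buf" / "gemm_b4_c2_buf" / "gemm_b4_c4_buf"),
// times each candidate with its own global work size, and keeps the cheapest.
// The types and the timing callback below are hypothetical stand-ins for
// buildKernel()/localWS2DDefault(); only the control flow mirrors the hunk.
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <vector>

namespace sketch {

inline uint32_t upDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

struct Candidate {
    std::string kernelName;           // e.g. "gemm_b4_c1_buf"
    uint32_t    tileN;                // output channels written per work-item
    uint32_t    globalX = 0;
    uint32_t    globalY = 0;
};

// timeCandidate stands in for "build the kernel, set its args and measure a
// tuned launch"; it returns a cost (e.g. in microseconds).
inline Candidate pickBestGemmKernel(uint32_t outChannel, uint32_t globalY,
                                    const std::function<unsigned int(const Candidate&)>& timeCandidate) {
    std::vector<Candidate> candidates = {
        {"gemm_b4_c1_buf", 1}, {"gemm_b4_c2_buf", 2}, {"gemm_b4_c4_buf", 4}};
    unsigned int bestCost  = std::numeric_limits<unsigned int>::max();
    size_t       bestIndex = 0;
    for (size_t i = 0; i < candidates.size(); ++i) {
        candidates[i].globalX = upDiv(outChannel, candidates[i].tileN); // one column of tiles per work-item
        candidates[i].globalY = upDiv(globalY, 4);                      // four rows handled per work-item, as in the hunk
        const unsigned int cost = timeCandidate(candidates[i]);
        if (cost < bestCost) {
            bestCost  = cost;
            bestIndex = i;
        }
    }
    // The chosen variant is then rebuilt once with its final options and recorded.
    return candidates[bestIndex];
}

} // namespace sketch
// ---------------------------------------------------------------------------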
uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel[knl_idx])); - - globalWorkSize[knl_idx] = {static_cast(UP_DIV(outChannel, itemC[knl_idx])), static_cast(global_y)}; + + globalWorkSize[knl_idx] = {static_cast(UP_DIV(outChannel, itemC[knl_idx])), static_cast(UP_DIV(global_y, 4))}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][0]); @@ -554,8 +705,9 @@ unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, T ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(outputChannelBlocks)); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(UP_DIV(global_y, 4) * 4)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(outputChannelAlign)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelAlign)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(blockNum)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(blockDim)); MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf Kernel Select"); @@ -572,8 +724,7 @@ unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, T mGlobalWorkSize = {globalWorkSize[min_index][0], globalWorkSize[min_index][1]}; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_quant_batch_buf", kernelName[min_index], buildOption); - //MNN_PRINT("Kernel is %d.\n", min_index); + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", kernelName[min_index], buildOption); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -587,38 +738,41 @@ unsigned int ConvBufLowMemoryExecution::tuneGemvBatchLowMemory(Tensor * input, T ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->dequantScaleOffset.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); - ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(UP_DIV(global_y, 4) * 4)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannelAlign)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockNum)); ret |= unit.kernel->get().setArg(idx++, static_cast(blockDim)); - MNN_CHECK_CL_SUCCESS(ret, "setArg gemv_conv1x1_buf"); - mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + MNN_CHECK_CL_SUCCESS(ret, "setArg gemm_conv1x1_buf"); + if(false == onlyGetTime){ + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + } unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; } // reshape n/4*hw*c*4 -> n*c/4*4*hw { auto &unit = mUnits[2]; - mGlobalWorkSize = {static_cast(UP_DIV(mResource->mOutputChannel, 4)), static_cast(UP_DIV(batch, 4)), static_cast(width_height)}; - unit.kernel = 
mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_quant_batch_buf", "reshape_nhwc4_nchw4", buildOption); + mGlobalWorkSize = {static_cast(UP_DIV(mResource->mOutputChannel, 4)), static_cast(UP_DIV(global_y, 4))}; + unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_conv1x1_buf", "reshape_nhwc4_nchw4", buildOption); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, static_cast(width_height)); - ret |= unit.kernel->get().setArg(idx++, static_cast(batch)); - ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(global_y)); + ret |= unit.kernel->get().setArg(idx++, static_cast(outputChannelAlign)); MNN_CHECK_CL_SUCCESS(ret, "setArg reshape_cn4_nc4"); - std::pair, unsigned int> retTune = localWS3DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "reshape_nhwc4_nchw4", unit.kernel); + std::pair, unsigned int> retTune = localWS2DDefault(mGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "reshape_nhwc4_nchw4", unit.kernel); mLocalWorkSize = retTune.first; total_time += retTune.second; - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + if(false == onlyGetTime){ + mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + } + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1]}; } return total_time; } @@ -695,10 +849,12 @@ bool ConvBufLowMemoryExecution::onClone(Backend* bn, const Op* op, Execution** d return true; } -ErrorCode ConvBufLowMemoryExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { +ErrorCode ConvBufLowMemoryExecution::onResize(const std::vector &inputs, const std::vector &outputs) { #ifdef LOG_VERBOSE MNN_PRINT("Start ConvBufLowMemoryExecution onResize !\n"); #endif + auto runTime = mOpenCLBackend->getOpenCLRuntime(); + mOpenCLBackend->startRecord(mRecording); mUnits.resize(1); auto input = inputs[0]; auto output = outputs[0]; @@ -707,30 +863,138 @@ ErrorCode ConvBufLowMemoryExecution::onEncode(const std::vector &input mPaddings[1] = padding.first;//padX // onclone default use conv1x1Opt, need reset std::vector outputShape = tensorShapeFormat(output); - const int batch = outputShape.at(0); - auto runTime = mOpenCLBackend->getOpenCLRuntime(); + const int batch = outputShape.at(0) * outputShape.at(1) * outputShape.at(2); + mUseFPWeight = false; if (mResource->mConv1x1Opt) { - if(batch > 1 && false == getPreParamInfo("ConvBufLowMemoryPreArrangeMode", &batchConvMode, runTime)){ - if(tuneGemvBatchLowMemory(input, output) < tuneGemmLowMemory(input, output)){ - batchConvMode = 1; - } else{ - batchConvMode = 2; + if(batch == 1){ + tuneGemvLowMemory(input, output); + } else { + if(batch > 512){ + useFPWeightGemmLowMemory(input, output); + mUseFPWeight = true; + } + else 
if(false == getPreParamInfo("ConvBufLowMemoryPreArrangeMode", &batchConvMode, runTime)){ + if(tuneGemmLowMemory(input, output, "-DFORMAT_CNHW", true) < tuneGemmLowMemory(input, output, "", true)){ + batchConvMode = 1; + } else{ + batchConvMode = 2; + } + setPreParamInfo("ConvBufLowMemoryPreArrangeMode", batchConvMode, runTime); + } else { + std::string option = ""; + if(1 == batchConvMode){ + option = "-DFORMAT_CNHW"; + } + tuneGemmLowMemory(input, output, option); } - setPreParamInfo("ConvBufLowMemoryPreArrangeMode", batchConvMode, runTime); - } - if(batch > 1 && batchConvMode == 1){ - tuneGemvBatchLowMemory(input, output); - }else{ - tuneGemmLowMemory(input, output); } } else { tuneGeneralCaseLowMemory(input, output); } + for (auto &unit : mUnits) { + bool lws_null = true; + for (size_t i = 0; i < unit.globalWorkSize.dimensions(); ++i) { + unit.globalWorkSize.get()[i] = ROUND_UP(unit.globalWorkSize.get()[i], std::max((size_t)1, unit.localWorkSize.get()[i])); + if(unit.localWorkSize.get()[i] != 0) { + lws_null = false; + } + } + if(lws_null){ + unit.localWorkSize = cl::NullRange; + } + } + mOpenCLBackend->endRecord(mRecording); #ifdef LOG_VERBOSE MNN_PRINT("end ConvBufLowMemoryExecution onResize !\n"); #endif return NO_ERROR; } + +ErrorCode ConvBufLowMemoryExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { +#ifdef LOG_VERBOSE + MNN_PRINT("Start ConvBufLowMemoryExecution onExecute !\n"); +#endif + auto runtime = mOpenCLBackend->getOpenCLRuntime(); +#ifdef ENABLE_OPENCL_TIME_PROFILER + int idx = 0; +#else + if(mOpenCLBackend->isUseRecordQueue()){ + mOpenCLBackend->addRecord(mRecording, mOpRecordUpdateInfo); + return NO_ERROR; + } +#endif + auto res = CL_SUCCESS; + if(mUseFPWeight){ + // arrange input and weight + int i = 0; + for (; i < 2; ++i){ + auto unit = mUnits[i]; + #ifdef ENABLE_OPENCL_TIME_PROFILER + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); + #else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize); + #endif + MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); + } + // call gemm execute + mStrassenComputor->onExecute(); + + // rearrange output + for (; i < mUnits.size(); ++i){ + auto unit = mUnits[i]; + #ifdef ENABLE_OPENCL_TIME_PROFILER + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); + #else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize); + #endif + MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); + } + }else{ + for (auto &unit : mUnits) { + #ifdef ENABLE_OPENCL_TIME_PROFILER + cl::Event event; + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize, + nullptr, + &event); + runtime->pushEvent({EnumNameOpType(mOpType) + std::to_string(idx++), event}); + #else + res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), + cl::NullRange, + unit.globalWorkSize, + unit.localWorkSize); + #endif + MNN_CHECK_CL_SUCCESS(res, EnumNameOpType(mOp->type())); + } + } +#ifdef LOG_VERBOSE 
+ MNN_PRINT("end ConvBufLowMemoryExecution onExecute !\n"); +#endif + return NO_ERROR; +} + } // namespace OpenCL } // namespace MNN #endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp index 8488f461b..5e04ac1aa 100644 --- a/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp +++ b/source/backend/opencl/execution/buffer/ConvBufLowMemoryExecution.hpp @@ -21,25 +21,30 @@ class ConvBufLowMemoryExecution : public ConvBufCommonExecution, public CommonEx ConvBufLowMemoryExecution(const std::vector &inputs, const std::vector &outputs, const MNN::Op *op, Backend *backend); ConvBufLowMemoryExecution(std::shared_ptr resource, const MNN::Op* op, Backend* backend); virtual ~ConvBufLowMemoryExecution(); - virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; + virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; private: void getInfoFromOpLowMemory(std::shared_ptr & quanCommon); void set1x1WeightLowMemory(int packCout, int packCin, void * filterDataPtr, std::shared_ptr & quanCommon); void setGeneralWeightLowMemory(void * filterDataPtr, std::shared_ptr & quanCommon); void tuneGeneralCaseLowMemory(Tensor * input, Tensor * output); - unsigned int tuneGemmLowMemory(Tensor * input, Tensor * output); - unsigned int tuneGemvBatchLowMemory(Tensor * input, Tensor * output); + void useFPWeightGemmLowMemory(Tensor * input, Tensor * output); + void tuneGemvLowMemory(Tensor * input, Tensor * output); + unsigned int tuneGemmLowMemory(Tensor * input, Tensor * output, std::string option, bool onlyGetTime = false); bool convertToQuantWeight1x1Buffer(cl::Buffer input, int pack); std::vector mPaddings{0, 0}; std::vector mGlobalWorkSize{1, 1, 1}; std::vector mLocalWorkSize{1, 1, 1, 1}; void *mFilterDataPtr = nullptr; bool mLowMemoryFlag = false; + bool mUseFPWeight = false; std::shared_ptr mConvGemmInpTensor; std::shared_ptr mConvGemmOutTensor; + std::shared_ptr mConvGemmWeightTensor; std::shared_ptr mBufferToConv1x1Kernel = nullptr; uint32_t batchConvMode = 0; // batch > 1 convolution input arrage mode. 
0 is need tune; 1 arrage to n/4chw4; 2 arrage to c/4hwn4 + std::shared_ptr mStrassenComputor; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp index c7b7fc644..bcabf60f6 100644 --- a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp @@ -35,11 +35,11 @@ bool ConvBufWinograd::valid(const Convolution2DCommon* common, const Tensor* inp return valid; } -void ConvBufWinograd::convertWeightFormat(cl::Buffer& buffer, const int tileK, const int tileN) { +void ConvBufWinograd::convertWeightFormat(cl::Buffer& buffer, const int alignK, const int alignN) { auto runtime = mOpenCLBackend->getOpenCLRuntime(); - auto icPad = ROUND_UP(mCi, tileK); - auto ocPad = ROUND_UP(mCo, tileN); + auto icPad = ROUND_UP(mCi, alignK); + auto ocPad = ROUND_UP(mCo, alignN); auto kernel = runtime->buildKernel("winogradTransform_buf", "winoTransWeightBuf2_3_1", {}); uint32_t gws[2] = {static_cast(icPad), static_cast(ocPad)}; @@ -205,15 +205,22 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx int kernelSize = kx; int alpha = unit + kernelSize - 1; - int tileK = 4; - int tileN = 32; + mResource->mAlignK = 4; + mResource->mAlignN = 16; + if(mCo > 1024) { + mResource->mAlignN = 128; + } else if(mCo > 256) { + mResource->mAlignN = 64; + } else if(mCo > 64) { + mResource->mAlignN = 32; + } std::shared_ptr tmpFilterTensor; tmpFilterTensor.reset(Tensor::createDevice({mCo * mCi * ky * kx})); mOpenCLBackend->onAcquireBuffer(tmpFilterTensor.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(tmpFilterTensor.get(), Backend::DYNAMIC); - mResource->mWeight.reset(Tensor::createDevice({alpha * alpha * ROUND_UP(mCo, tileN) * ROUND_UP(mCi, tileK)}));//NHWC + mResource->mWeight.reset(Tensor::createDevice({alpha * alpha * ROUND_UP(mCo, mResource->mAlignN) * ROUND_UP(mCi, mResource->mAlignK)}));//NHWC mOpenCLBackend->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC); buffer_size = mCo * mCi * ky * kx * sizeof(float); @@ -228,7 +235,7 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx } mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(weightBufferCL, ptrCL); - convertWeightFormat(weightBufferCL, tileK, tileN); + convertWeightFormat(weightBufferCL, mResource->mAlignK, mResource->mAlignN); } } @@ -277,7 +284,8 @@ ErrorCode ConvBufWinograd::SubgroupOnResize(const std::vector &inputs, auto icC4 = UP_DIV(input->channel(), 4); auto icC16 = UP_DIV(input->channel(), 16); auto ocC4 = UP_DIV(output->channel(), 4); - auto ocC16 = UP_DIV(output->channel(), 16); + auto ocC16 = UP_DIV(output->channel(), 16); + auto batch = output->batch(); auto inputpad = TensorUtils::getDescribe(input)->mPads; auto outputpad = TensorUtils::getDescribe(output)->mPads; int in_c_pack = TensorUtils::getTensorChannelPack(input); @@ -316,7 +324,7 @@ ErrorCode ConvBufWinograd::SubgroupOnResize(const std::vector &inputs, } } - for (int b = 0; b < input->batch(); ++b) { + for (int b = 0; b < batch; ++b) { int hCount = hUnit; int wCount = wUnit; int width_pack = ROUND_UP(hCount * wCount, 8); @@ -340,6 +348,7 @@ ErrorCode ConvBufWinograd::SubgroupOnResize(const std::vector &inputs, ret |= mUnits[b * 3].kernel->get().setArg(index++, icC16); ret |= mUnits[b * 3].kernel->get().setArg(index++, width_pack); ret |= mUnits[b * 3].kernel->get().setArg(index++, b); + ret |= mUnits[b * 
3].kernel->get().setArg(index++, batch); ret |= mUnits[b * 3].kernel->get().setArg(index++, static_cast(inputpad.left)); ret |= mUnits[b * 3].kernel->get().setArg(index++, static_cast(inputpad.right)); MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradBuf Source Trans"); @@ -400,6 +409,7 @@ ErrorCode ConvBufWinograd::SubgroupOnResize(const std::vector &inputs, ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, ocC16); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, width_pack); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, b); + ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, batch); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, static_cast(outputpad.left)); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, static_cast(outputpad.right)); MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradBuf Dest Trans"); @@ -458,13 +468,21 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st } else #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ { - int tileM = 16; - int tileN = 32; - int tileK = 4; + mAlignM = 16; + float ratio = 1.0 * alpha * alpha * wUnit * hUnit / 1024.0 * input->channel() / 1024.0 * output->channel() / 1024.0; + if (wUnit * hUnit > 512 && ratio > 1.0) { + mAlignM = 128; + } else if (wUnit * hUnit > 256 && ratio > 0.1) { + mAlignM = 64; + } else if (wUnit * hUnit > 64) { + mAlignM = 32; + } + int mAlignK = mResource->mAlignK; + int mAlignN = mResource->mAlignN; mSource.reset(Tensor::createDevice( - std::vector{alpha * alpha * ROUND_UP(input->channel(), tileK) * ROUND_UP(wUnit * hUnit, tileM)})); + std::vector{alpha * alpha * ROUND_UP(input->channel(), mAlignK) * ROUND_UP(wUnit * hUnit, mAlignM)})); mDest.reset(Tensor::createDevice( - std::vector{alpha * alpha * ROUND_UP(wUnit * hUnit, tileM) * ROUND_UP(output->channel(), tileN)})); + std::vector{alpha * alpha * ROUND_UP(wUnit * hUnit, mAlignM) * ROUND_UP(output->channel(), mAlignN)})); mOpenCLBackend->onAcquireBuffer(mSource.get(), Backend::DYNAMIC); mOpenCLBackend->onAcquireBuffer(mDest.get(), Backend::DYNAMIC); @@ -498,9 +516,9 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st int hCount = hUnit; int wCount = wUnit; - int M_pack = ROUND_UP(wCount * hCount, tileM); - int K_pack = ROUND_UP(input->channel(), tileK); - int N_pack = ROUND_UP(output->channel(), tileN); + int M_pack = ROUND_UP(wCount * hCount, mAlignM); + int K_pack = ROUND_UP(input->channel(), mAlignK); + int N_pack = ROUND_UP(output->channel(), mAlignN); for (int b = 0; b < input->batch(); ++b) { // Source Transform @@ -521,6 +539,7 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st ret |= mUnits[b * 3].kernel->get().setArg(index++, icC4); ret |= mUnits[b * 3].kernel->get().setArg(index++, M_pack); ret |= mUnits[b * 3].kernel->get().setArg(index++, K_pack); + ret |= mUnits[b * 3].kernel->get().setArg(index++, input->batch()); ret |= mUnits[b * 3].kernel->get().setArg(index++, b); MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradBuf Source Trans"); @@ -535,9 +554,9 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st { int loop = alpha * alpha; - int e_pack = ROUND_UP(wCount * hCount, tileM); - int l_pack = ROUND_UP(input->channel(), tileK); - int h_pack = ROUND_UP(output->channel(), tileN); + int e_pack = ROUND_UP(wCount * hCount, mAlignM); + int l_pack = ROUND_UP(input->channel(), mAlignK); + int h_pack = ROUND_UP(output->channel(), mAlignN); std::set buildOptions; uint32_t layout = 4; @@ -586,6 +605,10 @@ ErrorCode ConvBufWinograd::onEncode(const 
std::vector& inputs, const st int batch_offset_b = h_pack * l_pack; int batch_offset_c = e_pack * h_pack; + int batch_offset[4] = {batch_offset_a, batch_offset_b, batch_offset_c, 0}; + int stride[4] = {e_pack, h_pack, h_pack, h_pack}; + int group[4] = {1, 1, 1, loop}; + int idx = 0; cl_int ret = CL_SUCCESS; ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, static_cast(e_pack)); @@ -594,11 +617,11 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, alpha); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, beta); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mSource.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_a); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mResource->mWeight.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_b); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mDest.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_c); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, stride); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, group); MNN_CHECK_CL_SUCCESS(ret, "setArg Winograd batchmatmul Kernel"); mOpenCLBackend->recordKernel3d(mUnits[b * 3 + 1].kernel, mGWS_M[b], mLWS_M[b]); @@ -624,6 +647,7 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, ocC4); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, M_pack); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, N_pack); + ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, input->batch()); ret |= mUnits[b * 3 + 2].kernel->get().setArg(index++, b); MNN_CHECK_CL_SUCCESS(ret, "setArg ConvWinogradBuf Dest Trans"); diff --git a/source/backend/opencl/execution/buffer/ConvBufWinograd.hpp b/source/backend/opencl/execution/buffer/ConvBufWinograd.hpp index e200fc2ef..cec80e347 100644 --- a/source/backend/opencl/execution/buffer/ConvBufWinograd.hpp +++ b/source/backend/opencl/execution/buffer/ConvBufWinograd.hpp @@ -22,6 +22,8 @@ struct ConvBufWinoResource { bool mUseSubgroup{false}; std::shared_ptr mWeight; std::shared_ptr mBias; + int mAlignN; + int mAlignK; }; class ConvBufWinograd : public CommonExecution { @@ -41,7 +43,7 @@ class ConvBufWinograd : public CommonExecution { #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ private: - void convertWeightFormat(cl::Buffer& buffer, const int tileK, const int tileN); + void convertWeightFormat(cl::Buffer& buffer, const int alignK, const int alignN); private: OpenCLBackend* mOpenCLBackend; std::shared_ptr mResource; @@ -66,6 +68,8 @@ class ConvBufWinograd : public CommonExecution { std::vector > mLWS_S; std::vector > mLWS_D; std::vector > mLWS_M; +private: + int mAlignM; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp b/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp index 70685e7bb..dd4201f12 100644 --- a/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvSubgroupBufExecution.cpp @@ -258,6 +258,7 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s std::vector outputShape = tensorShapeFormat(output); int in_c_pack = TensorUtils::getTensorChannelPack(input); int out_c_pack = TensorUtils::getTensorChannelPack(output); + const int batch = 
outputShape.at(0); const int height = outputShape.at(1); const int width = outputShape.at(2); const int outChannel = outputShape.at(3); @@ -266,8 +267,6 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s const int inputWidth = inputShape.at(2); const int inputChannels = inputShape.at(3); - int input_width_pad = mResource->mStrides[1] * (8 - 1) + (mResource->mKernelWidth - 1) * mResource->mDilations[1] + 1 + width * mResource->mStrides[1] + mPaddings[1]; - int input_height_pad = inputHeight + 2 * mPaddings[0]; uint32_t MaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->MaxWorkGroupSize()); uint32_t MaxThreadsPerDevice = static_cast(mOpenCLBackend->getOpenCLRuntime()->MaxThreadsPerDevice()); bool isSupportedFP16 = mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16(); @@ -280,6 +279,8 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s int strideShape[2] = {mResource->mStrides[0], mResource->mStrides[1]}; int paddingShape[2] = {mPaddings[0], mPaddings[1]}; int dilationShape[2] = {mResource->mDilations[0], mResource->mDilations[1]}; + int trans_pad_x = inputpad.left; + int trans_pad_y = inputpad.right; auto tune_param = GetTuningParams(inputs, outputs, MaxWorkGroupSize, isSupportedFP16, MaxThreadsPerDevice); uint32_t blockWidth = tune_param.first; uint32_t sub_group_size = 16; @@ -318,14 +319,17 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s unit.kernel->get().setArg(idx++, static_cast(inputWidth)); unit.kernel->get().setArg(idx++, static_cast(inputHeight)); unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + unit.kernel->get().setArg(idx++, static_cast(batch)); unit.kernel->get().setArg(idx++, UP_DIV(inputShape.at(3), 4)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.right)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_x)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_y)); mTranseLocalWorkSize = localWS3DDefault(mTranseGlobalWorkSize, mMaxWGS_S, mOpenCLBackend->getOpenCLRuntime(), "conv_transe_c4_c1", unit.kernel).first; mOpenCLBackend->recordKernel3d(unit.kernel, mTranseGlobalWorkSize, mTranseLocalWorkSize); } else { - mSource.reset(Tensor::createDevice(std::vector{inputShape.at(0), UP_DIV(input->channel(), 16),inputHeight * inputWidth, 16}, Tensor::CAFFE_C4)); + trans_pad_x = std::max(inputpad.left, mPaddings[1]); + trans_pad_y = std::max(inputpad.right, mPaddings[1]); + mSource.reset(Tensor::createDevice(std::vector{inputShape.at(0), UP_DIV(input->channel(), 16),inputHeight * (inputWidth + trans_pad_x + trans_pad_y), 16}, Tensor::CAFFE_C4)); mOpenCLBackend->onAcquireBuffer(mSource.get(), Backend::DYNAMIC); mOpenCLBackend->onReleaseBuffer(mSource.get(), Backend::DYNAMIC); unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("input_transe_buf", "conv_transe_c4_c16", {}); @@ -344,9 +348,10 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s unit.kernel->get().setArg(idx++, static_cast(inputWidth)); unit.kernel->get().setArg(idx++, static_cast(inputHeight)); unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + unit.kernel->get().setArg(idx++, static_cast(batch)); unit.kernel->get().setArg(idx++, UP_DIV(inputShape.at(3), 4)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.right)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_x)); + unit.kernel->get().setArg(idx++, 
static_cast(trans_pad_y)); mTranseLocalWorkSize = localWS3DDefault(mTranseGlobalWorkSize, mMaxWGS_S, mOpenCLBackend->getOpenCLRuntime(), "conv_transe_c4_c16", unit.kernel).first; mOpenCLBackend->recordKernel3d(unit.kernel, mTranseGlobalWorkSize, mTranseLocalWorkSize); @@ -402,9 +407,10 @@ ErrorCode ConvSubgroupBuf::onEncode(const std::vector &inputs, const s unit.kernel->get().setArg(idx++, static_cast(width)); unit.kernel->get().setArg(idx++, static_cast(height)); unit.kernel->get().setArg(idx++, static_cast(outChannel)); + unit.kernel->get().setArg(idx++, static_cast(batch)); unit.kernel->get().setArg(idx++, static_cast(x_blocks)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.right)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_x)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_y)); unit.kernel->get().setArg(idx++, static_cast(outputpad.left)); unit.kernel->get().setArg(idx++, static_cast(outputpad.right)); #ifdef LOG_VERBOSE diff --git a/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp b/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp index 096594ebc..4fd445067 100644 --- a/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DeconvBufExecution.cpp @@ -153,6 +153,7 @@ ErrorCode DeconvBufExecution::onEncode(const std::vector &inputs, cons unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); unit.kernel->get().setArg(idx++, openCLBuffer(output)); + unit.kernel->get().setArg(idx++, static_cast(outputBatch)); unit.kernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); unit.kernel->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); unit.kernel->get().setArg(idx++, sizeof(strideShape), strideShape); diff --git a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp index 5bc18f9ff..44af28f35 100644 --- a/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DepthwiseConvBufExecution.cpp @@ -108,7 +108,8 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input const int outputHeight = outputShape.at(1); const int outputWidth = outputShape.at(2); const int outputChannel = outputShape.at(3); - + + const int batch = inputShape.at(0); const int inputHeight = inputShape.at(1); const int inputWidth = inputShape.at(2); const int inputChannels = inputShape.at(3); @@ -173,7 +174,7 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannels)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(batch)); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(paddingShape), paddingShape); @@ -206,7 +207,7 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= 
unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + ret |= unit.kernel->get().setArg(idx++, static_cast(batch)); ret |= unit.kernel->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= unit.kernel->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= unit.kernel->get().setArg(idx++, sizeof(paddingShape), paddingShape); @@ -249,7 +250,7 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); - ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannels)); + ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(batch)); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= kernel[knl_idx]->get().setArg(idx++, sizeof(paddingShape), paddingShape); @@ -283,7 +284,7 @@ ErrorCode DepthwiseConvBufExecution::onEncode(const std::vector &input ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, sizeof(inputImageShape), inputImageShape); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + ret |= unit.kernel->get().setArg(idx++, static_cast(batch)); ret |= unit.kernel->get().setArg(idx++, sizeof(outputImageShape), outputImageShape); ret |= unit.kernel->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= unit.kernel->get().setArg(idx++, sizeof(paddingShape), paddingShape); diff --git a/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp b/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp index 90bcb5c36..c7db48719 100644 --- a/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/DepthwiseConvSubgroupBufExecution.cpp @@ -178,7 +178,8 @@ ErrorCode DepthwiseConvSubgroupBufExecution::onEncode(const std::vectormConv2dCommonParams); mPaddings[0] = padding.second;//padY mPaddings[1] = padding.first;//padX - + + const int batch = outputShape.at(0); const int outputHeight = outputShape.at(1); const int outputWidth = outputShape.at(2); const int outputChannel = outputShape.at(3); @@ -201,6 +202,8 @@ ErrorCode DepthwiseConvSubgroupBufExecution::onEncode(const std::vectormPads; int input_c_pack = TensorUtils::getTensorChannelPack(input); int output_c_pack = TensorUtils::getTensorChannelPack(output); + int trans_pad_x = inputpad.left; + int trans_pad_y = inputpad.right; std::set buildOptions = mResource->mBuildOptions; buildOptions.emplace("-DFILTER_HEIGHT=" + std::to_string(kernelShape[0])); @@ -210,9 +213,11 @@ ErrorCode DepthwiseConvSubgroupBufExecution::onEncode(const std::vectorrecordKernel3d(unit.kernel, mTranseGlobalWorkSize, mTranseLocalWorkSize); @@ -265,8 +271,9 @@ ErrorCode DepthwiseConvSubgroupBufExecution::onEncode(const std::vectorget().setArg(idx++, static_cast(inputHeight)); unit.kernel->get().setArg(idx++, static_cast(inputWidth)); unit.kernel->get().setArg(idx++, static_cast(inputChannels)); - unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); - 
unit.kernel->get().setArg(idx++, static_cast(inputpad.right)); + unit.kernel->get().setArg(idx++, static_cast(batch)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_x)); + unit.kernel->get().setArg(idx++, static_cast(trans_pad_y)); unit.kernel->get().setArg(idx++, static_cast(outputHeight)); unit.kernel->get().setArg(idx++, static_cast(outputWidth)); unit.kernel->get().setArg(idx++, static_cast(outputpad.left)); diff --git a/source/backend/opencl/execution/buffer/GridSampleBufExecution.cpp b/source/backend/opencl/execution/buffer/GridSampleBufExecution.cpp index e4091b0ea..50b7fd25e 100644 --- a/source/backend/opencl/execution/buffer/GridSampleBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/GridSampleBufExecution.cpp @@ -76,7 +76,7 @@ ErrorCode GridSampleBufExecution::onEncode(const std::vector &inputs, ret |= unit.kernel->get().setArg(idx++, static_cast(inW)); ret |= unit.kernel->get().setArg(idx++, static_cast(outH)); ret |= unit.kernel->get().setArg(idx++, static_cast(outW)); - ret |= unit.kernel->get().setArg(idx++, static_cast(channelC4)); + ret |= unit.kernel->get().setArg(idx++, static_cast(batches)); ret |= unit.kernel->get().setArg(idx++, mPaddingMode); ret |= unit.kernel->get().setArg(idx++, mAlignCorners); MNN_CHECK_CL_SUCCESS(ret, "setArg GridSampleBufExecution"); diff --git a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp index 92485742b..1865696ea 100644 --- a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp @@ -117,47 +117,12 @@ ErrorCode GroupNormBufExecution::onEncode(const std::vector& inputs, co inner_size /= mGroup; mUnits.clear(); - mUnits.resize(3); + mUnits.resize(1); std::vector inputShape = tensorShapeFormat(inputs[0]); int inputWH[] = {inputShape[2], inputShape[1]}; int region[] = {inputShape[0], UP_DIV(inputShape[3], 4), inputShape[1], inputShape[2]}; - mInputPlain = std::make_shared(Tensor::createDevice(std::vector{inputShape[0] * inputShape[3] * ROUND_UP(inputShape[1] * inputShape[2], 4)})); - mOpenCLBackend->onAcquireBuffer(mInputPlain.get(), Backend::DYNAMIC); - mOutputPlain = std::make_shared(Tensor::createDevice(std::vector{inputShape[0] * inputShape[3] * ROUND_UP(inputShape[1] * inputShape[2], 4)})); - mOpenCLBackend->onAcquireBuffer(mOutputPlain.get(), Backend::DYNAMIC); - - mOpenCLBackend->onReleaseBuffer(mInputPlain.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mOutputPlain.get(), Backend::DYNAMIC); std::set buildOptions; - // convert nc4hw4 to nchw - { - auto &unit = mUnits[0]; - unit.kernel = runtime->buildKernel("buffer_convert_buf", "nc4hw4_buffer_to_nchw_buffer", {}, inputs[0], outputs[0]); - - mGWS = {(uint32_t)(UP_DIV(region[3] * region[1], 16) * 16), - (uint32_t)(UP_DIV(region[2] * region[0], 16) * 16)}; - mLWS = {16, 16}; - unit.globalWorkSize = {mGWS[0], mGWS[1]}; - unit.localWorkSize = {mLWS[0], mLWS[1]}; - - int global_dim0 = region[3] * region[1]; - int global_dim1 = region[2] * region[0]; - - //MNN_CHECK_CL_SUCCESS - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, inputWH[1]); - ret |= unit.kernel->get().setArg(idx++, inputWH[0]); - ret |= unit.kernel->get().setArg(idx++, inputShape[3]); - ret |= 
unit.kernel->get().setArg(idx++, openCLBuffer(input)); - MNN_CHECK_CL_SUCCESS(ret, "setArg GroupNormBufExecution with group, convert nc4hw4 to nchw"); - - mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); - } // do groupnorm { int area = inputWH[1] * inputWH[0]; @@ -175,7 +140,7 @@ ErrorCode GroupNormBufExecution::onEncode(const std::vector& inputs, co } auto MaxLocalSize = std::min(runtime->getMaxWorkItemSizes()[0], (uint32_t)256); - auto &unit = mUnits[1]; + auto &unit = mUnits[0]; std::string kernelName = "groupnorm_plain_buf"; int local_size = getLocalSize(UP_DIV(inner_size, 4), MaxLocalSize); buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); @@ -195,11 +160,11 @@ ErrorCode GroupNormBufExecution::onEncode(const std::vector& inputs, co ret |= unit.kernel->get().setArg(idx++, mGWS[0]); ret |= unit.kernel->get().setArg(idx++, mGWS[1]); ret |= unit.kernel->get().setArg(idx++, mGWS[2]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); if(inputs.size() > 1) { ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[1])); } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get())); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, static_cast(area)); ret |= unit.kernel->get().setArg(idx++, static_cast(mGroup)); ret |= unit.kernel->get().setArg(idx++, static_cast(inner_size)); @@ -212,33 +177,6 @@ ErrorCode GroupNormBufExecution::onEncode(const std::vector& inputs, co MNN_CHECK_CL_SUCCESS(ret, "setArg GroupNormBufExecution with group, do group layernorm"); mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS); } - // convert nchw to nc4hw4 - { - auto &unit = mUnits[2]; - - unit.kernel = runtime->buildKernel("buffer_convert_buf", "nchw_buffer_to_nc4hw4_buffer", {}, inputs[0], outputs[0]); - mLWS = {16, 16}; - mGWS = {(uint32_t)UP_DIV(region[3] * region[1], 16) * 16, - (uint32_t)UP_DIV(region[2] * region[0], 16) * 16}; - - unit.globalWorkSize = {mGWS[0], mGWS[1]}; - unit.localWorkSize = {mLWS[0], mLWS[1]}; - - int global_dim0 = region[3] * region[1]; - int global_dim1 = region[2] * region[0]; - - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, inputWH[1]); - ret |= unit.kernel->get().setArg(idx++, inputWH[0]); - ret |= unit.kernel->get().setArg(idx++, inputShape[3]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - MNN_CHECK_CL_SUCCESS(ret, "setArg GroupNormBufExecution with group, convert nchw to nc4hw4"); - mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); - } mOpenCLBackend->endRecord(mRecording); return NO_ERROR; diff --git a/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp b/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp index bf569f983..2076c2780 100644 --- a/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp @@ -31,7 +31,6 @@ class GroupNormBufExecution : public CommonExecution { int32_t mBatch; std::unique_ptr mGammaTensor; std::unique_ptr mBetaTensor; - std::shared_ptr mInputPlain, mOutputPlain; bool mHasGammaBeta = false; std::vector mLWS{0, 0, 0, 0}; std::vector mGWS{0, 0, 0, 0}; diff --git 
a/source/backend/opencl/execution/buffer/Interp3DBufExecution.cpp b/source/backend/opencl/execution/buffer/Interp3DBufExecution.cpp index 191c2fabe..fff8ef02f 100644 --- a/source/backend/opencl/execution/buffer/Interp3DBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/Interp3DBufExecution.cpp @@ -86,7 +86,7 @@ ErrorCode Interp3DBufExecution::onEncode(const std::vector &inputs, co ret |= unit.kernel->get().setArg(idx++, static_cast(outputDepth)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputHeight)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputWidth)); - ret |= unit.kernel->get().setArg(idx++, static_cast(channelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputBatch)); MNN_CHECK_CL_SUCCESS(ret, "setArg Interp3DBufExecution"); mLWS = localWS3DDefault(mGWS, mMaxWorkGroupSize, runtime, mKernelName, unit.kernel).first; diff --git a/source/backend/opencl/execution/buffer/InterpBufExecution.cpp b/source/backend/opencl/execution/buffer/InterpBufExecution.cpp index 00ab7de08..061dacbd0 100644 --- a/source/backend/opencl/execution/buffer/InterpBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/InterpBufExecution.cpp @@ -80,7 +80,7 @@ ErrorCode InterpBufExecution::onEncode(const std::vector &inputs, cons ret |= unit.kernel->get().setArg(idx++, static_cast(inputWidth)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputHeight)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputWidth)); - ret |= unit.kernel->get().setArg(idx++, static_cast(channelBlocks)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inputBatch)); MNN_CHECK_CL_SUCCESS(ret, "setArg InterpBufExecution"); mLWS = localWS3DDefault(mGWS, mMaxWorkGroupSize, runtime, mKernelName, unit.kernel).first; diff --git a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp index 0f6b3f629..100fe2db2 100644 --- a/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LayerNormBufExecution.cpp @@ -24,7 +24,7 @@ LayerNormBufExecution::LayerNormBufExecution(const std::vector &inputs group_ = layer_norm_param->group(); RMSNorm = layer_norm_param->useRMSNorm(); auto bufferUnitSize = runtime->isSupportedFP16() ? 
sizeof(half_float::half) : sizeof(float); - auto kernel = runtime->buildKernel("layernorm_buf", "layernorm_w_buf", {"-DLOCAL_SIZE=512"}); + auto kernel = runtime->buildKernel("layernorm_buf", "layernorm_buf", {"-DLOCAL_SIZE=512"}); mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(kernel)); if(layer_norm_param->gamma() && layer_norm_param->beta()){ @@ -99,11 +99,6 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, c std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); - const int inputBatch = inputShape[0]; - const int inputHeight = inputShape[1]; - const int inputWidth = inputShape[2]; - const int inputChannels = inputShape[3]; - int local_size; int rank = inputs.at(0)->dimensions(); int outter_size = 1; int inner_size = 1; @@ -122,169 +117,40 @@ ErrorCode LayerNormBufExecution::onEncode(const std::vector &inputs, c } inner_size /= group_; } -// printf("out:%d in:%d, %d %d %d %d, %d\n", outter_size, inner_size, inputBatch, inputHeight, inputWidth, inputChannels, group_); + + int local_size = getLocalSize(inner_size / 4, MaxLocalSize); std::set buildOptions; + buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); if(RMSNorm){ buildOptions.emplace("-DRMSNORM"); } if(has_gamma_beta_){ buildOptions.emplace("-DGAMMA_BETA"); } - std::string kernelName; - if (inner_size == inputWidth && outter_size == inputBatch * inputHeight * inputChannels) { - kernelName = "layernorm_w_buf"; - local_size = getLocalSize(inputWidth, MaxLocalSize); - buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions); - - mGWS = {static_cast(local_size), - static_cast(inputHeight * UP_DIV(inputChannels, 4)), - static_cast(inputBatch)}; - }else if(inner_size == inputWidth * inputHeight && outter_size == inputBatch * inputChannels){ - kernelName = "layernorm_hw_buf"; - local_size = getLocalSize(inputWidth * inputHeight, MaxLocalSize); - buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions); - - mGWS = {static_cast(local_size), - static_cast(UP_DIV(inputChannels, 4)), - static_cast(inputBatch)}; - }else if(inner_size == inputWidth * inputHeight * inputChannels && outter_size == inputBatch){ - kernelName = "layernorm_chw_buf"; - local_size = getLocalSize(inputWidth * inputHeight, MaxLocalSize); - buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions); - - mGWS = {static_cast(local_size), - static_cast(1), - static_cast(inputBatch)}; - } else if(inner_size == inputWidth * inputHeight * inputChannels / group_ && outter_size == inputBatch * group_){ - mUnits.clear(); - mUnits.resize(3); - std::vector inputShape = tensorShapeFormat(inputs[0]); - int inputWH[] = {inputShape[2], inputShape[1]}; - int region[] = {inputShape[0], UP_DIV(inputShape[3], 4), inputShape[1], inputShape[2]}; - - mInputPlain = std::make_shared(Tensor::createDevice(std::vector{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE)); - mOpenCLBackend->onAcquireBuffer(mInputPlain.get(), Backend::DYNAMIC); - mOutputPlain = std::make_shared(Tensor::createDevice(std::vector{inputShape[0], inputShape[3], ROUND_UP(inputShape[1] * inputShape[2], 4), 1}, Tensor::CAFFE)); - mOpenCLBackend->onAcquireBuffer(mOutputPlain.get(), Backend::DYNAMIC); - - // convert nc4hw4 
to nchw - { - auto &unit = mUnits[0]; - unit.kernel = runtime->buildKernel("buffer_convert_buf", "nc4hw4_buffer_to_nchw_buffer", {}, inputs[0], outputs[0]); - - mGWS = {(uint32_t)(UP_DIV(region[3] * region[1], 16) * 16), - (uint32_t)(UP_DIV(region[2] * region[0], 16) * 16)}; - mLWS = {16, 16}; - unit.globalWorkSize = {mGWS[0], mGWS[1]}; - unit.localWorkSize = {mLWS[0], mLWS[1]}; - - int global_dim0 = region[3] * region[1]; - int global_dim1 = region[2] * region[0]; - - //MNN_CHECK_CL_SUCCESS - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, inputWH[1]); - ret |= unit.kernel->get().setArg(idx++, inputWH[0]); - ret |= unit.kernel->get().setArg(idx++, inputShape[3]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); - MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, convert nc4hw4 to nchw"); - - mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); - } - // do group layernorm - { - auto &unit = mUnits[1]; - kernelName = "layernorm_plain_buf"; - local_size = getLocalSize(UP_DIV(inner_size, 4), MaxLocalSize); - buildOptions.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("layernorm_buf", kernelName, buildOptions); - - mGWS = {static_cast(local_size), - static_cast(1), - static_cast(outter_size)}; - - mLWS = {static_cast(local_size), 1, 1}; - - unit.globalWorkSize = {mGWS[0], mGWS[1], mGWS[2]}; - unit.localWorkSize = {mLWS[0], mLWS[1], mLWS[2]}; - - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, mGWS[0]); - ret |= unit.kernel->get().setArg(idx++, mGWS[1]); - ret |= unit.kernel->get().setArg(idx++, mGWS[2]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mInputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, static_cast(inner_size)); - ret |= unit.kernel->get().setArg(idx++, static_cast(outter_size)); - if(has_gamma_beta_){ - ret |= unit.kernel->get().setArg(idx++, *mGammaBuffer.get()); - ret |= unit.kernel->get().setArg(idx++, *mBetaBuffer.get()); - } - ret |= unit.kernel->get().setArg(idx++, epsilon_); - MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, do group layernorm"); - mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS); - } - // convert nchw to nc4hw4 - { - auto &unit = mUnits[2]; - - unit.kernel = runtime->buildKernel("buffer_convert_buf", "nchw_buffer_to_nc4hw4_buffer", {}, inputs[0], outputs[0]); - mLWS = {16, 16}; - mGWS = {(uint32_t)UP_DIV(region[3] * region[1], 16) * 16, - (uint32_t)UP_DIV(region[2] * region[0], 16) * 16}; - - unit.globalWorkSize = {mGWS[0], mGWS[1]}; - unit.localWorkSize = {mLWS[0], mLWS[1]}; - - int global_dim0 = region[3] * region[1]; - int global_dim1 = region[2] * region[0]; - - uint32_t idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, global_dim0); - ret |= unit.kernel->get().setArg(idx++, global_dim1); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mOutputPlain.get())); - ret |= unit.kernel->get().setArg(idx++, inputWH[1]); - ret |= unit.kernel->get().setArg(idx++, inputWH[0]); - ret |= unit.kernel->get().setArg(idx++, inputShape[3]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution with group, 
convert nchw to nc4hw4"); - mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); - } - - mOpenCLBackend->onReleaseBuffer(mInputPlain.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mOutputPlain.get(), Backend::DYNAMIC); - return NO_ERROR; + if(inner_size % 4 != 0){ + buildOptions.emplace("-DPACK_LEAVE"); } - mLWS = {static_cast(local_size), 1, 1}; + + unit.kernel = runtime->buildKernel("layernorm_buf", "layernorm_buf", buildOptions); + mGWS = {static_cast(local_size), static_cast(outter_size)}; + mLWS = {static_cast(local_size), 1}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGWS[0]); ret |= unit.kernel->get().setArg(idx++, mGWS[1]); - ret |= unit.kernel->get().setArg(idx++, mGWS[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputWidth)); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputHeight)); - ret |= unit.kernel->get().setArg(idx++, static_cast(inputChannels)); + ret |= unit.kernel->get().setArg(idx++, static_cast(inner_size)); if(has_gamma_beta_){ ret |= unit.kernel->get().setArg(idx++, *mGammaBuffer.get()); ret |= unit.kernel->get().setArg(idx++, *mBetaBuffer.get()); } ret |= unit.kernel->get().setArg(idx++, epsilon_); MNN_CHECK_CL_SUCCESS(ret, "setArg LayerNormBufExecution"); - mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS); - unit.globalWorkSize = {mGWS[0], mGWS[1], mGWS[2]}; - unit.localWorkSize = {mLWS[0], mLWS[1], mLWS[2]}; + mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); + unit.globalWorkSize = {mGWS[0], mGWS[1]}; + unit.localWorkSize = {mLWS[0], mLWS[1]}; return NO_ERROR; diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp index bf8dfc463..56476b59e 100644 --- a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp @@ -12,137 +12,6 @@ namespace MNN { namespace OpenCL { - -static void _TileOrPackTensor(Tensor *input, Tensor *output, std::shared_ptr& kernelW, cl::NDRange &globalWorkSize, - cl::NDRange &localWorkSize, const int Width, const int Height, const int Channel, - const int Batch, OpenCLBackend *bn, const std::string& KernelName, std::set buildOptions, - const int WidthPad, const int HeightPad, const int ChannelPad, OpenCLRuntime* runtime) { - bool fastTileTranspose = false; - if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC || TensorUtils::getDescribe(input)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){ - buildOptions.emplace("-DMNN_NHWC"); - } else { - if (KernelName == "tile_buf" && buildOptions.find("-DTRANSPOSE") != buildOptions.end() && (buildOptions.find("-DDIMENSION_3") != buildOptions.end() || buildOptions.find("-DDIMENSION_4") != buildOptions.end())) { - fastTileTranspose = true; - } - } - - std::string runKernelName = KernelName; - unsigned int tileW = 32; - unsigned int tileC = 32; - unsigned int tileH = 32; - - unsigned int localW = 8; - unsigned int localC = 8; - unsigned int localH = 8; - if(fastTileTranspose) { - // local memory limit - uint32_t local_mem_size = 4; - if(runtime->isSupportedFP16()) { - local_mem_size = 2; - } - - if(buildOptions.find("-DDIMENSION_4") != buildOptions.end()) { - local_mem_size *= (64 * 64 * 4); - if(local_mem_size <= runtime->getMaxLocalMem()) { - if((WidthPad & 63) == 0) { - tileW = 64; - } - if((HeightPad & 63) == 
0) { - tileH = 64; - } - } - - runKernelName = "tile_trans_4d_buf"; - // match with tileW tileH tileW/localW tileH/localH - buildOptions.emplace("-DWGSW=" + std::to_string(tileW)); - buildOptions.emplace("-DWGSH=" + std::to_string(tileH)); - buildOptions.emplace("-DTSW=" + std::to_string(tileW/localW)); - buildOptions.emplace("-DTSH=" + std::to_string(tileH/localH)); - } else { - local_mem_size *= (64 * 64); - if(local_mem_size <= runtime->getMaxLocalMem()) { - if((ChannelPad & 63) == 0) { - tileC = 64; - } - if((HeightPad & 63) == 0) { - tileH = 64; - } - } - runKernelName = "tile_trans_3d_buf"; - // match with tileW tileH tileW/localW tileH/localH - buildOptions.emplace("-DWGSC=" + std::to_string(tileC)); - buildOptions.emplace("-DWGSH=" + std::to_string(tileH)); - buildOptions.emplace("-DTSC=" + std::to_string(tileC/localC)); - buildOptions.emplace("-DTSH=" + std::to_string(tileH/localH)); - } - - } - if(input->getType().code == halide_type_int){ - kernelW = bn->getOpenCLRuntime()->buildKernel("loop_buf", runKernelName, buildOptions, input, input); - }else if (output->getType().code == halide_type_int){ - kernelW = bn->getOpenCLRuntime()->buildKernel("loop_buf", runKernelName, buildOptions, output, output); - }else { - kernelW = bn->getOpenCLRuntime()->buildKernel("loop_buf", runKernelName, buildOptions, input, output); - } - auto kernel = kernelW->get(); - - uint32_t mMaxWorkGroupSize = static_cast(bn->getOpenCLRuntime()->getMaxWorkGroupSize(kernelW)); - - if(fastTileTranspose) { - int w_per_thread = tileW / localW; - int h_per_thread = tileH / localH; - std::vector mGlobalWorkSize = {(uint32_t)WidthPad/w_per_thread, (uint32_t)HeightPad/h_per_thread, (uint32_t)(UP_DIV(ChannelPad, 4)*Batch)}; - std::vector mLocalWorkSize = {localW, localH, 1}; - - if(buildOptions.find("-DDIMENSION_3") != buildOptions.end()) { - int c_per_thread = tileC / localC; - int h_per_thread = tileH / localH; - mGlobalWorkSize = {(uint32_t)ChannelPad/c_per_thread, (uint32_t)HeightPad/h_per_thread, (uint32_t)Batch}; - mLocalWorkSize = {localC, localH, 1}; - } - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= kernel.setArg(index++, openCLBuffer(input)); - ret |= kernel.setArg(index++, openCLBuffer(output)); - ret |= kernel.setArg(index++, WidthPad); - ret |= kernel.setArg(index++, HeightPad); - ret |= kernel.setArg(index++, ChannelPad); - ret |= kernel.setArg(index++, Batch); - ret |= kernel.setArg(index++, Width); - ret |= kernel.setArg(index++, Height); - ret |= kernel.setArg(index++, Channel); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBuf _TileOrPackTensor tile_transpose_fast_buf"); - - globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - bn->recordKernel3d(kernelW, mGlobalWorkSize, mLocalWorkSize); - } else { - std::vector mGlobalWorkSize = {(uint32_t)WidthPad, (uint32_t)HeightPad, (uint32_t)(UP_DIV(ChannelPad, 4)*Batch)}; - - uint32_t index = 0; - cl_int ret = CL_SUCCESS; - ret |= kernel.setArg(index++, mGlobalWorkSize[0]); - ret |= kernel.setArg(index++, mGlobalWorkSize[1]); - ret |= kernel.setArg(index++, mGlobalWorkSize[2]); - ret |= kernel.setArg(index++, openCLBuffer(input)); - ret |= kernel.setArg(index++, openCLBuffer(output)); - ret |= kernel.setArg(index++, WidthPad); - ret |= kernel.setArg(index++, HeightPad); - ret |= kernel.setArg(index++, ChannelPad); - ret |= kernel.setArg(index++, Batch); - ret |= kernel.setArg(index++, Width); - ret |= kernel.setArg(index++, Height); - ret |= 
kernel.setArg(index++, Channel); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBuf _TileOrPackTensor"); - - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, bn->getOpenCLRuntime(), KernelName, kernelW).first; - - globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; - bn->recordKernel3d(kernelW, mGlobalWorkSize, mLocalWorkSize); - } -} static void _setTensorStack(std::vector &result, const std::vector &inputs, const std::vector &outputs, const LoopParam *loop) { @@ -190,23 +59,10 @@ ErrorCode LoopGatherBufExecution::onEncode(const std::vector &inputs, // gather { + Unit unit; auto input = mTensors[cmd->indexes()->data()[1]]; auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector inputShape = tensorShapeFormat(input); - std::vector outputShape = tensorShapeFormat(output); - int inputShapeVec[4] = {inputShape[2], inputShape[1], inputShape[3], inputShape[0]}; - int outputShapeVec[4] = {outputShape[2], outputShape[1], outputShape[3], outputShape[0]}; - int offset_index = 0; - - Unit unit; std::set buildOptions; - if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){ - buildOptions.emplace("-DGATHER_OUTPUT_NHWC"); - } - if (TensorUtils::getDescribe(input)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){ - buildOptions.emplace("-DGATHER_INPUT_NHWC"); - } - if (mIter[0] >= 0) { buildOptions.emplace("-DOFFSET_DST"); } @@ -239,8 +95,6 @@ ErrorCode LoopGatherBufExecution::onEncode(const std::vector &inputs, ret |= unit.kernel->get().setArg(index++, sizeof(mStride_dst), mStride_dst); ret |= unit.kernel->get().setArg(index++, sizeof(mStep), mStep); ret |= unit.kernel->get().setArg(index++, sizeof(mIter), mIter); - ret |= unit.kernel->get().setArg(index++, sizeof(outputShapeVec), outputShapeVec); - ret |= unit.kernel->get().setArg(index++, sizeof(inputShapeVec), inputShapeVec); ret |= unit.kernel->get().setArg(index++, inputSize); MNN_CHECK_CL_SUCCESS(ret, "setArg LoopGatherBufExecution"); @@ -261,142 +115,6 @@ LoopBatchMatMulBufExecution::LoopBatchMatMulBufExecution(const LoopParam *loop, mTensors.resize(mLoop->tensorNumber()); } -static std::tuple getTileDimensionSize(std::tuple shape, std::tuple tile, MNN_DATA_FORMAT format, int dimension, bool transpose, int index) { - if(index > 2 || index < 0) { - MNN_ERROR("Error getTileDimensionSize index, only support 1 for input_1, 2 for input_2, 0 for output!\n"); - return shape; - } - // tile: {e, l, h} - int tile_e = std::get<0>(tile); - int tile_l = std::get<1>(tile); - int tile_h = std::get<2>(tile); - // shape: {w, h, c} - int pad_w = std::get<0>(shape); - int pad_h = std::get<1>(shape); - int pad_c = std::get<2>(shape); - - // output - if(index == 0) { - if (format == MNN::MNN_DATA_FORMAT_NHWC) { - if(dimension == 3) { - // [N, H, W] -> (n, e, h) - pad_h = ROUND_UP(pad_h, tile_e); - pad_w = ROUND_UP(pad_w, tile_h); - } else { - // [N*H, W, C] -> [n, e, h] - pad_w = ROUND_UP(pad_w, tile_e); - pad_c = ROUND_UP(pad_c, tile_h); - } - } else { - if(dimension == 3) { - // [N, C, H] -> (n, e, h) - pad_c = ROUND_UP(pad_c, tile_e); - pad_h = ROUND_UP(pad_h, tile_h); - } else { - // [N*C, H, W] -> [n, e, h] - pad_h = ROUND_UP(pad_h, tile_e); - pad_w = ROUND_UP(pad_w, tile_h); - } - } - return std::make_tuple(pad_w, pad_h, pad_c); - } - - if (format == MNN::MNN_DATA_FORMAT_NHWC) { - if(dimension == 3) { - if(transpose) { - if(index == 1) { - // [N, H, W] -> (n, l, e) - pad_h = ROUND_UP(pad_h, 
tile_l); - pad_w = ROUND_UP(pad_w, tile_e); - } else { - // [N, H, W] -> (n, h, l) - pad_h = ROUND_UP(pad_h, tile_h); - pad_w = ROUND_UP(pad_w, tile_l); - } - } else { - if(index == 1) { - // [N, H, W] -> (n, e, l) - pad_h = ROUND_UP(pad_h, tile_e); - pad_w = ROUND_UP(pad_w, tile_l); - } else { - // [N, H, W] -> (n, l, h) - pad_h = ROUND_UP(pad_h, tile_l); - pad_w = ROUND_UP(pad_w, tile_h); - } - } - } else { - if(transpose) { - if(index == 1) { - // [N*H, W, C] -> (n, l, e) - pad_w = ROUND_UP(pad_w, tile_l); - pad_c = ROUND_UP(pad_c, tile_e); - } else { - // [N*H, W, C] -> (n, h, l) - pad_w = ROUND_UP(pad_w, tile_h); - pad_c = ROUND_UP(pad_c, tile_l); - } - } else { - if(index == 1) { - // [N*H, W, C] -> [n, e, l] - pad_w = ROUND_UP(pad_w, tile_e); - pad_c = ROUND_UP(pad_c, tile_l); - } else { - // [N*H, W, C] -> [n, l, h] - pad_w = ROUND_UP(pad_w, tile_l); - pad_c = ROUND_UP(pad_c, tile_h); - } - } - } - } else { - if(dimension == 3) { - if(transpose) { - if(index == 1) { - // [N, C, H] -> (n, l, e) - pad_c = ROUND_UP(pad_c, tile_l); - pad_h = ROUND_UP(pad_h, tile_e); - } else { - // [N, C, H] -> (n, h, l) - pad_c = ROUND_UP(pad_c, tile_h); - pad_h = ROUND_UP(pad_h, tile_l); - } - } else { - if(index == 1) { - // [N, C, H] -> (n, e, l) - pad_c = ROUND_UP(pad_c, tile_e); - pad_h = ROUND_UP(pad_h, tile_l); - } else { - // [N, C, H] -> (n, l, h) - pad_c = ROUND_UP(pad_c, tile_l); - pad_h = ROUND_UP(pad_h, tile_h); - } - } - } else { - if(transpose) { - if(index == 1) { - // [N*C, H, W] -> (n, l, e) - pad_h = ROUND_UP(pad_h, tile_l); - pad_w = ROUND_UP(pad_w, tile_e); - } else { - // [N*C, H, W] -> (n, h, l) - pad_h = ROUND_UP(pad_h, tile_h); - pad_w = ROUND_UP(pad_w, tile_l); - } - } else { - if(index == 1) { - // [N*C, H, W] -> [n, e, l] - pad_h = ROUND_UP(pad_h, tile_e); - pad_w = ROUND_UP(pad_w, tile_l); - } else { - // [N*C, H, W] -> [n, l, h] - pad_h = ROUND_UP(pad_h, tile_l); - pad_w = ROUND_UP(pad_w, tile_h); - } - } - } - } - return std::make_tuple(pad_w, pad_h, pad_c); -} - ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { auto cmd = mLoop->commands()->GetAs(0); mHasBias = cmd->indexes()->size() > 3; @@ -410,10 +128,7 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp mOffset[1] = cmd->view()->GetAs(1)->offset(); mOffset[2] = cmd->view()->GetAs(2)->offset(); mUnits.clear(); - mOffsetTensors.clear(); - mTmpTensors.resize(3); if (mHasBias) { - mTmpTensors.resize(4); mOffset[3] = cmd->view()->GetAs(3)->offset(); } @@ -424,190 +139,8 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp int h = cmd->size()->data()[2]; int n = mLoop->loopNumber(); - int tileM = 32; - int tileN = 32; - int tileK = 4; - bool isTotalLarge = (e * 1.0 / 512 * l / 512 * h / 512 > 0.5); - bool isDimLarge = (e > 256 && l > 256 && h > 256); - int max_eh = std::max(e, h); - int min_eh = std::min(e, h); - isDimLarge = isDimLarge || (l >= 512 && (max_eh > 1024 || min_eh > 32)); - - mBatchGemmOpt = isTotalLarge && isDimLarge; - for(int i = 0; i < cmd->iterIndexes()->size(); ++i){ - if (mIter[i] >= 0) { - mBatchGemmOpt = false; - break; - } - } - - if(mHasBias) { - mBatchGemmOpt = false; - } - - bool needRearrangeA = false; - if(mBatchGemmOpt && !mTransposeA) { - // rearrange to [n, l, e] - needRearrangeA = true; - } - bool needRearrangeB = false; - if(mBatchGemmOpt && mTransposeB) { - // rearrange to [n, l, h] - needRearrangeB = true; - } - - // tile input - for (int i = 1; i < cmd->indexes()->size(); ++i) { - auto 
input = mTensors[cmd->indexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - bool needTranspose = false; - if(i == 1) { - needTranspose = needRearrangeA; - } else if(i == 2) { - needTranspose = needRearrangeB; - } - - Unit unit; - std::set buildOptions = mBuildOptions; - if(needTranspose) { - buildOptions.emplace("-DTRANSPOSE"); - } - if(input->buffer().dimensions == 3) { - buildOptions.emplace("-DDIMENSION_3"); - } - if(input->buffer().dimensions == 4) { - buildOptions.emplace("-DDIMENSION_4"); - } - - int WidthPad = Width; - int HeightPad = Height; - int ChannelPad = Channel; - - if(mBatchGemmOpt) { - auto shape = getTileDimensionSize(std::make_tuple(Width, Height, Channel), std::make_tuple(tileM, tileK, tileN), TensorUtils::getDescribe(input)->dimensionFormat, input->buffer().dimensions, needTranspose, i); - WidthPad = std::get<0>(shape); - HeightPad = std::get<1>(shape); - ChannelPad = std::get<2>(shape); - } - - mTmpTensors[i] = std::make_shared(Tensor::createDevice(std::vector{Batch, ChannelPad, HeightPad, WidthPad}, Tensor::CAFFE)); - // MNN_PRINT("input%d, %d %d %d %d\n", i, Batch, ChannelPad, HeightPad, WidthPad); - - mOpenCLBackend->onAcquireBuffer(mTmpTensors[i].get(), Backend::DYNAMIC); - _TileOrPackTensor(input, mTmpTensors[i].get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", buildOptions, WidthPad, HeightPad, ChannelPad, runTime); - mUnits.emplace_back(unit); - } - - for(int i = 0; i < cmd->iterIndexes()->size(); ++i){ - if (mIter[i] >= 0) { - auto input = mTensors[cmd->iterIndexes()->data()[i]]; - std::vector Shape = tensorShapeFormat(input); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - mOffsetTensors.emplace_back(std::make_shared(Tensor::createDevice(std::vector{Batch, Channel, Height, Width}, Tensor::CAFFE))); - mOpenCLBackend->onAcquireBuffer(mOffsetTensors.back().get(), Backend::DYNAMIC); - // MNN_PRINT("input%d offset, %d %d %d %d\n", i, Batch, Channel, Height, Width); - - Unit unit; - _TileOrPackTensor(input, mOffsetTensors.back().get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", mBuildOptions, Width, Height, Channel, runTime); - mUnits.emplace_back(unit); - } - } - - mBatch = n; - mM = e; - mN = h; - mK = l; - if(mBatchGemmOpt) { - // matmul - int e_pack = ROUND_UP(e, tileM); - int l_pack = ROUND_UP(l, tileK); - int h_pack = ROUND_UP(h, tileN); - mTmpTensors[0] = std::make_shared(Tensor::createDevice(std::vector{n * e_pack * h_pack}, Tensor::CAFFE)); - mOpenCLBackend->onAcquireBuffer(mTmpTensors[0].get(), Backend::DYNAMIC); - - - std::set buildOptions; - - uint32_t layout = 0; - auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)n, (uint32_t)0}, {openCLBuffer(mTmpTensors[1].get()), openCLBuffer(mTmpTensors[2].get()), openCLBuffer(mTmpTensors[0].get())}, mOpenCLBackend->getOpenCLRuntime()); - - int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; - buildOptions.emplace("-DKWG=" + std::to_string(KWG)); - buildOptions.emplace("-DKWI=" + std::to_string(KWI)); - 
buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); - buildOptions.emplace("-DMDIMC=" + std::to_string(MDIMC)); - buildOptions.emplace("-DMWG=" + std::to_string(MWG)); - buildOptions.emplace("-DNDIMB=" + std::to_string(NDIMB)); - buildOptions.emplace("-DNDIMC=" + std::to_string(NDIMC)); - buildOptions.emplace("-DNWG=" + std::to_string(NWG)); - buildOptions.emplace("-DSA=" + std::to_string(SA)); - buildOptions.emplace("-DSB=" + std::to_string(SB)); - buildOptions.emplace("-DSTRM=" + std::to_string(STRM)); - buildOptions.emplace("-DSTRN=" + std::to_string(STRN)); - buildOptions.emplace("-DVWM=" + std::to_string(VWM)); - buildOptions.emplace("-DVWN=" + std::to_string(VWN)); - if(layout >= 4) { - buildOptions.emplace("-DOUTPUTMN"); - } - - tileM = MWG; - tileN = NWG; - int localM = MDIMC; - int localN = NDIMC; - - if(mOpenCLBackend->getOpenCLRuntime()->getGpuType() == GpuType::ADRENO) { - buildOptions.emplace("-DUSE_CL_MAD=1"); - buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); - } - - Unit unit; - unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); - - int out_per_thread_m = tileM / localM; - int out_per_thread_n = tileN / localN; - - std::vector globalWorkSize = {static_cast(e_pack/out_per_thread_m), static_cast(h_pack/out_per_thread_n), static_cast(n)}; - std::vector localWorkSize = {static_cast(localM), static_cast(localN), 1}; - - float alpha = 1.0; - float beta = 0.0f; - int batch_offset_a = e_pack * l_pack; - int batch_offset_b = h_pack * l_pack; - int batch_offset_c = e_pack * h_pack; - int idx = 0; - cl_int ret = CL_SUCCESS; - ret |= unit.kernel->get().setArg(idx++, static_cast(e_pack)); - ret |= unit.kernel->get().setArg(idx++, static_cast(h_pack)); - ret |= unit.kernel->get().setArg(idx++, static_cast(l_pack)); - ret |= unit.kernel->get().setArg(idx++, alpha); - ret |= unit.kernel->get().setArg(idx++, beta); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[1].get())); - ret |= unit.kernel->get().setArg(idx++, batch_offset_a); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[2].get())); - ret |= unit.kernel->get().setArg(idx++, batch_offset_b); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[0].get())); - ret |= unit.kernel->get().setArg(idx++, batch_offset_c); - MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBuf GemmTile Kernel"); - - unit.globalWorkSize = {globalWorkSize[0], globalWorkSize[1], globalWorkSize[2]}; - unit.localWorkSize = {localWorkSize[0], localWorkSize[1], localWorkSize[2]}; - mUnits.emplace_back(unit); - mOpenCLBackend->recordKernel3d(unit.kernel, globalWorkSize, localWorkSize); - - } else { + { // matmul - mTmpTensors[0] = std::make_shared(Tensor::createDevice(std::vector{1, n, e, h}, Tensor::CAFFE)); - mOpenCLBackend->onAcquireBuffer(mTmpTensors[0].get(), Backend::DYNAMIC); - int offset_index = 0; - - // MNN_PRINT("batchgemm:%d, %d %d %d, transAB %d %d, bias:%d, inputsize:%d\n", n, e, h, l, mTransposeA, mTransposeB, mHasBias, cmd->indexes()->size()); Unit unit; std::string KernelName = "batch_matmul"; std::set buildOptions = mBuildOptions; @@ -630,15 +163,15 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[1]); ret |= unit.kernel->get().setArg(index++, mGlobalWorkSize[2]); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTmpTensors[0].get())); - ret |= unit.kernel->get().setArg(index++, 
openCLBuffer(mTmpTensors[1].get())); - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTmpTensors[2].get())); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[0]])); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[2]])); if (mHasBias) { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTmpTensors[3].get())); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[3]])); } for (int i = 0; i < cmd->iterIndexes()->size(); ++i) { if (mIter[i] >= 0) { - ret |= unit.kernel->get().setArg(index++, openCLBuffer(mOffsetTensors[offset_index++].get())); + ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->iterIndexes()->data()[i]])); } else { ret |= unit.kernel->get().setArg(index++, openCLBuffer(mTensors[cmd->indexes()->data()[1]])); } @@ -659,116 +192,9 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); } - //pack output - { - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector Shape = tensorShapeFormat(output); - const int Channel = Shape.at(3); - const int Width = Shape.at(2); - const int Height = Shape.at(1); - const int Batch = Shape.at(0); - // MNN_PRINT("output, %d %d %d %d\n", Batch, Channel, Height, Width); - - Unit unit; - std::set buildOptions = mBuildOptions; - if(mBatchGemmOpt) { - buildOptions.emplace("-DTRANSPOSE"); - if (mHasBias) { - buildOptions.emplace("-DBIAS"); - } - if(output->buffer().dimensions == 3) { - buildOptions.emplace("-DDIMENSION_3"); - } - if(output->buffer().dimensions == 4) { - buildOptions.emplace("-DDIMENSION_4"); - } - } - - int WidthPad = Width; - int HeightPad = Height; - int ChannelPad = Channel; - if(mBatchGemmOpt) { - auto shape = getTileDimensionSize(std::make_tuple(Width, Height, Channel), std::make_tuple(tileM, tileK, tileN), TensorUtils::getDescribe(output)->dimensionFormat, output->buffer().dimensions, false, 0); - WidthPad = std::get<0>(shape); - HeightPad = std::get<1>(shape); - ChannelPad = std::get<2>(shape); - } - _TileOrPackTensor(mTmpTensors[0].get(), output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "pack_buf", buildOptions, WidthPad, HeightPad, ChannelPad, runTime); - mUnits.emplace_back(unit); - } - - for (int i = 0; i < cmd->indexes()->size(); ++i) { - mOpenCLBackend->onReleaseBuffer(mTmpTensors[i].get(), Backend::DYNAMIC); - } - for (int i = 0; i < mOffsetTensors.size(); ++i) { - mOpenCLBackend->onReleaseBuffer(mOffsetTensors[i].get(), Backend::DYNAMIC); - } - return NO_ERROR; } -ErrorCode LoopBatchMatMulBufExecution::onExecute(const std::vector &inputs, const std::vector &outputs) { - auto openCLBackend = static_cast(backend()); - auto runtime = openCLBackend->getOpenCLRuntime(); -#ifdef ENABLE_OPENCL_TIME_PROFILER - int idx = 0; -#else - if(openCLBackend->isUseRecordQueue()){ - openCLBackend->addRecord(mRecording, mOpRecordUpdateInfo); - return NO_ERROR; - } -#endif - auto res = CL_SUCCESS; - for (auto &unit : mUnits) { - #ifdef ENABLE_OPENCL_TIME_PROFILER - cl::Event event; - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize, - nullptr, - &event); - std::string name = "While-gemm"; - - if(mBatchGemmOpt) { - if(idx == 2) { - name += "-batchgemm"; - } 
else if(idx == 0) { - name += "-rearrangeA"; - } else if(idx == 1) { - name += "-rearrangeB"; - } else { - name += "-rearrangeC"; - } - } else { - if(idx == mUnits.size()-2) { - name += "-batchgemm"; - } else if(idx == 0) { - name += "-rearrangeA"; - } else if(idx == 1) { - name += "-rearrangeB"; - } else { - name += "-rearrangeC"; - } - } - std::string b = std::to_string(mBatch); - std::string m = std::to_string(mM); - std::string n = std::to_string(mN); - std::string k = std::to_string(mK); - std::string total = std::to_string(1.0 / 1000000 * mBatch * mM * mN * mK); - name += "-b" + b + "m" + m + "n" + n + "k" + k + "-total:" + total + "*10^6"; - runtime->pushEvent({name.c_str(), event}); - idx++; - #else - res = runtime->commandQueue().enqueueNDRangeKernel(unit.kernel->get(), - cl::NullRange, - unit.globalWorkSize, - unit.localWorkSize); - #endif - MNN_CHECK_CL_SUCCESS(res, "While-gemm execute"); - } - return NO_ERROR; -} LoopBinaryBufExecution::LoopBinaryBufExecution(const LoopParam *loop, const std::string &compute, const MNN::Op *op, Backend *bn) : CommonExecution(bn, op) { mLoop = loop; @@ -784,115 +210,28 @@ ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, mUnits.clear(); Unit unit; + int z = cmd->size()->data()[0]; + int y = cmd->size()->data()[1]; + int x = cmd->size()->data()[2]; + int n = mLoop->loopNumber(); + int inputSize = mTensors[cmd->indexes()->data()[1]]->elementSize(); + + auto src0Stride = cmd->view()->GetAs(1)->stride()->data(); + auto src1Stride = cmd->view()->GetAs(2)->stride()->data(); + auto dstStride = cmd->view()->GetAs(0)->stride()->data(); + for (int i = 0; i < 3; ++i) { + mStride_src0[i] = src0Stride[i]; + mStride_src1[i] = src1Stride[i]; + mStride_dst[i] = dstStride[i]; + } + auto input0 = mTensors[cmd->indexes()->data()[1]]; - std::vector input0C4Shape = tensorShapeFormat(input0); - int input0C4Size[4] = {input0C4Shape.at(0), input0C4Shape.at(3),input0C4Shape.at(1),input0C4Shape.at(2)}; - auto input1 = mTensors[cmd->indexes()->data()[2]]; - std::vector input1C4Shape = tensorShapeFormat(input1); - int input1C4Size[4] = {input1C4Shape.at(0), input1C4Shape.at(3),input1C4Shape.at(1),input1C4Shape.at(2)}; - auto output = mTensors[cmd->indexes()->data()[0]]; - std::vector outputC4Shape = tensorShapeFormat(output); - - int input0Shape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; - int input1Shape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; - int outputShape[8] = {1, 1, 1, 1, 1, 1, 1, 1}; - - int offset0 = output->dimensions() - input0->dimensions(); - int offset1 = output->dimensions() - input1->dimensions(); - for (int i = 0; i < input0->dimensions(); ++i) { - input0Shape[i + offset0] = input0->length(i); - } - for (int i = 0; i < input1->dimensions(); ++i) { - input1Shape[i + offset1] = input1->length(i); - } - for(int i =0;idimensions();++i){ - outputShape[i] = output->length(i); - } - if (TensorUtils::getDescribe(input0)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = input0Shape[0]; - int iH = input0Shape[1]; - int iW = input0Shape[2]; - int iC = input0Shape[3]; - - if(input0->dimensions() > 4) - { - for(int i = 4; i < input0->dimensions(); i++) - { - iC *= input0Shape[i]; - } - } - input0Shape[0] = iN; - input0Shape[1] = iC; - input0Shape[2] = iH; - input0Shape[3] = iW; - input0Shape[4] = 1; - } - if (TensorUtils::getDescribe(input1)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = input1Shape[0]; - int iH = input1Shape[1]; - int iW = input1Shape[2]; - int iC = input1Shape[3]; - - if(input1->dimensions() > 4) - { - for(int i = 4; i < 
input1->dimensions(); i++) - { - iC *= input1Shape[i]; - } - } - input1Shape[0] = iN; - input1Shape[1] = iC; - input1Shape[2] = iH; - input1Shape[3] = iW; - input1Shape[4] = 1; - } - if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC) - { - int iN = outputShape[0]; - int iH = outputShape[1]; - int iW = outputShape[2]; - int iC = outputShape[3]; - - if(input1->dimensions() > 4) - { - for(int i = 4; i < output->dimensions(); i++) - { - iC *= outputShape[i]; - } - } - outputShape[0] = iN; - outputShape[1] = iC; - outputShape[2] = iH; - outputShape[3] = iW; - outputShape[4] = 1; - } - auto BuildOptions = mBuildOptions; - for(int i = 0; i < 4; ++i){ - if(input1C4Shape[i] != outputC4Shape[i]){ - BuildOptions.emplace("-DBROADCAST_INPUT1"); - break; - } - } - - const int Channel = outputC4Shape.at(3); - const int Width = outputC4Shape.at(2); - const int Height = outputC4Shape.at(1); - const int Batch = outputC4Shape.at(0); - const int ChannelBlock = UP_DIV(Channel, 4); - std::string KernelName = "broadcast_binary_buf"; - if(input0Shape[1] == input1Shape[1] && input0C4Size[1] == input1C4Size[1]){ - KernelName = "broadcast_binary_channel_equall_buf"; - } else if((input0->dimensions() == 1 && input0Shape[1] == 1) || (input1->dimensions() == 1 && input1Shape[1] == 1)){ - KernelName = "broadcast_binary_dimmision1_channel1_buf"; - } - unit.kernel = runTime->buildKernel("loop_buf", KernelName, BuildOptions, input0, output); + unit.kernel = runTime->buildKernel("loop_buf", "loop_binary_buf", mBuildOptions, input0, output); uint32_t mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(unit.kernel)); - - std::vector mGlobalWorkSize = {(uint32_t)(Width), (uint32_t)(Height), (uint32_t)(Batch * ChannelBlock)}; + + std::vector mGlobalWorkSize = {(uint32_t)(x), (uint32_t)(y), (uint32_t)(z)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -902,18 +241,18 @@ ErrorCode LoopBinaryBufExecution::onEncode(const std::vector &inputs, ret |= unit.kernel->get().setArg(index++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(index++, openCLBuffer(input0)); ret |= unit.kernel->get().setArg(index++, openCLBuffer(input1)); - ret |= unit.kernel->get().setArg(index++, sizeof(input0Shape), input0Shape); - ret |= unit.kernel->get().setArg(index++, sizeof(input0C4Size), input0C4Size); - ret |= unit.kernel->get().setArg(index++, sizeof(input1Shape), input1Shape); - ret |= unit.kernel->get().setArg(index++, sizeof(input1C4Size), input1C4Size); - ret |= unit.kernel->get().setArg(index++, sizeof(outputShape), outputShape); - ret |= unit.kernel->get().setArg(index++, Width); - ret |= unit.kernel->get().setArg(index++, Height); - ret |= unit.kernel->get().setArg(index++, Channel); - ret |= unit.kernel->get().setArg(index++, ChannelBlock); + ret |= unit.kernel->get().setArg(index++, mStride_src0[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src0[2]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[0]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[1]); + ret |= unit.kernel->get().setArg(index++, mStride_src1[2]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[0]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[1]); + ret |= unit.kernel->get().setArg(index++, mStride_dst[2]); MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBinaryBufExecution"); - std::vector mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, KernelName, unit.kernel).first; + std::vector 
mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, runTime, "loop_binary_buf", unit.kernel).first; unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; @@ -958,35 +297,35 @@ class LoopBufCreator : public OpenCLBackend::Creator { case BinaryOpOperation_SUB: return new LoopBinaryBufExecution(loop, "in0-in1", op, backend); case BinaryOpOperation_REALDIV: - return new LoopBinaryBufExecution(loop, "sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001))", op, backend); + return new LoopBinaryBufExecution(loop, "sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001))", op, backend); case BinaryOpOperation_MINIMUM: return new LoopBinaryBufExecution(loop, "in0>in1?in1:in0", op, backend); case BinaryOpOperation_MAXIMUM: return new LoopBinaryBufExecution(loop, "in0>in1?in0:in1", op, backend); case BinaryOpOperation_GREATER: - return new LoopBinaryBufExecution(loop, "convert_float4(-isgreater(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(isgreater(in0,in1))", op, backend); case BinaryOpOperation_LESS: - return new LoopBinaryBufExecution(loop, "convert_float4(-isless(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(isless(in0,in1))", op, backend); case BinaryOpOperation_LESS_EQUAL: - return new LoopBinaryBufExecution(loop, "convert_float4(-islessequal(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(islessequal(in0,in1))", op, backend); case BinaryOpOperation_GREATER_EQUAL: - return new LoopBinaryBufExecution(loop, "convert_float4(-isgreaterequal(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(isgreaterequal(in0,in1))", op, backend); case BinaryOpOperation_EQUAL: - return new LoopBinaryBufExecution(loop, "convert_float4(-isequal(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(isequal(in0,in1))", op, backend); case BinaryOpOperation_FLOORDIV: - return new LoopBinaryBufExecution(loop, "floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))", op, backend); + return new LoopBinaryBufExecution(loop, "floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))", op, backend); case BinaryOpOperation_FLOORMOD: - return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend); + return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", op, backend); case BinaryOpOperation_POW: return new LoopBinaryBufExecution(loop, "pow(in0,in1)", op, backend); case BinaryOpOperation_SquaredDifference: return new LoopBinaryBufExecution(loop, "(in0-in1)*(in0-in1)", op, backend); case BinaryOpOperation_ATAN2: - return new LoopBinaryBufExecution(loop, "(in1==(float4)0?(sign(in0)*(float4)(PI/2)):(atan(in0/in1)+(in1>(float4)0?(float4)0:sign(in0)*(float4)PI)))", op, backend); + return new LoopBinaryBufExecution(loop, "(in1==(float)0?(sign(in0)*(float)(PI/2)):(atan(in0/in1)+(in1>(float)0?(float)0:sign(in0)*(float)PI)))", op, backend); case BinaryOpOperation_NOTEQUAL: - return new LoopBinaryBufExecution(loop, "convert_float4(-isnotequal(in0,in1))", op, backend); + return new LoopBinaryBufExecution(loop, "(float)(isnotequal(in0,in1))", op, 
backend);
         case BinaryOpOperation_MOD:
-            return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float4)((float)0.0000001)?fabs(in1):(float4)((float)0.0000001)))*in1", op, backend);
+            return new LoopBinaryBufExecution(loop, "in0-floor(sign(in1)*in0/(fabs(in1)>(float)((float)0.0000001)?fabs(in1):(float)((float)0.0000001)))*in1", op, backend);
         default:
             break;
     }
diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.hpp b/source/backend/opencl/execution/buffer/LoopBufExecution.hpp
index aba7848ff..6bad208af 100644
--- a/source/backend/opencl/execution/buffer/LoopBufExecution.hpp
+++ b/source/backend/opencl/execution/buffer/LoopBufExecution.hpp
@@ -39,14 +39,11 @@ class LoopBatchMatMulBufExecution : public CommonExecution {
     LoopBatchMatMulBufExecution(const LoopParam *loop, const MNN::Op *op, Backend *bn);
     virtual ~LoopBatchMatMulBufExecution() = default;
     virtual ErrorCode onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
-    virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override;
 
 private:
     const LoopParam *mLoop;
     std::vector<Tensor *> mTensors;
-    std::vector<std::shared_ptr<Tensor>> mTmpTensors;
-    std::vector<std::shared_ptr<Tensor>> mOffsetTensors;
     int mOffset[4];
     int mStep[4];
     int mIter[4];
@@ -54,8 +51,6 @@ class LoopBatchMatMulBufExecution : public CommonExecution {
     bool mTransposeA = false;
     bool mTransposeB = false;
     std::set<std::string> mBuildOptions;
-    bool mBatchGemmOpt = false;
-    int mBatch, mM, mN, mK;
 };
 
 
@@ -69,6 +64,9 @@ class LoopBinaryBufExecution : public CommonExecution {
     const LoopParam *mLoop;
     std::vector<Tensor *> mTensors;
     std::set<std::string> mBuildOptions;
+    int mStride_src0[3];
+    int mStride_src1[3];
+    int mStride_dst[3];
 };
 
 } // namespace OpenCL
diff --git a/source/backend/opencl/execution/buffer/MatmulBufExecution.cpp b/source/backend/opencl/execution/buffer/MatmulBufExecution.cpp
index ea055eb37..4062220fb 100644
--- a/source/backend/opencl/execution/buffer/MatmulBufExecution.cpp
+++ b/source/backend/opencl/execution/buffer/MatmulBufExecution.cpp
@@ -122,11 +122,21 @@ ErrorCode MatMulBufExecution::onEncode(const std::vector &inputs, cons
         unit.kernel = runtime->buildKernel("matmul_local_buf", "matmul_local_buf", buildOptions);
     } else {
         if(mTransposeA) {
-            mKernelName = mTransposeB ? "matmul_transA_transB_buf":"matmul_transA_buf";
-        } else {
-            mKernelName = mTransposeB ?
"matmul_transB_buf":"matmul_buf"; + buildOptions.emplace(" -DTRANSPOSE_A"); + } + if(mTransposeB) { + buildOptions.emplace(" -DTRANSPOSE_B"); } - unit.kernel = runtime->buildKernel("matmul_buf", mKernelName, buildOptions); + if(M % 4 != 0) { + buildOptions.emplace(" -DM_LEAVE"); + } + if(N % 4 != 0) { + buildOptions.emplace(" -DN_LEAVE"); + } + if(K % 4 != 0) { + buildOptions.emplace(" -DK_LEAVE"); + } + unit.kernel = runtime->buildKernel("matmul_buf", "matmul_buf", buildOptions); } mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); @@ -183,46 +193,22 @@ ErrorCode MatMulBufExecution::onEncode(const std::vector &inputs, cons MNN_CHECK_CL_SUCCESS(ret, "setArg MatMulBufExecution use tile opt"); } else { - if(mTransposeA) { - mGlobalWorkSize = {static_cast(N_4), static_cast(M_4)}; - int idx = 0; - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input0)); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input1)); - if(inputs.size() > 2) { - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[2])); - } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, static_cast(K)); - ret |= unit.kernel->get().setArg(idx++, static_cast(K_4)); - ret |= unit.kernel->get().setArg(idx++, static_cast(M)); - ret |= unit.kernel->get().setArg(idx++, static_cast(M_4)); - ret |= unit.kernel->get().setArg(idx++, static_cast(N_4)); - ret |= unit.kernel->get().setArg(idx++, static_cast(N)); - MNN_CHECK_CL_SUCCESS(ret, "setArg MatMulBufExecution mTransposeA"); - - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), mKernelName, unit.kernel).first; + mGlobalWorkSize = {static_cast(N_4), static_cast(M_4)}; + int idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input0)); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input1)); + if(inputs.size() > 2) { + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[2])); } - else { + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(idx++, static_cast(M)); + ret |= unit.kernel->get().setArg(idx++, static_cast(N)); + ret |= unit.kernel->get().setArg(idx++, static_cast(K)); + MNN_CHECK_CL_SUCCESS(ret, "setArg MatMulBufExecution mTransposeA"); - mGlobalWorkSize = {static_cast(N_4), static_cast(M)}; - int idx = 0; - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input0)); - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input1)); - if(inputs.size() > 2) { - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[2])); - } - ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, static_cast(K)); - ret |= unit.kernel->get().setArg(idx++, static_cast(K_4)); - ret |= unit.kernel->get().setArg(idx++, static_cast(N_4)); - ret |= unit.kernel->get().setArg(idx++, static_cast(N)); - MNN_CHECK_CL_SUCCESS(ret, "setArg MatMulBufExecution"); - mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), mKernelName, unit.kernel).first; - } + mLocalWorkSize = localWS2DDefault(mGlobalWorkSize, 
mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), mKernelName, unit.kernel).first; } mOpenCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; diff --git a/source/backend/opencl/execution/buffer/PoolBufExecution.cpp b/source/backend/opencl/execution/buffer/PoolBufExecution.cpp index 8b61b4d77..66e29d1b7 100644 --- a/source/backend/opencl/execution/buffer/PoolBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/PoolBufExecution.cpp @@ -159,7 +159,7 @@ ErrorCode PoolBufExecution::onEncode(const std::vector &inputs, const ret |= unit.kernel->get().setArg(idx++, sizeof(kernelShape), kernelShape); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(redice)); - ret |= unit.kernel->get().setArg(idx++, channelBlocks); + ret |= unit.kernel->get().setArg(idx++, batch); MNN_CHECK_CL_SUCCESS(ret, "setArg PoolBufExecution"); std::string kernelNameTune = "pooling_buf"; @@ -296,6 +296,7 @@ ErrorCode PoolBufExecution::SubgrouponResize(const std::vector &inputs ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(redice)); ret |= unit.kernel->get().setArg(idx++, channels); + ret |= unit.kernel->get().setArg(idx++, batch); ret |= unit.kernel->get().setArg(idx++, in_channel_block); ret |= unit.kernel->get().setArg(idx++, out_channel_block); ret |= unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); diff --git a/source/backend/opencl/execution/buffer/RangeBufExecution.cpp b/source/backend/opencl/execution/buffer/RangeBufExecution.cpp index 913e841e8..cc42c3624 100644 --- a/source/backend/opencl/execution/buffer/RangeBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/RangeBufExecution.cpp @@ -20,43 +20,35 @@ ErrorCode RangeBufExecution::onEncode(const std::vector& inputs, const mUnits.resize(1); auto &unit = mUnits[0]; auto openCLBackend = static_cast(backend()); - auto runtime = openCLBackend->getOpenCLRuntime(); - unit.kernel = runtime->buildKernel("range_buf", "range_buf", mBuildOptions, inputs[0], outputs[0]); - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - + auto runtime = openCLBackend->getOpenCLRuntime(); std::vector outputShape = tensorShapeFormat(outputs[0]); - - int batch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int channels = outputShape.at(3); - int channelBlocks = (channels + 3) / 4; - + int totalSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; mGlobalWorkSize = { - static_cast(outputWidth), - static_cast(outputHeight), - static_cast(batch * channelBlocks) + static_cast(UP_DIV(totalSize, 4)), + static_cast(1) }; + std::set buildOption = mBuildOptions; + if((totalSize % 4) != 0){ + buildOption.emplace("-DPACK_LEAVE"); + } + unit.kernel = runtime->buildKernel("range_buf", "range_buf", buildOption, inputs[0], outputs[0]); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[0])); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[2])); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[0])); - ret |= 
unit.kernel->get().setArg(idx++, outputWidth);
-    ret |= unit.kernel->get().setArg(idx++, outputHeight);
-    ret |= unit.kernel->get().setArg(idx++, channels);
-    ret |= unit.kernel->get().setArg(idx++, channelBlocks);
+    ret |= unit.kernel->get().setArg(idx++, totalSize);
     MNN_CHECK_CL_SUCCESS(ret, "setArg RangeBufExecution");
 
     std::string kernelName = "range_buf";
-    mLocalSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first;
-    openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize);
-    unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]};
-    unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]};
+    mLocalSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first;
+    openCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalSize);
+    unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]};
+    unit.localWorkSize = {mLocalSize[0], mLocalSize[1]};
     return NO_ERROR;
 }
diff --git a/source/backend/opencl/execution/buffer/RasterBufExecution.cpp b/source/backend/opencl/execution/buffer/RasterBufExecution.cpp
index d663a6c9f..8db39af02 100644
--- a/source/backend/opencl/execution/buffer/RasterBufExecution.cpp
+++ b/source/backend/opencl/execution/buffer/RasterBufExecution.cpp
@@ -36,24 +36,42 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs,
     }
     auto des = TensorUtils::getDescribe(output);
     auto outputDes = TensorUtils::getDescribe(output);
-    mNeedZero = !TensorUtils::regionIsFull(output);
     auto regionNum = des->regions.size();
     auto mOpenCLBackend = static_cast<OpenCLBackend *>(backend());
     auto runtime = mOpenCLBackend->getOpenCLRuntime();
-
-    bool cancombine = CanCombine(outputs);
+    int kernel_idx = 0;
+    auto outputShape = tensorShapeFormat(output);
+    mFast = false;
+    if (outputDes->dimensionFormat == MNN_DATA_FORMAT_NC4HW4) {
+        mFast = true;
+        for (int i=0; i< des->regions.size(); ++i) {
+            auto& slice = des->regions[i];
+            if (TensorUtils::getDescribe(slice.origin)->dimensionFormat != MNN_DATA_FORMAT_NC4HW4) {
+                mFast = false;
+                break;
+            }
+            if (!OpCommonUtils::canBlitFast(slice, output, 4, true)) {
+                mFast = false;
+                break;
+            }
+        }
+    }
+    mNeedZero = !TensorUtils::regionIsFull(output);
+    mNeedZero = mNeedZero || ((outputShape[3] % 4) != 0 && MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat && !mFast);
+    bool cancombine = CanCombine(outputs) && (!mFast);
     if(cancombine){
         regionNum = 1;
     }
-    int kernel_idx = 0;
     mUnits.resize(regionNum);
-    auto outputShape = tensorShapeFormat(output);
-    if(mNeedZero || (outputShape[3] % 4) != 0)
+    if(mNeedZero)
     {
         mUnits.resize(regionNum + 1);
-        int region[] = {outputShape[0], ROUND_UP(outputShape[3], 4), outputShape[1], outputShape[2]};//nhwc
+        int region[] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//nchw
+        if(MNN_DATA_FORMAT_NC4HW4 == outputDes->dimensionFormat){
+            region[1] = ROUND_UP(outputShape[3], 4);
+        }
         Unit &unit = mUnits[kernel_idx++];
-        unit.kernel = runtime->buildKernel("raster", "buffer_set_zero", {}, output, output);
+        unit.kernel = runtime->buildKernel("raster_buf", "buffer_set_zero", {}, output, output);
         unit.localWorkSize = {8, 8};
         unit.globalWorkSize = {(uint32_t)UP_DIV((region[2] * region[3]), 8)*8,
                                (uint32_t)UP_DIV((region[0] * region[1]), 8)*8};
@@ -73,6 +91,64 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs,
         mOpenCLBackend->recordKernel2d(unit.kernel, {(uint32_t)UP_DIV((region[2] * region[3]),
8)*8, (uint32_t)UP_DIV((region[0] * region[1]), 8)*8}, {8, 8}); } + if(mFast) + { + // nc4hw4 buffer raster + for (auto& slice : des->regions) + { + auto origin = slice.origin; + auto inputShape = tensorShapeFormat(origin); + Tensor::InsideDescribe::Region C4Region; + OpCommonUtils::turnToPackRegion(slice, C4Region, output, 4, true); + Unit &unit = mUnits[kernel_idx++]; + unit.kernel = runtime->buildKernel("raster_buf", "raster_nc4hw4_buffer", {}, origin, output); + + const std::vector gws = {(uint32_t)C4Region.size[2], + (uint32_t)C4Region.size[1], + (uint32_t)C4Region.size[0]}; + uint32_t mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + + auto outputShape = tensorShapeFormat(output); + auto sliceShape = tensorShapeFormat(slice.origin); + + uint32_t idx = 0; + cl_int ret = CL_SUCCESS; + ret |= unit.kernel->get().setArg(idx++, gws[0]); + ret |= unit.kernel->get().setArg(idx++, gws[1]); + ret |= unit.kernel->get().setArg(idx++, gws[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(slice.origin)); + ret |= unit.kernel->get().setArg(idx++, C4Region.src.offset); + ret |= unit.kernel->get().setArg(idx++, C4Region.src.stride[0]); + ret |= unit.kernel->get().setArg(idx++, C4Region.src.stride[1]); + ret |= unit.kernel->get().setArg(idx++, C4Region.src.stride[2]); + ret |= unit.kernel->get().setArg(idx++, sliceShape[1]); + ret |= unit.kernel->get().setArg(idx++, sliceShape[2]); + ret |= unit.kernel->get().setArg(idx++, sliceShape[3]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + ret |= unit.kernel->get().setArg(idx++, C4Region.dst.offset); + ret |= unit.kernel->get().setArg(idx++, C4Region.dst.stride[0]); + ret |= unit.kernel->get().setArg(idx++, C4Region.dst.stride[1]); + ret |= unit.kernel->get().setArg(idx++, C4Region.dst.stride[2]); + ret |= unit.kernel->get().setArg(idx++, outputShape[1]); + ret |= unit.kernel->get().setArg(idx++, outputShape[2]); + ret |= unit.kernel->get().setArg(idx++, outputShape[3]); + if(ret != CL_SUCCESS) + { + MNN_PRINT("setArg err %d\n", (int)ret); + } + std::string name = "raster_nc4hw4_buffer"; + const std::vector lws = localWS3DDefault(gws, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), name, unit.kernel).first; + + unit.localWorkSize = {lws[0], lws[1], lws[2]}; + + unit.globalWorkSize = {ROUND_UP(gws[0], std::max((uint32_t)1, lws[0])), + ROUND_UP(gws[1], std::max((uint32_t)1, lws[1])), + ROUND_UP(gws[2], std::max((uint32_t)1, lws[2]))}; + mOpenCLBackend->recordKernel3d(unit.kernel, gws, lws); + } + return NO_ERROR; + } + if(cancombine){ auto regions = des->regions; auto slice = regions[0]; @@ -82,17 +158,11 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, std::set buildOptions; auto origin = slice.origin; auto inputShape = tensorShapeFormat(origin); - if(TensorUtils::getDescribe(origin)->dimensionFormat == MNN_DATA_FORMAT_NHWC) - { - buildOptions.emplace(" -DINPUT_DATA_FORMAT_NHWC"); - } - if(outputDes->dimensionFormat == MNN_DATA_FORMAT_NHWC)//nhwc buffer to Image - { - buildOptions.emplace(" -DOUTPUT_DATA_FORMAT_NHWC"); - } + buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(TensorUtils::getDescribe(origin)->dimensionFormat)); + buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(outputDes->dimensionFormat)); Unit &unit = mUnits[kernel_idx++]; - unit.kernel = runtime->buildKernel("raster_buf", "raster_direct_buffer", buildOptions, output, output); + unit.kernel = runtime->buildKernel("raster_buf", "raster_direct_buffer", buildOptions, origin, output); const 
std::vector gws = {(uint32_t)slice.size[2] * nums, (uint32_t)slice.size[1], @@ -114,6 +184,7 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, ret |= unit.kernel->get().setArg(idx++, inputShape[2]); ret |= unit.kernel->get().setArg(idx++, inputShape[1]); ret |= unit.kernel->get().setArg(idx++, inputShape[3]); + ret |= unit.kernel->get().setArg(idx++, inputShape[0]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, slice.dst.offset); ret |= unit.kernel->get().setArg(idx++, dst_offset); @@ -123,6 +194,7 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, ret |= unit.kernel->get().setArg(idx++, outputShape[2]); ret |= unit.kernel->get().setArg(idx++, outputShape[1]); ret |= unit.kernel->get().setArg(idx++, outputShape[3]); + ret |= unit.kernel->get().setArg(idx++, outputShape[0]); if(ret != CL_SUCCESS) { MNN_PRINT("setArg err %d\n", (int)ret); @@ -141,18 +213,11 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, auto inputShape = tensorShapeFormat(origin); int src_offset = 0; int dst_offset = 0; - if(TensorUtils::getDescribe(origin)->dimensionFormat == MNN_DATA_FORMAT_NHWC) - { - buildOptions.emplace(" -DINPUT_DATA_FORMAT_NHWC"); - } - if(outputDes->dimensionFormat == MNN_DATA_FORMAT_NHWC)//nhwc buffer to Image - { - buildOptions.emplace(" -DOUTPUT_DATA_FORMAT_NHWC"); - } + buildOptions.emplace("-DINPUT_FORMAT=" + std::to_string(TensorUtils::getDescribe(origin)->dimensionFormat)); + buildOptions.emplace("-DOUTPUT_FORMAT=" + std::to_string(outputDes->dimensionFormat)); Unit &unit = mUnits[kernel_idx++]; - unit.kernel = runtime->buildKernel("raster_buf", "raster_direct_buffer", buildOptions, output, output); - + unit.kernel = runtime->buildKernel("raster_buf", "raster_direct_buffer", buildOptions, origin, output); const std::vector gws = {(uint32_t)slice.size[2], (uint32_t)slice.size[1], (uint32_t)slice.size[0]}; @@ -173,6 +238,7 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, ret |= unit.kernel->get().setArg(idx++, inputShape[2]); ret |= unit.kernel->get().setArg(idx++, inputShape[1]); ret |= unit.kernel->get().setArg(idx++, inputShape[3]); + ret |= unit.kernel->get().setArg(idx++, inputShape[0]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); ret |= unit.kernel->get().setArg(idx++, slice.dst.offset); ret |= unit.kernel->get().setArg(idx++, dst_offset); @@ -182,6 +248,7 @@ ErrorCode RasterBufExecution::onEncode(const std::vector &____inputs, ret |= unit.kernel->get().setArg(idx++, outputShape[2]); ret |= unit.kernel->get().setArg(idx++, outputShape[1]); ret |= unit.kernel->get().setArg(idx++, outputShape[3]); + ret |= unit.kernel->get().setArg(idx++, outputShape[0]); if(ret != CL_SUCCESS) { MNN_PRINT("setArg err %d\n", (int)ret); diff --git a/source/backend/opencl/execution/buffer/ReductionBufExecution.cpp b/source/backend/opencl/execution/buffer/ReductionBufExecution.cpp index 83fc56474..bc1760d34 100644 --- a/source/backend/opencl/execution/buffer/ReductionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ReductionBufExecution.cpp @@ -23,25 +23,31 @@ ReductionBufExecution::ReductionBufExecution(const std::vector &inputs mAxis = op->main_as_ReductionParam()->dim()->data()[0]; switch (op->main_as_ReductionParam()->operation()) { case ReductionType_MEAN: - mReductType = 0; + mBuildOptions.emplace("-DOPERATE(a,b)=(a+b)"); + mBuildOptions.emplace("-DGET_AVG"); + mBuildOptions.emplace("-DVALUE=0"); break; case 
ReductionType_MAXIMUM: - mReductType = 1; + mBuildOptions.emplace("-DOPERATE(a,b)=max(a,b)"); + mBuildOptions.emplace("-DVALUE=-FLT_MAX"); break; case ReductionType_MINIMUM: - mReductType = 2; + mBuildOptions.emplace("-DOPERATE(a,b)=min(a,b)"); + mBuildOptions.emplace("-DVALUE=FLT_MAX"); break; case ReductionType_PROD: - mReductType = 3; + mBuildOptions.emplace("-DOPERATE(a,b)=(a*b)"); + mBuildOptions.emplace("-DVALUE=1"); break; case ReductionType_SUM: - mReductType = 4; + mBuildOptions.emplace("-DOPERATE(a,b)=(a+b)"); + mBuildOptions.emplace("-DVALUE=0"); break; default: MNN_ASSERT(false); break; } - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("reduction_buf", "reduct_width_buf", {"-DOPERATE(a,b)=(a+b)","-DVALUE=0","-DLOCAL_SIZE=512"}, inputs[0], outputs[0]); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("reduction_buf", "reduct_buf", {"-DOPERATE(a,b)=(a+b)","-DVALUE=0","-DLOCAL_SIZE=512"}, inputs[0], outputs[0]); mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); #ifdef LOG_VERBOSE MNN_PRINT("end ReductionBufExecution init !\n"); @@ -76,102 +82,24 @@ ErrorCode ReductionBufExecution::onEncode(const std::vector &inputs, c inside *= input->length(i); } int dim = input->length(mAxis); - int local_size = 0; - if(dim >= 16){ - mUseLocal = true; - } - - std::vector inputShape = tensorShapeFormat(input); - std::vector outputShape = tensorShapeFormat(output); - - int batch = inputShape.at(0); - int inputHeight = inputShape.at(1); - int inputWidth = inputShape.at(2); - int inputChannels = inputShape.at(3); - int inputChannelBlocks = (inputChannels + 3) / 4; - int outputBatch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int outputChannels = outputShape.at(3); - int outputChannelBlocks = (outputChannels + 3) / 4; - - std::set buildOption; - switch (mReductType) { - case 0: - buildOption.emplace("-DOPERATE(a,b)=(a+b)"); - buildOption.emplace("-DGET_AVG"); - buildOption.emplace("-DVALUE=0"); - break; - case 1: - buildOption.emplace("-DOPERATE(a,b)=max(a,b)"); - buildOption.emplace("-DVALUE=-FLT_MAX"); - break; - case 2: - buildOption.emplace("-DOPERATE(a,b)=min(a,b)"); - buildOption.emplace("-DVALUE=FLT_MAX"); - break; - case 3: - buildOption.emplace("-DOPERATE(a,b)=(a*b)"); - buildOption.emplace("-DVALUE=1"); - break; - case 4: - buildOption.emplace("-DOPERATE(a,b)=(a+b)"); - buildOption.emplace("-DVALUE=0"); - break; - default: - MNN_ASSERT(false); - break; + int localSize = getLocalSize(dim, MaxLocalSize); + if(localSize < 4){ + localSize = 1; } - mGlobalWorkSize = { - static_cast(outputWidth), - static_cast(outputHeight), - static_cast(outputBatch * outputChannelBlocks) - }; - - if(mUseLocal){ - if(batch * inputHeight * inputChannels == outside && 1 == inside && dim == inputWidth){ - local_size = getLocalSize(inputWidth, MaxLocalSize); - buildOption.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_width_buf", buildOption, input, output); - }else if(batch * inputChannels == outside && inputWidth == inside && dim == inputHeight){ - local_size = getLocalSize(inputHeight, MaxLocalSize); - buildOption.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_height_buf", buildOption, input, output); - }else if(batch == outside && inputWidth * inputHeight == inside && dim == inputChannels){ - local_size = getLocalSize(inputChannelBlocks - 
1, MaxLocalSize); - buildOption.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - if(output->buffer().dimensions == 1){ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_channel_dim1_buf", buildOption, input, output); - }else{ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_channel_buf", buildOption, input, output); - } - mGlobalWorkSize[2] = static_cast(outputBatch * outputChannels); - }else if(1 == outside && inputWidth * inputHeight * inputChannels == inside && dim == batch){ - local_size = getLocalSize(batch, MaxLocalSize); - buildOption.emplace("-DLOCAL_SIZE=" + std::to_string(local_size)); - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_batch_buf", buildOption, input, output); - } - mGlobalWorkSize[0] *= local_size; - }else{ - buildOption.emplace("-DLOCAL_SIZE=0"); - if(batch * inputHeight * inputChannels == outside && 1 == inside && dim == inputWidth){ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_width_buf", buildOption, input, output); - }else if(batch * inputChannels == outside && inputWidth == inside && dim == inputHeight){ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_height_buf", buildOption, input, output); - }else if(batch == outside && inputWidth * inputHeight == inside && dim == inputChannels){ - if(output->buffer().dimensions == 1){ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_channel_dim1_buf", buildOption, input, output); - }else{ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_channel_buf", buildOption, input, output); - } - mGlobalWorkSize[2] = static_cast(outputBatch * outputChannels); - }else if(1 == outside && inputWidth * inputHeight * inputChannels == inside && dim == batch){ - unit.kernel = runtime->buildKernel("reduction_buf", "reduct_batch_buf", buildOption, input, output); - } + std::set buildOptions = mBuildOptions; + buildOptions.emplace("-DREDUCT_LOCAL_SIZE=" + std::to_string(localSize)); + std::string kernelName; + if(inside % 4 == 0){ + unit.kernel = runtime->buildKernel("reduction_buf", "reduct_v4_buf", buildOptions, input, output); + mGlobalWorkSize = {static_cast(localSize), static_cast(UP_DIV(inside, 4)), static_cast(outside)}; + }else { + unit.kernel = runtime->buildKernel("reduction_buf", "reduct_buf", buildOptions, input, output); + mGlobalWorkSize = {static_cast(localSize), static_cast(inside), static_cast(outside)}; } - //printf("reduce axis:%d , %d %d %d %d, useLocal:%d\n", mAxis[0], inputShape[0], inputShape[1], inputShape[2], inputShape[3], mUseLocal); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalWorkSize = {(uint32_t)(localSize), 1, 1}; mUnits.resize(1); uint32_t idx = 0; @@ -181,20 +109,12 @@ ErrorCode ReductionBufExecution::onEncode(const std::vector &inputs, c ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, inputWidth); - ret |= unit.kernel->get().setArg(idx++, inputHeight); - ret |= unit.kernel->get().setArg(idx++, inputChannels); - ret |= unit.kernel->get().setArg(idx++, batch); - ret |= unit.kernel->get().setArg(idx++, inputChannelBlocks); - ret |= unit.kernel->get().setArg(idx++, outputWidth); - ret |= unit.kernel->get().setArg(idx++, outputHeight); - ret |= unit.kernel->get().setArg(idx++, outputChannels); - ret |= unit.kernel->get().setArg(idx++, outputChannelBlocks); + ret |= 
unit.kernel->get().setArg(idx++, inside); + ret |= unit.kernel->get().setArg(idx++, outside); + ret |= unit.kernel->get().setArg(idx++, dim); MNN_CHECK_CL_SUCCESS(ret, "setArg ReductionBufExecution"); - if(mUseLocal){ - mLocalWorkSize = {static_cast(local_size), 1, 1}; - }else{ + if(localSize == 1){ mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); std::string kernelName = "reduct_buf"; mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; diff --git a/source/backend/opencl/execution/buffer/ReductionBufExecution.hpp b/source/backend/opencl/execution/buffer/ReductionBufExecution.hpp index fb1d78172..091617b82 100644 --- a/source/backend/opencl/execution/buffer/ReductionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/ReductionBufExecution.hpp @@ -26,12 +26,11 @@ class ReductionBufExecution : public CommonExecution { int getLocalSize(int size, int maxGroupSize); OpenCLBackend *mOpenCLBackend; MNN::DataType mdataType; - int mReductType; int mAxis; std::vector mGlobalWorkSize = {1, 1, 1}; std::vector mLocalWorkSize{1, 1, 1}; - bool mUseLocal = false; uint32_t mMaxWorkGroupSize; + std::set mBuildOptions; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/ReluBufExecution.cpp b/source/backend/opencl/execution/buffer/ReluBufExecution.cpp index 6d1b9ee3d..b268f8dc4 100644 --- a/source/backend/opencl/execution/buffer/ReluBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ReluBufExecution.cpp @@ -61,7 +61,7 @@ ErrorCode ReluBufExecution::onEncode(const std::vector &inputs, const int nhwcArray[4] = {nhwc[0], nhwc[1], nhwc[2], UP_DIV(nhwc[3], 4)}; auto imageWidth = nhwc[0] * UP_DIV(nhwc[3], 4); auto imageHeight = nhwc[1] * nhwc[2]; - + std::vector localSize = {1, 1}; std::vector globalSize = {(uint32_t)imageWidth, (uint32_t)imageHeight}; @@ -71,7 +71,10 @@ ErrorCode ReluBufExecution::onEncode(const std::vector &inputs, const return SubgrouponResize(inputs, outputs); } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ - mUnits[0].kernel = runTime->buildKernel("binary_buf", "prelu_buf", {"-DOPERATOR=select(in0*in1,in0,in0>=(float4)0)"}, inputs[0], outputs[0]); + + std::set buildOption; + buildOption.emplace("-DOPERATOR=select(in0*in1,in0,in0>=(float4)0)"); + mUnits[0].kernel = runTime->buildKernel("binary_buf", "prelu_buf", buildOption, inputs[0], outputs[0]); mMaxWorkGroupSize = static_cast(runTime->getMaxWorkGroupSize(mUnits[0].kernel)); int fullCount[2] = {1, 1}; diff --git a/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp b/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp index 43ce99a58..764ea8c95 100644 --- a/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ScaleBufExecution.cpp @@ -93,14 +93,10 @@ ScaleBufExecution::ScaleBufExecution(const std::vector &inputs, const } openclBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(biasBuffer, biasPtrCL); - buildOptions.emplace("-DBIAS"); + mBuildOptions.emplace("-DBIAS"); mHasBias = true; } - auto runtime = mOpenCLBackend->getOpenCLRuntime(); - unit.kernel = runtime->buildKernel("scale_buf", "scale_buf", buildOptions); - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - #ifdef LOG_VERBOSE MNN_PRINT("end ScaleBufExecution init !\n"); #endif @@ -122,13 +118,15 @@ ErrorCode ScaleBufExecution::onEncode(const std::vector &inputs, const const int height = inputShape.at(1); const int width = 
inputShape.at(2); const int channels = inputShape.at(3); - + const int inside = width * height; const int channelBlocks = UP_DIV(channels, 4); - mGlobalWorkSize = {static_cast(width * channelBlocks), - static_cast(height * batch)}; + std::set buildOptions = mBuildOptions; + unit.kernel = runtime->buildKernel("scale_buf", "scale_buf", buildOptions); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - int shape[4] = {batch, height, width, channelBlocks}; + mGlobalWorkSize = {static_cast(inside), + static_cast(channelBlocks * batch)}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); @@ -139,7 +137,9 @@ ErrorCode ScaleBufExecution::onEncode(const std::vector &inputs, const ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mBias.get())); } ret |= unit.kernel->get().setArg(idx++, openCLBuffer(outputs[0])); - ret |= unit.kernel->get().setArg(idx++, shape); + ret |= unit.kernel->get().setArg(idx++, channelBlocks); + ret |= unit.kernel->get().setArg(idx++, batch); + ret |= unit.kernel->get().setArg(idx++, inside); MNN_CHECK_CL_SUCCESS(ret, "setArg ScaleBufExecution"); std::string name = "scale_buf"; diff --git a/source/backend/opencl/execution/buffer/ScaleBufExecution.hpp b/source/backend/opencl/execution/buffer/ScaleBufExecution.hpp index b01897bf3..f288a6a73 100644 --- a/source/backend/opencl/execution/buffer/ScaleBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/ScaleBufExecution.hpp @@ -31,6 +31,7 @@ class ScaleBufExecution : public CommonExecution { std::vector mLocalWorkSize{1, 1, 1}; OpenCLBackend *mOpenCLBackend; bool mHasBias = false; + std::set mBuildOptions; }; } // namespace OpenCL diff --git a/source/backend/opencl/execution/buffer/SelectBufExecution.cpp b/source/backend/opencl/execution/buffer/SelectBufExecution.cpp index 385c853a4..94a2f934a 100644 --- a/source/backend/opencl/execution/buffer/SelectBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SelectBufExecution.cpp @@ -34,13 +34,12 @@ ErrorCode SelectBufExecution::onEncode(const std::vector& inputs, const mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); std::vector outputShape = tensorShapeFormat(outputs[0]); - - int batch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int channels = outputShape.at(3); - int channelBlocks = (channels + 3) / 4; - int outSize = batch * channelBlocks * outputWidth * outputHeight * 4; + int outSize = 0; + if(MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(outputs[0])->dimensionFormat){ + outSize = outputShape[0] * outputShape[1] * outputShape[2] * ROUND_UP(outputShape[3], 4); + }else{ + outSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + } mGlobalWorkSize = { static_cast(outSize), diff --git a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp index bc7aba4ef..6f4a3a97f 100644 --- a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp @@ -150,6 +150,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, seq_len); ret |= mKernel_split[seq_idx]->get().setArg(index++, mNumHead); ret |= mKernel_split[seq_idx]->get().setArg(index++, mHeadDim); + ret |= mKernel_split[seq_idx]->get().setArg(index++, batch); ret |= mKernel_split[seq_idx]->get().setArg(index++, 
seq_idx); MNN_CHECK_CL_SUCCESS(ret, "setArg split_transpose_qkv"); mLocalWorkSizeSplit[seq_idx] = localWS3DDefault(mGlobalWorkSizeSplit[seq_idx], maxWorkGroupSize, runtime, "split_transpose_qkv", mKernel_split[seq_idx]).first; @@ -216,6 +217,10 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, static_cast(e_pack)); @@ -224,11 +229,11 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, alpha); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, beta); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQ.get())); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_a); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempK.get())); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_b); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get())); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_c); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, stride); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, group); MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qk Kernel"); mOpenCLBackend->recordKernel3d(mKernel_qk[seq_idx], mGlobalWorkSizeQk[seq_idx], mLocalWorkSizeQk[seq_idx]); @@ -283,6 +288,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, mGlobalWorkSizeTrans[seq_idx][0]); + ret |= mKernel_trans[seq_idx]->get().setArg(index++, mGlobalWorkSizeTrans[seq_idx][1]); + ret |= mKernel_trans[seq_idx]->get().setArg(index++, mGlobalWorkSizeTrans[seq_idx][2]); ret |= mKernel_trans[seq_idx]->get().setArg(index++, openCLBuffer(mTempSoftMax.get())); ret |= mKernel_trans[seq_idx]->get().setArg(index++, openCLBuffer(mTempTrans.get())); ret |= mKernel_trans[seq_idx]->get().setArg(index++, loop); @@ -291,6 +299,10 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorgetOpenCLRuntime(), "trans_3d_buf", mKernel_trans[seq_idx]).first; + mGlobalWorkSizeTrans[seq_idx][0] = ROUND_UP(mGlobalWorkSizeTrans[seq_idx][0], std::max((uint32_t)1, mLocalWorkSizeTrans[seq_idx][0])); + mGlobalWorkSizeTrans[seq_idx][1] = ROUND_UP(mGlobalWorkSizeTrans[seq_idx][1], std::max((uint32_t)1, mLocalWorkSizeTrans[seq_idx][1])); + mGlobalWorkSizeTrans[seq_idx][2] = ROUND_UP(mGlobalWorkSizeTrans[seq_idx][2], std::max((uint32_t)1, mLocalWorkSizeTrans[seq_idx][2])); + mOpenCLBackend->recordKernel3d(mKernel_trans[seq_idx], mGlobalWorkSizeTrans[seq_idx], mLocalWorkSizeTrans[seq_idx]); } @@ -361,6 +373,10 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, static_cast(e_pack)); @@ -369,11 +385,11 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, alpha); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, beta); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempTrans.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_a); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempV.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_b); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQKV.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_c); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, stride); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, 
group); MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qkv Kernel"); mOpenCLBackend->recordKernel3d(mKernel_qkv[seq_idx], mGlobalWorkSizeQkv[seq_idx], mLocalWorkSizeQkv[seq_idx]); } @@ -403,6 +419,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, seq_len_piece); ret |= mKernel_clip[seq_idx]->get().setArg(index++, mNumHead); ret |= mKernel_clip[seq_idx]->get().setArg(index++, mHeadDim); + ret |= mKernel_clip[seq_idx]->get().setArg(index++, batch); ret |= mKernel_clip[seq_idx]->get().setArg(index++, seq_idx); mLocalWorkSizeClip[seq_idx] = localWS3DDefault(mGlobalWorkSizeClip[seq_idx], maxWorkGroupSize, runtime, "clip_transpose_qkv", mKernel_clip[seq_idx]).first; diff --git a/source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp b/source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp index fa2dc9216..a01c91832 100644 --- a/source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SoftmaxBufExecution.cpp @@ -17,27 +17,10 @@ SoftmaxBufExecution::SoftmaxBufExecution(const std::vector &inputs, in : CommonExecution(backend, Op) { mAxis = axis; mOpenCLBackend = static_cast(backend); - auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_channel", {"-DSOFTMAX_LOCAL_SIZE=512"}); + auto kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("softmax_buf", "softmax_buf", {"-DSOFTMAX_LOCAL_SIZE=512"}); mMaxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(kernel)); } -bool SoftmaxBufExecution::buildSoftmaxKernel(int localSize) { - auto runtime = mOpenCLBackend->getOpenCLRuntime(); - std::set buildOptions; - buildOptions.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - std::string kernelName; - if (mAxis == 1) { - mUnits[0].kernel = runtime->buildKernel("softmax_buf", "softmax_channel", buildOptions); - } else if (mAxis == 2) { - mUnits[0].kernel = runtime->buildKernel("softmax_buf", "softmax_height", buildOptions); - } else { - MNN_ASSERT(mAxis == 3); - mUnits[0].kernel = runtime->buildKernel("softmax_buf", "softmax_width", buildOptions); - } - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(mUnits[0].kernel)); - return true; -} - int SoftmaxBufExecution::getLocalSize(int size, int maxGroupSize){ int local_size = 1; while(local_size * 2 <= maxGroupSize && local_size * 2 <= size){ @@ -47,8 +30,7 @@ int SoftmaxBufExecution::getLocalSize(int size, int maxGroupSize){ } ErrorCode SoftmaxBufExecution::onEncode(const std::vector &inputs, const std::vector &outputs) { - mUnits.resize(1); - auto &unit = mUnits[0]; + mUnits.clear(); Tensor *input = inputs[0]; Tensor *output = outputs[0]; @@ -57,6 +39,18 @@ ErrorCode SoftmaxBufExecution::onEncode(const std::vector &inputs, con auto MaxLocalSize = std::min(std::min(runtime->getMaxWorkItemSizes()[0], mMaxWorkGroupSize), static_cast(256)); + const auto layout = TensorUtils::getDescribe(input)->dimensionFormat; + mNeedUnpackC4 = layout == MNN_DATA_FORMAT_NC4HW4; + if (mNeedUnpackC4) { + int totalSize = 1; + for (int i = 1; i < dims; ++i) { + totalSize *= input->length(i); + } + mTempTensor.reset(Tensor::createDevice({totalSize})); + mOpenCLBackend->onAcquireBuffer(mTempTensor.get(), Backend::DYNAMIC); + mOpenCLBackend->onReleaseBuffer(mTempTensor.get(), Backend::DYNAMIC); + } + int inside = 1; int outside = 1; int channel = 1; @@ -67,62 +61,123 @@ ErrorCode SoftmaxBufExecution::onEncode(const std::vector &inputs, con for (int i = mAxis + 1; 
i < dims; ++i) { inside *= input->length(i); } - - std::vector inputShape = tensorShapeFormat(input); - std::vector outputShape = tensorShapeFormat(output); - - const int inputBatch = inputShape.at(0); - const int inputHeight = inputShape.at(1); - const int inputWidth = inputShape.at(2); - const int inputChannels = inputShape.at(3); - const int outputBatch = outputShape.at(0); - const int outputHeight = outputShape.at(1); - const int outputWidth = outputShape.at(2); - const int outputChannels = outputShape.at(3); - - const int channelBlocks = UP_DIV(outputChannels, 4); - const int remainChannels = channelBlocks * 4 - outputChannels; - int shape[] = {outputBatch, channelBlocks, outputHeight, outputWidth}; - int localSize = getLocalSize(channel, MaxLocalSize); - if(localSize < 4){ - localSize = 1; - } - if(inputBatch == outside && channel == inputChannels && inside == inputWidth * inputHeight){ - mAxis = 1; - mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)outputWidth, (uint32_t)outputHeight * outputBatch}; - localSize = getLocalSize(channelBlocks, MaxLocalSize); - }else if(inputBatch * inputChannels == outside && channel == inputHeight && inside == inputWidth){ - mAxis = 2; - mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks*outputWidth, (uint32_t)outputBatch}; - }else if(inputBatch * inputChannels * inputHeight == outside && channel == inputWidth && inside == 1){ - mAxis = 3; - mGlobalWorkSize = {(uint32_t)(localSize), (uint32_t)channelBlocks, (uint32_t)outputBatch*outputHeight}; + // NC4HW4 -> NCHW + if(mNeedUnpackC4){ + Unit unit; + std::vector outputShape = tensorShapeFormat(input); + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + std::set buildOptions; + buildOptions.emplace("-DINPUT_FORMAT=MNN_DATA_FORMAT_NC4HW4"); + buildOptions.emplace("-DOUTPUT_FORMAT=MNN_DATA_FORMAT_NCHW"); + unit.kernel = runtime->buildKernel("buffer_convert_buf", "buffer_convert_to_buffer", buildOptions, input, output); + mGlobalWorkSize = {static_cast(shape[2] * shape[3]), static_cast(shape[1]), static_cast(shape[0])}; + cl_int ret = CL_SUCCESS; + uint32_t idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); + ret |= unit.kernel->get().setArg(idx++, sizeof(shape), shape); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); + + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalWorkSize = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); } -// printf("softmax: %d %d %d %d, %d\n", inputBatch, inputChannels, inputHeight, inputWidth, mAxis); - buildSoftmaxKernel(localSize); - - cl_int ret = CL_SUCCESS; - mLocalWorkSize = {(uint32_t)(localSize), 1, 1}; + // softmax + { + Unit unit; + int localSize = getLocalSize(channel, MaxLocalSize); + if(localSize < 4){ + localSize = 1; + } + std::set buildOptions = mBuildOptions; + buildOptions.emplace("-DARGMAX_LOCAL_SIZE=" + std::to_string(localSize)); + std::string kernelName; + if(inside == 1){ + 
buildOptions.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); + unit.kernel = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOptions, inputs[0], outputs[0]); + mGlobalWorkSize = {static_cast(localSize), static_cast(outside), static_cast(1)}; + } + else if(inside % 4 == 0){ + unit.kernel = runtime->buildKernel("softmax_buf", "softmax_v4_buf", buildOptions); + mGlobalWorkSize = {static_cast(localSize), static_cast(UP_DIV(inside, 4)), static_cast(outside)}; + }else { + unit.kernel = runtime->buildKernel("softmax_buf", "softmax_buf", buildOptions); + mGlobalWorkSize = {static_cast(localSize), static_cast(inside), static_cast(outside)}; + } + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalWorkSize = {(uint32_t)(localSize), 1, 1}; + + cl_int ret = CL_SUCCESS; + + uint32_t idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + if(mNeedUnpackC4){ + ret |= unit.kernel->get().setArg(idx++, openCLImage(output)); + ret |= unit.kernel->get().setArg(idx++, openCLImage(mTempTensor.get())); + }else{ + ret |= unit.kernel->get().setArg(idx++, openCLImage(input)); + ret |= unit.kernel->get().setArg(idx++, openCLImage(output)); + } + if(inside == 1){ + ret |= unit.kernel->get().setArg(idx++, channel); + int shape[4] = {1, outside, channel, 1}; + ret |= unit.kernel->get().setArg(idx++, shape); + } else { + ret |= unit.kernel->get().setArg(idx++, inside); + ret |= unit.kernel->get().setArg(idx++, outside); + ret |= unit.kernel->get().setArg(idx++, channel); + } + MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution"); + if(localSize == 1){ + mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "softmax_buf", unit.kernel).first; + } + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); + } - uint32_t idx = 0; - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); - - ret |= unit.kernel->get().setArg(idx++, openCLImage(input)); - ret |= unit.kernel->get().setArg(idx++, openCLImage(output)); - ret |= unit.kernel->get().setArg(idx++, remainChannels); - ret |= unit.kernel->get().setArg(idx++, shape); - MNN_CHECK_CL_SUCCESS(ret, "setArg SoftmaxBufExecution"); - if(localSize == 1){ - mLocalWorkSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "softmax_buf", unit.kernel).first; + // NCHW -> NC4HW4 + if(mNeedUnpackC4){ + Unit unit; + std::vector outputShape = tensorShapeFormat(output); + int shape[4] = {outputShape[0], outputShape[3], outputShape[1], outputShape[2]};//N C H W + std::set buildOptions; + buildOptions.emplace("-DINPUT_FORMAT=MNN_DATA_FORMAT_NCHW"); + buildOptions.emplace("-DOUTPUT_FORMAT=MNN_DATA_FORMAT_NC4HW4"); + unit.kernel = runtime->buildKernel("buffer_convert_buf", "buffer_convert_to_buffer", buildOptions, input, output); + mGlobalWorkSize = {static_cast(shape[2] * shape[3]), static_cast(shape[1]), static_cast(shape[0])}; + cl_int ret = CL_SUCCESS; + uint32_t idx = 0; + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); + ret |= 
unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); + ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTempTensor.get())); + ret |= unit.kernel->get().setArg(idx++, sizeof(shape), shape); + ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); + MNN_CHECK_CL_SUCCESS(ret, "setArg buffer_convert_to_buffer"); + + const uint32_t maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); + mLocalWorkSize = {16, std::max((uint32_t)1, maxWorkGroupSize / 16), 1}; + + mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; + unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; + mUnits.emplace_back(unit); } - mOpenCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalWorkSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalWorkSize[0], mLocalWorkSize[1], mLocalWorkSize[2]}; return NO_ERROR; } diff --git a/source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp b/source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp index 4385bae7d..e6d154ffa 100644 --- a/source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/SoftmaxBufExecution.hpp @@ -22,8 +22,6 @@ class SoftmaxBufExecution : public CommonExecution { virtual ~SoftmaxBufExecution() = default; virtual ErrorCode onEncode(const std::vector &inputs, const std::vector &outputs) override; - - bool buildSoftmaxKernel(int localSize); private: int getLocalSize(int size, int maxGroupSize); uint32_t mMaxWorkGroupSize; @@ -31,6 +29,9 @@ class SoftmaxBufExecution : public CommonExecution { std::vector mGlobalWorkSize{1, 1, 1}; std::vector mLocalWorkSize{1, 1, 1, 1}; int mAxis; + std::set mBuildOptions; + std::shared_ptr mTempTensor; + bool mNeedUnpackC4; }; } // namespace OpenCL } // namespace MNN diff --git a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp index 0baee6428..78d72de88 100644 --- a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp @@ -39,7 +39,10 @@ ErrorCode SplitGeluBufExecution::onEncode(const std::vector& inputs, co buildOptions.emplace("-DDOUBLE_INPUTS"); } int pack_wh = 1; - if(shape[2] % 4 == 0) { + if(shape[2] % 16 == 0) { + pack_wh = 16; + buildOptions.emplace("-DWH_16"); + } else if(shape[2] % 4 == 0) { pack_wh = 4; buildOptions.emplace("-DWH_4"); } @@ -49,15 +52,13 @@ ErrorCode SplitGeluBufExecution::onEncode(const std::vector& inputs, co auto maxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - mGWS = {static_cast(shape[0]), - static_cast(UP_DIV(shape[1], 4)), - static_cast(UP_DIV(shape[2],pack_wh))}; + mGWS = {static_cast(UP_DIV(shape[2], pack_wh)), + static_cast(shape[0] * shape[1])}; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGWS[0]); ret |= unit.kernel->get().setArg(idx++, mGWS[1]); - ret |= unit.kernel->get().setArg(idx++, mGWS[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); if(inputs.size() > 1) { ret |= unit.kernel->get().setArg(idx++, openCLBuffer(inputs[1])); @@ -67,12 +68,12 @@ ErrorCode SplitGeluBufExecution::onEncode(const std::vector& inputs, co MNN_CHECK_CL_SUCCESS(ret, "setArg SplitGeluBufExecution"); 
- mLWS = localWS3DDefault(mGWS, maxWorkGroupSize, runtime, "splitgelu_buf", unit.kernel).first; + mLWS = localWS2DDefault(mGWS, maxWorkGroupSize, runtime, "splitgelu_buf", unit.kernel).first; - unit.globalWorkSize = {mGWS[0], mGWS[1], mGWS[2]}; - unit.localWorkSize = {mLWS[0], mLWS[1], mLWS[2]}; + unit.globalWorkSize = {mGWS[0], mGWS[1]}; + unit.localWorkSize = {mLWS[0], mLWS[1]}; - mOpenCLBackend->recordKernel3d(unit.kernel, mGWS, mLWS); + mOpenCLBackend->recordKernel2d(unit.kernel, mGWS, mLWS); mOpenCLBackend->endRecord(mRecording); return NO_ERROR; diff --git a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp index ff1bddda1..501e6f1ae 100644 --- a/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp +++ b/source/backend/opencl/execution/buffer/StrassenMatmulOpenCLComputor.cpp @@ -210,6 +210,19 @@ ErrorCode StrassenMatrixComputor::_generateBasicMatMul(int e, int l, int h, cons return NO_ERROR; } + +static int getMaxMultiple(int number) { + if(number % 128 == 0) { + return 128; + } else if(number % 64 == 0) { + return 64; + } else if(number % 32 == 0) { + return 32; + } else if(number % 16 == 0) { + return 16; + } + return 1; +} ErrorCode StrassenMatrixComputor::_generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, int postType) { @@ -244,6 +257,14 @@ ErrorCode StrassenMatrixComputor::_generateMatMul(int e, int l, int h, const Mat return res; } + // sub_matrix cannot own sufficient tile + if(getMaxMultiple(e) != getMaxMultiple(eSub) || getMaxMultiple(h) != getMaxMultiple(eSub) || (lSub % 4 != 0)) { + Unit unit; + auto res = _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postType, unit); + mUnits.emplace_back(unit); + return res; + } + // Strassen Construct currentDepth += 1; diff --git a/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp b/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp index 75897d2fe..56a1c9027 100644 --- a/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/UnaryBufExecution.cpp @@ -22,54 +22,47 @@ ErrorCode UnaryBufExecution::onEncode(const std::vector& inputs, const Tensor* output = outputs[0]; auto openCLBackend = static_cast(backend()); auto runtime = openCLBackend->getOpenCLRuntime(); - - auto dataType = inputs[0]->getType(); std::set buildOptions = mBuildOptions; - if (dataType.code == halide_type_int){ - buildOptions.emplace("-DOPENCL_INPUT_INT"); - } #ifdef MNN_SUPPORT_INTEL_SUBGROUP - if (runtime->isSupportedIntelSubgroup()) { + if (runtime->isSupportedIntelSubgroup() && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat) { return SubgrouponResize(inputs, outputs); } #endif /* MNN_SUPPORT_INTEL_SUBGROUP */ - unit.kernel = runtime->buildKernel("unary_buf", "unary_buf", buildOptions, input, output); - mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); - std::vector inputShape = tensorShapeFormat(input); std::vector outputShape = tensorShapeFormat(output); - - int batch = outputShape.at(0); - int outputHeight = outputShape.at(1); - int outputWidth = outputShape.at(2); - int channels = outputShape.at(3); - - int channelBlocks = (channels + 3) / 4; + int totalSize = 0; + if(MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(output)->dimensionFormat){ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * 
ROUND_UP(outputShape[3], 4); + }else{ + totalSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + } + if(totalSize % 4 != 0) { + buildOptions.emplace("-DPACK_LEAVE"); + } + unit.kernel = runtime->buildKernel("unary_buf", "unary_buf", buildOptions, input, output); + mMaxWorkGroupSize = static_cast(runtime->getMaxWorkGroupSize(unit.kernel)); mGlobalWorkSize = { - static_cast(channelBlocks), - static_cast(outputWidth), - static_cast(batch * outputHeight), + static_cast(UP_DIV(totalSize, 4)), + static_cast(1) }; uint32_t idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[0]); ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[1]); - ret |= unit.kernel->get().setArg(idx++, mGlobalWorkSize[2]); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(input)); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(output)); - ret |= unit.kernel->get().setArg(idx++, outputHeight); + ret |= unit.kernel->get().setArg(idx++, totalSize); MNN_CHECK_CL_SUCCESS(ret, "setArg UnaryBufExecution"); std::string kernelName = "unary_buf"; - mLocalSize = localWS3DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; - openCLBackend->recordKernel3d(unit.kernel, mGlobalWorkSize, mLocalSize); - unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1], mGlobalWorkSize[2]}; - unit.localWorkSize = {mLocalSize[0], mLocalSize[1], mLocalSize[2]}; + mLocalSize = localWS2DDefault(mGlobalWorkSize, mMaxWorkGroupSize, openCLBackend->getOpenCLRuntime(), kernelName, unit.kernel).first; + openCLBackend->recordKernel2d(unit.kernel, mGlobalWorkSize, mLocalSize); + unit.globalWorkSize = {mGlobalWorkSize[0], mGlobalWorkSize[1]}; + unit.localWorkSize = {mLocalSize[0], mLocalSize[1]}; return NO_ERROR; } - #ifdef MNN_SUPPORT_INTEL_SUBGROUP ErrorCode UnaryBufExecution::SubgrouponResize(const std::vector& inputs, const std::vector& outputs) { auto &unit = mUnits[0]; @@ -162,6 +155,7 @@ ErrorCode UnaryBufExecution::SubgrouponResize(const std::vector& inputs ret |= unit.kernel->get().setArg(idx++, outputWidth); ret |= unit.kernel->get().setArg(idx++, outputHeight); ret |= unit.kernel->get().setArg(idx++, channels); + ret |= unit.kernel->get().setArg(idx++, batch); ret |= unit.kernel->get().setArg(idx++, static_cast(inputpad.left)); ret |= unit.kernel->get().setArg(idx++, static_cast(inputpad.right)); ret |= unit.kernel->get().setArg(idx++, static_cast(outputpad.left)); @@ -187,7 +181,8 @@ class UnaryBufCreator : public OpenCLBackend::Creator { const MNN::Op* op, Backend* backend) const override { for (int i = 0; i < inputs.size(); ++i) { int channel = inputs[i]->channel(); - if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup()) { + if (channel >= 16 && static_cast(backend)->getOpenCLRuntime()->isSupportedIntelSubgroup() + && MNN::MNN_DATA_FORMAT_NC4HW4 == TensorUtils::getDescribe(inputs[i])->dimensionFormat) { TensorUtils::setTensorChannelPack(inputs[i], 16); } } diff --git a/source/backend/opencl/execution/cl/argmax_buf.cl b/source/backend/opencl/execution/cl/argmax_buf.cl index 5240eea4e..f6a675070 100644 --- a/source/backend/opencl/execution/cl/argmax_buf.cl +++ b/source/backend/opencl/execution/cl/argmax_buf.cl @@ -22,219 +22,37 @@ __private const int global_size_dim0, __private const int global_size_dim1, __pr if(A.z > B.z){ A.z = B.z; C.z = D; } \ if(A.w > B.w){ A.w = B.w; C.w = D; } -__kernel void argmax_width_buf(GLOBAL_SIZE_3_DIMS - __global const FLOAT* input, - __global 
int* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { - const int x = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); - DEAL_NON_UNIFORM_DIM3(x, height_idx, batch_channel_idx); - - const int batch_idx = batch_channel_idx / outputChannelBlock; - const int channel_idx = batch_channel_idx % outputChannelBlock; - - const int offset = ((((batch_idx * inputChannelBlock) + channel_idx) * inputHeight + height_idx) * inputWidth + 0)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + channel_idx) * outputHeight + height_idx) * oututWidth + 0)*4; - int4 index = 0; -#ifdef ARGMAX - FLOAT4 maxValue = (FLOAT4)-FLT_MAX; -#else - FLOAT4 maxValue = (FLOAT4)FLT_MAX; -#endif -#if ARGMAX_LOCAL_SIZE >= 4 - int lid = get_local_id(0); - FLOAT4 local reduce[ARGMAX_LOCAL_SIZE]; - int4 local index_reduce[ARGMAX_LOCAL_SIZE]; - - for (int i=lid; i < inputWidth; i+=ARGMAX_LOCAL_SIZE) { - FLOAT4 value = vload4(i, input + offset); -#ifdef ARGMAX - ARGMAX_SELECT(maxValue, value, index, i); -#else - ARGMIN_SELECT(maxValue, value, index, i); -#endif - } - reduce[lid] = maxValue; - index_reduce[lid] = index; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = ARGMAX_LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i){ -#ifdef ARGMAX - if(reduce[lid].x < reduce[lid + i].x){reduce[lid].x = reduce[lid + i].x; index_reduce[lid].x = index_reduce[lid + i].x;} - if(reduce[lid].y < reduce[lid + i].y){reduce[lid].y = reduce[lid + i].y; index_reduce[lid].y = index_reduce[lid + i].y;} - if(reduce[lid].z < reduce[lid + i].z){reduce[lid].z = reduce[lid + i].z; index_reduce[lid].z = index_reduce[lid + i].z;} - if(reduce[lid].w < reduce[lid + i].w){reduce[lid].w = reduce[lid + i].w; index_reduce[lid].w = index_reduce[lid + i].w;} -#else - if(reduce[lid].x > reduce[lid + i].x){reduce[lid].x = reduce[lid + i].x; index_reduce[lid].x = index_reduce[lid + i].x;} - if(reduce[lid].y > reduce[lid + i].y){reduce[lid].y = reduce[lid + i].y; index_reduce[lid].y = index_reduce[lid + i].y;} - if(reduce[lid].z > reduce[lid + i].z){reduce[lid].z = reduce[lid + i].z; index_reduce[lid].z = index_reduce[lid + i].z;} - if(reduce[lid].w > reduce[lid + i].w){reduce[lid].w = reduce[lid + i].w; index_reduce[lid].w = index_reduce[lid + i].w;} -#endif - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if(lid == 0){ - vstore4(index_reduce[0], 0, output + outputOffset); - } -#else - for(int i = 0; i < inputWidth; ++i){ - FLOAT4 value = vload4(i, input + offset); -#ifdef ARGMAX - ARGMAX_SELECT(maxValue, value, index, i); -#else - ARGMIN_SELECT(maxValue, value, index, i); -#endif - } - vstore4(index, 0, output + outputOffset); -#endif -} - - -__kernel void argmax_height_buf(GLOBAL_SIZE_3_DIMS - __global const FLOAT* input, - __global int* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { +__kernel void argmax_buf(GLOBAL_SIZE_3_DIMS + __global const FLOAT* input, + __global int* output, + __private const int inside, + __private const int 
outside, + __private const int dim){ const int x = get_global_id(0); - const int width_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(x, width_idx, batch_channel_idx); - - const int batch_idx = batch_channel_idx / outputChannelBlock; - const int channel_idx = batch_channel_idx % outputChannelBlock; - - const int offset = ((((batch_idx * inputChannelBlock) + channel_idx) * inputHeight + 0) * inputWidth + width_idx)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + channel_idx) * outputHeight + 0) * oututWidth + width_idx)*4; - int4 index = 0; -#ifdef ARGMAX - FLOAT4 maxValue = (FLOAT4)-FLT_MAX; -#else - FLOAT4 maxValue = (FLOAT4)FLT_MAX; -#endif -#if ARGMAX_LOCAL_SIZE >= 4 - int lid = get_local_id(0); - FLOAT4 local reduce[ARGMAX_LOCAL_SIZE]; - int4 local index_reduce[ARGMAX_LOCAL_SIZE]; + const int y = get_global_id(1); // inside + const int z = get_global_id(2); // outside - for (int i=lid; i < inputHeight; i+=ARGMAX_LOCAL_SIZE) { - FLOAT4 value = vload4(i * inputWidth, input + offset); -#ifdef ARGMAX - ARGMAX_SELECT(maxValue, value, index, i); -#else - ARGMIN_SELECT(maxValue, value, index, i); -#endif - } - reduce[lid] = maxValue; - index_reduce[lid] = index; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = ARGMAX_LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i){ -#ifdef ARGMAX - if(reduce[lid].x < reduce[lid + i].x){reduce[lid].x = reduce[lid + i].x; index_reduce[lid].x = index_reduce[lid + i].x;} - if(reduce[lid].y < reduce[lid + i].y){reduce[lid].y = reduce[lid + i].y; index_reduce[lid].y = index_reduce[lid + i].y;} - if(reduce[lid].z < reduce[lid + i].z){reduce[lid].z = reduce[lid + i].z; index_reduce[lid].z = index_reduce[lid + i].z;} - if(reduce[lid].w < reduce[lid + i].w){reduce[lid].w = reduce[lid + i].w; index_reduce[lid].w = index_reduce[lid + i].w;} -#else - if(reduce[lid].x > reduce[lid + i].x){reduce[lid].x = reduce[lid + i].x; index_reduce[lid].x = index_reduce[lid + i].x;} - if(reduce[lid].y > reduce[lid + i].y){reduce[lid].y = reduce[lid + i].y; index_reduce[lid].y = index_reduce[lid + i].y;} - if(reduce[lid].z > reduce[lid + i].z){reduce[lid].z = reduce[lid + i].z; index_reduce[lid].z = index_reduce[lid + i].z;} - if(reduce[lid].w > reduce[lid + i].w){reduce[lid].w = reduce[lid + i].w; index_reduce[lid].w = index_reduce[lid + i].w;} -#endif - } - barrier(CLK_LOCAL_MEM_FENCE); - } - if(lid == 0){ - vstore4(index_reduce[0], 0, output + outputOffset); - } -#else - for(int i = 0; i < inputHeight; ++i){ - FLOAT4 value = vload4(i * inputWidth, input + offset); -#ifdef ARGMAX - ARGMAX_SELECT(maxValue, value, index, i); -#else - ARGMIN_SELECT(maxValue, value, index, i); -#endif - } - vstore4(index, 0, output + outputOffset); -#endif -} - -__kernel void argmax_channel_buf(GLOBAL_SIZE_3_DIMS - __global const FLOAT* input, - __global int* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { - const int x = get_global_id(0); - const int wh = get_global_id(1); - const int batch_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(x, wh, batch_idx); - - const int width_idx = wh % oututWidth; - const int height_idx = wh / oututWidth; - const int offset = ((((batch_idx * inputChannelBlock) + 0) * inputHeight + height_idx) * inputWidth + 
width_idx)*4; -#ifdef ARGMAX_CHANNEL_DIM1 - const int outputOffset = ((batch_idx * outputHeight + height_idx) * oututWidth + width_idx); -#else - const int outputOffset = ((((batch_idx * outputChannelBlock) + 0) * outputHeight + height_idx) * oututWidth + width_idx)*4; -#endif - int remain = inputChannel - (inputChannelBlock - 1) * 4; + DEAL_NON_UNIFORM_DIM3(x, y, z); + int index = 0; #ifdef ARGMAX FLOAT maxValue = (FLOAT)-FLT_MAX; #else - FLOAT maxValue = (FLOAT)FLT_MAX; +FLOAT maxValue = (FLOAT)FLT_MAX; #endif - int index = 0; - FLOAT4 value; - FLOAT *valuePtr = (FLOAT*)&value; + const int offset = z * dim * inside + y; #if ARGMAX_LOCAL_SIZE >= 4 int lid = get_local_id(0); FLOAT local reduce[ARGMAX_LOCAL_SIZE]; int local index_reduce[ARGMAX_LOCAL_SIZE]; - - for (int i=lid; i < inputChannelBlock - 1; i+=ARGMAX_LOCAL_SIZE) { - value = vload4(i * inputWidth * inputHeight, input + offset); - for(int j = 0; j < 4; ++j){ + + for (int i=lid; i < dim; i+=ARGMAX_LOCAL_SIZE) { + FLOAT value = input[offset + i * inside]; #ifdef ARGMAX - if(maxValue < valuePtr[j]){ - index = i * 4 + j; - maxValue = valuePtr[j]; - } + if(maxValue < value){ maxValue = value; index = i; } #else - if(maxValue > valuePtr[j]){ - index = i * 4 + j; - maxValue = valuePtr[j]; - } + if(maxValue > value){ maxValue = value; index = i; } #endif - } } reduce[lid] = maxValue; index_reduce[lid] = index; @@ -250,96 +68,47 @@ __kernel void argmax_channel_buf(GLOBAL_SIZE_3_DIMS barrier(CLK_LOCAL_MEM_FENCE); } if(lid == 0){ - maxValue = reduce[lid]; - index = index_reduce[lid]; - value = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 0; j < remain; ++j){ -#ifdef ARGMAX - if(maxValue < valuePtr[j]){ - index = (inputChannelBlock - 1) * 4 + j; - maxValue = valuePtr[j]; - } -#else - if(maxValue > valuePtr[j]){ - index = (inputChannelBlock - 1) * 4 + j; - maxValue = valuePtr[j]; - } -#endif - } - output[outputOffset] = index; + output[z * inside + y] = index_reduce[0]; } #else - for(int i = 0; i < inputChannelBlock - 1; ++i){ - value = vload4(i * inputWidth * inputHeight, input + offset); - for(int j = 0; j < 4; ++j){ -#ifdef ARGMAX - if(maxValue < valuePtr[j]){ - index = i * 4 + j; - maxValue = valuePtr[j]; - } -#else - if(maxValue > valuePtr[j]){ - index = i * 4 + j; - maxValue = valuePtr[j]; - } -#endif - } - } - value = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 0; j < remain; ++j){ + for(int i = 0; i < dim; ++i){ + FLOAT value = input[ + offset + i * inside]; #ifdef ARGMAX - if(maxValue < valuePtr[j]){ - index = (inputChannelBlock - 1) * 4 + j; - maxValue = valuePtr[j]; - } + if(maxValue < value){ maxValue = value; index = i; } #else - if(maxValue > valuePtr[j]){ - index = (inputChannelBlock - 1) * 4 + j; - maxValue = valuePtr[j]; - } + if(maxValue > value){ maxValue = value; index = i; } #endif } - output[outputOffset] = index; + output[z * inside + y] = index; #endif } -__kernel void argmax_batch_buf(GLOBAL_SIZE_3_DIMS - __global const FLOAT* input, - __global int* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { - const int x = get_global_id(0); - const int wh = get_global_id(1); - const int channel_idx = get_global_id(2); - DEAL_NON_UNIFORM_DIM3(x, wh, 
channel_idx); +__kernel void argmax_v4_buf(GLOBAL_SIZE_3_DIMS + __global const FLOAT* input, + __global int* output, + __private const int inside, + __private const int outside, + __private const int dim){ + const int x = get_global_id(0); + const int y = get_global_id(1) << 2; // inside + const int z = get_global_id(2); // outside - const int width_idx = wh % oututWidth; - const int height_idx = wh / oututWidth; - const int offset = ((((0 * inputChannelBlock) + channel_idx) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((((0 * outputChannelBlock) + channel_idx) * outputHeight + height_idx) * oututWidth + width_idx)*4; + DEAL_NON_UNIFORM_DIM3(x, y, z); int4 index = 0; - int batchOffset = inputChannelBlock * inputHeight * inputWidth; #ifdef ARGMAX FLOAT4 maxValue = (FLOAT4)-FLT_MAX; #else FLOAT4 maxValue = (FLOAT4)FLT_MAX; #endif + const int offset = z * dim * inside + y; #if ARGMAX_LOCAL_SIZE >= 4 int lid = get_local_id(0); FLOAT4 local reduce[ARGMAX_LOCAL_SIZE]; int4 local index_reduce[ARGMAX_LOCAL_SIZE]; - - for (int i=lid; i < inputBatch; i+=ARGMAX_LOCAL_SIZE) { - FLOAT4 value = vload4(i * batchOffset, input + offset); + + for (int i=lid; i < dim; i+=ARGMAX_LOCAL_SIZE) { + FLOAT4 value = vload4(0, input + offset + i * inside); #ifdef ARGMAX ARGMAX_SELECT(maxValue, value, index, i); #else @@ -366,17 +135,17 @@ __kernel void argmax_batch_buf(GLOBAL_SIZE_3_DIMS barrier(CLK_LOCAL_MEM_FENCE); } if(lid == 0){ - vstore4(index_reduce[0], 0, output + outputOffset); + vstore4(index_reduce[0], 0, output + z * inside + y); } #else - for(int i = 0; i < inputBatch; ++i){ - FLOAT4 value = vload4(i * batchOffset, input + offset); + for(int i = 0; i < dim; ++i){ + FLOAT4 value = vload4(0, input + offset + i * inside); #ifdef ARGMAX ARGMAX_SELECT(maxValue, value, index, i); #else ARGMIN_SELECT(maxValue, value, index, i); #endif } - vstore4(index, 0, output + outputOffset); + vstore4(index, 0, output + z * inside + y); #endif } diff --git a/source/backend/opencl/execution/cl/attention_buf.cl b/source/backend/opencl/execution/cl/attention_buf.cl index c17be5a4f..074956902 100644 --- a/source/backend/opencl/execution/cl/attention_buf.cl +++ b/source/backend/opencl/execution/cl/attention_buf.cl @@ -10,359 +10,784 @@ return; \ } +#define DEAL_OUTER_SEQLEN_NOT_ALIGN(length) \ + if(4 * sl + 3 >= length) {\ + temp_3 = (FLOAT4)0;\ + }\ + if(4 * sl + 2 >= length) {\ + temp_2 = (FLOAT4)0;\ + }\ + if(4 * sl + 1 >= length) {\ + temp_1 = (FLOAT4)0;\ + } + +#define DEAL_INNER_HEADDIM_NOT_ALIGN(length) \ + if(hd * 4 + 3 >= length) {\ + temp_0.w = (FLOAT)0;\ + temp_1.w = (FLOAT)0;\ + temp_2.w = (FLOAT)0;\ + temp_3.w = (FLOAT)0;\ + }\ + if(hd * 4 + 2 >= length) {\ + temp_0.z = (FLOAT)0;\ + temp_1.z = (FLOAT)0;\ + temp_2.z = (FLOAT)0;\ + temp_3.z = (FLOAT)0;\ + }\ + if(hd * 4 + 1 >= length) {\ + temp_0.y = (FLOAT)0;\ + temp_1.y = (FLOAT)0;\ + temp_2.y = (FLOAT)0;\ + temp_3.y = (FLOAT)0;\ + } + + + +__kernel void rearrange_qkv(GLOBAL_SIZE_3_DIMS + __global const FLOAT *input_q, //[batch, seqLenQ/4, headNum, headDim, seqLenQ_4] + __global const FLOAT *input_k, // [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] + __global const FLOAT *input_v, // [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] + __global FLOAT *output_q, // [batch*headNum, ROUND_UP(headDim, mTileHDK), ROUND_UP(seqLenQ, mTileQ)] + __global FLOAT *output_k, // [batch*headNum/group, ROUND_UP(headDim, mTileHDK), ROUND_UP(seqLenKV, mTileKV)] + __global FLOAT *output_v, // [batch*headNum/group, 
ROUND_UP(seqLenKV, mTileKV), ROUND_UP(headDim, mTileHDN)] + __global FLOAT *past_k, // [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] + __global FLOAT *past_v, // [batch, seqLenKV/4, headNum/group, headDim, seqLenKV_4] + __private const int4 tile, // [mTileQ, mTileKV, mTileHDK, mTileHDN] + __private const int4 shape,// [seqLenQ, seqLenKV, headNum, headDim] + __private const int4 param // [group, batch] +) { + const int sl = get_global_id(0); // seqLen/4 : max(seqLenPackQ/4, seqLenPackKV/4) + const int hd = get_global_id(1); // headDim/4 : max(headDimPackQK/4, headDimPackV/4) + const int z = get_global_id(2); // batch * headNum + DEAL_NON_UNIFORM_DIM3(sl, hd, z); + + const int seqLenQ = shape.x; + const int seqLenKV = shape.y; + const int headNum = shape.z; + const int headDim = shape.w; + const int group = param.x; + const int batch = param.y; + + const int b = z % batch; + const int hn = z / batch; + + const int seqLenQ_4 = (seqLenQ + 3) / 4; + //const int in_offset_q = (((b * seqLenQ_4 + sl) * headNum + hn) * headDim + 4 * hd) * 4; + const int in_offset_q = (((b * seqLenQ + sl * 4) * headNum + hn) * headDim + 4 * hd); + + const int seqLenPackQ = ((seqLenQ + tile.x - 1) / tile.x) * tile.x; + const int headDimPackQK = ((headDim + tile.z - 1) / tile.z) * tile.z; + const int out_offset_q = (((b * headNum + hn) * headDimPackQK + hd * 4) * seqLenPackQ + sl * 4); + + if(sl * 4 < seqLenPackQ && hd * 4 < headDimPackQK) { + if(sl * 4 >= seqLenQ || hd * 4 >= headDim) { + vstore4((FLOAT4)0, 0, output_q + out_offset_q); + vstore4((FLOAT4)0, 0, output_q + out_offset_q + seqLenPackQ); + vstore4((FLOAT4)0, 0, output_q + out_offset_q + 2 * seqLenPackQ); + vstore4((FLOAT4)0, 0, output_q + out_offset_q + 3 * seqLenPackQ); + } else { + FLOAT4 temp_0 = vload4(0, input_q + in_offset_q); + FLOAT4 temp_1 = (sl * 4 + 1 >= seqLenQ) ? (FLOAT4)0 : vload4(0, input_q + in_offset_q + headNum*headDim); + FLOAT4 temp_2 = (sl * 4 + 2 >= seqLenQ) ? (FLOAT4)0 : vload4(0, input_q + in_offset_q + 2*headNum*headDim); + FLOAT4 temp_3 = (sl * 4 + 3 >= seqLenQ) ? 
(FLOAT4)0 : vload4(0, input_q + in_offset_q + 3*headNum*headDim); + #ifdef HEADDIM_LEAVE + DEAL_INNER_HEADDIM_NOT_ALIGN(headDim) + #endif + #ifdef SEQLEN_LEAVE + DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenQ) + #endif + vstore4((FLOAT4)(temp_0.s0, temp_1.s0, temp_2.s0, temp_3.s0), 0, output_q + out_offset_q); + vstore4((FLOAT4)(temp_0.s1, temp_1.s1, temp_2.s1, temp_3.s1), 0, output_q + out_offset_q + seqLenPackQ); + vstore4((FLOAT4)(temp_0.s2, temp_1.s2, temp_2.s2, temp_3.s2), 0, output_q + out_offset_q + 2 * seqLenPackQ); + vstore4((FLOAT4)(temp_0.s3, temp_1.s3, temp_2.s3, temp_3.s3), 0, output_q + out_offset_q + 3 * seqLenPackQ); + } + } + + if(hn >= headNum / group) { + return; + } + + + const int seqLenPackKV = ((seqLenKV + tile.y - 1) / tile.y) * tile.y; + const int headDimPackV = ((headDim + tile.w - 1) / tile.w) * tile.w; + const int seqLenKV_4 = (seqLenKV + 3) / 4; + const int in_offset_kv = (((b * seqLenKV + sl*4) * headNum/group + hn) * headDim + 4 * hd); + + if(sl * 4 < seqLenPackKV && hd * 4 < headDimPackQK) { + const int out_offset_k = (((b * headNum/group + hn) * headDimPackQK + hd * 4) * seqLenPackKV + sl * 4); + + if(sl * 4 >= seqLenKV || hd * 4 >= headDim) { + vstore4((FLOAT4)0, 0, output_k + out_offset_k); + vstore4((FLOAT4)0, 0, output_k + out_offset_k + seqLenPackKV); + vstore4((FLOAT4)0, 0, output_k + out_offset_k + 2 * seqLenPackKV); + vstore4((FLOAT4)0, 0, output_k + out_offset_k + 3 * seqLenPackKV); + } else { + FLOAT4 temp_0 = vload4(0, input_k + in_offset_kv); + FLOAT4 temp_1 = (sl * 4 + 1 >= seqLenKV) ? (FLOAT4)0 : vload4(0, input_k + in_offset_kv + headNum*headDim/group); + FLOAT4 temp_2 = (sl * 4 + 2 >= seqLenKV) ? (FLOAT4)0 : vload4(0, input_k + in_offset_kv + 2*headNum*headDim/group); + FLOAT4 temp_3 = (sl * 4 + 3 >= seqLenKV) ? (FLOAT4)0 : vload4(0, input_k + in_offset_kv + 3*headNum*headDim/group); + #ifdef HEADDIM_LEAVE + DEAL_INNER_HEADDIM_NOT_ALIGN(headDim) + #endif + #ifdef SEQLEN_LEAVE + DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV) + #endif + vstore4((FLOAT4)(temp_0.s0, temp_1.s0, temp_2.s0, temp_3.s0), 0, output_k + out_offset_k); + vstore4((FLOAT4)(temp_0.s1, temp_1.s1, temp_2.s1, temp_3.s1), 0, output_k + out_offset_k + seqLenPackKV); + vstore4((FLOAT4)(temp_0.s2, temp_1.s2, temp_2.s2, temp_3.s2), 0, output_k + out_offset_k + 2 * seqLenPackKV); + vstore4((FLOAT4)(temp_0.s3, temp_1.s3, temp_2.s3, temp_3.s3), 0, output_k + out_offset_k + 3 * seqLenPackKV); + + // pastK + vstore4(temp_0, 0, past_k + in_offset_kv); + if(sl * 4 + 1 < seqLenKV) { + vstore4(temp_1, 0, past_k + in_offset_kv + headNum*headDim/group); + } + if(sl * 4 + 2 < seqLenKV) { + vstore4(temp_2, 0, past_k + in_offset_kv + 2*headNum*headDim/group); + } + if(sl * 4 + 3 < seqLenKV) { + vstore4(temp_3, 0, past_k + in_offset_kv + 3*headNum*headDim/group); + } + } + + } + + if(sl * 4 < seqLenPackKV && hd * 4 < headDimPackV) { + const int out_offset_v = (((b * headNum/group + hn) * seqLenPackKV + sl * 4) * headDimPackV + hd * 4); + + if(sl * 4 >= seqLenKV || hd * 4 >= headDim) { + vstore4((FLOAT4)0, 0, output_v + out_offset_v); + vstore4((FLOAT4)0, 0, output_v + out_offset_v + headDimPackV); + vstore4((FLOAT4)0, 0, output_v + out_offset_v + 2 * headDimPackV); + vstore4((FLOAT4)0, 0, output_v + out_offset_v + 3 * headDimPackV); + } else { + FLOAT4 temp_0 = vload4(0, input_v + in_offset_kv); + FLOAT4 temp_1 = (sl * 4 + 1 >= seqLenKV) ? (FLOAT4)0 : vload4(0, input_v + in_offset_kv + headNum*headDim/group); + FLOAT4 temp_2 = (sl * 4 + 2 >= seqLenKV) ? 
(FLOAT4)0 : vload4(0, input_v + in_offset_kv + 2*headNum*headDim/group); + FLOAT4 temp_3 = (sl * 4 + 3 >= seqLenKV) ? (FLOAT4)0 : vload4(0, input_v + in_offset_kv + 3*headNum*headDim/group); + #ifdef HEADDIM_LEAVE + DEAL_INNER_HEADDIM_NOT_ALIGN(headDim) + #endif + #ifdef SEQLEN_LEAVE + DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV) + #endif + vstore4(temp_0, 0, output_v + out_offset_v); + vstore4(temp_1, 0, output_v + out_offset_v + headDimPackV); + vstore4(temp_2, 0, output_v + out_offset_v + 2 * headDimPackV); + vstore4(temp_3, 0, output_v + out_offset_v + 3 * headDimPackV); + + // pastV + vstore4(temp_0, 0, past_v + in_offset_kv); + if(sl * 4 + 1 < seqLenKV) { + vstore4(temp_1, 0, past_v + in_offset_kv + headNum*headDim/group); + } + if(sl * 4 + 2 < seqLenKV) { + vstore4(temp_2, 0, past_v + in_offset_kv + 2*headNum*headDim/group); + } + if(sl * 4 + 3 < seqLenKV) { + vstore4(temp_3, 0, past_v + in_offset_kv + 3*headNum*headDim/group); + } + } + + } +} + +#ifndef MASK_DTYPE +#define MASK_DTYPE FLOAT +#define MASK_DTYPE4 FLOAT4 +#endif +__kernel void rearrange_mask(GLOBAL_SIZE_3_DIMS + __global const MASK_DTYPE *input_mask, // [batch, 1, seqLenQ, seqLenKV, 4] + __global MASK_DTYPE *output_mask, // [batch, ROUND_UP(seqLenQ, mTileQ), ROUND_UP(seqLenKV, mTileKV)] + const int4 shape // [seqLenQ, seqLenKV, mTileQ, mTileKV] +) { + const int sl = get_global_id(0); // seqLen_4 + const int sl_kv = get_global_id(1); // seqLenKV_4 + const int b = get_global_id(2); // Batch + DEAL_NON_UNIFORM_DIM3(sl, sl_kv, b); + + const int seq_len_pack = ((shape.x + shape.z - 1) / shape.z) * shape.z; + const int seq_len_kv_pack = ((shape.y + shape.w - 1) / shape.w) * shape.w; + + int in_offset = ((b * shape.x + sl * 4) * shape.y + sl_kv * 4); + int out_offset = (b * seq_len_pack + sl * 4) * seq_len_kv_pack + sl_kv * 4; + + if(sl * 4 >= shape.x || sl_kv * 4 >= shape.y) { + vstore4((MASK_DTYPE4)0, 0, output_mask + out_offset); + vstore4((MASK_DTYPE4)0, 0, output_mask + out_offset + seq_len_kv_pack); + vstore4((MASK_DTYPE4)0, 0, output_mask + out_offset + seq_len_kv_pack * 2); + vstore4((MASK_DTYPE4)0, 0, output_mask + out_offset + seq_len_kv_pack * 3); + } else { + int y_down_align4 = (shape.y / 4 * 4); + MASK_DTYPE4 temp_0, temp_1, temp_2, temp_3; + + if(sl_kv * 4 < y_down_align4) { + temp_0 = vload4(0, input_mask + in_offset); + temp_1 = (sl * 4 + 1 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0, input_mask + in_offset + shape.y); + temp_2 = (sl * 4 + 2 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0, input_mask + in_offset + shape.y * 2); + temp_3 = (sl * 4 + 3 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0, input_mask + in_offset + shape.y * 3); + } else if(sl_kv * 4 + 1 == shape.y){ + temp_0 = (MASK_DTYPE4)(input_mask[in_offset], 0, 0, 0); + temp_1 = (sl * 4 + 1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y], 0, 0, 0);//vload4(0, input_mask + in_offset + shape.y); + temp_2 = (sl * 4 + 2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*2], 0, 0, 0);//vload4(0, input_mask + in_offset + shape.y * 2); + temp_3 = (sl * 4 + 3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*3], 0, 0, 0);//vload4(0, input_mask + in_offset + shape.y * 3); + } else if(sl_kv * 4 + 2 == shape.y){ + temp_0 = (MASK_DTYPE4)(input_mask[in_offset], input_mask[in_offset+1], 0, 0); + temp_1 = (sl * 4 + 1 >= shape.x) ? 
(MASK_DTYPE4)0 : (FLOAT4)(input_mask[in_offset + shape.y], input_mask[in_offset + shape.y + 1], 0, 0);//vload4(0, input_mask + in_offset + shape.y); + temp_2 = (sl * 4 + 2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*2], input_mask[in_offset + shape.y*2 + 1], 0, 0);//vload4(0, input_mask + in_offset + shape.y * 2); + temp_3 = (sl * 4 + 3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*3], input_mask[in_offset + shape.y*3 + 1], 0, 0);//vload4(0, input_mask + in_offset + shape.y * 3); + } else if(sl_kv * 4 + 3 == shape.y){ + temp_0 = (MASK_DTYPE4)(input_mask[in_offset], input_mask[in_offset+1], input_mask[in_offset+2], 0); + temp_1 = (sl * 4 + 1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y], input_mask[in_offset + shape.y + 1], input_mask[in_offset + shape.y + 2], 0);//vload4(0, input_mask + in_offset + shape.y); + temp_2 = (sl * 4 + 2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*2], input_mask[in_offset + shape.y*2 + 1], input_mask[in_offset + shape.y*2 + 2], 0);//vload4(0, input_mask + in_offset + shape.y * 2); + temp_3 = (sl * 4 + 3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset + shape.y*3], input_mask[in_offset + shape.y*3 + 1], input_mask[in_offset + shape.y*3 + 2], 0);//vload4(0, input_mask + in_offset + shape.y * 3); + } + + vstore4(temp_0, 0, output_mask + out_offset); + vstore4(temp_1, 0, output_mask + out_offset + seq_len_kv_pack); + vstore4(temp_2, 0, output_mask + out_offset + 2 * seq_len_kv_pack); + vstore4(temp_3, 0, output_mask + out_offset + 3 * seq_len_kv_pack); + } + +} + +__kernel void qkv_transpose_output(GLOBAL_SIZE_3_DIMS + __global const FLOAT *input, // [Batch * mNumHead, ROUND_UP(mHeadDim, mTileHDN), ROUND_UP(seqLen, mTileQ)] + __global FLOAT *output, // [Batch, seqLen/4, mNumHead, mHeadDim, 4] + __private const int tile_q, + __private const int tile_hdn, + __private const int seq_len, + __private const int head_num, + __private const int head_dim +) { + + const int sl = get_global_id(0); // seqLen_4 + const int hd = get_global_id(1); // mHeadDim_4 + const int z = get_global_id(2); // Batch * mNumHead + DEAL_NON_UNIFORM_DIM3(sl, hd, z); + + const int b = z / head_num; + const int hn = z % head_num; + + const int seq_len_pack = ((seq_len + tile_q - 1) / tile_q) * tile_q; + const int head_dim_pack = ((head_dim + tile_hdn - 1) / tile_hdn) * tile_hdn; + + const int offset_inp = ((b * head_num + hn) * head_dim_pack + 4 * hd) * seq_len_pack + 4 * sl; + + const int offset_out = (((b * seq_len + sl*4) * head_num + hn) * head_dim + 4 * hd); + + // Q + FLOAT4 temp_0 = vload4(0, input + offset_inp); + FLOAT4 temp_1 = vload4(0, input + offset_inp + seq_len_pack); + FLOAT4 temp_2 = vload4(0, input + offset_inp + 2 * seq_len_pack); + FLOAT4 temp_3 = vload4(0, input + offset_inp + 3 * seq_len_pack); + + vstore4((FLOAT4)(temp_0.s0, temp_1.s0, temp_2.s0, temp_3.s0), 0, output + offset_out); + if(4 * sl + 1 >= seq_len) return; + vstore4((FLOAT4)(temp_0.s1, temp_1.s1, temp_2.s1, temp_3.s1), 0, output + offset_out + head_num*head_dim); + if(4 * sl + 2 >= seq_len) return; + vstore4((FLOAT4)(temp_0.s2, temp_1.s2, temp_2.s2, temp_3.s2), 0, output + offset_out + 2*head_num*head_dim); + if(4 * sl + 3 >= seq_len) return; + vstore4((FLOAT4)(temp_0.s3, temp_1.s3, temp_2.s3, temp_3.s3), 0, output + offset_out + 3*head_num*head_dim); + +} + +#ifndef NUMHEAD_GROUP_SIZE +#define NUMHEAD_GROUP_SIZE 1 +#endif __kernel void matmul_qk_div_mask(GLOBAL_SIZE_3_DIMS - 
__global const FLOAT *input0, // query [1 query_seq_len/4 head_num head_dim 4] - __global const FLOAT *input1, // key [1 key_seq_len/4 head_num head_dim 4] - __global FLOAT *output, // prefill [1 head_num query_seq_len/4 key_seq_len 4] decode[1 head_num key_seq_len/4 4] - __global FLOAT *past_key, // [1 head_num max_length/4 head_dim 4] -#ifdef ADD_MASK + __global const FLOAT *input0, // query [1 query_seq_len head_num head_dim] + __global const FLOAT *input1, // key [1 key_seq_len head_num head_dim] + __global FLOAT *output, // prefill [1 head_num query_seq_len key_seq_len] decode[1 head_num key_seq_len/4 4] + __global FLOAT *past_key, // [1 max_length head_num head_dim] + #ifdef ADD_MASK __global const FLOAT* mask, -#else - __global const int* mask, // [1 1 query_seq_len key_seq_len 4] -#endif + #else + __global const int* mask, // [1 1 query_seq_len key_seq_len] + #endif __private const float scale, __private const int query_seq_len, __private const int key_seq_len, __private const int head_num, __private const int kv_head_num, __private const int head_dim) { - - const int x = get_global_id(0); // query_seq_len / 4 for prefill 1 for decode - const int y = get_global_id(1); // head_num - const int z = get_global_id(2); // key_seq_len / 4 + + const int x = get_global_id(0); // key_seq_len + const int y = get_global_id(1); // query_seq_len for prefill 1 for decode + const int z = get_global_id(2); // head_num DEAL_NON_UNIFORM_DIM3(x, y, z); - int yin = y / NUMHEAD_GROUP_SIZE; - const int offset = head_num * head_dim * 4; - const int offset_head = y * head_dim * 4; - __global const FLOAT *A_offset = input0 + x * offset + offset_head; - __global FLOAT *Pastkey_offset = past_key + (z * kv_head_num + yin) * head_dim * 4; - const int z4 = z << 2; - float4 Vscale = (float4)scale; + int x4 = x << 2; + int y4 = y << 2; + int zin = z / NUMHEAD_GROUP_SIZE; + __global const FLOAT *A_offset = input0 + (y4 * head_num + z) * head_dim; + __global FLOAT *Pastkey_offset = past_key + (x4 * kv_head_num + zin) * head_dim; + int strideA = head_num * head_dim; + int strideB = kv_head_num * head_dim; #ifdef OPENCL_PREFILL_ATTENTION - __global const FLOAT *B_offset = input1 + (z * kv_head_num + yin) * head_dim * 4; - const int x4 = x << 2; - const int query_seq_len4 = (query_seq_len + 3) / 4; - const int output_offset = y * query_seq_len4 * key_seq_len * 4; + __global const FLOAT *B_offset = input1 + (x4 * kv_head_num + zin) * head_dim; + int output_offset = (z * query_seq_len + y4) * key_seq_len + x4; float4 out0 = 0; float4 out1 = 0; float4 out2 = 0; float4 out3 = 0; + bool A1_enable = y4 + 1 < query_seq_len; + bool A2_enable = y4 + 2 < query_seq_len; + bool A3_enable = y4 + 3 < query_seq_len; + + bool B1_enable = x4 + 1 < key_seq_len; + bool B2_enable = x4 + 2 < key_seq_len; + bool B3_enable = x4 + 3 < key_seq_len; + const int head_dim4 = (head_dim + 3) / 4; -#ifdef HEADDIM_LEAVE + #ifdef HEADDIM_LEAVE for(int i = 0; i < head_dim4 - 1; ++i){ - float16 A = convert_float16(vload16(i, A_offset)); - float16 B = convert_float16(vload16(i, B_offset)); + float4 A0 = convert_float4(vload4(i, A_offset)); + float4 A1 = A1_enable ? convert_float4(vload4(i, A_offset + strideA)) : (float4)0; + float4 A2 = A2_enable ? convert_float4(vload4(i, A_offset + strideA + strideA)) : (float4)0; + float4 A3 = A3_enable ? convert_float4(vload4(i, A_offset + strideA + strideA + strideA)) : (float4)0; + float4 B0 = convert_float4(vload4(i, B_offset)); + float4 B1 = B1_enable ? 
convert_float4(vload4(i, B_offset + strideB)) : (float4)0; + float4 B2 = B2_enable ? convert_float4(vload4(i, B_offset + strideB + strideB)) : (float4)0; + float4 B3 = B3_enable ? convert_float4(vload4(i, B_offset + strideB + strideB + strideB)) : (float4)0; - out0 = mad(A.s0123, (float4)B.s0, out0); - out1 = mad(A.s0123, (float4)B.s1, out1); - out2 = mad(A.s0123, (float4)B.s2, out2); - out3 = mad(A.s0123, (float4)B.s3, out3); + out0.x += dot(A0, B0); + out0.y += dot(A0, B1); + out0.z += dot(A0, B2); + out0.w += dot(A0, B3); - out0 = mad(A.s4567, (float4)B.s4, out0); - out1 = mad(A.s4567, (float4)B.s5, out1); - out2 = mad(A.s4567, (float4)B.s6, out2); - out3 = mad(A.s4567, (float4)B.s7, out3); + out1.x += dot(A1, B0); + out1.y += dot(A1, B1); + out1.z += dot(A1, B2); + out1.w += dot(A1, B3); - out0 = mad(A.s89ab, (float4)B.s8, out0); - out1 = mad(A.s89ab, (float4)B.s9, out1); - out2 = mad(A.s89ab, (float4)B.sa, out2); - out3 = mad(A.s89ab, (float4)B.sb, out3); + out2.x += dot(A2, B0); + out2.y += dot(A2, B1); + out2.z += dot(A2, B2); + out2.w += dot(A2, B3); - out0 = mad(A.scdef, (float4)B.sc, out0); - out1 = mad(A.scdef, (float4)B.sd, out1); - out2 = mad(A.scdef, (float4)B.se, out2); - out3 = mad(A.scdef, (float4)B.sf, out3); + out3.x += dot(A3, B0); + out3.y += dot(A3, B1); + out3.z += dot(A3, B2); + out3.w += dot(A3, B3); - vstore16(CONVERT_FLOAT16(B), i, Pastkey_offset); + vstore4(CONVERT_FLOAT4(B0), i, Pastkey_offset); + vstore4(CONVERT_FLOAT4(B1), i, Pastkey_offset + strideB); + vstore4(CONVERT_FLOAT4(B2), i, Pastkey_offset + strideB + strideB); + vstore4(CONVERT_FLOAT4(B3), i, Pastkey_offset + strideB + strideB + strideB); } for(int i = (head_dim4 - 1) * 4; i < head_dim; ++i){ - float4 A = convert_float4(vload4(i, A_offset)); - float4 B = convert_float4(vload4(i, B_offset)); + float A0 = A_offset[i]; + float A1 = A1_enable ? A_offset[i + strideA] : 0; + float A2 = A2_enable ? A_offset[i + strideA + strideA] : 0; + float A3 = A3_enable ? A_offset[i + strideA + strideA + strideA] : 0; + float B0 = B_offset[i]; + float B1 = B1_enable ? B_offset[i + strideB] : 0; + float B2 = B2_enable ? B_offset[i + strideB + strideB] : 0; + float B3 = B3_enable ? B_offset[i + strideB + strideB + strideB] : 0; + + out0.x += A0 * B0; + out0.y += A0 * B1; + out0.z += A0 * B2; + out0.w += A0 * B3; - out0 = mad(A, (float4)B.s0, out0); - out1 = mad(A, (float4)B.s1, out1); - out2 = mad(A, (float4)B.s2, out2); - out3 = mad(A, (float4)B.s3, out3); + out1.x += A1 * B0; + out1.y += A1 * B1; + out1.z += A1 * B2; + out1.w += A1 * B3; - vstore4(CONVERT_FLOAT4(B), i, Pastkey_offset); + out2.x += A2 * B0; + out2.y += A2 * B1; + out2.z += A2 * B2; + out2.w += A2 * B3; + + out3.x += A3 * B0; + out3.y += A3 * B1; + out3.z += A3 * B2; + out3.w += A3 * B3; + + Pastkey_offset[i] = (FLOAT)B0; + Pastkey_offset[i + strideB] = (FLOAT)B1; + Pastkey_offset[i + strideB + strideB] = (FLOAT)B2; + Pastkey_offset[i + strideB + strideB + strideB] = (FLOAT)B3; } -#else + #else for(int i = 0; i < head_dim4; ++i){ - float16 A = convert_float16(vload16(i, A_offset)); - float16 B = convert_float16(vload16(i, B_offset)); + float4 A0 = convert_float4(vload4(i, A_offset)); + float4 A1 = A1_enable ? convert_float4(vload4(i, A_offset + strideA)) : (float4)0; + float4 A2 = A2_enable ? convert_float4(vload4(i, A_offset + strideA + strideA)) : (float4)0; + float4 A3 = A3_enable ? convert_float4(vload4(i, A_offset + strideA + strideA + strideA)) : (float4)0; + float4 B0 = convert_float4(vload4(i, B_offset)); + float4 B1 = B1_enable ? 
convert_float4(vload4(i, B_offset + strideB)) : (float4)0; + float4 B2 = B2_enable ? convert_float4(vload4(i, B_offset + strideB + strideB)) : (float4)0; + float4 B3 = B3_enable ? convert_float4(vload4(i, B_offset + strideB + strideB + strideB)) : (float4)0; - out0 = mad(A.s0123, (float4)B.s0, out0); - out1 = mad(A.s0123, (float4)B.s1, out1); - out2 = mad(A.s0123, (float4)B.s2, out2); - out3 = mad(A.s0123, (float4)B.s3, out3); + out0.x += dot(A0, B0); + out0.y += dot(A0, B1); + out0.z += dot(A0, B2); + out0.w += dot(A0, B3); - out0 = mad(A.s4567, (float4)B.s4, out0); - out1 = mad(A.s4567, (float4)B.s5, out1); - out2 = mad(A.s4567, (float4)B.s6, out2); - out3 = mad(A.s4567, (float4)B.s7, out3); + out1.x += dot(A1, B0); + out1.y += dot(A1, B1); + out1.z += dot(A1, B2); + out1.w += dot(A1, B3); - out0 = mad(A.s89ab, (float4)B.s8, out0); - out1 = mad(A.s89ab, (float4)B.s9, out1); - out2 = mad(A.s89ab, (float4)B.sa, out2); - out3 = mad(A.s89ab, (float4)B.sb, out3); + out2.x += dot(A2, B0); + out2.y += dot(A2, B1); + out2.z += dot(A2, B2); + out2.w += dot(A2, B3); - out0 = mad(A.scdef, (float4)B.sc, out0); - out1 = mad(A.scdef, (float4)B.sd, out1); - out2 = mad(A.scdef, (float4)B.se, out2); - out3 = mad(A.scdef, (float4)B.sf, out3); - - vstore16(CONVERT_FLOAT16(B), i, Pastkey_offset); + out3.x += dot(A3, B0); + out3.y += dot(A3, B1); + out3.z += dot(A3, B2); + out3.w += dot(A3, B3); + + vstore4(CONVERT_FLOAT4(B0), i, Pastkey_offset); + vstore4(CONVERT_FLOAT4(B1), i, Pastkey_offset + strideB); + vstore4(CONVERT_FLOAT4(B2), i, Pastkey_offset + strideB + strideB); + vstore4(CONVERT_FLOAT4(B3), i, Pastkey_offset + strideB + strideB + strideB); } -#endif - - out0 *= Vscale; - out1 *= Vscale; - out2 *= Vscale; - out3 *= Vscale; - - float4 mask0, mask1, mask2, mask3; - mask = mask + (x4 * key_seq_len + z4) * 4; - mask0.s0 = mask[0]; mask1.s0 = mask[4]; mask2.s0 = mask[8]; mask3.s0 = mask[12]; mask += key_seq_len * 4; - mask0.s1 = mask[0]; mask1.s1 = mask[4]; mask2.s1 = mask[8]; mask3.s1 = mask[12]; mask += key_seq_len * 4; - mask0.s2 = mask[0]; mask1.s2 = mask[4]; mask2.s2 = mask[8]; mask3.s2 = mask[12]; mask += key_seq_len * 4; - mask0.s3 = mask[0]; mask1.s3 = mask[4]; mask2.s3 = mask[8]; mask3.s3 = mask[12]; -#ifdef ADD_MASK + #endif + out0 *= (float4)scale; + out1 *= (float4)scale; + out2 *= (float4)scale; + out3 *= (float4)scale; + float4 mask0 = convert_float4(vload4(0, mask + y4 * key_seq_len + x4)); + float4 mask1 = convert_float4(vload4(0, mask + (y4 + 1) * key_seq_len + x4)); + float4 mask2 = convert_float4(vload4(0, mask + (y4 + 2) * key_seq_len + x4)); + float4 mask3 = convert_float4(vload4(0, mask + (y4 + 3) * key_seq_len + x4)); + #ifdef ADD_MASK out0 += mask0; out1 += mask1; out2 += mask2; out3 += mask3; -#else + #else out0 = (mask0 == (float4)0) ? (float4)(-FLT_MAX) : out0; out1 = (mask1 == (float4)0) ? (float4)(-FLT_MAX) : out1; out2 = (mask2 == (float4)0) ? (float4)(-FLT_MAX) : out2; out3 = (mask3 == (float4)0) ? 
(float4)(-FLT_MAX) : out3; -#endif - - vstore4(CONVERT_FLOAT4(out0), 0, output + output_offset + x * key_seq_len * 4 + z4 * 4); - if(z4 + 1 >= key_seq_len) return; - vstore4(CONVERT_FLOAT4(out1), 0, output + output_offset + x * key_seq_len * 4 + (z4 + 1) * 4); - if(z4 + 2 >= key_seq_len) return; - vstore4(CONVERT_FLOAT4(out2), 0, output + output_offset + x * key_seq_len * 4 + (z4 + 2) * 4); - if(z4 + 3 >= key_seq_len) return; - vstore4(CONVERT_FLOAT4(out3), 0, output + output_offset + x * key_seq_len * 4 + (z4 + 3) * 4); + #endif + if(B3_enable){ + vstore4(CONVERT_FLOAT4(out0), 0, output + output_offset); + if(!A1_enable) return; + output_offset += key_seq_len; + vstore4(CONVERT_FLOAT4(out1), 0, output + output_offset); + if(!A2_enable) return; + output_offset += key_seq_len; + vstore4(CONVERT_FLOAT4(out2), 0, output + output_offset); + if(!A3_enable) return; + output_offset += key_seq_len; + vstore4(CONVERT_FLOAT4(out3), 0, output + output_offset); + } else if(B2_enable){ + vstore3(CONVERT_FLOAT3((float3)(out0.x, out0.y, out0.z)), 0, output + output_offset); + if(!A1_enable) return; + output_offset += key_seq_len; + vstore3(CONVERT_FLOAT3((float3)(out1.x, out1.y, out1.z)), 0, output + output_offset); + if(!A2_enable) return; + output_offset += key_seq_len; + vstore3(CONVERT_FLOAT3((float3)(out2.x, out2.y, out2.z)), 0, output + output_offset); + if(!A3_enable) return; + output_offset += key_seq_len; + vstore3(CONVERT_FLOAT3((float3)(out3.x, out3.y, out3.z)), 0, output + output_offset); + } else if(B1_enable){ + vstore2(CONVERT_FLOAT2((float2)(out0.x, out0.y)), 0, output + output_offset); + if(!A1_enable) return; + output_offset += key_seq_len; + vstore2(CONVERT_FLOAT2((float2)(out1.x, out1.y)), 0, output + output_offset); + if(!A2_enable) return; + output_offset += key_seq_len; + vstore2(CONVERT_FLOAT2((float2)(out2.x, out2.y)), 0, output + output_offset); + if(!A3_enable) return; + output_offset += key_seq_len; + vstore2(CONVERT_FLOAT2((float2)(out3.x, out3.y)), 0, output + output_offset); + } else { + output[output_offset] = out0.x; + if(!A1_enable) return; + output[output_offset + key_seq_len] = out1.x; + if(!A2_enable) return; + output[output_offset + key_seq_len + key_seq_len] = out2.x; + if(!A3_enable) return; + output[output_offset + key_seq_len + key_seq_len + key_seq_len] = out3.x; + } #else - __global const FLOAT *B_offset = input1 + yin * head_dim * 4; - const int key_seq_len4 = (key_seq_len + 3) / 4; float4 out = 0; const int head_dim4 = (head_dim + 3) / 4; - -#ifdef HEADDIM_LEAVE + int key_seq_len4 = (key_seq_len + 3) / 4; + #ifdef HEADDIM_LEAVE for(int i = 0; i < head_dim4 - 1; ++i){ - float16 A = convert_float16(vload16(i, A_offset)); - float16 B = convert_float16(vload16(i, Pastkey_offset)); - - out = mad((float4)A.s0, B.s0123, out); - out = mad((float4)A.s4, B.s4567, out); - out = mad((float4)A.s8, B.s89ab, out); - out = mad((float4)A.sc, B.scdef, out); + float4 A = convert_float4(vload4(i, A_offset)); + float4 B0 = convert_float4(vload4(i, Pastkey_offset)); + float4 B1 = convert_float4(vload4(i, Pastkey_offset + strideB)); + float4 B2 = convert_float4(vload4(i, Pastkey_offset + strideB + strideB)); + float4 B3 = convert_float4(vload4(i, Pastkey_offset + strideB + strideB + strideB)); + + out.x += dot(A, B0); + out.y += dot(A, B1); + out.z += dot(A, B2); + out.w += dot(A, B3); } for(int i = (head_dim4 - 1) * 4; i < head_dim; ++i){ - float4 A = convert_float4(vload4(i, A_offset)); - float4 B = convert_float4(vload4(i, Pastkey_offset)); - - out = mad((float4)A.s0, B, out); + 
float A = A_offset[i]; + float B0 = Pastkey_offset[i]; + float B1 = Pastkey_offset[i + strideB]; + float B2 = Pastkey_offset[i + strideB + strideB]; + float B3 = Pastkey_offset[i + strideB + strideB + strideB]; + out.x += A * B0; + out.y += A * B1; + out.z += A * B2; + out.w += A * B3; } -#else + #else for(int i = 0; i < head_dim4; ++i){ - float16 A = convert_float16(vload16(i, A_offset)); - float16 B = convert_float16(vload16(i, Pastkey_offset)); + float4 A = convert_float4(vload4(i, A_offset)); + float4 B0 = convert_float4(vload4(i, Pastkey_offset)); + float4 B1 = convert_float4(vload4(i, Pastkey_offset + strideB)); + float4 B2 = convert_float4(vload4(i, Pastkey_offset + strideB + strideB)); + float4 B3 = convert_float4(vload4(i, Pastkey_offset + strideB + strideB + strideB)); - out = mad((float4)A.s0, B.s0123, out); - out = mad((float4)A.s4, B.s4567, out); - out = mad((float4)A.s8, B.s89ab, out); - out = mad((float4)A.sc, B.scdef, out); + out.x += dot(A, B0); + out.y += dot(A, B1); + out.z += dot(A, B2); + out.w += dot(A, B3); } -#endif - if(z == key_seq_len4 - 1){ - int remain = key_seq_len - z * 4 - 1; - Pastkey_offset += remain; + #endif + int remain = key_seq_len - x4; + if(x == key_seq_len4 - 1){ + __global const FLOAT *B_offset = input1 + zin * head_dim; + Pastkey_offset += (remain - 1) * strideB; float tmp = 0; - for(int i = 0; i < head_dim; ++i){ - float A = A_offset[i*4]; - float B = B_offset[i*4]; - Pastkey_offset[i * 4] = B; + #ifdef HEADDIM_LEAVE + for(int i = 0; i < head_dim4 - 1; ++i){ + float4 A = convert_float4(vload4(i, A_offset)); + float4 B = convert_float4(vload4(i, B_offset)); + + tmp += dot(A, B); + vstore4(CONVERT_FLOAT4(B), i, Pastkey_offset); + } + for(int i = (head_dim4 - 1) * 4; i < head_dim; ++i){ + float A = A_offset[i]; + float B = B_offset[i]; tmp += A * B; + Pastkey_offset[i] = B; } + #else + for(int i = 0; i < head_dim4; ++i){ + float4 A = convert_float4(vload4(i, A_offset)); + float4 B = convert_float4(vload4(i, B_offset)); + + tmp += dot(A, B); + vstore4(CONVERT_FLOAT4(B), i, Pastkey_offset); + } + #endif float *out_ptr = (float*)&out; - out_ptr[remain] = tmp; + out_ptr[remain - 1] = tmp; + } + out *= (float4)scale; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out), 0, output + z * key_seq_len + x4); + } else if (remain >= 3){ + vstore3(CONVERT_FLOAT3((float3)(out.x, out.y, out.z)), 0, output + z * key_seq_len + x4); + } else if (remain >= 2){ + vstore2(CONVERT_FLOAT2((float2)(out.x, out.y)), 0, output + z * key_seq_len + x4); + } else { + output[z * key_seq_len + x4] = out.x; } - out *= Vscale; - vstore4(CONVERT_FLOAT4(out), 0, output + y * key_seq_len4 * 4 + z4); #endif } __kernel void matmul_qkv(GLOBAL_SIZE_3_DIMS - __global const FLOAT *input0, // qk prefill [1 head_num qk_seq_len/4 value_seq_len 4] decode[1 head_num value_seq_len/4 4] - __global const FLOAT *input1, // [1 value_seq_len/4 head_num head_dim 4] - __global FLOAT *output, // [1 qk_seq_len head_num*head_dim 1 4] - __global FLOAT *past_value, // [1 value_seq_len/4 head_num head_dim 4] + __global const FLOAT *input0, // qk prefill [1 head_num qk_seq_len value_seq_len] decode[1 head_num value_seq_len] + __global const FLOAT *input1, // [1 value_seq_len head_num head_dim] + __global FLOAT *output, // [1 qk_seq_len head_num head_dim] + __global FLOAT *past_value, // [1 value_seq_len head_num head_dim] __private const int qk_seq_len, __private const int value_seq_len, __private const int head_num, __private const int kv_head_num, __private const int head_dim) { - - const int x = get_global_id(0); // 
prefill qk_seq_len / 4 decode 1 + + const int x = get_global_id(0); // head_dim << 2 const int y = get_global_id(1); // head_num - const int z = get_global_id(2); // head_dim << 2 - const int z4 = z << 2; + const int z = get_global_id(2); // prefill qk_seq_len decode 1 + + const int x4 = x << 2; DEAL_NON_UNIFORM_DIM3(x, y, z); const int yin = y / NUMHEAD_GROUP_SIZE; #ifdef OPENCL_PREFILL_ATTENTION - const int offset = head_num * head_dim * 4; - const int stride = kv_head_num * head_dim * 4; - const int offset_head = y * head_dim * 4 + z4 * 4; - const int value_seq_len4 = (value_seq_len + 3) / 4; - const int qk_seq_len4 = (qk_seq_len + 3) / 4; - __global const FLOAT *A_offset = input0 + (y * qk_seq_len4 + x) * value_seq_len * 4; - __global const FLOAT *B_offset = input1 + yin * head_dim * 4 + z4 * 4; - __global FLOAT *Pastvalue_offset = past_value + yin * head_dim * 4 + z4 * 4; + int z4 = z << 2; + int value_seq_len4 = (value_seq_len + 3) / 4; + int loop_end = max(value_seq_len4 - 1, 0); + const int stride = kv_head_num * head_dim; + __global const FLOAT *A_offset = input0 + (y * qk_seq_len + z4) * value_seq_len; + __global const FLOAT *B_offset = input1 + yin * head_dim + x4; + __global FLOAT *Pastvalue_offset = past_value + yin * head_dim + x4; COMPUTE_FLOAT4 out0 = 0; COMPUTE_FLOAT4 out1 = 0; COMPUTE_FLOAT4 out2 = 0; COMPUTE_FLOAT4 out3 = 0; - for(int i = 0; i < value_seq_len4 - 1; ++i){ + for(int i = 0; i < loop_end; ++i){ int index = i << 2; - COMPUTE_FLOAT4 A0 = CONVERT_COMPUTE_FLOAT4(vload4(index, A_offset)); - COMPUTE_FLOAT4 A1 = CONVERT_COMPUTE_FLOAT4(vload4(index + 1, A_offset)); - COMPUTE_FLOAT4 A2 = CONVERT_COMPUTE_FLOAT4(vload4(index + 2, A_offset)); - COMPUTE_FLOAT4 A3 = CONVERT_COMPUTE_FLOAT4(vload4(index + 3, A_offset)); - COMPUTE_FLOAT16 B = CONVERT_COMPUTE_FLOAT16(vload16(0, B_offset + i * stride)); + COMPUTE_FLOAT4 A0 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset)); + COMPUTE_FLOAT4 A1 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset + value_seq_len)); + COMPUTE_FLOAT4 A2 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset + value_seq_len + value_seq_len)); + COMPUTE_FLOAT4 A3 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset + value_seq_len + value_seq_len + value_seq_len)); + COMPUTE_FLOAT4 B0 = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset + (index + 0) * stride)); + COMPUTE_FLOAT4 B1 = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset + (index + 1) * stride)); + COMPUTE_FLOAT4 B2 = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset + (index + 2) * stride)); + COMPUTE_FLOAT4 B3 = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset + (index + 3) * stride)); - out0 = mad(A0, (COMPUTE_FLOAT4)B.s0, out0); - out0 = mad(A1, (COMPUTE_FLOAT4)B.s1, out0); - out0 = mad(A2, (COMPUTE_FLOAT4)B.s2, out0); - out0 = mad(A3, (COMPUTE_FLOAT4)B.s3, out0); + out0 = mad(B0, (COMPUTE_FLOAT4)A0.x, out0); + out0 = mad(B1, (COMPUTE_FLOAT4)A0.y, out0); + out0 = mad(B2, (COMPUTE_FLOAT4)A0.z, out0); + out0 = mad(B3, (COMPUTE_FLOAT4)A0.w, out0); - out1 = mad(A0, (COMPUTE_FLOAT4)B.s4, out1); - out1 = mad(A1, (COMPUTE_FLOAT4)B.s5, out1); - out1 = mad(A2, (COMPUTE_FLOAT4)B.s6, out1); - out1 = mad(A3, (COMPUTE_FLOAT4)B.s7, out1); + out1 = mad(B0, (COMPUTE_FLOAT4)A1.x, out1); + out1 = mad(B1, (COMPUTE_FLOAT4)A1.y, out1); + out1 = mad(B2, (COMPUTE_FLOAT4)A1.z, out1); + out1 = mad(B3, (COMPUTE_FLOAT4)A1.w, out1); - out2 = mad(A0, (COMPUTE_FLOAT4)B.s8, out2); - out2 = mad(A1, (COMPUTE_FLOAT4)B.s9, out2); - out2 = mad(A2, (COMPUTE_FLOAT4)B.sa, out2); - out2 = mad(A3, (COMPUTE_FLOAT4)B.sb, out2); + out2 = mad(B0, (COMPUTE_FLOAT4)A2.x, out2); + out2 = mad(B1, 
(COMPUTE_FLOAT4)A2.y, out2); + out2 = mad(B2, (COMPUTE_FLOAT4)A2.z, out2); + out2 = mad(B3, (COMPUTE_FLOAT4)A2.w, out2); - out3 = mad(A0, (COMPUTE_FLOAT4)B.sc, out3); - out3 = mad(A1, (COMPUTE_FLOAT4)B.sd, out3); - out3 = mad(A2, (COMPUTE_FLOAT4)B.se, out3); - out3 = mad(A3, (COMPUTE_FLOAT4)B.sf, out3); - - vstore16(CONVERT_FLOAT16(B), 0, Pastvalue_offset + i * stride); + out3 = mad(B0, (COMPUTE_FLOAT4)A3.x, out3); + out3 = mad(B1, (COMPUTE_FLOAT4)A3.y, out3); + out3 = mad(B2, (COMPUTE_FLOAT4)A3.z, out3); + out3 = mad(B3, (COMPUTE_FLOAT4)A3.w, out3); + vstore4(CONVERT_FLOAT4(B0), 0, Pastvalue_offset + (index + 0) * stride); + vstore4(CONVERT_FLOAT4(B1), 0, Pastvalue_offset + (index + 1) * stride); + vstore4(CONVERT_FLOAT4(B2), 0, Pastvalue_offset + (index + 2) * stride); + vstore4(CONVERT_FLOAT4(B3), 0, Pastvalue_offset + (index + 3) * stride); } - -#ifdef HEADDIM_LEAVE - COMPUTE_FLOAT16 B = CONVERT_COMPUTE_FLOAT16(vload16(0, B_offset + (value_seq_len4 - 1) * stride)); - COMPUTE_FLOAT *B_ptr = (COMPUTE_FLOAT*)&B; - for(int i = (value_seq_len4 - 1) * 4, j = 0; i < value_seq_len; ++i, ++j){ - COMPUTE_FLOAT4 A0 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset)); - out0 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j], out0); - out1 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 4], out1); - out2 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 8], out2); - out3 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 12], out3); + for(int i = loop_end << 2; i < value_seq_len; ++i){ + COMPUTE_FLOAT A0 = A_offset[i]; + COMPUTE_FLOAT A1 = A_offset[i + value_seq_len]; + COMPUTE_FLOAT A2 = A_offset[i + value_seq_len + value_seq_len]; + COMPUTE_FLOAT A3 = A_offset[i + value_seq_len + value_seq_len + value_seq_len]; + COMPUTE_FLOAT4 B = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset + i * stride)); + + out0 = mad(B, (COMPUTE_FLOAT4)A0, out0); + out1 = mad(B, (COMPUTE_FLOAT4)A1, out1); + out2 = mad(B, (COMPUTE_FLOAT4)A2, out2); + out3 = mad(B, (COMPUTE_FLOAT4)A3, out3); + vstore4(CONVERT_FLOAT4(B), 0, Pastvalue_offset + i * stride); } - vstore4(CONVERT_FLOAT4(out0), 0, output + x * offset + (y * head_dim + z4) * 4); - vstore4(CONVERT_FLOAT4(B.s0123), 0, Pastvalue_offset + (value_seq_len4 - 1) * stride); - if(z4 + 1 >= head_dim) return; - vstore4(CONVERT_FLOAT4(out1), 0, output + x * offset + (y * head_dim + z4 + 1) * 4); - vstore4(CONVERT_FLOAT4(B.s4567), 1, Pastvalue_offset + (value_seq_len4 - 1) * stride); - if(z4 + 2 >= head_dim) return; - vstore4(CONVERT_FLOAT4(out2), 0, output + x * offset + (y * head_dim + z4 + 2) * 4); - vstore4(CONVERT_FLOAT4(B.s89ab), 2, Pastvalue_offset + (value_seq_len4 - 1) * stride); - if(z4 + 3 >= head_dim) return; - vstore4(CONVERT_FLOAT4(out3), 0, output + x * offset + (y * head_dim + z4 + 3) * 4); - vstore4(CONVERT_FLOAT4(B.scdef), 3, Pastvalue_offset + (value_seq_len4 - 1) * stride); -#else - COMPUTE_FLOAT16 B = CONVERT_COMPUTE_FLOAT16(vload16(0, B_offset + (value_seq_len4 - 1) * stride)); - vstore16(CONVERT_FLOAT16(B), 0, Pastvalue_offset + (value_seq_len4 - 1) * stride); - COMPUTE_FLOAT *B_ptr = (COMPUTE_FLOAT*)&B; - for(int i = (value_seq_len4 - 1) * 4, j = 0; i < value_seq_len; ++i, ++j){ - COMPUTE_FLOAT4 A0 = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset)); - out0 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j], out0); - out1 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 4], out1); - out2 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 8], out2); - out3 = mad(A0, (COMPUTE_FLOAT4)B_ptr[j + 12], out3); + + #ifdef HEADDIM_LEAVE + int remain = head_dim - x4; + int output_offset = (z4 * head_num + y) * head_dim + x4; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out0), 0, 
output + output_offset); + } else if(remain == 3){ + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out0.x, out0.y, out0.z)), 0, output + output_offset); + } else if(remain == 2){ + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(out0.x, out0.y)), 0, output + output_offset); + } else{ + output[output_offset] = out0.x; } - vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output + x * offset + (y * head_dim + z4) * 4); -#endif + if(z4 + 1 >= qk_seq_len) return; + output_offset += head_num * head_dim; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out1), 0, output + output_offset); + } else if(remain == 3){ + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out1.x, out1.y, out1.z)), 0, output + output_offset); + } else if(remain == 2){ + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(out1.x, out1.y)), 0, output + output_offset); + } else{ + output[output_offset] = out1.x; + } + if(z4 + 2 >= qk_seq_len) return; + output_offset += head_num * head_dim; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out2), 0, output + output_offset); + } else if(remain == 3){ + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out2.x, out2.y, out2.z)), 0, output + output_offset); + } else if(remain == 2){ + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(out2.x, out2.y)), 0, output + output_offset); + } else{ + output[output_offset] = out2.x; + } + if(z4 + 3 >= qk_seq_len) return; + output_offset += head_num * head_dim; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out3), 0, output + output_offset); + } else if(remain == 3){ + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out3.x, out3.y, out3.z)), 0, output + output_offset); + } else if(remain == 2){ + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(out3.x, out3.y)), 0, output + output_offset); + } else{ + output[output_offset] = out3.x; + } + #else + int output_offset = (z4 * head_num + y) * head_dim + x4; + vstore4(CONVERT_FLOAT4(out0), 0, output + output_offset); + if(z4 + 1 >= qk_seq_len) return; + output_offset += head_num * head_dim; + vstore4(CONVERT_FLOAT4(out1), 0, output + output_offset); + if(z4 + 2 >= qk_seq_len) return; + output_offset += head_num * head_dim; + vstore4(CONVERT_FLOAT4(out2), 0, output + output_offset); + if(z4 + 3 >= qk_seq_len) return; + output_offset += head_num * head_dim; + vstore4(CONVERT_FLOAT4(out3), 0, output + output_offset); + #endif #else - const int value_seq_len4 = (value_seq_len + 3) / 4; - const int stride = kv_head_num * head_dim * 4; - const int offset = head_num * head_dim * 4; - const int offset_head = y * head_dim * 4 + z4 * 4; - const int loop = (value_seq_len + 2) / 4; - __global const FLOAT *A_offset = input0 + y * value_seq_len4 * 4; - __global const FLOAT *B_offset = input1 + yin * head_dim * 4 + z4 * 4; - __global FLOAT *Pastvalue_offset = past_value + yin * head_dim * 4 + z4 * 4; + int value_seq_len4 = (value_seq_len - 1 + 3) / 4; + int loop_end = max(value_seq_len4 - 1, 0); + const int stride = kv_head_num * head_dim; + __global const FLOAT *A_offset = input0 + y * value_seq_len; + __global const FLOAT *B_offset = input1 + yin * head_dim + x4; + __global FLOAT *Pastvalue_offset = past_value + yin * head_dim + x4; COMPUTE_FLOAT4 out = 0; - for(int i = 0; i < loop - 1; i++){ + for(int i = 0; i < loop_end; i++){ + int index = i << 2; COMPUTE_FLOAT4 A = CONVERT_COMPUTE_FLOAT4(vload4(i, A_offset)); - COMPUTE_FLOAT16 B = CONVERT_COMPUTE_FLOAT16(vload16(0, Pastvalue_offset + i * stride)); + COMPUTE_FLOAT4 B0 = CONVERT_COMPUTE_FLOAT4(vload4(0, Pastvalue_offset + (index + 0) * stride)); + COMPUTE_FLOAT4 B1 = CONVERT_COMPUTE_FLOAT4(vload4(0, 
Pastvalue_offset + (index + 1) * stride)); + COMPUTE_FLOAT4 B2 = CONVERT_COMPUTE_FLOAT4(vload4(0, Pastvalue_offset + (index + 2) * stride)); + COMPUTE_FLOAT4 B3 = CONVERT_COMPUTE_FLOAT4(vload4(0, Pastvalue_offset + (index + 3) * stride)); - out.s0 += dot(A, B.s0123); - out.s1 += dot(A, B.s4567); - out.s2 += dot(A, B.s89ab); - out.s3 += dot(A, B.scdef); + out = mad(B0, (COMPUTE_FLOAT4)A.x, out); + out = mad(B1, (COMPUTE_FLOAT4)A.y, out); + out = mad(B2, (COMPUTE_FLOAT4)A.z, out); + out = mad(B3, (COMPUTE_FLOAT4)A.w, out); } - int start = (loop - 1) < 0 ? 0 : (loop - 1); - COMPUTE_FLOAT16 B_Vec = CONVERT_COMPUTE_FLOAT16(vload16(0, Pastvalue_offset + start * stride)); - COMPUTE_FLOAT *B_ptr = (COMPUTE_FLOAT *)&B_Vec; - for(int i = start * 4; i < value_seq_len - 1; ++i){ + for(int i = loop_end << 2; i < value_seq_len - 1; i++){ COMPUTE_FLOAT A = A_offset[i]; + COMPUTE_FLOAT4 B = CONVERT_COMPUTE_FLOAT4(vload4(0, Pastvalue_offset + i * stride)); - int index = i % 4; - out.s0 += A * B_ptr[index]; - out.s1 += A * B_ptr[index+4]; - out.s2 += A * B_ptr[index+8]; - out.s3 += A * B_ptr[index+12]; + out = mad(B, (COMPUTE_FLOAT4)A, out); } COMPUTE_FLOAT A = A_offset[value_seq_len - 1]; - COMPUTE_FLOAT B0 = B_offset[0]; - COMPUTE_FLOAT B1 = B_offset[4]; - COMPUTE_FLOAT B2 = B_offset[8]; - COMPUTE_FLOAT B3 = B_offset[12]; - out.s0 += A * B0; - out.s1 += A * B1; - out.s2 += A * B2; - out.s3 += A * B3; - int index = ((value_seq_len - 1) >> 2) * stride + ((value_seq_len - 1) % 4); - -#ifdef HEADDIM_LEAVE - Pastvalue_offset[index] = B0; - output[(y * head_dim + z4) * 4] = out.s0; - if(z4 + 1 >= head_dim) return; - Pastvalue_offset[index + 4] = B1; - output[(y * head_dim + z4 + 1) * 4] = out.s1; - if(z4 + 2 >= head_dim) return; - Pastvalue_offset[index + 8] = B2; - output[(y * head_dim + z4 + 2) * 4] = out.s2; - if(z4 + 3 >= head_dim) return; - Pastvalue_offset[index + 12] = B3; - output[(y * head_dim + z4 + 3) * 4] = out.s3; -#else - Pastvalue_offset[index] = B0; - Pastvalue_offset[index + 4] = B1; - Pastvalue_offset[index + 8] = B2; - Pastvalue_offset[index + 12] = B3; + COMPUTE_FLOAT4 B = CONVERT_COMPUTE_FLOAT4(vload4(0, B_offset)); + out = mad(B, (COMPUTE_FLOAT4)A, out); - output[(y * head_dim + z4) * 4] = out.s0; - output[(y * head_dim + z4 + 1) * 4] = out.s1; - output[(y * head_dim + z4 + 2) * 4] = out.s2; - output[(y * head_dim + z4 + 3) * 4] = out.s3; -#endif + #ifdef HEADDIM_LEAVE + int remain = head_dim - x4; + if(remain >= 4){ + vstore4(CONVERT_FLOAT4(out), 0, output + y * head_dim + x4); + vstore4(CONVERT_FLOAT4(B), 0, Pastvalue_offset + (value_seq_len - 1) * stride); + } else if(remain == 3){ + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out.x, out.y, out.z)), 0, output + y * head_dim + x4); + vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(B.x, B.y, B.z)), 0, Pastvalue_offset + (value_seq_len - 1) * stride); + } else if(remain == 2){ + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(out.x, out.y)), 0, output + y * head_dim + x4); + vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT2)(B.x, B.y)), 0, Pastvalue_offset + (value_seq_len - 1) * stride); + } else{ + output[y * head_dim + x4] = out.x; + Pastvalue_offset[(value_seq_len - 1) * stride] = B.x; + } + #else + vstore4(CONVERT_FLOAT4(B), 0, Pastvalue_offset + (value_seq_len - 1) * stride); + vstore4(CONVERT_FLOAT4(out), 0, output + y * head_dim + x4); + #endif #endif } diff --git a/source/backend/opencl/execution/cl/binary_buf.cl b/source/backend/opencl/execution/cl/binary_buf.cl index 3528882ab..4ad1bad8d 100644 --- 
a/source/backend/opencl/execution/cl/binary_buf.cl +++ b/source/backend/opencl/execution/cl/binary_buf.cl @@ -5,84 +5,79 @@ __kernel void binary_buf(__private int global_dim0, __private int global_dim1, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, __global OUTPUT_TYPE* output, - __private const int4 shape,//[N,H,W,C4] - __private const int2 isFull, + __private const int size, __private const int activationType) { - int2 pos = (int2)(get_global_id(0), get_global_id(1));//NC4, HW + int2 pos = (int2)(get_global_id(0), get_global_id(1));//NCHW, 1 if (pos.x < global_dim0 && pos.y < global_dim1) { - #ifdef WH_PACK4 - int offset = pos.x * (shape.y*shape.z/4) + pos.y; - #ifdef A_SINGLE - float data0 = input0[0]; - float16 in0_16 = (float16)data0; - #else - float16 in0_16 = convert_float16(vload16(offset, input0)); - #endif - - #ifdef B_SINGLE - float data1 = input1[0]; - float16 in1_16 = (float16)data1; - #else - float16 in1_16 = convert_float16(vload16(offset, input1)); - #endif - - float16 out; - float4 in0 = in0_16.s0123; - float4 in1 = in1_16.s0123; - out.s0123 = OPERATOR; - - in0 = in0_16.s4567; - in1 = in1_16.s4567; - out.s4567 = OPERATOR; - - in0 = in0_16.s89ab; - in1 = in1_16.s89ab; - out.s89ab = OPERATOR; - - in0 = in0_16.scdef; - in1 = in1_16.scdef; - out.scdef = OPERATOR; - - if(activationType == 1) { - out = fmax(out, (float16)0); - } - vstore16(CONVERT_OUTPUT16(out), offset, output); - #else - int offset = pos.x * (shape.y*shape.z) + pos.y; - #ifdef A_SINGLE - float data0 = input0[0]; - float4 in0 = (float4)(data0, data0, data0, data0); - #else - float4 in0 = convert_float4(vload4(offset, input0)); - #endif + int offset = pos.x << 2; +#ifdef PACK_LEAVE + if(offset + 3 >= size){ + int remain = size - offset; + float4 in0, in1; + float* in0_ptr = (float*)&in0; + float* in1_ptr = (float*)&in1; + + for(int i = 0; i < remain; ++i){ + #ifdef A_SINGLE + in0_ptr[i] = (float)input0[0]; + #else + in0_ptr[i] = (float)input0[offset + i]; + #endif - #ifdef B_SINGLE - float data1 = input1[0]; - float4 in1 = (float4)(data1, data1, data1, data1); - #else - float4 in1 = convert_float4(vload4(offset, input1)); - #endif + #ifdef B_SINGLE + in1_ptr[i] = (float)input1[0]; + #else + in1_ptr[i] = (float)input1[offset + i]; + #endif + } + float4 out = OPERATOR; + if(activationType == 1) { + out = fmax(out, (float4)0); + } + float* out_ptr = (float*)&out; + for(int i = 0; i < remain; ++i){ + output[offset + i] = (OUTPUT_TYPE)out_ptr[i]; + } + }else { +#endif + #ifdef A_SINGLE + float data0 = input0[0]; + float4 in0 = (float4)(data0, data0, data0, data0); + #else + float4 in0 = convert_float4(vload4(0, input0 + offset)); + #endif - float4 out = OPERATOR; + #ifdef B_SINGLE + float data1 = input1[0]; + float4 in1 = (float4)(data1, data1, data1, data1); + #else + float4 in1 = convert_float4(vload4(0, input1 + offset)); + #endif + + float4 out = OPERATOR; - if(activationType == 1) { - out = fmax(out, (float4)0); + if(activationType == 1) { + out = fmax(out, (float4)0); + } + vstore4(CONVERT_OUTPUT4(out), 0, output + offset); +#ifdef PACK_LEAVE } - vstore4(CONVERT_OUTPUT4(out), offset, output); - #endif +#endif } } __kernel void prelu_buf(__private int global_dim0, __private int global_dim1, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, __global OUTPUT_TYPE* output, - __private const int4 shape//[N,H,W,C4] + __private const int4 shape ) { int2 pos = (int2)(get_global_id(0), get_global_id(1));//NC4, HW - + if (pos.x < global_dim0 && pos.y < global_dim1) { - int offset = pos.x * 
(shape.y*shape.z) + pos.y; + int b = pos.x / shape.w; + int c = pos.x % shape.w; + int offset = (b + c * shape.x) * (shape.y*shape.z) + pos.y; float4 in0 = convert_float4(vload4(offset, input0)); float4 in1 = convert_float4(vload4(pos.x % shape.w, input1)); float4 out = OPERATOR; diff --git a/source/backend/opencl/execution/cl/binary_subgroup_buf.cl b/source/backend/opencl/execution/cl/binary_subgroup_buf.cl index e3362a21b..54c162120 100644 --- a/source/backend/opencl/execution/cl/binary_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/binary_subgroup_buf.cl @@ -19,7 +19,7 @@ __kernel void binary_buf_c4_c4_c4(__private int global_dim0, __private int globa const int batch_idx = get_global_id(2); const int channel_idx = get_global_id(1); - const int offset = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; float4 in0 = convert_float4(vload4(0, input0 + offset*isFull.x)); float4 in1 = convert_float4(vload4(0, input1 + offset*isFull.y)); @@ -57,7 +57,7 @@ __kernel void binary_buf_c4_c4_c16(__private int global_dim0, __private int glob const int dst_width = shape.z + output_pad_left + output_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int dst_offset = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left) * 16 + (channel_idx % 4) * 4; float4 in0 = convert_float4(vload4(0, input0 + offset*isFull.x)); @@ -105,7 +105,7 @@ __kernel void binary_buf_c4_c16_c4(__private int global_dim0, __private int glob const int src_width = shape.z + input1_pad_left + input1_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset0 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset0 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset1 = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left) * 16 + (channel_idx % 4) * 4; float4 in0 = convert_float4(vload4(0, input0 + offset0*isFull.x)); @@ -142,7 +142,7 @@ __kernel void binary_buf_c16_c4_c4(__private int global_dim0, __private int glob const int src_width = shape.z + input0_pad_left + input0_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset1 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset1 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset0 = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left) * 16 + (channel_idx % 4) * 4; float4 in0 = convert_float4(vload4(0, input0 + offset0*isFull.x)); @@ -181,7 +181,7 @@ __kernel void binary_buf_c4_c16_c16(__private int global_dim0, __private int glo const int dst_width = shape.z + output_pad_left + output_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset0 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset0 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset1 = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left) * 16 + (channel_idx % 4) * 4; const int dst_offset = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left) * 16 + (channel_idx % 4) * 4; @@ -231,7 
+231,7 @@ __kernel void binary_buf_c16_c4_c16(__private int global_dim0, __private int glo const int dst_width = shape.z + output_pad_left + output_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset1 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset1 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset0 = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left) * 16 + (channel_idx % 4) * 4; const int dst_offset = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left) * 16 + (channel_idx % 4) * 4; @@ -277,7 +277,7 @@ __kernel void prelu_buf_c4_c4(__private int global_dim0, __private int global_di const int batch_idx = get_global_id(2); const int channel_idx = get_global_id(1); - const int offset0 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset0 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset1 = channel_idx * 4; float4 in0 = convert_float4(vload4(0, input0 + offset0)); @@ -304,7 +304,7 @@ __kernel void prelu_buf_c4_c16(__private int global_dim0, __private int global_d const int dst_width = shape.z + output_pad_left + output_pad_right; const int channe_out_idx = channel_idx >> 2; - const int offset0 = (((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset0 = (((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; const int offset1 = channel_idx * 4; const int offset = (((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left) * 16 + (channel_idx % 4) * 4; @@ -385,11 +385,11 @@ __kernel void prelu_buf_c16_c4(__private int global_dim0, __private int global_d const int channel_idx = get_group_id(1); const int sglid = get_sub_group_local_id(); const int src_width = shape.z + input0_pad_left + input0_pad_right; - const int width_height = shape.z * shape.y * 4; + const int batch_width_height = shape.x * shape.z * shape.y * 4; const int offset0 = (((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left) * 16; const int offset1 = channel_idx * 16; - const int offset = (((batch_idx*channel4+(channel_idx<<2))*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset = (((batch_idx+(channel_idx<<2)*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; float4 in0 = convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0 + offset0)))); float4 in1 = (float4)(AS_INPUT_DATA(INTEL_SUB_GROUP_READ((__global INTEL_DATA*)(input1 + offset1)))); @@ -400,7 +400,7 @@ __kernel void prelu_buf_c16_c4(__private int global_dim0, __private int global_d const int lid_y = sglid / 4; int block_size = w_idx + 4 > shape.z ? 
(shape.z % 4) : 4; for (int i = 0; i < block_size; i++) { - output[offset + i * 4 + lid_y * width_height + lid_x] = (OUTPUT_TYPE)out[i]; + output[offset + i * 4 + lid_y * batch_width_height + lid_x] = (OUTPUT_TYPE)out[i]; } } @@ -478,11 +478,11 @@ __kernel void binary_buf_c16_c16_c4(__private int global_dim0, __private int glo const int sglid = get_sub_group_local_id(); const int src0_width = shape.z + input0_pad_left + input0_pad_right; const int src1_width = shape.z + input1_pad_left + input1_pad_right; - const int width_height = shape.z * shape.y * 4; + const int batch_width_height = shape.x * shape.z * shape.y * 4; const int offset0 = (((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src0_width+w_idx+input0_pad_left) * 16; const int offset1 = (((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src1_width+w_idx+input1_pad_left) * 16; - const int offset = (((batch_idx*channel4+(channel_idx << 2))*shape.y+h_idx)*shape.z+w_idx) * 4; + const int offset = (((batch_idx+(channel_idx << 2)*shape.x)*shape.y+h_idx)*shape.z+w_idx) * 4; float4 in0 = isFull.x ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0 + offset0)))) : (float4)(input0[0]); float4 in1 = isFull.y ? convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input1 + offset1)))) : (float4)(input1[0]); @@ -496,6 +496,6 @@ __kernel void binary_buf_c16_c16_c4(__private int global_dim0, __private int glo const int lid_y = sglid / 4; int block_size = w_idx + 4 > shape.z ? (shape.z % 4) : 4; for (int i = 0; i < block_size; i++) { - output[offset + i * 4 + lid_y * width_height + lid_x] = (OUTPUT_TYPE)out[i]; + output[offset + i * 4 + lid_y * batch_width_height + lid_x] = (OUTPUT_TYPE)out[i]; } } diff --git a/source/backend/opencl/execution/cl/buffer_convert_buf.cl b/source/backend/opencl/execution/cl/buffer_convert_buf.cl index 1563c65ea..6a4b4e220 100644 --- a/source/backend/opencl/execution/cl/buffer_convert_buf.cl +++ b/source/backend/opencl/execution/cl/buffer_convert_buf.cl @@ -7,231 +7,71 @@ if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ return; \ } -// convert data from buffer(nhwc) to buffer(nc4hw4) -__kernel void nhwc_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS - __global const INPUT_TYPE *input_ptr, - __private const int height, - __private const int width, __private const int channels, - __global OUTPUT_TYPE *output) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); - - const int batch_idx = image_height_idx / height; - const int height_idx = image_height_idx % height; - const int width_idx = image_width_idx % width; - const int channel_4_idx = (image_width_idx / width) << 2; - const int buffer_offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + channel_4_idx; - - const int remain_channel = channels - channel_4_idx; - float4 values = convert_float4(vload4(0, input_ptr + buffer_offset)); - - if (remain_channel == 3) { - values.w = 0; - } else if (remain_channel == 2) { - values.z = 0; - values.w = 0; - } else if (remain_channel == 1) { - values.y = 0; - values.z = 0; - values.w = 0; - } - const int out_offset = (((batch_idx * ((channels+3)/4) + channel_4_idx/4) * height + height_idx) * width + width_idx)*4; - vstore4(CONVERT_OUTPUT4(values), 0, output+out_offset); -} - -// convert data from buffer(nchw) to buffer(nc4hw4) -__kernel void nchw_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS - __global const INPUT_TYPE *input_ptr, - __private const 
int height, __private const int width, __private const int channels, - __global OUTPUT_TYPE *output) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); - - const int batch_idx = image_height_idx / height; - const int height_idx = image_height_idx % height; - const int width_idx = image_width_idx % width; - const int channel_4_idx = image_width_idx / width << 2; - const int buffer_offset = ((batch_idx * channels + channel_4_idx) * height + height_idx) * width + width_idx; - - const int remain_channel = channels - channel_4_idx; - const int height_width_size = height * width; - float4 output_values = 0; - - if (remain_channel >= 4) { - int offset = buffer_offset; - output_values.x = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.y = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.z = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.w = (float)*(input_ptr + offset); - } else if (remain_channel == 3) { - int offset = buffer_offset; - output_values.x = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.y = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.z = (float)*(input_ptr + offset); - } else if (remain_channel == 2) { - int offset = buffer_offset; - output_values.x = (float)*(input_ptr + offset); - offset += height_width_size; - output_values.y = (float)*(input_ptr + offset); - } else if (remain_channel == 1) { - int offset = buffer_offset; - output_values.x = (float)*(input_ptr + offset); +#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, +#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ + return; \ } - const int out_offset = (((batch_idx * ((channels+3)/4) + channel_4_idx/4) * height + height_idx) * width + width_idx)*4; - vstore4(CONVERT_OUTPUT4(output_values), 0, output+out_offset); -} - - -__kernel void nchw_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS - __global INPUT_TYPE *input_ptr, - __private const int height, __private const int width, __private const int channels, - __private const int input_pad_left, __private const int input_pad_right, - __private const int output_pad_left, __private const int output_pad_right, - __global OUTPUT_TYPE *output) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); - - const int src_width = width + input_pad_left + input_pad_right; - const int dst_width = width + output_pad_left + output_pad_right; - const int batch_idx = image_height_idx / height; - const int height_idx = image_height_idx % height; - const int width_idx = image_width_idx % width; - const int channel_idx = image_width_idx / width; - const int in_offset = ((batch_idx * channels + channel_idx) * height + height_idx) * src_width + width_idx + input_pad_left; - const int out_offset = ((batch_idx * channels + channel_idx) * height + height_idx) * dst_width + width_idx + output_pad_left; - - output[out_offset] = (OUTPUT_TYPE)input_ptr[in_offset]; -} - -// convert data from image(b h, ic/4 w ic4) to buffer(nhwc) -__kernel void nc4hw4_buffer_to_nhwc_buffer(GLOBAL_SIZE_2_DIMS - __global OUTPUT_TYPE *output, - __private const int height, __private const int width, - 
__private const int channels, - __global INPUT_TYPE *input_ptr) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); +#define MNN_DATA_FORMAT_NCHW 0 +#define MNN_DATA_FORMAT_NHWC 1 +#define MNN_DATA_FORMAT_NC4HW4 2 +#define MNN_DATA_FORMAT_C4NHW4 3 +__kernel void buffer_convert_to_buffer(GLOBAL_SIZE_3_DIMS + __global const INPUT_TYPE *input_ptr, + __private const int4 shape, // N C H W + __global OUTPUT_TYPE *output_ptr +) { - const int batch_idx = image_height_idx / height; - const int height_idx = image_height_idx % height; - const int width_idx = image_width_idx % width; - const int channel_4_idx = (image_width_idx / width) << 2; - const int buffer_offset = ((batch_idx * height + height_idx) * width + width_idx) * channels + channel_4_idx; + int wh = get_global_id(0); + int c = get_global_id(1); + int n = get_global_id(2); - const int in_offset = (((batch_idx * ((channels+3)/4) + channel_4_idx/4) * height + height_idx) * width + width_idx)*4; + DEAL_NON_UNIFORM_DIM3(wh, c, n); + int w = wh % shape.w; + int h = wh / shape.w; - float4 values = convert_float4(vload4(0, input_ptr+in_offset)); - const int remain_channel = channels - channel_4_idx; - if (remain_channel >= 4) { - vstore4(CONVERT_OUTPUT4(values), 0, output + buffer_offset); - } else if (remain_channel == 3) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - offset++; - output[offset] = (OUTPUT_TYPE)values.y; - offset++; - output[offset] = (OUTPUT_TYPE)values.z; - } else if (remain_channel == 2) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - offset++; - output[offset] = (OUTPUT_TYPE)values.y; - } else if (remain_channel == 1) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - } -} +#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int input_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w; +#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int input_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c; +#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int input_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4); +#endif -// convert data from buffer(nc4hw4) to buffer(nchw) -__kernel void nc4hw4_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS - __global OUTPUT_TYPE *output, - __private const int height, __private const int width, - __private const int channels, - __global INPUT_TYPE *input_ptr) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); +#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int output_offset = ((n * shape.y + c) * shape.z + h) * shape.w + w; +#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int output_offset = ((n * shape.z + h) * shape.w + w) * shape.y + c; +#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int output_offset = ((((c / 4) * shape.x + n) * shape.z + h) * shape.w + w) * 4 + (c % 4); +#endif - const int batch_idx = image_height_idx / height; - const int height_idx = image_height_idx % height; - const int width_idx = image_width_idx % width; - int channel_4_idx = (image_width_idx / width) * 4; - int buffer_offset = ((batch_idx * channels + channel_4_idx) * height + height_idx) * width + width_idx; - - const int in_offset = (((batch_idx * ((channels+3)/4) + channel_4_idx/4) * height + height_idx) * width + width_idx)*4; - float4 values = convert_float4(vload4(0, input_ptr+in_offset)); - - const int height_width_size = 
height * width; - - const int remain_channel = channels - channel_4_idx; - - if (remain_channel >= 4) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.y; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.z; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.w; - } else if (remain_channel == 3) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.y; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.z; - } else if (remain_channel == 2) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - offset += height_width_size; - output[offset] = (OUTPUT_TYPE)values.y; - } else if (remain_channel == 1) { - int offset = buffer_offset; - output[offset] = (OUTPUT_TYPE)values.x; - } + output_ptr[output_offset] = input_ptr[input_offset]; } -__kernel void nc4hw4_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS +__kernel void buffer_copy_to_buffer(GLOBAL_SIZE_2_DIMS __global const INPUT_TYPE *input_ptr, - __private const int2 output_shape, - __private const int2 src_stride, - __private const int2 dst_stride, - __global OUTPUT_TYPE *output + __global OUTPUT_TYPE *output_ptr, + __private const int size // N C H W ) { - int image_width_idx = get_global_id(0); - int image_height_idx = get_global_id(1); + const int x = get_global_id(0); + const int y = get_global_id(1); - DEAL_NON_UNIFORM_DIM2(image_width_idx, image_height_idx); - - const int batch_idx = image_height_idx / output_shape.x; - const int height_idx = image_height_idx % output_shape.x; - const int width_idx = image_width_idx % output_shape.y; - const int channel_block_idx = image_width_idx / output_shape.y; - int2 src_bc_offset = src_stride * (int2)(batch_idx, channel_block_idx); - int2 dst_bc_offset = dst_stride * (int2)(batch_idx, channel_block_idx); - int src_buffer_offset = - (((src_bc_offset.x + src_bc_offset.y) * output_shape.x + height_idx) * output_shape.y + width_idx) * 4; - int dst_buffer_offset = - (((dst_bc_offset.x + dst_bc_offset.y) * output_shape.x + height_idx) * output_shape.y + width_idx) * 4; - - vstore4(CONVERT_OUTPUT4(vload4(0, input_ptr + src_buffer_offset)), 0, output+dst_buffer_offset); + DEAL_NON_UNIFORM_DIM2(x, y); + const int offset = x << 2; +#ifdef PACK_LEAVE + if(offset + 3 >= size){ + for(int i = 0; i < size - offset; ++i){ + output_ptr[offset + i] = (OUTPUT_TYPE)input_ptr[offset + i]; + } + } else { +#endif + vstore4(CONVERT_OUTPUT4(vload4(0, input_ptr+offset)), 0, output_ptr+offset); +#ifdef PACK_LEAVE + } +#endif } // convert kernel : from buffer(oihw) to image(oc/4 h w , ic oc4) diff --git a/source/backend/opencl/execution/cl/buffer_convert_quant.cl b/source/backend/opencl/execution/cl/buffer_convert_quant.cl index 5043e1418..8062b1a49 100644 --- a/source/backend/opencl/execution/cl/buffer_convert_quant.cl +++ b/source/backend/opencl/execution/cl/buffer_convert_quant.cl @@ -155,28 +155,29 @@ __kernel void conv2d_1x1_weight_quant_image(GLOBAL_SIZE_2_DIMS __private const int input_channel, __private const int output_channel) { - int x = get_global_id(0); // ic / 16 + int x = get_global_id(0); // ic / 32 int y = get_global_id(1); // oc DEAL_NON_UNIFORM_DIM2(x, y); - const int xin = x << 4; #ifdef USE_LOW_BIT_WEIGHT_INT4 + const int xin = x << 5; #ifdef CHANNEL_LEAVE - uchar8 out = 0; + uchar16 out = 0; uchar *out_ptr = (uchar*)&out; - for(int i = 0; i < 8; ++i){ + 
for(int i = 0; i < 16; ++i){ int index0 = y * input_channel + xin + i * 2; int index1 = y * input_channel + xin + i * 2 + 1; uchar s0 = input_ptr[index0/2]; uchar s1 = input_ptr[index1/2]; out_ptr[i] = ((index0 % 2) == 0 ? (s0 & 0xf0) : (s0 << 4)) | ((index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f)); } - write_imageui(output, (int2)(y, x), convert_uint4(as_ushort4(out))); + write_imagei(output, (int2)(y, x), as_int4(out)); #else const int inputOffset = (y * input_channel + xin)/2; - write_imageui(output, (int2)(y, x), convert_uint4(as_ushort4(vload8(0, input_ptr + inputOffset)))); + write_imagei(output, (int2)(y, x), as_int4(vload16(0, input_ptr + inputOffset))); #endif #else + const int xin = x << 4; const int inputOffset = y * input_channel + xin; write_imagei(output, (int2)(y, x), as_int4(vload16(0, input_ptr + inputOffset))); #endif @@ -205,7 +206,6 @@ __kernel void conv2d_1x1_ic_oc_weight_quant_buffer(GLOBAL_SIZE_2_DIMS #ifdef USE_LOW_BIT_WEIGHT_INT4 const int inputOffset = (yin * input_channel + xin) / 2; const int outputOffset = ((x * outputChannelC4 + y) * icPack * ocPack) / 2; -#ifdef CHANNEL_LEAVE for(int i = 0; i < icPack; ++i){ for(int j = 0; j < ocPack / 2; ++j){ int index0 = (yin + j * 2) * input_channel + xin + i; @@ -217,18 +217,6 @@ __kernel void conv2d_1x1_ic_oc_weight_quant_buffer(GLOBAL_SIZE_2_DIMS output_ptr[outputOffset + i * (ocPack / 2) + j] = s0 | s1; } } -#else - for(int i = 0; i < icPack/2; ++i){ - for(int j = 0; j < ocPack / 2; ++j){ - char s0 = input_ptr[inputOffset + (j * 2) * (input_channel / 2) + i]; - char s1 = input_ptr[inputOffset + (j * 2 + 1) * (input_channel / 2) + i]; - char d0 = (s0 & 0xf0) | ((s1 & 0xf0) >> 4); - char d1 = ((s0 & 0x0f) << 4) | (s1 & 0x0f); - output_ptr[outputOffset + (i * 2) * (ocPack / 2) + j] = d0; - output_ptr[outputOffset + (i * 2 + 1) * (ocPack / 2) + j] = d1; - } - } -#endif #else const int inputOffset = yin * input_channel + xin; const int outputOffset = (x * outputChannelC4 + y) * icPack * ocPack; diff --git a/source/backend/opencl/execution/cl/cast_buf.cl b/source/backend/opencl/execution/cl/cast_buf.cl index d9c1fb8e7..247071e8b 100644 --- a/source/backend/opencl/execution/cl/cast_buf.cl +++ b/source/backend/opencl/execution/cl/cast_buf.cl @@ -2,36 +2,46 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif -#define GLOBAL_SIZE_3_DIMS \ -__private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, +#define GLOBAL_SIZE_2_DIMS \ +__private const int global_size_dim0, __private const int global_size_dim1, -#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ - if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ +#define DEAL_NON_UNIFORM_DIM2(input1, input2) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ return; \ } -__kernel void cast_buf(GLOBAL_SIZE_3_DIMS +__kernel void cast_buf(GLOBAL_SIZE_2_DIMS __global INPUT_TYPE* input, __global OUTPUT_TYPE* output, - __private const int width, - __private const int height, - __private const int channelBlock + __private const int size ) { - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); + const int idx = get_global_id(0); + const int idy = get_global_id(1); - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_channel_idx); - - const int batch_idx = batch_channel_idx / channelBlock; - const int channel_idx = batch_channel_idx % channelBlock; - - const int inp_offset = ((((batch_idx * 
channelBlock) + channel_idx) * height + height_idx) * width + width_idx)*4; -#ifdef TO_BOOL - int4 value = convert_int4(vload4(0, input + inp_offset)); - value = value == (int4)0 ? (int4)0 : (int4)1; - vstore4(CONVERT_OUTPUT4(value), 0, output + inp_offset); -#else - vstore4(CONVERT_OUTPUT4(vload4(0, input + inp_offset)), 0, output + inp_offset); + DEAL_NON_UNIFORM_DIM2(idx, idy); + const int inp_offset = idx * 4; +#ifdef PACK_LEAVE + if(inp_offset + 3 >= size){ + int remain = size - inp_offset; + for(int i = 0; i < remain; ++i){ + #ifdef TO_BOOL + int value = (int)input[inp_offset + i]; + value = value == 0 ? 0 : 1; + output[inp_offset + i] = (OUTPUT_TYPE)value; + #else + output[inp_offset + i] = (OUTPUT_TYPE)input[inp_offset + i]; + #endif + } + }else { +#endif + #ifdef TO_BOOL + int4 value = convert_int4(vload4(0, input + inp_offset)); + value = value == (int4)0 ? (int4)0 : (int4)1; + vstore4(CONVERT_OUTPUT4(value), 0, output + inp_offset); + #else + vstore4(CONVERT_OUTPUT4(vload4(0, input + inp_offset)), 0, output + inp_offset); + #endif +#ifdef PACK_LEAVE + } #endif } diff --git a/source/backend/opencl/execution/cl/conv_2d_buf.cl b/source/backend/opencl/execution/cl/conv_2d_buf.cl index 9aed2e670..07f8d96fe 100644 --- a/source/backend/opencl/execution/cl/conv_2d_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_buf.cl @@ -9,6 +9,77 @@ return; \ } +#ifdef CONV_LOCAL_SIZE +__kernel +void conv_2d_1x1_local(__private const int out_w_blocks, + __global const FLOAT *input, + __global const FLOAT *kernel_ptr, + __global const FLOAT *bias_ptr, + __global FLOAT *output, + __private const int in_c_block, + __private const int batch, + __private const int out_h, + __private const int out_w, + __private const int out_c_block, + __private const int out_c_pack) { + + const int lid = get_local_id(0); + const int out_c_w_idx = get_global_id(1); //c/4 w + const int out_b_h_idx = get_global_id(2); //b h + + COMPUTE_FLOAT4 local sum[CONV_LOCAL_SIZE]; + + const int out_c_idx = out_c_w_idx / out_w_blocks; + const int out_w_idx = out_c_w_idx % out_w_blocks; + const int out_b_idx = out_b_h_idx / out_h; // equal to in_b_idx + const int out_h_idx = out_b_h_idx % out_h; // equal to in_h_idx + + COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias_ptr)); + COMPUTE_FLOAT4 out0 = (COMPUTE_FLOAT4)0; + + int offset = out_c_idx*4; + int inp_offset = (((out_b_idx+in_c_block*batch)*out_h + out_h_idx)* out_w + out_w_idx) << 2; + + const int inp_add = batch*out_h*out_w*4; + for (ushort in_channel_block_idx = lid; in_channel_block_idx < in_c_block; in_channel_block_idx+=CONV_LOCAL_SIZE) { + + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); + + COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset+in_channel_block_idx*inp_add)); + COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights1, out0); + out0 = mad(in0.z, weights2, out0); + out0 = mad(in0.w, weights3, out0); + } + + sum[lid] = out0; + barrier(CLK_LOCAL_MEM_FENCE); + for(int i = CONV_LOCAL_SIZE/2; i > 0; i /= 2){ + if (lid < i) + sum[lid] = sum[lid] + sum[lid + i]; + 
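The rewritten `cast_buf` above drops the 3-D (width, height, batch*channel) indexing and walks the tensor as a flat buffer of `size` elements, four per work item, with a scalar tail loop guarded by `PACK_LEAVE` when fewer than four elements remain. A minimal host-side C analog of that tail handling, assuming a float-to-int cast for concreteness:

```c
#include <stdio.h>

/* C analog of the new cast_buf control flow: each "work item" converts 4
 * contiguous elements; the last one falls back to a scalar loop when fewer
 * than 4 elements remain (the PACK_LEAVE branch in the kernel). */
static void cast_float_to_int(const float *input, int *output, int size) {
    for (int idx = 0; idx * 4 < size; ++idx) {  /* idx plays the role of get_global_id(0) */
        int offset = idx * 4;
        if (offset + 3 >= size) {               /* PACK_LEAVE branch: partial group */
            for (int i = 0; i < size - offset; ++i)
                output[offset + i] = (int)input[offset + i];
        } else {                                /* full group (vload4/vstore4 in the kernel) */
            for (int i = 0; i < 4; ++i)
                output[offset + i] = (int)input[offset + i];
        }
    }
}

int main(void) {
    float in[6] = {0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f};
    int out[6];
    cast_float_to_int(in, out, 6);
    for (int i = 0; i < 6; ++i) printf("%d ", out[i]); /* 0 1 2 3 4 5 */
    printf("\n");
    return 0;
}
```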
barrier(CLK_LOCAL_MEM_FENCE); + } + out0 = sum[0] + bias0; + if(lid == 0){ +#ifdef RELU + out0 = fmax(out0, (COMPUTE_FLOAT4)0); +#endif + +#ifdef RELU6 + out0 = clamp(out0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); +#endif + + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_h + out_h_idx)* out_w + out_w_idx)*4; + vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); + } +} +#endif + __kernel void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __global const FLOAT *input, @@ -18,6 +89,7 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, + __private const int out_b, __private const int out_c_block, __private const int out_c_pack) { @@ -38,15 +110,11 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out3 = out0; const int intput_width_idx0 = out_w4_idx; - - + int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0) << 2; int offset = out_c_idx*4; - int inp_offset = (((out_b_idx*in_c_block)*out_h + out_h_idx)* out_w + intput_width_idx0) << 2; - - const int inp_add = out_h*out_w*4; + const int inp_add = out_b*out_h*out_w*4; for (ushort in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); @@ -95,7 +163,7 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w4_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx * out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_w - out_w4_idx; if (remain >= 4) { @@ -123,6 +191,7 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, + __private const int out_b, __private const int out_c_block, __private const int out_c_pack) { @@ -148,12 +217,12 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out7 = out4; const int intput_width_idx0 = out_w4_idx; + int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2; + int offset = out_c_idx*8; + const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*8); - const int inp_offset = - (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); @@ -208,6 +277,9 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out7 = mad(in3.y, weights3, out7); out7 = mad(in3.z, weights5, out7); out7 = mad(in3.w, weights7, out7); + + offset += 4 * out_c_pack; + inp_offset += inp_add; } #ifdef RELU @@ -234,10 +306,10 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_block + 
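The new `conv_2d_1x1_local` kernel above splits the input-channel-block loop across the work-group (each lane takes indices `lid, lid + CONV_LOCAL_SIZE, ...`), stores the partial sums in local memory, and combines them with a halving tree reduction before lane 0 adds the bias and writes the result. A sequential C sketch of that partition-plus-reduction pattern, with `LOCAL_SIZE` standing in for `CONV_LOCAL_SIZE` (assumed to be a power of two, as the halving loop requires):

```c
#include <stdio.h>

#define LOCAL_SIZE 8   /* stand-in for CONV_LOCAL_SIZE; must be a power of two */

/* Each of LOCAL_SIZE lanes accumulates a strided subset of the channel
 * blocks, then the partial sums are combined by a binary tree reduction.
 * The barriers in the kernel are implicit here because this runs serially. */
static float reduce_channel_blocks(const float *per_block, int num_blocks) {
    float sum[LOCAL_SIZE] = {0};
    for (int lid = 0; lid < LOCAL_SIZE; ++lid)
        for (int c = lid; c < num_blocks; c += LOCAL_SIZE)  /* strided partition */
            sum[lid] += per_block[c];
    for (int i = LOCAL_SIZE / 2; i > 0; i /= 2)             /* tree reduction */
        for (int lid = 0; lid < i; ++lid)
            sum[lid] += sum[lid + i];
    return sum[0];                                          /* lane 0 holds the total */
}

int main(void) {
    float blocks[10] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    printf("%g\n", reduce_channel_blocks(blocks, 10));      /* prints 55 */
    return 0;
}
```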
out_c_idx*2)*out_h + out_h_idx)* out_w + out_w4_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w4_idx)*4; __global FLOAT * _tempoutput = output + out_offset; - __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w; + __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w*out_b; #ifdef BLOCK_LEAVE const int remain = out_w - out_w4_idx; @@ -287,6 +359,7 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, + __private const int out_b, __private const int out_c_block, __private const int out_c_pack) { @@ -308,11 +381,10 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out5 = out4; const int intput_width_idx0 = out_w2_idx; + int inp_offset = ((out_b_idx * out_h + out_h_idx)* out_w + intput_width_idx0)<<2; + int offset = out_c_idx*8; + const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - - int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*8); - const int inp_offset = - (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); @@ -344,6 +416,9 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out5 = mad(in1.y, weights3, out5); out5 = mad(in1.z, weights5, out5); out5 = mad(in1.w, weights7, out5); + + offset += 4 * out_c_pack; + inp_offset += inp_add; } #ifdef RELU @@ -362,11 +437,11 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out5 = clamp(out5, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_block + out_c_idx*2)*out_h + out_h_idx)* out_w + out_w2_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*2*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4; __global FLOAT * _tempoutput = output + out_offset; - __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w; + __global FLOAT * _tempoutput1 = _tempoutput + 4*out_h*out_w*out_b; #ifdef BLOCK_LEAVE const int remain = out_w - out_w2_idx; @@ -405,6 +480,7 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, + __private const int out_b, __private const int out_c_block, __private const int out_c_pack) { @@ -420,12 +496,12 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias_ptr)); const int intput_width_idx0 = out_w_idx; + int offset = out_c_idx*4; + int inp_offset = ((out_b_idx * out_h + out_h_idx) * out_w + intput_width_idx0)*4; + const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); - const int inp_offset = - (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); @@ -437,6 +513,9 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out0 = mad(in0.y, weights1, 
out0); out0 = mad(in0.z, weights2, out0); out0 = mad(in0.w, weights3, out0); + + offset += 4 * out_c_pack; + inp_offset += inp_add; } #ifdef RELU @@ -447,7 +526,7 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out0 = clamp(out0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*out_b)*out_h + out_h_idx)* out_w + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } @@ -462,6 +541,7 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, + __private const int out_b, __private const int out_c_block, __private const int out_c_pack) { @@ -481,12 +561,11 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out1 = out0; const int intput_width_idx0 = out_w2_idx; + int offset = out_c_idx*4; + int inp_offset = ((out_b_idx*out_h + out_h_idx)* out_w + intput_width_idx0)*4; + const int inp_add = out_b*out_h*out_w*4; for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - - int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); - const int inp_offset = - (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); @@ -505,6 +584,9 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out1 = mad(in1.y, weights1, out1); out1 = mad(in1.z, weights2, out1); out1 = mad(in1.w, weights3, out1); + + offset += 4 * out_c_pack; + inp_offset += inp_add; } #ifdef RELU @@ -517,7 +599,7 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_block + out_c_idx)*out_h + out_h_idx)* out_w + out_w2_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*out_b)*out_h + out_h_idx)* out_w + out_w2_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_w - out_w2_idx; @@ -541,6 +623,7 @@ void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -580,7 +663,7 @@ void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + kw_start) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { for(int ix = in_w_idx_start; ix < in_w_idx_end; ix += dilate_hw.y) { - int inp_offset = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + ix) * 4; + int inp_offset = (((out_b_idx + in_c_idx * batch) * in_hw.x + iy) * in_hw.y + ix) * 4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); const int filter_w_inc = (ix-in_w_idx_start)/dilate_hw.y; @@ -606,8 +689,7 @@ void conv_2d_c4h1w1(GLOBAL_SIZE_2_DIMS #ifdef RELU6 out0 = clamp(out0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + 
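Across the `conv_2d_1x1_*` kernels above, the per-iteration `mad24` recomputation of the weight and input offsets is replaced by running offsets initialized once and bumped by constant strides (`offset += 4 * out_c_pack; inp_offset += inp_add;`). The two forms visit the same addresses; the sketch below, with illustrative helper names, shows the equivalence under that assumption:

```c
/* Both loops produce identical address sequences; the second form is the
 * strength-reduced version now used by the 1x1 kernels. */
void offsets_recomputed(int in_c_block, int out_c_pack, int out_c_idx,
                        int base_inp, int inp_add, int *w_off, int *i_off) {
    for (int c = 0; c < in_c_block; ++c) {
        w_off[c] = c * 4 * out_c_pack + out_c_idx * 4;  /* old: multiply every iteration */
        i_off[c] = base_inp + c * inp_add;
    }
}

void offsets_hoisted(int in_c_block, int out_c_pack, int out_c_idx,
                     int base_inp, int inp_add, int *w_off, int *i_off) {
    int offset = out_c_idx * 4;                         /* new: computed once */
    int inp_offset = base_inp;
    for (int c = 0; c < in_c_block; ++c) {
        w_off[c] = offset;
        i_off[c] = inp_offset;
        offset += 4 * out_c_pack;                       /* new: constant-stride bump */
        inp_offset += inp_add;
    }
}
```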
out_w_idx)*4; vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } @@ -621,6 +703,7 @@ void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -658,7 +741,7 @@ void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -696,7 +779,7 @@ void conv_2d_c4h1w2(GLOBAL_SIZE_2_DIMS out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); if(out_w_idx + 1 >= out_hw.y) return; @@ -715,6 +798,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -756,7 +840,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -812,7 +896,7 @@ void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; @@ -840,6 +924,7 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -879,7 +964,7 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -937,7 +1022,7 @@ void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = 
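The recurring offset change in these hunks (and in the deconv, depthwise and int-quant kernels later in the patch) is the switch from `(b * c_blocks + c4)` to `(b + c4 * batch)`: each group of four channels now stores all batches contiguously, so the batch index strides by one H×W plane while the channel-block index strides over the whole batch. The two formulas below are copied from the old and new code; only the helper names and the "layout" wording are mine.

```c
/* Address of the 4-element group (b, c4, h, w) under the two layouts. */
int offset_old_layout(int b, int c4, int h, int w,
                      int c_blocks, int H, int W) {   /* old: (b, c4, h, w, 4) order */
    return (((b * c_blocks + c4) * H + h) * W + w) * 4;
}

int offset_new_layout(int b, int c4, int h, int w,
                      int batch, int H, int W) {      /* new: (c4, b, h, w, 4) order */
    return (((b + c4 * batch) * H + h) * W + w) * 4;
}
```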
(((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -972,6 +1057,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -1016,7 +1102,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx * batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -1107,7 +1193,7 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -1125,12 +1211,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks){ return; } -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); @@ -1151,12 +1237,12 @@ void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out2), 2 * out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out3), 3 * out_hw.y, output+out_offset); -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks){ return; } -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset); @@ -1173,6 +1259,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -1212,7 +1299,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS for(ushort in_c_idx = 0; in_c_idx < in_c_blocks; in_c_idx++) { //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 
0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -1270,7 +1357,7 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 2){ @@ -1279,12 +1366,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks){ return; } -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 2){ vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); @@ -1294,12 +1381,12 @@ void conv_2d_c8h2w1(GLOBAL_SIZE_2_DIMS #else vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out1), out_hw.y, output+out_offset); -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks){ return; } -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); #endif @@ -1314,6 +1401,7 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -1361,7 +1449,7 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx * batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -1450,7 +1538,7 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; if(remain >= 4){ @@ -1463,10 +1551,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS }else if(remain == 1){ vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks)return; -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); }else if(remain == 3){ @@ 
-1479,10 +1567,10 @@ void conv_2d_c8h1w4(GLOBAL_SIZE_2_DIMS } #else vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0, out1, out2, out3)), 0, output+out_offset); -#ifdef CHANNEL_LEAVE + #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks)return; -#endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + #endif + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); #endif } diff --git a/source/backend/opencl/execution/cl/conv_2d_c16_subgroup_buf.cl b/source/backend/opencl/execution/cl/conv_2d_c16_subgroup_buf.cl index a64f6d9ab..2167f8014 100644 --- a/source/backend/opencl/execution/cl/conv_2d_c16_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_c16_subgroup_buf.cl @@ -48,6 +48,7 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b2( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -82,10 +83,10 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b2( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * ((output_channel + 3) / 4); + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - (feature_block << 2) * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + (feature_block << 2) * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -242,13 +243,13 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b2( if ((feature_block+1)*16 >= output_channel) { for (int i = 0; i < 2 && (x + i) < output_width; i++) { if ((feature_block*16 + lid_y * 4 + lid_x < output_channel)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 2 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } #if SLM_DIV_FACTOR > 1 @@ -269,6 +270,7 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b4( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -303,10 +305,10 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b4( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * ((output_channel + 3) / 4); + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - (feature_block << 2) * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + (feature_block << 2) * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -463,13 +465,13 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b4( if ((feature_block+1)*16 >= output_channel) { for (int i = 0; i < 4 && (x + i) < output_width; i++) { if ((feature_block*16 + lid_y * 4 + 
lid_x < output_channel)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 4 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } #if SLM_DIV_FACTOR > 1 @@ -490,6 +492,7 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b8( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -524,10 +527,10 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b8( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * ((output_channel + 3) / 4); + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - (feature_block << 2) * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + (feature_block << 2) * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -684,13 +687,13 @@ __kernel void conv_2d_buf_subgroup_c16_c4_b8( if ((feature_block+1)*16 >= output_channel) { for (int i = 0; i < 8 && (x + i) < output_width; i++) { if ((feature_block*16 + lid_y * 4 + lid_x < output_channel)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 8 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } #if SLM_DIV_FACTOR > 1 @@ -711,6 +714,7 @@ __kernel void conv_2d_buf_subgroup_c16_c16_b2( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -944,6 +948,7 @@ __kernel void conv_2d_buf_subgroup_c16_c16_b4( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -1177,6 +1182,7 @@ __kernel void conv_2d_buf_subgroup_c16_c16_b8( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, diff --git a/source/backend/opencl/execution/cl/conv_2d_c1_subgroup_buf.cl b/source/backend/opencl/execution/cl/conv_2d_c1_subgroup_buf.cl index 6e4f81324..2e40d99e4 100644 --- a/source/backend/opencl/execution/cl/conv_2d_c1_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_c1_subgroup_buf.cl @@ -47,6 +47,7 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b2( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const 
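The subgroup kernels express the same layout change through their pitch variables: `output_b_pitch` becomes `output_fs_pitch * batch`, and the batch and feature-block terms swap strides in the output offset (and in the `lid_y * output_b_pitch` stores). A small C sketch of the new offset, using the kernel's own variable names as parameters:

```c
/* New output offset used by the subgroup kernels: batch strides by one
 * (h, w, 4) plane, each group of 4 output channels strides over the batch. */
unsigned output_offset_new(int b, int feature_block, int y, int x,
                           int out_w, int out_h, int batch) {
    unsigned x_pitch  = 4;
    unsigned y_pitch  = x_pitch * (unsigned)out_w;
    unsigned fs_pitch = y_pitch * (unsigned)out_h;   /* one (h, w, 4) plane           */
    unsigned b_pitch  = fs_pitch * (unsigned)batch;  /* was fs_pitch * ((C + 3) / 4)  */
    return (unsigned)b * fs_pitch                    /* was b * b_pitch               */
         + (unsigned)(feature_block << 2) * b_pitch  /* was (feature_block << 2) * fs_pitch */
         + (unsigned)y * y_pitch + (unsigned)x * x_pitch;
}
```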
int input_pad_left, __private const int input_pad_right, @@ -80,11 +81,11 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b2( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * output_pack; + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - f_block * 4 * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + f_block * 4 * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -160,13 +161,13 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b2( if ((f_block+1)*16 >= output_channel) { for (int i = 0; i < 2 && (x + i) < output_width; i++) { if ((f_block*16 + lid_y * 4 < output_pack * 4)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 2 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } } @@ -184,6 +185,7 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b4( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -217,11 +219,11 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b4( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * output_pack; + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - f_block * 4 * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + f_block * 4 * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -297,13 +299,13 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b4( if ((f_block+1)*16 >= output_channel) { for (int i = 0; i < 4 && (x + i) < output_width; i++) { if ((f_block*16 + lid_y * 4 < output_pack * 4)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 4 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } } @@ -321,6 +323,7 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b8( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -354,11 +357,11 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b8( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * output_width; const uint output_fs_pitch = output_y_pitch * output_height; - const uint output_b_pitch = output_fs_pitch * output_pack; + const uint output_b_pitch = output_fs_pitch * batch; - const uint output_offset = b * output_b_pitch + - f_block * 4 * output_fs_pitch + + const uint output_offset = b * output_fs_pitch + + 
f_block * 4 * output_b_pitch + y * output_y_pitch + x * output_x_pitch; @@ -434,13 +437,13 @@ __kernel void conv_2d_buf_subgroup_c1_c4_b8( if ((f_block+1)*16 >= output_channel) { for (int i = 0; i < 8 && (x + i) < output_width; i++) { if ((f_block*16 + lid_y * 4 < output_pack * 4)) - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } else { for (int i = 0; i < 8 && (x + i) < output_width; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = (FLOAT)dst[i]; } } } @@ -458,6 +461,7 @@ __kernel void conv_2d_buf_subgroup_c1_c16_b2( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -607,6 +611,7 @@ __kernel void conv_2d_buf_subgroup_c1_c16_b4( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -756,6 +761,7 @@ __kernel void conv_2d_buf_subgroup_c1_c16_b8( __private const int output_width, __private const int output_height, __private const int output_channel, + __private const int batch, __private const int x_blocks, __private const int input_pad_left, __private const int input_pad_right, @@ -890,4 +896,4 @@ __kernel void conv_2d_buf_subgroup_c1_c16_b8( } } } -} \ No newline at end of file +} diff --git a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl index aeed184ce..e42398c63 100644 --- a/source/backend/opencl/execution/cl/conv_2d_int_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_int_buf.cl @@ -34,6 +34,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -77,7 +78,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + kw_start) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { for(int ix = in_w_idx_start; ix < in_w_idx_end; ix += dilate_hw.y) { - int inp_offset = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + ix) * 4; + int inp_offset = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + ix) * 4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); const int filter_w_inc = (ix-in_w_idx_start)/dilate_hw.y; @@ -141,7 +142,7 @@ void conv_2d_int_c4h1w1(GLOBAL_SIZE_2_DIMS out0 = clamp(out0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } @@ -160,6 +161,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, 
__private const int2 filter_hw, __private const int2 stride_hw, @@ -203,7 +205,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -278,7 +280,7 @@ void conv_2d_int_c4h1w2(GLOBAL_SIZE_2_DIMS out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); if(out_w_idx + 1 >= out_hw.y) return; @@ -302,6 +304,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -349,7 +352,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -442,7 +445,7 @@ void conv_2d_int_c4h1w4(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; @@ -475,6 +478,7 @@ void conv_2d_int_c4h4w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -520,7 +524,7 @@ void conv_2d_int_c4h4w1(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 offset = (COMPUTE_FLOAT4)(ScaleOffset.s1, ScaleOffset.s3, ScaleOffset.s5, ScaleOffset.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -615,7 +619,7 @@ void conv_2d_int_c4h4w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int 
remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -655,6 +659,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -709,7 +714,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -873,7 +878,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 4){ @@ -896,7 +901,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS return; } #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); @@ -922,7 +927,7 @@ void conv_2d_int_c8h4w1(GLOBAL_SIZE_2_DIMS return; } #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out5), out_hw.y, output+out_offset); vstore4(CONVERT_FLOAT4(out6), 2 * out_hw.y, output+out_offset); @@ -944,6 +949,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -993,7 +999,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS COMPUTE_FLOAT4 offset1 = (COMPUTE_FLOAT4)(ScaleOffset1.s1, ScaleOffset1.s3, ScaleOffset1.s5, ScaleOffset1.s7); //weights NC4HW4 [1, 4*icC4, ocC4*kh*kw, 1] xic4 //index: [0, 4*in_c_idx, out_c_idx*kh*kw + kh_start*kw + kw_start, 0] - const int inp_offset_base = (out_b_idx * in_c_blocks + in_c_idx) * in_hw.x * in_hw.y * 4; + const int inp_offset_base = (out_b_idx + in_c_idx*batch) * in_hw.x * in_hw.y * 4; for(int iy = 0; iy < filter_hw.x; iy++) { int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + iy)*filter_hw.y + kw_start) * 4; @@ -1122,7 +1128,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.x - out_h_idx; if(remain >= 2){ @@ 
-1136,7 +1142,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS return; } #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 2){ vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); @@ -1151,7 +1157,7 @@ void conv_2d_int_c8h2w1(GLOBAL_SIZE_2_DIMS return; } #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); vstore4(CONVERT_FLOAT4(out3), out_hw.y, output+out_offset); #endif @@ -1171,6 +1177,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int2 in_hw, __private const int inChannel, __private const int in_c_blocks, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 stride_hw, @@ -1227,7 +1234,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS int weight_offset = ((((4*in_c_idx+0)* out_c_blocks + out_c_idx) *filter_hw.x + kh_start)*filter_hw.y + 0) * 4; for(int iy = in_h_idx_start; iy < in_h_idx_end; iy += dilate_hw.x) { - const int inp_offset_base = (((out_b_idx * in_c_blocks + in_c_idx) * in_hw.x + iy) * in_hw.y + 0) * 4; + const int inp_offset_base = (((out_b_idx + in_c_idx*batch) * in_hw.x + iy) * in_hw.y + 0) * 4; for(int fw = 0; fw < filter_hw.y; fw++) { const int in_w0_idx = fw * dilate_hw.y + in_w0_idx_base; @@ -1389,7 +1396,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS out7 = clamp(out7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((out_b_idx*out_c_blocks + out_c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + int out_offset = (((out_b_idx + out_c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; #ifdef BLOCK_LEAVE const int remain = out_hw.y - out_w_idx; if(remain >= 4){ @@ -1405,7 +1412,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; if(remain >= 4){ vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); }else if(remain == 3){ @@ -1421,7 +1428,7 @@ void conv_2d_int_c8h1w4(GLOBAL_SIZE_2_DIMS #ifdef CHANNEL_LEAVE if(out_c_idx + 1 >= out_c_blocks)return; #endif - out_offset = (((out_b_idx*out_c_blocks + out_c_idx + 1)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + out_offset = (((out_b_idx + (out_c_idx + 1)*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4, out5, out6, out7)), 0, output+out_offset); #endif } diff --git a/source/backend/opencl/execution/cl/deconv_2d.cl b/source/backend/opencl/execution/cl/deconv_2d.cl index 195806221..80fdf982b 100644 --- a/source/backend/opencl/execution/cl/deconv_2d.cl +++ b/source/backend/opencl/execution/cl/deconv_2d.cl @@ -17,7 +17,7 @@ __kernel void deconv_2d(GLOBAL_SIZE_3_DIMS #ifdef BIAS __global FLOAT* bias, #endif - __global FLOAT* output, + __global FLOAT* output, __private const int batch, #else __read_only image2d_t input, __read_only image2d_t weights, @@ -82,7 +82,7 @@ __kernel void deconv_2d(GLOBAL_SIZE_3_DIMS weights3 = 
vload4(kernel_x_3*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y, weights); bool outBoundry = (idx_h < 0 || idx_h >= input_shape.x || kernel_start_x < 0 || in_width0 >= input_shape.y); - int inp_offset = (((out_b_idx * in_channel_blocks + ic) * input_shape.x + idx_h) * input_shape.y + in_width0) * 4; + int inp_offset = (((out_b_idx + ic * batch) * input_shape.x + idx_h) * input_shape.y + in_width0) * 4; in0 = outBoundry ? (FLOAT4)0 : vload4(0, input+inp_offset); out0 = mad(in0.x, weights0, out0); @@ -127,7 +127,7 @@ __kernel void deconv_2d(GLOBAL_SIZE_3_DIMS #endif #ifdef USE_BUFFER - const int out_offset = (((out_b_idx*out_channel_blocks + out_channel_blocks_idx)*output_shape.x + out_h_idx)*output_shape.y + out_w_idx)*4; + const int out_offset = (((out_b_idx + out_channel_blocks_idx*batch)*output_shape.x + out_h_idx)*output_shape.y + out_w_idx)*4; vstore4(out0, 0, output+out_offset); #else int out_image_width_idx = mad24(out_channel_blocks_idx, output_shape.y, out_w_idx); diff --git a/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl b/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl index 586315962..c32400af9 100644 --- a/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl +++ b/source/backend/opencl/execution/cl/depthwise_conv2d_buf.cl @@ -23,7 +23,7 @@ void depthwise_conv2d_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -58,7 +58,7 @@ void depthwise_conv2d_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh * dilate_hw.x; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); const int kw_dilate = kw * dilate_hw.y; @@ -92,7 +92,7 @@ void depthwise_conv2d_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue3 = clamp(outValue3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; const int remain = out_hw.y - out_w4_idx; if (remain >= 4) { @@ -119,7 +119,7 @@ void depthwise_conv2d_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -150,7 +150,7 @@ void depthwise_conv2d_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh * dilate_hw.x; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); const int kw_dilate = kw * dilate_hw.y; @@ -176,7 +176,7 @@ void depthwise_conv2d_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue1 = clamp(outValue1, 
(COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; const int remain = out_hw.y - out_w2_idx; if (remain >= 2) { @@ -194,7 +194,7 @@ void depthwise_conv2d_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -222,7 +222,7 @@ void depthwise_conv2d_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh * dilate_hw.x; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); const int kw_dilate = kw * dilate_hw.y; @@ -244,7 +244,7 @@ void depthwise_conv2d_c4h1w1(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue0 = clamp(outValue0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w_idx)*4; vstore4(CONVERT_FLOAT4(outValue0), 0, output+out_offset); } @@ -255,7 +255,7 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -294,8 +294,8 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset_c0 = (((b_idx*c_blocks + c_idx+0)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; - int inp_offset_c1 = (((b_idx*c_blocks + c_idx+1)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset_c0 = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset_c1 = (((b_idx + (c_idx+1)*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); COMPUTE_FLOAT4 inValue0 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c0)); @@ -349,7 +349,7 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue7 = clamp(outValue7, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; + int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; const int remain = out_hw.y - out_w4_idx; if (remain >= 4) { @@ -370,7 +370,7 @@ void depthwise_conv2d_s1_c8h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, if(c_idx + 1 >= c_blocks) return; - out_offset += out_hw.x * out_hw.y * 4; + out_offset += batch * out_hw.x * out_hw.y * 4; if (remain >= 4) { vstore4(CONVERT_FLOAT4(outValue4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(outValue5), 1, output+out_offset); @@ -395,7 +395,7 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -428,8 +428,8 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset_c0 = (((b_idx*c_blocks + c_idx+0)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; - int inp_offset_c1 = (((b_idx*c_blocks + c_idx+1)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset_c0 = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset_c1 = (((b_idx + (c_idx+1)*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); COMPUTE_FLOAT4 inValue0 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset_c0)); @@ -467,7 +467,7 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue5 = clamp(outValue5, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; + int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; const int remain = out_hw.y - out_w2_idx; if (remain >= 2) { @@ -479,7 +479,7 @@ void depthwise_conv2d_s1_c8h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, if(c_idx + 1 >= c_blocks) return; - out_offset += out_hw.x * out_hw.y * 4; + out_offset += batch * out_hw.x * out_hw.y * 4; if (remain >= 2) { vstore4(CONVERT_FLOAT4(outValue4), 0, output+out_offset); vstore4(CONVERT_FLOAT4(outValue5), 1, output+out_offset); @@ -494,7 +494,7 @@ void depthwise_conv2d_s1_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -530,7 +530,7 @@ void depthwise_conv2d_s1_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, const int in_h_cur = in_h_start + kh; if(in_h_cur < 0 || in_h_cur >= in_hw.x) continue; - int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; + int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_cur)* in_hw.y + in_w_start_0)*4; for (int kw = 0; kw < filter_hw.y; kw++) { const int filter_idx = mad24(kh, filter_hw.y, kw); inValue0 = (in_w_start_0+kw < 0 || in_w_start_0+kw >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0, input+inp_offset)); @@ -563,7 +563,7 @@ void depthwise_conv2d_s1_c4h1w4(GLOBAL_SIZE_2_DIMS __global const FLOAT *input, outValue3 = clamp(outValue3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w4_idx)*4; const int remain = out_hw.y - out_w4_idx; if (remain >= 4) { @@ -590,7 +590,7 @@ void depthwise_conv2d_k3s1p1_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -617,7 +617,7 @@ void depthwise_conv2d_k3s1p1_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp const int in_h_start = out_h_idx - pad_hw.x; COMPUTE_FLOAT4 inValue0, inValue1, inValue2, inValue3; //first line - const int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_start)* in_hw.y + in_w_start_0)*4; + const int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_start)* in_hw.y + in_w_start_0)*4; inValue0 = (in_h_start < 0 || in_w_start_0 < 0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); inValue1 = (in_h_start < 0 || in_w_start_0+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); inValue2 = (in_h_start < 0 || in_w_start_0+2 >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset)); @@ -690,7 +690,7 @@ void depthwise_conv2d_k3s1p1_c4h1w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp outValue1 = clamp(outValue1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h_idx)*out_hw.y + out_w2_idx)*4; const int remain = out_hw.y - out_w2_idx; if (remain >= 2) { @@ -708,7 +708,7 @@ void depthwise_conv2d_k3s1p1_c4h2w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp __global const FLOAT *bias, __global FLOAT *output, __private const int2 in_hw, - __private const int channel, + __private const int batch, __private const int2 out_hw, __private const int2 filter_hw, __private const int2 pad_hw, @@ -739,7 +739,7 @@ void depthwise_conv2d_k3s1p1_c4h2w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp const int in_h_start = out_h2_idx - pad_hw.x; COMPUTE_FLOAT4 inValue0, inValue1, inValue2, inValue3; //first line - const int inp_offset = (((b_idx*c_blocks + c_idx)*in_hw.x + in_h_start)* in_hw.y + in_w_start)*4; + const int inp_offset = (((b_idx + c_idx*batch)*in_hw.x + in_h_start)* in_hw.y + in_w_start)*4; inValue0 = (in_h_start < 0 || in_w_start < 0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); inValue1 = (in_h_start < 0 || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); inValue2 = (in_h_start < 0 || in_w_start+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset)); @@ -830,7 +830,7 @@ void depthwise_conv2d_k3s1p1_c4h2w2(GLOBAL_SIZE_2_DIMS __global const FLOAT *inp outValue3 = clamp(outValue3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif - const int out_offset = (((b_idx*c_blocks + c_idx)*out_hw.x + out_h2_idx)*out_hw.y + out_w2_idx)*4; + const int out_offset = (((b_idx + c_idx*batch)*out_hw.x + out_h2_idx)*out_hw.y + out_w2_idx)*4; const int remain_w = out_hw.y - out_w2_idx; const int remain_h = out_hw.x - out_h2_idx; diff --git a/source/backend/opencl/execution/cl/depthwise_conv2d_subgroup_buf.cl b/source/backend/opencl/execution/cl/depthwise_conv2d_subgroup_buf.cl index 7d7698059..1bed1618e 100644 --- a/source/backend/opencl/execution/cl/depthwise_conv2d_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/depthwise_conv2d_subgroup_buf.cl @@ -12,6 +12,7 @@ __kernel void depthwise_conv_2d_buf_c16_c16( __private const int inputHeight, __private const int inputWidth, __private const int Channel, + __private const int Batch, __private const int input_pad_left, __private const int input_pad_right, __private const int outputHeight, @@ -130,6 +131,7 @@ __kernel void depthwise_conv_2d_buf_c16_c4( __private const int inputHeight, __private const int inputWidth, __private const int Channel, + __private const int Batch, __private const int input_pad_left, __private const int input_pad_right, __private const int outputHeight, @@ -167,10 +169,10 @@ __kernel void depthwise_conv_2d_buf_c16_c4( const uint output_x_pitch = 4; const uint output_y_pitch = output_x_pitch * outputWidth; const uint output_fs_pitch = output_y_pitch * outputHeight; - const uint output_b_pitch = output_fs_pitch * ((Channel + 3) / 4); + const uint output_b_pitch = output_fs_pitch * Batch; - const uint output_offset = b * output_b_pitch + - (c << 2) * output_fs_pitch + + const uint output_offset = (c << 2) * output_b_pitch + + b * output_fs_pitch + y * output_y_pitch + x * 
output_x_pitch; @@ -223,6 +225,6 @@ __kernel void depthwise_conv_2d_buf_c16_c4( const uint lid_x = sglid % 4; const uint lid_y = sglid / 4; for (int i = 0; i < 8 && (x + i) < outputWidth; i++) { - output[output_offset + lid_y * output_fs_pitch + i * output_x_pitch + lid_x] = dst[i]; + output[output_offset + lid_y * output_b_pitch + i * output_x_pitch + lid_x] = dst[i]; } -} \ No newline at end of file +} diff --git a/source/backend/opencl/execution/cl/gather_buf.cl b/source/backend/opencl/execution/cl/gather_buf.cl index 22b23dbe0..8af02b080 100644 --- a/source/backend/opencl/execution/cl/gather_buf.cl +++ b/source/backend/opencl/execution/cl/gather_buf.cl @@ -17,8 +17,6 @@ __kernel void batch_gather_buf(__private int global_dim0, __private int global_d __private const int4 stride_dst, __private const int2 steps, __private const int2 iters, - __private const int4 dst_c4size,// w, h, c, n - __private const int4 src_c4size,// w, h, c, n __private const int inputSize) { int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); @@ -28,91 +26,22 @@ __kernel void batch_gather_buf(__private int global_dim0, __private int global_d int y = pos.x / x_size; int2 index = (int2)(pos.z, pos.z); - - #ifdef OFFSET_DST - { - int offset_value = pos.z; - int off_c4_size = (offset_dst_shape.z + 3) >> 2; - #ifdef GATHER_INPUT_NHWC - int off_c = offset_value % offset_dst_shape.z; offset_value /= offset_dst_shape.z; - int off_w = offset_value % offset_dst_shape.x; offset_value /= offset_dst_shape.x; - int off_h = offset_value % offset_dst_shape.y; - int off_b = offset_value / offset_dst_shape.y; - #else - int off_w = offset_value % offset_dst_shape.x; offset_value /= offset_dst_shape.x; - int off_h = offset_value % offset_dst_shape.y; offset_value /= offset_dst_shape.y; - int off_c = offset_value % offset_dst_shape.z; - int off_b = offset_value / offset_dst_shape.z; - #endif - int real_dst_offset = (((off_b * off_c4_size + off_c / 4) * offset_dst_shape.y + off_h) * offset_dst_shape.x + off_w) * 4 + off_c % 4; - index.x = offset_dst_ptr[real_dst_offset]; - } - #endif - - #ifdef OFFSET_SRC - { - int offset_value = pos.z; - int off_c4_size = (offset_src_shape.z + 3) >> 2; - #ifdef GATHER_INPUT_NHWC - int off_c = offset_value % offset_src_shape.z; offset_value /= offset_src_shape.z; - int off_w = offset_value % offset_src_shape.x; offset_value /= offset_src_shape.x; - int off_h = offset_value % offset_src_shape.y; - int off_b = offset_value / offset_src_shape.y; - #else - int off_w = offset_value % offset_src_shape.x; offset_value /= offset_src_shape.x; - int off_h = offset_value % offset_src_shape.y; offset_value /= offset_src_shape.y; - int off_c = offset_value % offset_src_shape.z; - int off_b = offset_value / offset_src_shape.z; - #endif - int real_src_offset = (((off_b * off_c4_size + off_c / 4) * offset_src_shape.y + off_h) * offset_src_shape.x + off_w) * 4 + off_c % 4; - index.y = offset_src_ptr[real_src_offset]; - } - #endif - +#ifdef OFFSET_DST + index.x = offset_dst_ptr[pos.z]; +#endif + +#ifdef OFFSET_SRC + index.y = offset_src_ptr[pos.z]; +#endif int2 offset = index * steps; int src_offset = offset.y + stride_src.w + x * stride_src.x + y * stride_src.y + pos.y * stride_src.z; int dst_offset = offset.x + stride_dst.w + x * stride_dst.x + y * stride_dst.y + pos.y * stride_dst.z; - int src_offsetC4, dst_offsetC4; - { -#ifdef GATHER_INPUT_NHWC - int c = src_offset % src_c4size.z; src_offset /= src_c4size.z; - int w = src_offset % src_c4size.x; src_offset /= src_c4size.x; - int h = src_offset % 
src_c4size.y; - int b = src_offset / src_c4size.y; - int c4_size = (src_c4size.z + 3) / 4; - src_offsetC4 = (((b * c4_size + (c / 4)) * src_c4size.y + h) * src_c4size.x + w) * 4 + (c % 4); -#else - int w = src_offset % src_c4size.x; src_offset /= src_c4size.x; - int h = src_offset % src_c4size.y; src_offset /= src_c4size.y; - int c = src_offset % src_c4size.z; - int b = src_offset / src_c4size.z; - int c4_size = (src_c4size.z + 3) / 4; - src_offsetC4 = (((b * c4_size + (c / 4)) * src_c4size.y + h) * src_c4size.x + w) * 4 + (c % 4); -#endif - } - { -#ifdef GATHER_OUTPUT_NHWC - int c = dst_offset % dst_c4size.z; dst_offset /= dst_c4size.z; - int w = dst_offset % dst_c4size.x; dst_offset /= dst_c4size.x; - int h = dst_offset % dst_c4size.y; - int b = dst_offset / dst_c4size.y; - int c4_size = (dst_c4size.z + 3) / 4; - dst_offsetC4 = (((b * c4_size + (c / 4)) * dst_c4size.y + h) * dst_c4size.x + w) * 4 + (c % 4); -#else - int w = dst_offset % dst_c4size.x; dst_offset /= dst_c4size.x; - int h = dst_offset % dst_c4size.y; dst_offset /= dst_c4size.y; - int c = dst_offset % dst_c4size.z; - int b = dst_offset / dst_c4size.z; - int c4_size = (dst_c4size.z + 3) / 4; - dst_offsetC4 = (((b * c4_size + (c / 4)) * dst_c4size.y + h) * dst_c4size.x + w) * 4 + (c % 4); -#endif - } if(offset.x >= 0){ if(offset.y >= 0 && offset.y < inputSize){ - output[dst_offsetC4] = (OUTPUT_TYPE)input[src_offsetC4]; + output[dst_offset] = (OUTPUT_TYPE)input[src_offset]; }else{ - output[dst_offsetC4] = (OUTPUT_TYPE)(0); + output[dst_offset] = (OUTPUT_TYPE)(0); } } } diff --git a/source/backend/opencl/execution/cl/gemm_buf.cl b/source/backend/opencl/execution/cl/gemm_buf.cl index 903b62252..0e4fe3d46 100644 --- a/source/backend/opencl/execution/cl/gemm_buf.cl +++ b/source/backend/opencl/execution/cl/gemm_buf.cl @@ -10,118 +10,7 @@ return; \ } -__kernel void gemm_buf(GLOBAL_SIZE_DIM2 - __global const FLOAT* input0, - __global const FLOAT* input1, - __global FLOAT* output, - __private const int width,//UP_DIV(wUnit*hUnit,4) - __private const int height,//dstChannelC4 - __private const int srcChannelC4, - __private const int alpha2) { - int2 pos = (int2)(get_global_id(0), get_global_id(1)); - UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); - - const int pos_x = pos.x % width; - const int pos_y = pos.x / width; - const int pos_z = pos.y; - - COMPUTE_FLOAT16 o = (COMPUTE_FLOAT16)0; - - int kenerlY = mad24(pos_z, height, pos_y); - - for (int k = 0; k < srcChannelC4; ++k) { - //NHWC [1, 1, alpha2*height, srcChannelC4*4] x 4 - //index:[0, 0, pos_z*width+pos_y, index+0] - //int inp1_offset = (((k * (alpha2*height) + kenerlY) * (srcChannelC4*4) + index)*4 + 0)*4; - - COMPUTE_FLOAT16 k_v16 = CONVERT_COMPUTE_FLOAT16(vload16(kenerlY*(srcChannelC4) + k, input1)); - - //NC4HW4 [alpha*alpha, srcChannelC4, width, 4] x 4 - //index: [pos_z, k, pos_x, 0] - - COMPUTE_FLOAT16 s = CONVERT_COMPUTE_FLOAT16(vload16(((pos_z*srcChannelC4 + k) * width + pos_x), input0)); - - o = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s0, (COMPUTE_FLOAT4)s.s4, (COMPUTE_FLOAT4)s.s8, (COMPUTE_FLOAT4)s.sc), (COMPUTE_FLOAT16)(k_v16.s0123, k_v16.s0123, k_v16.s0123, k_v16.s0123), o); - o = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s1, (COMPUTE_FLOAT4)s.s5, (COMPUTE_FLOAT4)s.s9, (COMPUTE_FLOAT4)s.sd), (COMPUTE_FLOAT16)(k_v16.s4567, k_v16.s4567, k_v16.s4567, k_v16.s4567), o); - o = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s2, (COMPUTE_FLOAT4)s.s6, (COMPUTE_FLOAT4)s.sa, (COMPUTE_FLOAT4)s.se), (COMPUTE_FLOAT16)(k_v16.s89ab, k_v16.s89ab, k_v16.s89ab, k_v16.s89ab), o); - o = 
mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s3, (COMPUTE_FLOAT4)s.s7, (COMPUTE_FLOAT4)s.sb, (COMPUTE_FLOAT4)s.sf), (COMPUTE_FLOAT16)(k_v16.scdef, k_v16.scdef, k_v16.scdef, k_v16.scdef), o); - } - - //index: [pos_y, pos_z, 0, pos_x] - int out_offset = (((pos_y * alpha2 + pos_z) * 4 + 0) * width + pos_x) * 4; - - vstore4(CONVERT_FLOAT4(o.s0123), 0, output+out_offset); - vstore4(CONVERT_FLOAT4(o.s4567), 0, output+out_offset+4*width); - vstore4(CONVERT_FLOAT4(o.s89ab), 0, output+out_offset+8*width); - vstore4(CONVERT_FLOAT4(o.scdef), 0, output+out_offset+12*width); -} - - - -__kernel void gemm_buf2(GLOBAL_SIZE_DIM2 - __global const FLOAT* input0, - __global const FLOAT* input1, - __global FLOAT* output, - __private const int width,//UP_DIV(wUnit*hUnit,8) - __private const int height,//dstChannelC4 - __private const int srcChannelC4, - __private const int alpha2) { - int2 pos = (int2)(get_global_id(0), get_global_id(1)); - UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); - - const int width_block = (width+1) >> 1; - const int pos_x = (pos.x % width_block) << 1; - const int pos_y = pos.x / width_block; - const int pos_z = pos.y; - - COMPUTE_FLOAT16 o0 = (COMPUTE_FLOAT16)0; - COMPUTE_FLOAT16 o1 = (COMPUTE_FLOAT16)0; - - const int kenerlY = mad24(pos_z, height, pos_y); - const int kernel_base = mul24(kenerlY, srcChannelC4); - const int inp_base = (pos_z*srcChannelC4 + 0) * width + pos_x; - - for (int k = 0; k < srcChannelC4; ++k) { - //NHWC [1, 1, alpha2*height, srcChannelC4*4] x 4 - //index:[0, 0, pos_z*width+pos_y, index+0] - //int inp1_offset = (((k * (alpha2*height) + kenerlY) * (srcChannelC4*4) + index)*4 + 0)*4; - - COMPUTE_FLOAT16 k_v16 = CONVERT_COMPUTE_FLOAT16(vload16(kernel_base + k, input1)); - - //NC4HW4 [alpha*alpha, srcChannelC4, width, 4] x 4 - //index: [pos_z, k, pos_x, 0] - - const int inp_offset = mad24(k, width, inp_base); - COMPUTE_FLOAT16 s = CONVERT_COMPUTE_FLOAT16(vload16(inp_offset, input0)); - - o0 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s0, (COMPUTE_FLOAT4)s.s4, (COMPUTE_FLOAT4)s.s8, (COMPUTE_FLOAT4)s.sc), (COMPUTE_FLOAT16)(k_v16.s0123, k_v16.s0123, k_v16.s0123, k_v16.s0123), o0); - o0 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s1, (COMPUTE_FLOAT4)s.s5, (COMPUTE_FLOAT4)s.s9, (COMPUTE_FLOAT4)s.sd), (COMPUTE_FLOAT16)(k_v16.s4567, k_v16.s4567, k_v16.s4567, k_v16.s4567), o0); - o0 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s2, (COMPUTE_FLOAT4)s.s6, (COMPUTE_FLOAT4)s.sa, (COMPUTE_FLOAT4)s.se), (COMPUTE_FLOAT16)(k_v16.s89ab, k_v16.s89ab, k_v16.s89ab, k_v16.s89ab), o0); - o0 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s3, (COMPUTE_FLOAT4)s.s7, (COMPUTE_FLOAT4)s.sb, (COMPUTE_FLOAT4)s.sf), (COMPUTE_FLOAT16)(k_v16.scdef, k_v16.scdef, k_v16.scdef, k_v16.scdef), o0); - - s = CONVERT_COMPUTE_FLOAT16(vload16(inp_offset + 1, input0)); - o1 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s0, (COMPUTE_FLOAT4)s.s4, (COMPUTE_FLOAT4)s.s8, (COMPUTE_FLOAT4)s.sc), (COMPUTE_FLOAT16)(k_v16.s0123, k_v16.s0123, k_v16.s0123, k_v16.s0123), o1); - o1 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s1, (COMPUTE_FLOAT4)s.s5, (COMPUTE_FLOAT4)s.s9, (COMPUTE_FLOAT4)s.sd), (COMPUTE_FLOAT16)(k_v16.s4567, k_v16.s4567, k_v16.s4567, k_v16.s4567), o1); - o1 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s2, (COMPUTE_FLOAT4)s.s6, (COMPUTE_FLOAT4)s.sa, (COMPUTE_FLOAT4)s.se), (COMPUTE_FLOAT16)(k_v16.s89ab, k_v16.s89ab, k_v16.s89ab, k_v16.s89ab), o1); - o1 = mad((COMPUTE_FLOAT16)((COMPUTE_FLOAT4)s.s3, (COMPUTE_FLOAT4)s.s7, (COMPUTE_FLOAT4)s.sb, (COMPUTE_FLOAT4)s.sf), (COMPUTE_FLOAT16)(k_v16.scdef, k_v16.scdef, k_v16.scdef, k_v16.scdef), o1); - } - - 
//index: [pos_y, pos_z, 0, pos_x] - int out_offset = (((pos_y * alpha2 + pos_z) * 4 + 0) * width + pos_x) * 4; - - vstore4(CONVERT_FLOAT4(o0.s0123), 0, output+out_offset); - vstore4(CONVERT_FLOAT4(o0.s4567), 0, output+out_offset+4*width); - vstore4(CONVERT_FLOAT4(o0.s89ab), 0, output+out_offset+8*width); - vstore4(CONVERT_FLOAT4(o0.scdef), 0, output+out_offset+12*width); - - if(pos_x + 1 >= width) return; - vstore4(CONVERT_FLOAT4(o1.s0123), 1, output+out_offset); - vstore4(CONVERT_FLOAT4(o1.s4567), 1, output+out_offset+4*width); - vstore4(CONVERT_FLOAT4(o1.s89ab), 1, output+out_offset+8*width); - vstore4(CONVERT_FLOAT4(o1.scdef), 1, output+out_offset+12*width); -} - -// [B, K/4, area, 4] -> [alignK, alignM] (M = B * area) +// [K/4, M, 4] -> [alignK, alignM] __kernel void transpose_pad(GLOBAL_SIZE_DIM2 const int alignM, const int alignK, @@ -131,7 +20,6 @@ __kernel void transpose_pad(GLOBAL_SIZE_DIM2 __global const FLOAT* input, __global FLOAT* output ) { -#ifdef AREA_EQUAL_1 const int idx_m4 = get_global_id(0); // idx M const int idx_k4 = get_global_id(1); // idx K UNIFORM_BOUNDRY_CHECK(idx_m4, idx_k4); @@ -139,71 +27,25 @@ __kernel void transpose_pad(GLOBAL_SIZE_DIM2 const int idx_m = idx_m4 << 2; const int idx_k = idx_k4 << 2; const int K_4 = (K + 3) >> 2; - const int in_offset_base = (idx_m * K_4 + idx_k4) * 4; + const int in_offset_base = (idx_k4 * M + idx_m) * 4; const int out_offset_base = idx_k * alignM + idx_m; - FLOAT4 m0k4 = (idx_k4 >= K_4 || idx_m + 0 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base); - FLOAT4 m1k4 = (idx_k4 >= K_4 || idx_m + 1 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + (K_4 << 2)); - FLOAT4 m2k4 = (idx_k4 >= K_4 || idx_m + 2 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + (K_4 << 2) * 2); - FLOAT4 m3k4 = (idx_k4 >= K_4 || idx_m + 3 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + (K_4 << 2) * 3); - - vstore4((FLOAT4)(m0k4.x, m1k4.x, m2k4.x, m3k4.x), 0, output + out_offset_base); - vstore4((FLOAT4)(m0k4.y, m1k4.y, m2k4.y, m3k4.y), 0, output + out_offset_base + alignM); - vstore4((FLOAT4)(m0k4.z, m1k4.z, m2k4.z, m3k4.z), 0, output + out_offset_base + alignM + alignM); - vstore4((FLOAT4)(m0k4.w, m1k4.w, m2k4.w, m3k4.w), 0, output + out_offset_base + alignM + alignM + alignM); -#elif defined BATCH_EQUAL_1 - - const int idx_m4 = get_global_id(0); // idx M - const int idx_k4 = get_global_id(1); // idx K - UNIFORM_BOUNDRY_CHECK(idx_m4, idx_k4); - - const int idx_m = idx_m4 << 2; - const int idx_k = idx_k4 << 2; - const int K_4 = (K + 3) >> 2; - const int in_offset_base = (idx_k4 * area + idx_m) * 4; - const int out_offset_base = idx_k * alignM + idx_m; - FLOAT4 m0k4 = (idx_k4 >= K_4 || idx_m + 0 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base); FLOAT4 m1k4 = (idx_k4 >= K_4 || idx_m + 1 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + 4); FLOAT4 m2k4 = (idx_k4 >= K_4 || idx_m + 2 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + 8); FLOAT4 m3k4 = (idx_k4 >= K_4 || idx_m + 3 >= M) ? 
(FLOAT4)0 : vload4(0, input + in_offset_base + 12); - + vstore4((FLOAT4)(m0k4.x, m1k4.x, m2k4.x, m3k4.x), 0, output + out_offset_base); vstore4((FLOAT4)(m0k4.y, m1k4.y, m2k4.y, m3k4.y), 0, output + out_offset_base + alignM); vstore4((FLOAT4)(m0k4.z, m1k4.z, m2k4.z, m3k4.z), 0, output + out_offset_base + alignM + alignM); vstore4((FLOAT4)(m0k4.w, m1k4.w, m2k4.w, m3k4.w), 0, output + out_offset_base + alignM + alignM + alignM); +} -#else - - const int idx_m = get_global_id(0); // idx M - const int idx_k4 = get_global_id(1); // idx K - UNIFORM_BOUNDRY_CHECK(idx_m, idx_k4); - - const int K_4 = (K + 3) >> 2; - const int idx_k = idx_k4 << 2; - const int out_offset_base = idx_k * alignM + idx_m; - - if(idx_k4 >= K_4 || idx_m >= M) { - output[out_offset_base] = (FLOAT)0; - output[out_offset_base + alignM] = (FLOAT)0; - output[out_offset_base + alignM + alignM] = (FLOAT)0; - output[out_offset_base + alignM + alignM + alignM] = (FLOAT)0; - return; - } - const int idx_b = idx_m / area; - const int idx_area = idx_m % area; - - const int in_offset_base = ((idx_b * K_4 + idx_k4) * area + idx_area) * 4; - FLOAT4 data = vload4(0, input + in_offset_base); - - output[out_offset_base] = data.x; - output[out_offset_base + alignM] = data.y; - output[out_offset_base + alignM + alignM] = data.z; - output[out_offset_base + alignM + alignM + alignM] = data.w; +#ifndef M_VEC +#define M_VEC 1 #endif -} -// [alignM, alignN] -> [B, N/4, area, 4] (M = B * area) +// [alignM, alignN] -> [N/4, B, area, N4] (M = B * area) __kernel void transpose_bias(GLOBAL_SIZE_DIM2 const int alignM, const int alignN, @@ -214,133 +56,24 @@ __kernel void transpose_bias(GLOBAL_SIZE_DIM2 __global const FLOAT* input1, __global FLOAT* output ) { -#ifdef AREA_EQUAL_1 - const int idx_m = get_global_id(0); // idx M - const int idx_n_16 = get_global_id(1); // idx N - UNIFORM_BOUNDRY_CHECK(idx_m, idx_n_16); + int idx_m = get_global_id(0); // idx M + int idx_n4 = get_global_id(1); // idx N + UNIFORM_BOUNDRY_CHECK(idx_m, idx_n4); - const int N_4 = (N + 3) >> 2; - const int N_16 = (N + 15) >> 4; - const int N_left = N & 15; - bool canVec16 = (N_left == 0 || (N_left != 0 && idx_n_16 < N_16 - 1)); - if(canVec16) { - FLOAT16 res0 = vload16(0, input0 + idx_m * alignN + (idx_n_16 << 4)); - FLOAT16 res1 = vload16(0, input1 + (idx_n_16 << 4)); - FLOAT16 res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT16)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT16)0, (FLOAT16)6); - #endif - vstore16(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2)); - } else { + const int idx_n = idx_n4 << 2; - FLOAT4 res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4)); - FLOAT4 res1 = vload4(0, input1 + (idx_n_16 << 4)); + idx_m = idx_m * M_VEC; + FLOAT4 res1 = vload4(0, input1 + idx_n); + #pragma unroll + for(int i = 0; i < M_VEC; i++) { + FLOAT4 res0 = vload4(0, input0 + (idx_m + i) * alignN + idx_n); FLOAT4 res = res0 + res1; #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2)); - - if(idx_n_16 * 4 + 1 >= N_4) return; - res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4) + 4); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 4); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 4); - - if(idx_n_16 * 4 + 2 >= N_4) return; - res0 = vload4(0, input0 + idx_m * alignN + 
(idx_n_16 << 4) + 8); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 8); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 8); - - if(idx_n_16 * 4 + 3 >= N_4) return; - res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4) + 12); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 12); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); + res = fmax(res, (FLOAT4)0); #endif #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); + res = clamp(res, (FLOAT4)0, (FLOAT4)6); #endif - vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 12); + vstore4(res, 0, output + ((idx_n4 * M + idx_m + i) << 2)); } -#else - const int idx_m = get_global_id(0); // idx M - const int idx_n_16 = get_global_id(1); // idx N - UNIFORM_BOUNDRY_CHECK(idx_m, idx_n_16); - - const int N_4 = (N + 3) >> 2; - - const int idx_b = idx_m / area; - const int idx_area = idx_m % area; - - const int inp_base_offset = idx_m * alignN + (idx_n_16 << 4); - const int out_base_offset = ((idx_b * N_4 + idx_n_16 * 4) * area + idx_area) * 4; - - FLOAT4 res0 = vload4(0, input0 + inp_base_offset); - FLOAT4 res1 = vload4(0, input1 + (idx_n_16 << 4)); - FLOAT4 res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + out_base_offset); - - if(idx_n_16 * 4 + 1 >= N_4) return; - res0 = vload4(0, input0 + inp_base_offset + 4); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 4); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + out_base_offset + area * 4); - - if(idx_n_16 * 4 + 2 >= N_4) return; - res0 = vload4(0, input0 + inp_base_offset + 8); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 8); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + out_base_offset + area * 8); - - if(idx_n_16 * 4 + 3 >= N_4) return; - res0 = vload4(0, input0 + inp_base_offset + 12); - res1 = vload4(0, input1 + (idx_n_16 << 4) + 12); - res = res0 + res1; - #ifdef RELU - res = fmax(res, (FLOAT4)0); - #endif - #ifdef RELU6 - res = clamp(res, (FLOAT4)0, (FLOAT4)6); - #endif - vstore4(res, 0, output + out_base_offset + area * 12); -#endif } diff --git a/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl b/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl new file mode 100644 index 000000000..35304f433 --- /dev/null +++ b/source/backend/opencl/execution/cl/gemm_conv1x1_buf.cl @@ -0,0 +1,760 @@ +#ifdef MNN_SUPPORT_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#endif + +#define GLOBAL_SIZE_DIM2 \ + __private int global_size_dim0, __private int global_size_dim1, + +#define UNIFORM_BOUNDRY_CHECK(index0, index1) \ + if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \ + return; \ + } + +#define GLOBAL_SIZE_DIM3 \ + __private int global_size_dim0, __private int global_size_dim1, __private int global_size_dim2, + +#define UNIFORM_BOUNDRY_CHECK3(index0, index1, index2) \ + if(index0 >= global_size_dim0 || index1 >= global_size_dim1 || index2 >= global_size_dim2) { \ + return; \ + } + +#define UCHAR16_TO_2CHAR16(a, b, c) \ + a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 
= (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; \ + a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; \ + b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; \ + b.s8 = (c.sc >> 4) - 8; b.s9 = (c.sc & 15) - 8; b.sa = (c.sd >> 4) - 8; b.sb = (c.sd & 15) - 8; b.sc = (c.se >> 4) - 8; b.sd = (c.se & 15) - 8; b.se = (c.sf >> 4) - 8; b.sf = (c.sf & 15) - 8; + +#define UCHAR8_TO_CHAR16(a, c) \ + a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; \ + a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; + +#define DOT16X16(a, b, c) \ + c += dot(a.s0123, b.s0123); \ + c += dot(a.s4567, b.s4567); \ + c += dot(a.s89ab, b.s89ab); \ + c += dot(a.scdef, b.scdef); + +#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) +#define CHANNEL_PACK 32 +#else +#define CHANNEL_PACK 16 +#endif + +#if (defined USE_LOW_BIT_WEIGHT_INT8) +#define WEIGHT_STRIDE 16 +#elif (defined USE_LOW_BIT_WEIGHT_INT4) +#define WEIGHT_STRIDE 8 +#endif + +__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +#ifdef USE_IMAGE +inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){ + return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(ix, iy)))) * scale + offset; +} +#else + +#if (defined USE_LOW_BIT_WEIGHT_INT8) +inline COMPUTE_FLOAT16 readWeight(__global const char *weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){ + return CONVERT_COMPUTE_FLOAT16(vload16(0, weight)) * scale + offset; +} +#elif (defined USE_LOW_BIT_WEIGHT_INT4) +inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){ + uchar16 charWeightsInt40 = vload16(0, weight); + uchar8 charWeightsInt4 = vload8(0, weight); + char16 charWeights = 0; + UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); + return CONVERT_COMPUTE_FLOAT16(charWeights) * scale + offset; +} +#endif +#endif + +__kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2 + #ifdef USE_IMAGE + __read_only image2d_t weight, + #else + #if (defined USE_LOW_BIT_WEIGHT_INT8) + __global const char *weight, + #elif (defined USE_LOW_BIT_WEIGHT_INT4) + __global const uchar *weight, + #endif + #endif + __global const float *dequantScaleOffset, + __global FLOAT* output, + __private const int outputChannelAlign, + __private const int outputChannel4Align, + __private const int blockDim){ + const int x = get_global_id(0); //ic + const int y = get_global_id(1); //oc + + UNIFORM_BOUNDRY_CHECK(x, y); + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + + const int ic = x << 5; + const int oc = y << 2; + const int output_offset = ic * outputChannelAlign + oc; + + int kindex = (ic / blockDim) * outputChannel4Align * 2; + COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(0, dequantScaleOffset + kindex + oc * 2)); + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11, weights20, weights21, weights30, weights31; + { + 
uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(oc, x))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(oc + 1, x))); + uchar16 charWeightsInt42 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(oc + 2, x))); + uchar16 charWeightsInt43 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(oc + 3, x))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt42); + weights20 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; + weights21 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s4 + ScaleOffset.s5; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt43); + weights30 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s6 + ScaleOffset.s7; + weights31 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; + } + COMPUTE_FLOAT *weights00_ptr = (COMPUTE_FLOAT *)&weights00; + COMPUTE_FLOAT *weights10_ptr = (COMPUTE_FLOAT *)&weights10; + COMPUTE_FLOAT *weights20_ptr = (COMPUTE_FLOAT *)&weights20; + COMPUTE_FLOAT *weights30_ptr = (COMPUTE_FLOAT *)&weights30; + COMPUTE_FLOAT *weights01_ptr = (COMPUTE_FLOAT *)&weights01; + COMPUTE_FLOAT *weights11_ptr = (COMPUTE_FLOAT *)&weights11; + COMPUTE_FLOAT *weights21_ptr = (COMPUTE_FLOAT *)&weights21; + COMPUTE_FLOAT *weights31_ptr = (COMPUTE_FLOAT *)&weights31; + #pragma unroll + for (int i = 0; i < 16; ++i){ + FLOAT4 out = CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights00_ptr[i], weights10_ptr[i], weights20_ptr[i], weights30_ptr[i])); + vstore4(out, 0, output+output_offset+i*outputChannelAlign); + } + #pragma unroll + for (int i = 0; i < 16; ++i){ + FLOAT4 out = CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights01_ptr[i], weights11_ptr[i], weights21_ptr[i], weights31_ptr[i])); + vstore4(out, 0, output+output_offset+(i + 16)*outputChannelAlign); + } + #else + const int ic = x << 4; + const int oc = y << 2; +#ifndef USE_IMAGE + #if (defined USE_LOW_BIT_WEIGHT_INT4) + int weight_offset = oc * 8; + int weight_oc_offset = outputChannel4Align * 8; + int weight_stride = 8; + #else + int weight_offset = oc * 16; + int weight_oc_offset = outputChannel4Align * 16; + int weight_stride = 16; + #endif +#endif + const int output_offset = ic * outputChannelAlign + oc; + + int kindex = (ic / blockDim) * outputChannel4Align * 2; + COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(0, dequantScaleOffset + kindex + oc * 2)); + #ifdef USE_IMAGE + COMPUTE_FLOAT16 weights0 = readWeight(weight, oc, x, ScaleOffset.s0, ScaleOffset.s1); + COMPUTE_FLOAT16 weights1 = readWeight(weight, oc + 1, x, ScaleOffset.s2, ScaleOffset.s3); + COMPUTE_FLOAT16 weights2 = readWeight(weight, oc + 2, x, ScaleOffset.s4, ScaleOffset.s5); + COMPUTE_FLOAT16 weights3 = readWeight(weight, oc + 3, x, ScaleOffset.s6, ScaleOffset.s7); + #else + COMPUTE_FLOAT16 weights0 = readWeight(weight + weight_offset + x * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + COMPUTE_FLOAT16 weights1 = readWeight(weight + weight_offset + x * weight_oc_offset + weight_stride, 0, 0, 
ScaleOffset.s2, ScaleOffset.s3); + COMPUTE_FLOAT16 weights2 = readWeight(weight + weight_offset + x * weight_oc_offset + 2 * weight_stride, 0, 0, ScaleOffset.s4, ScaleOffset.s5); + COMPUTE_FLOAT16 weights3 = readWeight(weight + weight_offset + x * weight_oc_offset + 3 * weight_stride, 0, 0, ScaleOffset.s6, ScaleOffset.s7); + #endif + COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT*)&weights0; + COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT*)&weights1; + COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT*)&weights2; + COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT*)&weights3; + #pragma unroll + for (int i = 0; i < 16; ++i){ + FLOAT4 out = CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights0_ptr[i], weights1_ptr[i], weights2_ptr[i], weights3_ptr[i])); + vstore4(out, 0, output+output_offset+i*outputChannelAlign); + } + #endif +} + +__kernel void reshape_nchw4_nhwc4(GLOBAL_SIZE_DIM2 +__global const FLOAT* input, +__global FLOAT* output, +__private const int bhw, +__private const int channel, +__private const int channelAlign){ + const int x = get_global_id(0); //c + const int y = get_global_id(1); //bhw + + UNIFORM_BOUNDRY_CHECK(x, y); + + const int x4 = x << 2; + const int y4 = y << 2; + const int input_offset = (x * bhw + y4) * 4; + FLOAT4 in0 = vload4(0, input + input_offset); + FLOAT4 in1 = (y4 + 1 < bhw) ? vload4(0, input + input_offset + 4) : (FLOAT4)0; + FLOAT4 in2 = (y4 + 2 < bhw) ? vload4(0, input + input_offset + 8) : (FLOAT4)0; + FLOAT4 in3 = (y4 + 3 < bhw) ? vload4(0, input + input_offset + 12) : (FLOAT4)0; + +#ifdef INPUT_CHANNEL_LEAVE + if(x4 + 3 >= channel){ + FLOAT *in0_ptr = (FLOAT*)&in0; + FLOAT *in1_ptr = (FLOAT*)&in1; + FLOAT *in2_ptr = (FLOAT*)&in2; + FLOAT *in3_ptr = (FLOAT*)&in3; + int remain = x4 + 3 - channel; + for(int i = remain; i >= 0; i--){ + in0_ptr[3 - i] = 0; + in1_ptr[3 - i] = 0; + in2_ptr[3 - i] = 0; + in3_ptr[3 - i] = 0; + } + } +#endif + +#ifdef FORMAT_CNHW + int idx = x / 4; + int idy = x % 4; + const int bhw4 = (bhw + 3) / 4 * 4; + int output_offset = ((idx * bhw4 + y4) * 4 + idy) * 4; // [c/16 b 4 4] + vstore4(in0, 0, output+output_offset); + vstore4(in1, 0, output+output_offset+16); + vstore4(in2, 0, output+output_offset+32); + vstore4(in3, 0, output+output_offset+48); +#else + FLOAT16 out = (FLOAT16)(in0.s0, in1.s0, in2.s0, in3.s0, in0.s1, in1.s1, in2.s1, in3.s1, in0.s2, in1.s2, in2.s2, in3.s2, in0.s3, in1.s3, in2.s3, in3.s3); + const int output_offset = (y * channelAlign + x4) * 4; + vstore16(out, 0, output+output_offset); +#endif +} + +__kernel void reshape_nhwc4_nchw4(GLOBAL_SIZE_DIM2 +__global const FLOAT* input, +__global FLOAT* output, +__private const int bhw, +__private const int channelAlign){ + const int x = get_global_id(0); //c + const int y = get_global_id(1); //bhw + + UNIFORM_BOUNDRY_CHECK(x, y); + + const int x4 = x << 2; + const int y4 = y << 2; + const int output_offset = (x * bhw + y4) * 4; + + + const int input_offset = (y * channelAlign + x4) * 4; + FLOAT16 in = vload16(0, input + input_offset); + + FLOAT4 out0 = (FLOAT4)(in.s0, in.s4, in.s8, in.sc); + FLOAT4 out1 = (FLOAT4)(in.s1, in.s5, in.s9, in.sd); + FLOAT4 out2 = (FLOAT4)(in.s2, in.s6, in.sa, in.se); + FLOAT4 out3 = (FLOAT4)(in.s3, in.s7, in.sb, in.sf); + + vstore4(out0, 0, output+output_offset); + if(y4 + 1 >= bhw) return; + vstore4(out1, 0, output+output_offset+4); + if(y4 + 2 >= bhw) return; + vstore4(out2, 0, output+output_offset+8); + if(y4 + 3 >= bhw) return; + vstore4(out3, 0, output+output_offset+12); +} + + +__kernel void gemm_b4_c4_buf(GLOBAL_SIZE_DIM2 + __global const FLOAT* input, 
+#ifdef USE_IMAGE + __read_only image2d_t weight, +#else +#if (defined USE_LOW_BIT_WEIGHT_INT8) + __global const char *weight, +#elif (defined USE_LOW_BIT_WEIGHT_INT4) + __global const uchar *weight, +#endif +#endif + __global const float *dequantScaleOffset, + __global const FLOAT *bias, + __global FLOAT* output, + __private const int bhw4, + __private const int dstChannelAlign, + __private const int srcChannelAlign, + __private const int blockNum, + __private const int blockDim) { + const int x = get_global_id(0); //c + const int y = get_global_id(1); //b + + UNIFORM_BOUNDRY_CHECK(x, y); + + const int out_c_idx = x << 2; + const int out_b_idx = y << 2; + + COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(0, bias + out_c_idx)); + COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; + COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1, out2 = (COMPUTE_FLOAT4)bias0.s2, out3 = (COMPUTE_FLOAT4)bias0.s3; + +#ifdef FORMAT_CNHW + int input_offset = out_b_idx * 16; +#else + int input_offset = out_b_idx * srcChannelAlign; +#endif + int out_offset = out_b_idx * dstChannelAlign + out_c_idx * 4; + +#ifndef USE_IMAGE + int weight_offset = out_c_idx * WEIGHT_STRIDE; + int weight_oc_offset = dstChannelAlign * WEIGHT_STRIDE; +#endif + + const int loop = (blockDim + CHANNEL_PACK - 1) / CHANNEL_PACK; + + for (int i = 0; i < blockNum; i++){ + int kindex = i * dstChannelAlign * 2; + COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(0, dequantScaleOffset + kindex + out_c_idx * 2)); + for (int j = 0; j < loop; j++) { + int k = i * loop + j; + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11, weights20, weights21, weights30, weights31; + { + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k))); + uchar16 charWeightsInt42 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 2, k))); + uchar16 charWeightsInt43 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 3, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt42); + weights20 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; + weights21 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s4 + ScaleOffset.s5; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt43); + weights30 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s6 + ScaleOffset.s7; + weights31 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; + } + #ifdef FORMAT_CNHW + int k2 = k << 1; + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16)); + DOT16X16(in, weights00, out.s0); + DOT16X16(in, weights10, out1.s0); + DOT16X16(in, weights20, out2.s0); + DOT16X16(in, weights30, out3.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 16)); + DOT16X16(in, weights00, out.s1); + DOT16X16(in, weights10, out1.s1); 
+ DOT16X16(in, weights20, out2.s1); + DOT16X16(in, weights30, out3.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 32)); + DOT16X16(in, weights00, out.s2); + DOT16X16(in, weights10, out1.s2); + DOT16X16(in, weights20, out2.s2); + DOT16X16(in, weights30, out3.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 48)); + DOT16X16(in, weights00, out.s3); + DOT16X16(in, weights10, out1.s3); + DOT16X16(in, weights20, out2.s3); + DOT16X16(in, weights30, out3.s3); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16)); + DOT16X16(in, weights01, out.s0); + DOT16X16(in, weights11, out1.s0); + DOT16X16(in, weights21, out2.s0); + DOT16X16(in, weights31, out3.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 16)); + DOT16X16(in, weights01, out.s1); + DOT16X16(in, weights11, out1.s1); + DOT16X16(in, weights21, out2.s1); + DOT16X16(in, weights31, out3.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 32)); + DOT16X16(in, weights01, out.s2); + DOT16X16(in, weights11, out1.s2); + DOT16X16(in, weights21, out2.s2); + DOT16X16(in, weights31, out3.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 48)); + DOT16X16(in, weights01, out.s3); + DOT16X16(in, weights11, out1.s3); + DOT16X16(in, weights21, out2.s3); + DOT16X16(in, weights31, out3.s3); + #else + int k32 = k << 5; + COMPUTE_FLOAT *weights00_ptr = (COMPUTE_FLOAT *)&weights00; + COMPUTE_FLOAT *weights10_ptr = (COMPUTE_FLOAT *)&weights10; + COMPUTE_FLOAT *weights20_ptr = (COMPUTE_FLOAT *)&weights20; + COMPUTE_FLOAT *weights30_ptr = (COMPUTE_FLOAT *)&weights30; + COMPUTE_FLOAT *weights01_ptr = (COMPUTE_FLOAT *)&weights01; + COMPUTE_FLOAT *weights11_ptr = (COMPUTE_FLOAT *)&weights11; + COMPUTE_FLOAT *weights21_ptr = (COMPUTE_FLOAT *)&weights21; + COMPUTE_FLOAT *weights31_ptr = (COMPUTE_FLOAT *)&weights31; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i) * 4)); + out = mad(in, weights00_ptr[i], out); + out1 = mad(in, weights10_ptr[i], out1); + out2 = mad(in, weights20_ptr[i], out2); + out3 = mad(in, weights30_ptr[i], out3); + } + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i + 16) * 4)); + out = mad(in, weights01_ptr[i], out); + out1 = mad(in, weights11_ptr[i], out1); + out2 = mad(in, weights21_ptr[i], out2); + out3 = mad(in, weights31_ptr[i], out3); + } + #endif + #else + COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; + #ifdef USE_IMAGE + weights0 = readWeight(weight, out_c_idx, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, out_c_idx + 1, k, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight, out_c_idx + 2, k, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight, out_c_idx + 3, k, ScaleOffset.s6, ScaleOffset.s7); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight + weight_offset + k * weight_oc_offset + 2 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight + weight_offset + k * weight_oc_offset + 3 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s6, 
ScaleOffset.s7); + #endif + #ifdef FORMAT_CNHW + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16)); + DOT16X16(in, weights0, out.s0); + DOT16X16(in, weights1, out1.s0); + DOT16X16(in, weights2, out2.s0); + DOT16X16(in, weights3, out3.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 16)); + DOT16X16(in, weights0, out.s1); + DOT16X16(in, weights1, out1.s1); + DOT16X16(in, weights2, out2.s1); + DOT16X16(in, weights3, out3.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 32)); + DOT16X16(in, weights0, out.s2); + DOT16X16(in, weights1, out1.s2); + DOT16X16(in, weights2, out2.s2); + DOT16X16(in, weights3, out3.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 48)); + DOT16X16(in, weights0, out.s3); + DOT16X16(in, weights1, out1.s3); + DOT16X16(in, weights2, out2.s3); + DOT16X16(in, weights3, out3.s3); + #else + int k16 = k << 4; + COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; + COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; + COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT *)&weights2; + COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT *)&weights3; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); + out = mad(in, weights0_ptr[i], out); + out1 = mad(in, weights1_ptr[i], out1); + out2 = mad(in, weights2_ptr[i], out2); + out3 = mad(in, weights3_ptr[i], out3); + } + #endif + #endif + } + } +#ifdef RELU + out = fmax(out, (COMPUTE_FLOAT4)0); + out1 = fmax(out1, (COMPUTE_FLOAT4)0); + out2 = fmax(out2, (COMPUTE_FLOAT4)0); + out3 = fmax(out3, (COMPUTE_FLOAT4)0); +#endif + +#ifdef RELU6 + out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); + out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); + out2 = clamp(out2, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); + out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); +#endif + + vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); + vstore4(CONVERT_FLOAT4(out1), 0, output+out_offset + 4); + vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset + 8); + vstore4(CONVERT_FLOAT4(out3), 0, output+out_offset + 12); +} + +__kernel void gemm_b4_c2_buf(GLOBAL_SIZE_DIM2 + __global const FLOAT* input, +#ifdef USE_IMAGE + __read_only image2d_t weight, +#else +#if (defined USE_LOW_BIT_WEIGHT_INT8) + __global const char *weight, +#elif (defined USE_LOW_BIT_WEIGHT_INT4) + __global const uchar *weight, +#endif +#endif + __global const float *dequantScaleOffset, + __global const FLOAT *bias, + __global FLOAT* output, + __private const int bhw4, + __private const int dstChannelAlign, + __private const int srcChannelAlign, + __private const int blockNum, + __private const int blockDim) { + const int x = get_global_id(0); //c + const int y = get_global_id(1); //b + + UNIFORM_BOUNDRY_CHECK(x, y); + + const int out_c_idx = x << 1; + const int out_b_idx = y << 2; + + COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(0, bias + out_c_idx)); + COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; + COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1; + +#ifdef FORMAT_CNHW + int input_offset = out_b_idx * 16; +#else + int input_offset = out_b_idx * srcChannelAlign; +#endif + int out_offset = out_b_idx * dstChannelAlign + out_c_idx * 4; + +#ifndef USE_IMAGE + int weight_offset = out_c_idx * WEIGHT_STRIDE; + int weight_oc_offset = dstChannelAlign * WEIGHT_STRIDE; +#endif + + const int loop = (blockDim + CHANNEL_PACK - 1) / 
CHANNEL_PACK; + + for (int i = 0; i < blockNum; i++){ + int kindex = i * dstChannelAlign * 2; + COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(0, dequantScaleOffset + kindex + out_c_idx * 2)); + for (int j = 0; j < loop; j++) { + int k = i * loop + j; + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11; + { + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + } + #ifdef FORMAT_CNHW + int k2 = k << 1; + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16)); + DOT16X16(in, weights00, out.s0); + DOT16X16(in, weights10, out1.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 16)); + DOT16X16(in, weights00, out.s1); + DOT16X16(in, weights10, out1.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 32)); + DOT16X16(in, weights00, out.s2); + DOT16X16(in, weights10, out1.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 48)); + DOT16X16(in, weights00, out.s3); + DOT16X16(in, weights10, out1.s3); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16)); + DOT16X16(in, weights01, out.s0); + DOT16X16(in, weights11, out1.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 16)); + DOT16X16(in, weights01, out.s1); + DOT16X16(in, weights11, out1.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 32)); + DOT16X16(in, weights01, out.s2); + DOT16X16(in, weights11, out1.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 48)); + DOT16X16(in, weights01, out.s3); + DOT16X16(in, weights11, out1.s3); + #else + int k32 = k << 5; + COMPUTE_FLOAT *weights00_ptr = (COMPUTE_FLOAT *)&weights00; + COMPUTE_FLOAT *weights10_ptr = (COMPUTE_FLOAT *)&weights10; + COMPUTE_FLOAT *weights01_ptr = (COMPUTE_FLOAT *)&weights01; + COMPUTE_FLOAT *weights11_ptr = (COMPUTE_FLOAT *)&weights11; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i) * 4)); + out = mad(in, weights00_ptr[i], out); + out1 = mad(in, weights10_ptr[i], out1); + } + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i + 16) * 4)); + out = mad(in, weights01_ptr[i], out); + out1 = mad(in, weights11_ptr[i], out1); + } + #endif + #else + COMPUTE_FLOAT16 weights0, weights1; + #ifdef USE_IMAGE + weights0 = readWeight(weight, out_c_idx, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, out_c_idx + 1, k, ScaleOffset.s2, ScaleOffset.s3); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 
= readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); + #endif + #ifdef FORMAT_CNHW + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16)); + DOT16X16(in, weights0, out.s0); + DOT16X16(in, weights1, out1.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 16)); + DOT16X16(in, weights0, out.s1); + DOT16X16(in, weights1, out1.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 32)); + DOT16X16(in, weights0, out.s2); + DOT16X16(in, weights1, out1.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 48)); + DOT16X16(in, weights0, out.s3); + DOT16X16(in, weights1, out1.s3); + #else + int k16 = k << 4; + COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; + COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); + out = mad(in, weights0_ptr[i], out); + out1 = mad(in, weights1_ptr[i], out1); + } + #endif + #endif + } + } + +#ifdef RELU + out = fmax(out, (COMPUTE_FLOAT4)0); + out1 = fmax(out1, (COMPUTE_FLOAT4)0); +#endif + +#ifdef RELU6 + out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); + out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); +#endif + + vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); + vstore4(CONVERT_FLOAT4(out1), 0, output+out_offset+4); +} + +__kernel void gemm_b4_c1_buf(GLOBAL_SIZE_DIM2 + __global const FLOAT* input, +#ifdef USE_IMAGE + __read_only image2d_t weight, +#else +#if (defined USE_LOW_BIT_WEIGHT_INT8) + __global const char *weight, +#elif (defined USE_LOW_BIT_WEIGHT_INT4) + __global const uchar *weight, +#endif +#endif + __global const float *dequantScaleOffset, + __global const FLOAT *bias, + __global FLOAT* output, + __private const int bhw4, + __private const int dstChannelAlign, + __private const int srcChannelAlign, + __private const int blockNum, + __private const int blockDim) { + const int x = get_global_id(0); //c + const int y = get_global_id(1); //b + + UNIFORM_BOUNDRY_CHECK(x, y); + + const int out_c_idx = x; + const int out_b_idx = y << 2; + + COMPUTE_FLOAT bias0 = bias[out_c_idx]; + COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0; + +#ifdef FORMAT_CNHW + int input_offset = out_b_idx * 16; +#else + int input_offset = out_b_idx * srcChannelAlign; +#endif + int out_offset = out_b_idx * dstChannelAlign + out_c_idx * 4; + +#ifndef USE_IMAGE + int weight_offset = out_c_idx * WEIGHT_STRIDE; + int weight_oc_offset = dstChannelAlign * WEIGHT_STRIDE; +#endif + + const int loop = (blockDim + CHANNEL_PACK - 1) / CHANNEL_PACK; + + for (int i = 0; i < blockNum; i++){ + int kindex = i * dstChannelAlign * 2; + COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, dequantScaleOffset + kindex)); + for (int j = 0; j < loop; j++) { + int k = i * loop + j; + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11; + { + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + } + #ifdef FORMAT_CNHW + int k2 = k << 1; + 
COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16)); + DOT16X16(in, weights00, out.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 16)); + DOT16X16(in, weights00, out.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 32)); + DOT16X16(in, weights00, out.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k2 * bhw4 * 16 + 48)); + DOT16X16(in, weights00, out.s3); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16)); + DOT16X16(in, weights01, out.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 16)); + DOT16X16(in, weights01, out.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 32)); + DOT16X16(in, weights01, out.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + (k2 + 1) * bhw4 * 16 + 48)); + DOT16X16(in, weights01, out.s3); + #else + int k32 = k << 5; + COMPUTE_FLOAT *weights00_ptr = (COMPUTE_FLOAT *)&weights00; + COMPUTE_FLOAT *weights01_ptr = (COMPUTE_FLOAT *)&weights01; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i) * 4)); + out = mad(in, weights00_ptr[i], out); + } + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k32 + i + 16) * 4)); + out = mad(in, weights01_ptr[i], out); + } + #endif + #else + COMPUTE_FLOAT16 weights; + #ifdef USE_IMAGE + weights = readWeight(weight, out_c_idx, k, ScaleOffset.s0, ScaleOffset.s1); + #else + weights = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + #endif + #ifdef FORMAT_CNHW + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16)); + DOT16X16(in, weights, out.s0); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 16)); + DOT16X16(in, weights, out.s1); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 32)); + DOT16X16(in, weights, out.s2); + in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * bhw4 * 16 + 48)); + DOT16X16(in, weights, out.s3); + #else + int k16 = k << 4; + COMPUTE_FLOAT *weights_ptr = (COMPUTE_FLOAT *)&weights; + #pragma unroll + for (int i = 0; i < 16; ++i){ + COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); + out = mad(in, weights_ptr[i], out); + } + #endif + #endif + } + } + +#ifdef RELU + out = fmax(out, (COMPUTE_FLOAT4)0); +#endif + +#ifdef RELU6 + out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); +#endif + vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); +} diff --git a/source/backend/opencl/execution/cl/gemm_quant_batch_buf.cl b/source/backend/opencl/execution/cl/gemm_quant_batch_buf.cl deleted file mode 100644 index 083268503..000000000 --- a/source/backend/opencl/execution/cl/gemm_quant_batch_buf.cl +++ /dev/null @@ -1,821 +0,0 @@ -#ifdef MNN_SUPPORT_FP16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#endif - -#define GLOBAL_SIZE_DIM2 \ - __private int global_size_dim0, __private int global_size_dim1, - -#define UNIFORM_BOUNDRY_CHECK(index0, index1) \ - if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { \ - return; \ - } - -#define GLOBAL_SIZE_DIM3 \ - __private int global_size_dim0, __private int global_size_dim1, __private int global_size_dim2, - -#define 
UNIFORM_BOUNDRY_CHECK3(index0, index1, index2) \ - if(index0 >= global_size_dim0 || index1 >= global_size_dim1 || index2 >= global_size_dim2) { \ - return; \ - } - -#define UCHAR16_TO_2CHAR16(a, b, c) \ - a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; \ - a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; \ - b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; \ - b.s8 = (c.sc >> 4) - 8; b.s9 = (c.sc & 15) - 8; b.sa = (c.sd >> 4) - 8; b.sb = (c.sd & 15) - 8; b.sc = (c.se >> 4) - 8; b.sd = (c.se & 15) - 8; b.se = (c.sf >> 4) - 8; b.sf = (c.sf & 15) - 8; - -#define UCHAR8_TO_CHAR16(a, c) \ - a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; \ - a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; - -#define DOT16X16(a, b, c) \ - c += dot(a.s0123, b.s0123); \ - c += dot(a.s4567, b.s4567); \ - c += dot(a.s89ab, b.s89ab); \ - c += dot(a.scdef, b.scdef); - -__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - -__kernel void reshape_nchw4_nhwc4(GLOBAL_SIZE_DIM3 -__global const FLOAT* input, -__global FLOAT* output, -__private const int width_height, -__private const int batch, -__private const int channel, -__private const int channelC4){ - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - const int wh = get_global_id(2); // w*h - - UNIFORM_BOUNDRY_CHECK3(x, y, wh); - - const int x4 = x << 2; - const int y4 = y << 2; - const int channel4 = channelC4 * 4; - const int stride = channel4 * width_height; - const int input_offset = (y4 * channel4 + x4) * width_height + wh * 4; - const int output_offset = ((y * width_height + wh) * channel4 + x4) * 4; - FLOAT4 in0 = vload4(0, input + input_offset); - FLOAT4 in1 = (y4 + 1 < batch) ? vload4(0, input + input_offset + stride) : (FLOAT4)0; - FLOAT4 in2 = (y4 + 2 < batch) ? vload4(0, input + input_offset + 2 * stride) : (FLOAT4)0; - FLOAT4 in3 = (y4 + 3 < batch) ? 
vload4(0, input + input_offset + 3 * stride) : (FLOAT4)0; - -#ifdef INPUT_CHANNEL_LEAVE - if(x4 + 3 >= channel){ - FLOAT *in0_ptr = (FLOAT*)&in0; - FLOAT *in1_ptr = (FLOAT*)&in1; - FLOAT *in2_ptr = (FLOAT*)&in2; - FLOAT *in3_ptr = (FLOAT*)&in3; - int remain = x4 + 3 - channel; - for(int i = remain; i >= 0; i--){ - in0_ptr[3 - remain] = 0; - in1_ptr[3 - remain] = 0; - in2_ptr[3 - remain] = 0; - in3_ptr[3 - remain] = 0; - } - } -#endif - - FLOAT16 out = (FLOAT16)(in0.s0, in1.s0, in2.s0, in3.s0, in0.s1, in1.s1, in2.s1, in3.s1, in0.s2, in1.s2, in2.s2, in3.s2, in0.s3, in1.s3, in2.s3, in3.s3); - - vstore16(out, 0, output+output_offset); -} - -__kernel void reshape_nhwc4_nchw4(GLOBAL_SIZE_DIM3 -__global const FLOAT* input, -__global FLOAT* output, -__private const int width_height, -__private const int batch, -__private const int channelC4){ - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - const int wh = get_global_id(2); //w*h - - UNIFORM_BOUNDRY_CHECK3(x, y, wh); - - const int x4 = x << 2; - const int y4 = y << 2; - const int channel4 = channelC4 * 4; - const int stride = channel4 * width_height; - const int input_offset = ((y * width_height + wh) * channel4 + x4) * 4; - const int output_offset = (y4 * channel4 + x4) * width_height + wh * 4; - FLOAT16 in = vload16(0, input + input_offset); - - FLOAT4 out0 = (FLOAT4)(in.s0, in.s4, in.s8, in.sc); - FLOAT4 out1 = (FLOAT4)(in.s1, in.s5, in.s9, in.sd); - FLOAT4 out2 = (FLOAT4)(in.s2, in.s6, in.sa, in.se); - FLOAT4 out3 = (FLOAT4)(in.s3, in.s7, in.sb, in.sf); - - vstore4(out0, 0, output+output_offset); - if(y4 + 1 >= batch) return; - vstore4(out1, 0, output+output_offset+stride); - if(y4 + 2 >= batch) return; - vstore4(out2, 0, output+output_offset+2*stride); - if(y4 + 3 >= batch) return; - vstore4(out3, 0, output+output_offset+3*stride); -} - - -__kernel void gemm_b4_c4_buf(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *weight, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *weight, -#endif - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; - COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1, out2 = (COMPUTE_FLOAT4)bias0.s2, out3 = (COMPUTE_FLOAT4)bias0.s3; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx * 4) * 4; - -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = out_c_idx * 4 * 8; - int weight_oc_offset = dstChannelC4 * 32; -#else - int weight_offset = out_c_idx * 4 * 16; - int weight_oc_offset = dstChannelC4 * 64; -#endif - - const int loop = (blockDim + 15) / 16; -#ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; -#else - const int loop_end = loop; -#endif - - for (int i = 0; i < blockNum; i++){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(out_c_idx, dequantScaleOffset + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - 
COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 32)) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 48)) * ScaleOffset.s6 + ScaleOffset.s7; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar16 charWeightsInt40 = vload16(0, weight + weight_offset + k * weight_oc_offset); - uchar16 charWeightsInt41 = vload16(0, weight + weight_offset + k * weight_oc_offset + 16); - { - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; - } - } -#endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT *)&weights2; - COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT *)&weights3; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - out2 = mad(in, weights2_ptr[i], out2); - out3 = mad(in, weights3_ptr[i], out3); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 32)) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 48)) * ScaleOffset.s6 + ScaleOffset.s7; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar16 charWeightsInt40 = vload16(0, weight + weight_offset + k * weight_oc_offset); - uchar16 charWeightsInt41 = vload16(0, weight + weight_offset + k * weight_oc_offset + 16); - { - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; - } - } -#endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT 
*weights1_ptr = (COMPUTE_FLOAT *)&weights1; - COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT *)&weights2; - COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT *)&weights3; - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - out2 = mad(in, weights2_ptr[i], out2); - out3 = mad(in, weights3_ptr[i], out3); - } - } -#endif - } -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); - out1 = fmax(out1, (COMPUTE_FLOAT4)0); - out2 = fmax(out2, (COMPUTE_FLOAT4)0); - out3 = fmax(out3, (COMPUTE_FLOAT4)0); -#endif - -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out2 = clamp(out2, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); - vstore4(CONVERT_FLOAT4(out1), 0, output+out_offset + 4); - vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset + 8); - vstore4(CONVERT_FLOAT4(out3), 0, output+out_offset + 12); -} - -__kernel void gemm_b4_c2_buf(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *weight, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *weight, -#endif - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, bias)); - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; - COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx * 2) * 4; - -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = out_c_idx * 2 * 8; - int weight_oc_offset = dstChannelC4 * 32; -#else - int weight_offset = out_c_idx * 2 * 16; - int weight_oc_offset = dstChannelC4 * 64; -#endif - - const int loop = (blockDim + 15) / 16; -#ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; -#else - const int loop_end = loop; -#endif - - for (int i = 0; i < blockNum; i++){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, dequantScaleOffset + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - COMPUTE_FLOAT16 weights0, weights1; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar16 charWeightsInt4 = vload16(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - 
} -#endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - - COMPUTE_FLOAT16 weights0, weights1; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar16 charWeightsInt4 = vload16(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - } -#endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - } - } -#endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); - out1 = fmax(out1, (COMPUTE_FLOAT4)0); -#endif - -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); - vstore4(CONVERT_FLOAT4(out1), 0, output+out_offset+4); -} - -__kernel void gemm_b4_c1_buf(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, -#if (defined USE_LOW_BIT_WEIGHT_INT8) - __global const char *weight, -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - __global const uchar *weight, -#endif - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT bias0 = bias[out_c_idx]; - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx) * 4; - -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = out_c_idx * 8; - int weight_oc_offset = dstChannelC4 * 32; -#else - int weight_offset = out_c_idx * 16; - int weight_oc_offset = dstChannelC4 * 64; -#endif - - const int loop = (blockDim + 15) / 16; -#ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; -#else - const int loop_end = loop; -#endif - - for (int i = 0; i < blockNum; i++){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, dequantScaleOffset + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - COMPUTE_FLOAT16 weights; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights = 
CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } -#endif - COMPUTE_FLOAT *weights_ptr = (COMPUTE_FLOAT *)&weights; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights_ptr[i], out); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - COMPUTE_FLOAT16 weights; -#if (defined USE_LOW_BIT_WEIGHT_INT8) - weights = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; -#elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } -#endif - COMPUTE_FLOAT *weights_ptr = (COMPUTE_FLOAT *)&weights; - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights_ptr[i], out); - } - } -#endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); -#endif - -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); -} - -__kernel void gemm_b4_c4_image(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, - __read_only image2d_t weight, - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x << 2; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(0, bias + out_c_idx)); - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; - COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1; - COMPUTE_FLOAT4 out2 = (COMPUTE_FLOAT4)bias0.s2; - COMPUTE_FLOAT4 out3 = (COMPUTE_FLOAT4)bias0.s3; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx) * 4; - - const int loop = (blockDim + 15) / 16; - #ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; - #else - const int loop_end = loop; - #endif - - for (int i = 0; i < blockNum; i++){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT8 ScaleOffset = CONVERT_COMPUTE_FLOAT8(vload8(0, dequantScaleOffset + out_c_idx * 2 + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - COMPUTE_FLOAT16 weights2 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 2, 
k)))) * ScaleOffset.s4 + ScaleOffset.s5; - COMPUTE_FLOAT16 weights3 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 3, k)))) * ScaleOffset.s6 + ScaleOffset.s7; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; - { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - uchar8 charWeightsInt42 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 2, k)))); - uchar8 charWeightsInt43 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 3, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - char16 charWeights2 = 0; - char16 charWeights3 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - UCHAR8_TO_CHAR16(charWeights2, charWeightsInt42); - UCHAR8_TO_CHAR16(charWeights3, charWeightsInt43); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights2) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights3) * ScaleOffset.s6 + ScaleOffset.s7; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT *)&weights2; - COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT *)&weights3; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - out2 = mad(in, weights2_ptr[i], out2); - out3 = mad(in, weights3_ptr[i], out3); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - COMPUTE_FLOAT16 weights2 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 2, k)))) * ScaleOffset.s4 + ScaleOffset.s5; - COMPUTE_FLOAT16 weights3 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 3, k)))) * ScaleOffset.s6 + ScaleOffset.s7; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; - { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - uchar8 charWeightsInt42 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 2, k)))); - uchar8 charWeightsInt43 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 3, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - char16 charWeights2 = 0; - char16 charWeights3 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - UCHAR8_TO_CHAR16(charWeights2, charWeightsInt42); - 
UCHAR8_TO_CHAR16(charWeights3, charWeightsInt43); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights2) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights3) * ScaleOffset.s6 + ScaleOffset.s7; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - COMPUTE_FLOAT *weights2_ptr = (COMPUTE_FLOAT *)&weights2; - COMPUTE_FLOAT *weights3_ptr = (COMPUTE_FLOAT *)&weights3; - #pragma unroll - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - out2 = mad(in, weights2_ptr[i], out2); - out3 = mad(in, weights3_ptr[i], out3); - } - } -#endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); - out1 = fmax(out1, (COMPUTE_FLOAT4)0); - out2 = fmax(out2, (COMPUTE_FLOAT4)0); - out3 = fmax(out3, (COMPUTE_FLOAT4)0); -#endif -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out2 = clamp(out2, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - vstore4(CONVERT_FLOAT4(out), 0, output + out_offset); - vstore4(CONVERT_FLOAT4(out1), 0, output + out_offset + 4); - vstore4(CONVERT_FLOAT4(out2), 0, output + out_offset + 8); - vstore4(CONVERT_FLOAT4(out3), 0, output + out_offset + 12); -} -__kernel void gemm_b4_c2_image(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, - __read_only image2d_t weight, - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x << 1; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(0, bias + out_c_idx)); - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0.s0; - COMPUTE_FLOAT4 out1 = (COMPUTE_FLOAT4)bias0.s1; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx) * 4; - - const int loop = (blockDim + 15) / 16; - #ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; - #else - const int loop_end = loop; - #endif - - for (int i = 0; i < blockNum; i++){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(0, dequantScaleOffset + out_c_idx * 2 + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1; - { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = 
as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1; - { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - COMPUTE_FLOAT *weights1_ptr = (COMPUTE_FLOAT *)&weights1; - #pragma unroll - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - out1 = mad(in, weights1_ptr[i], out1); - } - } -#endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); - out1 = fmax(out1, (COMPUTE_FLOAT4)0); -#endif -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); - out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - vstore4(CONVERT_FLOAT4(out), 0, output + out_offset); - vstore4(CONVERT_FLOAT4(out1), 0, output + out_offset + 4); -} -__kernel void gemm_b4_c1_image(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, - __read_only image2d_t weight, - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int blockNum, - __private const int blockDim) { - const int x = get_global_id(0); //c - const int y = get_global_id(1); //b - UNIFORM_BOUNDRY_CHECK(x, y); - - const int out_c_idx = x; - const int out_b_idx = y << 2; - - COMPUTE_FLOAT bias0 = bias[out_c_idx]; - COMPUTE_FLOAT4 out = (COMPUTE_FLOAT4)bias0; - - int input_offset = out_b_idx * srcChannelC4 * 4; - int out_offset = (out_b_idx * dstChannelC4 + out_c_idx) * 4; - - const int loop = (blockDim + 15) / 16; - #ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - const int remain = blockDim - loop_end*16; - #else - const int loop_end = loop; - #endif - - for (int i = 0; i < blockNum; ++i){ 
- int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, dequantScaleOffset + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0; - { - uchar8 charWeightsInt4 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - #pragma unroll - for (int i = 0; i < 16; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - } - } -#ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k16 = k << 4; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0; - { - uchar8 charWeightsInt4 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } - #endif - COMPUTE_FLOAT *weights0_ptr = (COMPUTE_FLOAT *)&weights0; - #pragma unroll - for (int i = 0; i < remain; ++i){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k16 + i) * 4)); - out = mad(in, weights0_ptr[i], out); - } - } -#endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); -#endif -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); -} - diff --git a/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl b/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl index 82b7b02db..df362ce09 100644 --- a/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl +++ b/source/backend/opencl/execution/cl/gemv_conv1x1_buf.cl @@ -31,21 +31,58 @@ COMPUTE_FLOAT* ptr = (COMPUTE_FLOAT*)&data; \ int remain = k + 15 - channel; \ for(int r = remain; r >= 0; r--){ \ - ptr[15 - remain] = 0; \ + ptr[15 - r] = 0; \ } \ } #else #define PADZEROS(k, channel, data) #endif +#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) +#define CHANNEL_PACK 32 +#else +#define CHANNEL_PACK 16 +#endif + +#if (defined USE_LOW_BIT_WEIGHT_INT8) +#define WEIGHT_STRIDE 16 +#elif (defined USE_LOW_BIT_WEIGHT_INT4) +#define WEIGHT_STRIDE 8 +#endif + __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; +#ifdef USE_IMAGE +inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){ + return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(ix, iy)))) * scale + offset; +} +#else + +#if (defined USE_LOW_BIT_WEIGHT_INT8) +inline COMPUTE_FLOAT16 readWeight(__global const char *weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){ + return CONVERT_COMPUTE_FLOAT16(vload16(0, weight)) * scale + offset; +} +#elif (defined USE_LOW_BIT_WEIGHT_INT4) 
+inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight, int ix, int iy, COMPUTE_FLOAT scale, COMPUTE_FLOAT offset){
+    uchar16 charWeightsInt40 = vload16(0, weight);
+    uchar8 charWeightsInt4 = vload8(0, weight);
+    char16 charWeights = 0;
+    UCHAR8_TO_CHAR16(charWeights, charWeightsInt4);
+    return CONVERT_COMPUTE_FLOAT16(charWeights) * scale + offset;
+}
+#endif
+#endif
+
-__kernel void gemm_conv_c4_buf(GLOBAL_SIZE_DIM2
+__kernel void gemv_conv_c4_buf(GLOBAL_SIZE_DIM2
                        __global const FLOAT* input,
+#ifdef USE_IMAGE
+                        __read_only image2d_t weight,
+#else
 #if (defined USE_LOW_BIT_WEIGHT_INT8)
                        __global const char *weight,
 #elif (defined USE_LOW_BIT_WEIGHT_INT4)
                        __global const uchar *weight,
+#endif
 #endif
                        __global const float *dequantScaleOffset,
                        __global const FLOAT *bias,
@@ -53,49 +90,28 @@ __kernel void gemm_conv_c4_buf(GLOBAL_SIZE_DIM2
                        __private const int dstChannelC4,
                        __private const int srcChannelC4,
                        __private const int srcChannel,
-                        __private const int batch,
-                        __private const int height,
-                        __private const int width,
+                        __private const int bhw,
                        __private const int blockNum,
                        __private const int blockDim) {
-    const int out_c_w_idx = get_global_id(0); //c/4 w
-    const int out_b_h_idx = get_global_id(1); //b h
+    const int x = get_global_id(0); //c/4
+    const int y = get_global_id(1); //b h w
-    UNIFORM_BOUNDRY_CHECK(out_c_w_idx, out_b_h_idx);
-
-    const int out_c_idx = out_c_w_idx / width;
-    const int out_w_idx = out_c_w_idx % width;
-#ifdef BACTH_BLOCK4
-    const int out_b_idx = (out_b_h_idx / height) << 2;
-#else
-    const int out_b_idx = out_b_h_idx / height;
-#endif
-    const int out_h_idx = out_b_h_idx % height;
+    UNIFORM_BOUNDRY_CHECK(x, y);
-    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias));
-    COMPUTE_FLOAT4 out = bias0;
-#ifdef BACTH_BLOCK4
-    COMPUTE_FLOAT4 out1 = bias0, out2 = bias0, out3 = bias0;
-    int input_offset1 = (((out_b_idx + 1) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4;
-    int input_offset2 = (((out_b_idx + 2) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4;
-    int input_offset3 = (((out_b_idx + 3) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4;
-    bool isValidBatch1 = out_b_idx + 1 < batch;
-    bool isValidBatch2 = out_b_idx + 2 < batch;
-    bool isValidBatch3 = out_b_idx + 3 < batch;
-#endif
+    COMPUTE_FLOAT4 bias0 = CONVERT_COMPUTE_FLOAT4(vload4(x, bias));
+    COMPUTE_FLOAT4 out0 = bias0;
+    int idn = x << 2;
+    int idm = y;
-    int input_offset = ((out_b_idx * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4;
-    int out_offset = (((out_b_idx * dstChannelC4 + out_c_idx) * height + out_h_idx) * width + out_w_idx) * 4;
-    int wh = width * height * 4;
-#if (defined USE_LOW_BIT_WEIGHT_INT4)
-    int weight_offset = out_c_idx * 4 * 8;
-    int weight_oc_offset = dstChannelC4 * 32;
-#else
-    int weight_offset = out_c_idx * 4 * 16;
-    int weight_oc_offset = dstChannelC4 * 64;
+    int input_offset0 = idm * 4;
+
+    int out_offset = (x * bhw + idm) * 4;
+#ifndef USE_IMAGE
+    int weight_offset = x * 4 * WEIGHT_STRIDE;
+    int weight_oc_offset = dstChannelC4 * 4 * WEIGHT_STRIDE;
 #endif
-    const int loop = (blockDim + 15) / 16;
+    const int loop = (blockDim + CHANNEL_PACK - 1) / CHANNEL_PACK;
 #ifdef INPUT_CHANNEL_LEAVE
     const int loop_end = max(loop - 1, 0);
 #else
@@ -104,122 +120,119 @@ __kernel void gemm_conv_c4_buf(GLOBAL_SIZE_DIM2
     for (int i = 0; i < blockNum; ++i){
         int kindex = i * dstChannelC4 * 4 * 2;
-        COMPUTE_FLOAT8 ScaleOffset
= CONVERT_COMPUTE_FLOAT8(vload8(x, dequantScaleOffset + kindex)); for (int j = 0; j < loop_end; ++j) { int k = i * loop + j; - #ifndef WIDTH_HEIGHT_1 - int k4 = k << 2; - #endif - COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 32)) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 48)) * ScaleOffset.s6 + ScaleOffset.s7; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k32 = k << 5; + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11, weights20, weights21, weights30, weights31; { - uchar16 charWeightsInt40 = vload16(0, weight + weight_offset + k * weight_oc_offset); - uchar16 charWeightsInt41 = vload16(0, weight + weight_offset + k * weight_oc_offset + 16); - { - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; - } + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 1, k))); + uchar16 charWeightsInt42 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 2, k))); + uchar16 charWeightsInt43 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 3, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt42); + weights20 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; + weights21 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s4 + ScaleOffset.s5; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt43); + weights30 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s6 + ScaleOffset.s7; + weights31 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; } + { + COMPUTE_FLOAT16 in0 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32)); + COMPUTE_FLOAT16 in1 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32 + 16)); + DOT16X16(in0, weights00, out0.s0);DOT16X16(in1, weights01, out0.s0); + DOT16X16(in0, weights10, out0.s1);DOT16X16(in1, weights11, out0.s1); + DOT16X16(in0, weights20, out0.s2);DOT16X16(in1, weights21, out0.s2); + DOT16X16(in0, weights30, 
out0.s3);DOT16X16(in1, weights31, out0.s3); + } + #else + COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; + #ifdef USE_IMAGE + weights0 = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, idn + 1, k, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight, idn + 2, k, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight, idn + 3, k, ScaleOffset.s6, ScaleOffset.s7); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight + weight_offset + k * weight_oc_offset + 2 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight + weight_offset + k * weight_oc_offset + 3 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s6, ScaleOffset.s7); #endif { - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - DOT16X16(in, weights2, out.s2); - DOT16X16(in, weights3, out.s3); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset1)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - DOT16X16(in, weights2, out1.s2); - DOT16X16(in, weights3, out1.s3); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset2)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); - DOT16X16(in, weights2, out2.s2); - DOT16X16(in, weights3, out2.s3); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset3)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); - DOT16X16(in, weights2, out3.s2); - DOT16X16(in, weights3, out3.s3); + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(k, input)); + 
DOT16X16(in, weights0, out0.s0); + DOT16X16(in, weights1, out0.s1); + DOT16X16(in, weights2, out0.s2); + DOT16X16(in, weights3, out0.s3); } #endif } #ifdef INPUT_CHANNEL_LEAVE { int k = i * loop + loop_end; - int k4 = k << 2; - COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - weights2 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 32)) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 48)) * ScaleOffset.s6 + ScaleOffset.s7; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k8 = k << 3; + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11, weights20, weights21, weights30, weights31; { - uchar16 charWeightsInt40 = vload16(0, weight + weight_offset + k * weight_oc_offset); - uchar16 charWeightsInt41 = vload16(0, weight + weight_offset + k * weight_oc_offset + 16); - { - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); - weights2 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; - weights3 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; - } + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 1, k))); + uchar16 charWeightsInt42 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 2, k))); + uchar16 charWeightsInt43 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 3, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt42); + weights20 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s4 + ScaleOffset.s5; + weights21 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s4 + ScaleOffset.s5; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt43); + weights30 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s6 + ScaleOffset.s7; + weights31 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s6 + ScaleOffset.s7; + + PADZEROS(k, srcChannel, weights00);PADZEROS(k + 16, srcChannel, weights01); + PADZEROS(k, srcChannel, weights10);PADZEROS(k + 16, srcChannel, weights11); + PADZEROS(k, srcChannel, weights20);PADZEROS(k + 16, srcChannel, weights21); + PADZEROS(k, srcChannel, weights30);PADZEROS(k + 16, srcChannel, weights31); } + { + COMPUTE_FLOAT16 in0, in1; + in0.s0123 = 
CONVERT_COMPUTE_FLOAT4(vload4(0, input + k8 * 4)); + in0.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 1 < srcChannelC4 ? vload4(0, input + (k8 + 1) * 4) : (FLOAT4)0); + in0.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 2 < srcChannelC4 ? vload4(0, input + (k8 + 2) * 4) : (FLOAT4)0); + in0.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 3 < srcChannelC4 ? vload4(0, input + (k8 + 3) * 4) : (FLOAT4)0); + in1.s0123 = CONVERT_COMPUTE_FLOAT4(k8 + 4 < srcChannelC4 ? vload4(0, input + (k8 + 4) * 4) : (FLOAT4)0); + in1.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 5 < srcChannelC4 ? vload4(0, input + (k8 + 5) * 4) : (FLOAT4)0); + in1.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 6 < srcChannelC4 ? vload4(0, input + (k8 + 6) * 4) : (FLOAT4)0); + in1.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 7 < srcChannelC4 ? vload4(0, input + (k8 + 7) * 4) : (FLOAT4)0); + DOT16X16(in0, weights00, out0.s0);DOT16X16(in1, weights01, out0.s0); + DOT16X16(in0, weights10, out0.s1);DOT16X16(in1, weights11, out0.s1); + DOT16X16(in0, weights20, out0.s2);DOT16X16(in1, weights21, out0.s2); + DOT16X16(in0, weights30, out0.s3);DOT16X16(in1, weights31, out0.s3); + } + #else + int k4 = k << 2; + COMPUTE_FLOAT16 weights0, weights1, weights2, weights3; + #ifdef USE_IMAGE + weights0 = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, idn + 1, k, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight, idn + 2, k, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight, idn + 3, k, ScaleOffset.s6, ScaleOffset.s7); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); + weights2 = readWeight(weight + weight_offset + k * weight_oc_offset + 2 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s4, ScaleOffset.s5); + weights3 = readWeight(weight + weight_offset + k * weight_oc_offset + 3 * WEIGHT_STRIDE, 0, 0, ScaleOffset.s6, ScaleOffset.s7); #endif PADZEROS(k, srcChannel, weights0); PADZEROS(k, srcChannel, weights1); @@ -227,109 +240,40 @@ __kernel void gemm_conv_c4_buf(GLOBAL_SIZE_DIM2 PADZEROS(k, srcChannel, weights3); { COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - DOT16X16(in, weights2, out.s2); - DOT16X16(in, weights3, out.s3); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + input_offset1 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - DOT16X16(in, weights2, out1.s2); - DOT16X16(in, weights3, out1.s3); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); - DOT16X16(in, weights2, out2.s2); - DOT16X16(in, weights3, out2.s3); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); - DOT16X16(in, weights2, out3.s2); - DOT16X16(in, weights3, out3.s3); + in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4 * 4)); + in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + (k4 + 1) * 4) : (FLOAT4)0); + in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + (k4 + 2) * 4) : (FLOAT4)0); + in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + (k4 + 3) * 4) : (FLOAT4)0); + DOT16X16(in, weights0, out0.s0); + DOT16X16(in, weights1, out0.s1); + DOT16X16(in, weights2, out0.s2); + DOT16X16(in, weights3, out0.s3); } #endif } #endif } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT4)0); -#endif - -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out), 0, output+out_offset); -#ifdef BACTH_BLOCK4 - if(isValidBatch1){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out1 = fmax(out1, (COMPUTE_FLOAT4)0); -#endif - -#ifdef RELU6 - out1 = clamp(out1, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out1), 0, output+out_offset); - } - if(isValidBatch2){ - out_offset += dstChannelC4 * height * width * 4; #ifdef RELU - out2 = fmax(out2, (COMPUTE_FLOAT4)0); + out0 = fmax(out0, (COMPUTE_FLOAT4)0); #endif #ifdef RELU6 - out2 = clamp(out2, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out2), 0, output+out_offset); - } - if(isValidBatch3){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out3 = fmax(out3, (COMPUTE_FLOAT4)0); + out0 = clamp(out0, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); #endif -#ifdef RELU6 - out3 = clamp(out3, (COMPUTE_FLOAT4)0, (COMPUTE_FLOAT4)6); -#endif - - vstore4(CONVERT_FLOAT4(out3), 0, output+out_offset); - } -#endif + vstore4(CONVERT_FLOAT4(out0), 0, output+out_offset); } -__kernel void gemm_conv_c2_buf(GLOBAL_SIZE_DIM2 +__kernel void gemv_conv_c2_buf(GLOBAL_SIZE_DIM2 __global const FLOAT* input, +#ifdef USE_IMAGE + __read_only image2d_t weight, +#else #if (defined USE_LOW_BIT_WEIGHT_INT8) __global const char *weight, #elif (defined USE_LOW_BIT_WEIGHT_INT4) __global const uchar *weight, +#endif #endif __global const float *dequantScaleOffset, __global const FLOAT *bias, @@ -337,48 +281,26 @@ __kernel void gemm_conv_c2_buf(GLOBAL_SIZE_DIM2 __private const int dstChannelC4, __private const int srcChannelC4, __private const int srcChannel, - __private const int batch, - __private const int height, - __private const int width, + __private const int bhw, __private const int blockNum, __private const int blockDim) { - const int out_c_w_idx = get_global_id(0); //c/4 w - const int out_b_h_idx = get_global_id(1); //b h - - UNIFORM_BOUNDRY_CHECK(out_c_w_idx, out_b_h_idx); + const int x = get_global_id(0); //c/2 + const int y = get_global_id(1); //b h w - const int out_c_idx = out_c_w_idx / width; - const int out_w_idx = out_c_w_idx % width; -#ifdef BACTH_BLOCK4 - const int out_b_idx = (out_b_h_idx / height) << 2; -#else - const int out_b_idx = out_b_h_idx / height; -#endif - const int out_h_idx = out_b_h_idx % height; - - COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, bias)); - COMPUTE_FLOAT2 out = bias0; -#ifdef BACTH_BLOCK4 - COMPUTE_FLOAT2 out1 = bias0, out2 = bias0, out3 = bias0; - int input_offset1 = (((out_b_idx + 1) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset2 = (((out_b_idx + 2) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset3 = (((out_b_idx + 3) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - bool isValidBatch1 = out_b_idx + 1 < batch; - bool isValidBatch2 = out_b_idx + 2 < batch; - bool isValidBatch3 = out_b_idx + 3 < batch; -#endif - int input_offset = ((out_b_idx * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int out_offset = (((out_b_idx * dstChannelC4 + (out_c_idx * 2) / 4) * height + out_h_idx) * width + out_w_idx) * 4 + ((out_c_idx * 2)%4); 
- int wh = width * height * 4; -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = out_c_idx * 2 * 8; - int weight_oc_offset = dstChannelC4 * 32; -#else - int weight_offset = out_c_idx * 2 * 16; - int weight_oc_offset = dstChannelC4 * 64; -#endif - - const int loop = (blockDim + 15) / 16; + UNIFORM_BOUNDRY_CHECK(x, y); + + int idn = x << 1; + int idm = y; + COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(x, bias)); + COMPUTE_FLOAT2 out0 = bias0; + int input_offset0 = idm * 4; + int out_offset = ((x * 2) / 4 * bhw + idm) * 4 + ((x * 2) % 4); +#ifndef USE_IMAGE + int weight_offset = x * 2 * WEIGHT_STRIDE; + int weight_oc_offset = dstChannelC4 * 4 * WEIGHT_STRIDE; +#endif + + const int loop = (blockDim + CHANNEL_PACK - 1) / CHANNEL_PACK; #ifdef INPUT_CHANNEL_LEAVE const int loop_end = max(loop - 1, 0); #else @@ -387,137 +309,98 @@ __kernel void gemm_conv_c2_buf(GLOBAL_SIZE_DIM2 for (int i = 0; i < blockNum; ++i){ int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(x, dequantScaleOffset + kindex)); for (int j = 0; j < loop_end; ++j) { int k = i * loop + j; - #ifndef WIDTH_HEIGHT_1 - int k4 = k << 2; - #endif - COMPUTE_FLOAT16 weights0, weights1; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k32 = k << 5; + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11; { - uchar16 charWeightsInt4 = vload16(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 1, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; } - #endif { - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = 
CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset1)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset2)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); + COMPUTE_FLOAT16 in0 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32)); + COMPUTE_FLOAT16 in1 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32 + 16)); + DOT16X16(in0, weights00, out0.s0);DOT16X16(in1, weights01, out0.s0); + DOT16X16(in0, weights10, out0.s1);DOT16X16(in1, weights11, out0.s1); } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset3)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); + #else + COMPUTE_FLOAT16 weights0, weights1; + #ifdef USE_IMAGE + weights0 = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, idn + 1, k, ScaleOffset.s2, ScaleOffset.s3); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); + #endif + { + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(k, input)); + DOT16X16(in, weights0, out0.s0); + DOT16X16(in, weights1, out0.s1); } #endif } #ifdef INPUT_CHANNEL_LEAVE { int k = i * loop + loop_end; - int k4 = k << 2; - COMPUTE_FLOAT16 weights0, weights1; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights0 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset + 16)) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k8 = k << 3; + COMPUTE_FLOAT16 weights00, weights01, weights10, weights11; { - uchar16 charWeightsInt4 = vload16(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + uchar16 charWeightsInt40 = 
as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + uchar16 charWeightsInt41 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn + 1, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt41); + weights10 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s2 + ScaleOffset.s3; + weights11 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; + + PADZEROS(k, srcChannel, weights00);PADZEROS(k + 16, srcChannel, weights01); + PADZEROS(k, srcChannel, weights10);PADZEROS(k + 16, srcChannel, weights11); } + { + COMPUTE_FLOAT16 in0, in1; + in0.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k8 * 4)); + in0.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 1 < srcChannelC4 ? vload4(0, input + (k8 + 1) * 4) : (FLOAT4)0); + in0.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 2 < srcChannelC4 ? vload4(0, input + (k8 + 2) * 4) : (FLOAT4)0); + in0.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 3 < srcChannelC4 ? vload4(0, input + (k8 + 3) * 4) : (FLOAT4)0); + in1.s0123 = CONVERT_COMPUTE_FLOAT4(k8 + 4 < srcChannelC4 ? vload4(0, input + (k8 + 4) * 4) : (FLOAT4)0); + in1.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 5 < srcChannelC4 ? vload4(0, input + (k8 + 5) * 4) : (FLOAT4)0); + in1.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 6 < srcChannelC4 ? vload4(0, input + (k8 + 6) * 4) : (FLOAT4)0); + in1.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 7 < srcChannelC4 ? vload4(0, input + (k8 + 7) * 4) : (FLOAT4)0); + DOT16X16(in0, weights00, out0.s0);DOT16X16(in1, weights01, out0.s0); + DOT16X16(in0, weights10, out0.s1);DOT16X16(in1, weights11, out0.s1); + } + #else + int k4 = k << 2; + COMPUTE_FLOAT16 weights0, weights1; + #ifdef USE_IMAGE + weights0 = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight, idn + 1, k, ScaleOffset.s2, ScaleOffset.s3); + #else + weights0 = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); + weights1 = readWeight(weight + weight_offset + k * weight_oc_offset + WEIGHT_STRIDE, 0, 0, ScaleOffset.s2, ScaleOffset.s3); #endif PADZEROS(k, srcChannel, weights0); PADZEROS(k, srcChannel, weights1); { COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + input_offset1 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); + in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4 * 4)); + in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + (k4 + 1) * 4) : (FLOAT4)0); + in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + (k4 + 2) * 4) : (FLOAT4)0); + in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + (k4 + 3) * 4) : (FLOAT4)0); + DOT16X16(in, weights0, out0.s0); + DOT16X16(in, weights1, out0.s1); } #endif } @@ -525,60 +408,26 @@ __kernel void gemm_conv_c2_buf(GLOBAL_SIZE_DIM2 } #ifdef RELU - out = fmax(out, (COMPUTE_FLOAT2)0); + out0 = fmax(out0, (COMPUTE_FLOAT2)0); #endif #ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); + out0 = clamp(out0, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); #endif - vstore2(CONVERT_FLOAT2(out), 0, output+out_offset); -#ifdef BACTH_BLOCK4 - if(isValidBatch1){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out1 = fmax(out1, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out1 = clamp(out1, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out1), 0, output+out_offset); - } - if(isValidBatch2){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out2 = fmax(out2, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out2 = clamp(out2, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out2), 0, output+out_offset); - } - if(isValidBatch3){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out3 = fmax(out3, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out3 = clamp(out3, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out3), 0, output+out_offset); - } -#endif + vstore2(CONVERT_FLOAT2(out0), 0, output+out_offset); } -__kernel void gemm_conv_c1_buf(GLOBAL_SIZE_DIM2 +__kernel void gemv_conv_c1_buf(GLOBAL_SIZE_DIM2 __global const FLOAT* input, +#ifdef USE_IMAGE + __read_only image2d_t weight, +#else #if (defined USE_LOW_BIT_WEIGHT_INT8) __global const char *weight, #elif (defined USE_LOW_BIT_WEIGHT_INT4) __global const uchar *weight, +#endif #endif __global const float *dequantScaleOffset, __global const FLOAT *bias, @@ -586,50 +435,28 @@ __kernel void gemm_conv_c1_buf(GLOBAL_SIZE_DIM2 __private const int dstChannelC4, __private const int srcChannelC4, 
__private const int srcChannel, - __private const int batch, - __private const int height, - __private const int width, + __private const int bhw, __private const int blockNum, __private const int blockDim) { - const int out_c_w_idx = get_global_id(0); //c/4 w - const int out_b_h_idx = get_global_id(1); //b h + const int x = get_global_id(0); //c + const int y = get_global_id(1); //b h w - UNIFORM_BOUNDRY_CHECK(out_c_w_idx, out_b_h_idx); + UNIFORM_BOUNDRY_CHECK(x, y); + int idn = x; + int idm = y; - const int out_c_idx = out_c_w_idx / width; - const int out_w_idx = out_c_w_idx % width; -#ifdef BACTH_BLOCK4 - const int out_b_idx = (out_b_h_idx / height) << 2; -#else - const int out_b_idx = out_b_h_idx / height; -#endif - const int out_h_idx = out_b_h_idx % height; - - COMPUTE_FLOAT bias0 = bias[out_c_idx]; - COMPUTE_FLOAT out = bias0; + COMPUTE_FLOAT bias0 = bias[x]; + COMPUTE_FLOAT out0 = bias0; -#ifdef BACTH_BLOCK4 - COMPUTE_FLOAT out1 = bias0, out2 = bias0, out3 = bias0; - int input_offset1 = (((out_b_idx + 1) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset2 = (((out_b_idx + 2) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset3 = (((out_b_idx + 3) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - bool isValidBatch1 = out_b_idx + 1 < batch; - bool isValidBatch2 = out_b_idx + 2 < batch; - bool isValidBatch3 = out_b_idx + 3 < batch; -#endif + int input_offset0 = idm * 4; - int input_offset = ((out_b_idx * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int out_offset = (((out_b_idx * dstChannelC4 + out_c_idx/4) * height + out_h_idx) * width + out_w_idx) * 4 + (out_c_idx%4); - int wh = width * height * 4; -#if (defined USE_LOW_BIT_WEIGHT_INT4) - int weight_offset = out_c_idx * 8; - int weight_oc_offset = dstChannelC4 * 32; -#else - int weight_offset = out_c_idx * 16; - int weight_oc_offset = dstChannelC4 * 64; + int out_offset = ((x / 4) * bhw + idm) * 4 + (x % 4); +#ifndef USE_IMAGE + int weight_offset = x * WEIGHT_STRIDE; + int weight_oc_offset = dstChannelC4 * 4 * WEIGHT_STRIDE; #endif - const int loop = (blockDim + 15) / 16; + const int loop = (blockDim + CHANNEL_PACK - 1) / CHANNEL_PACK; #ifdef INPUT_CHANNEL_LEAVE const int loop_end = max(loop - 1, 0); #else @@ -638,633 +465,92 @@ __kernel void gemm_conv_c1_buf(GLOBAL_SIZE_DIM2 for (int i = 0; i < blockNum; ++i){ int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, dequantScaleOffset + kindex)); + COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(x, dequantScaleOffset + kindex)); for (int j = 0; j < loop_end; ++j) { int k = i * loop + j; - #ifndef WIDTH_HEIGHT_1 - int k4 = k << 2; - #endif - COMPUTE_FLOAT16 weights; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k32 = k << 5; + COMPUTE_FLOAT16 weights00, weights01; { - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + 
weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; } - #endif { - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights, out); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset1)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights, out1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset2)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights, out2); + COMPUTE_FLOAT16 in0 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32)); + COMPUTE_FLOAT16 in1 = CONVERT_COMPUTE_FLOAT16(vload16(0, input + k32 + 16)); + DOT16X16(in0, weights00, out0);DOT16X16(in1, weights01, out0); } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(k, input + input_offset3)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights, out3); - } - #endif - } - #ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; - int k4 = k << 2; + #else COMPUTE_FLOAT16 weights; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - weights = CONVERT_COMPUTE_FLOAT16(vload16(0, weight + weight_offset + k * weight_oc_offset)) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - { - uchar8 charWeightsInt4 = vload8(0, weight + weight_offset + k * weight_oc_offset); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } - #endif - PADZEROS(k, srcChannel, weights); - { - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? 
vload4(0, input + input_offset + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights, out); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights, out1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights, out2); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights, out3); - } - #endif - } - #endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - output[out_offset] = out; -#ifdef BACTH_BLOCK4 - if(isValidBatch1){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out1 = fmax(out1, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out1 = clamp(out1, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out1; - } - if(isValidBatch2){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out2 = fmax(out2, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out2 = clamp(out2, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out2; - } - if(isValidBatch3){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out3 = fmax(out3, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out3 = clamp(out3, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out3; - } -#endif -} -__kernel void gemm_conv_c2_image(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, - __read_only image2d_t weight, - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int srcChannel, - __private const int batch, - __private const int height, - __private const int width, - __private const int blockNum, - __private const int blockDim) { - const int out_c_w_idx = get_global_id(0); //c/4 w - const int out_b_h_idx = get_global_id(1); //b h - UNIFORM_BOUNDRY_CHECK(out_c_w_idx, out_b_h_idx); - - const int out_c_idx = (out_c_w_idx / width) << 1; - const int out_w_idx = out_c_w_idx % width; -#ifdef BACTH_BLOCK4 - const 
int out_b_idx = (out_b_h_idx / height) << 2; -#else - const int out_b_idx = out_b_h_idx / height; -#endif - const int out_h_idx = out_b_h_idx % height; - - COMPUTE_FLOAT2 bias0 = CONVERT_COMPUTE_FLOAT2(vload2(0, bias + out_c_idx)); - COMPUTE_FLOAT2 out = bias0; - -#ifdef BACTH_BLOCK4 - COMPUTE_FLOAT2 out1 = bias0, out2 = bias0, out3 = bias0; - int input_offset1 = (((out_b_idx + 1) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset2 = (((out_b_idx + 2) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset3 = (((out_b_idx + 3) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - bool isValidBatch1 = out_b_idx + 1 < batch; - bool isValidBatch2 = out_b_idx + 2 < batch; - bool isValidBatch3 = out_b_idx + 3 < batch; -#endif - - int input_offset = ((out_b_idx * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int out_offset = (((out_b_idx * dstChannelC4 + out_c_idx/4) * height + out_h_idx) * width + out_w_idx) * 4 + (out_c_idx % 4); - int wh = width * height * 4; - - const int loop = (blockDim + 15) / 16; - #ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - #else - const int loop_end = loop; - #endif - - for (int i = 0; i < blockNum; ++i){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT4 ScaleOffset = CONVERT_COMPUTE_FLOAT4(vload4(0, dequantScaleOffset + out_c_idx * 2 + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - #ifndef WIDTH_HEIGHT_1 - int k4 = k << 2; - #endif - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1; - { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - } + #ifdef USE_IMAGE + weights = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + #else + weights = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); #endif { - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset1 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = 
CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset2 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset3 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); + COMPUTE_FLOAT16 in = CONVERT_COMPUTE_FLOAT16(vload16(k, input)); + DOT16X16(in, weights, out0); } #endif } #ifdef INPUT_CHANNEL_LEAVE { int k = i * loop + loop_end; - int k4 = k << 2; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - COMPUTE_FLOAT16 weights1 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx + 1, k)))) * ScaleOffset.s2 + ScaleOffset.s3; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0, weights1; + #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE) + int k8 = k << 3; + COMPUTE_FLOAT16 weights00, weights01; { - uchar8 charWeightsInt40 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - uchar8 charWeightsInt41 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx + 1, k)))); - char16 charWeights0 = 0; - char16 charWeights1 = 0; - UCHAR8_TO_CHAR16(charWeights0, charWeightsInt40); - UCHAR8_TO_CHAR16(charWeights1, charWeightsInt41); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; - weights1 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s2 + ScaleOffset.s3; - } - #endif - PADZEROS(k, srcChannel, weights0); - PADZEROS(k, srcChannel, weights1); - { - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + input_offset + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out.s0); - DOT16X16(in, weights1, out.s1); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out1.s0); - DOT16X16(in, weights1, out1.s1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out2.s0); - DOT16X16(in, weights1, out2.s1); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out3.s0); - DOT16X16(in, weights1, out3.s1); - } - #endif - } - #endif - } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT2)0); -#endif -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - vstore2(CONVERT_FLOAT2(out), 0, output + out_offset); -#ifdef BACTH_BLOCK4 - if(isValidBatch1){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out1 = fmax(out1, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out1 = clamp(out1, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out1), 0, output+out_offset); - } - if(isValidBatch2){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out2 = fmax(out2, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out2 = clamp(out2, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out2), 0, output+out_offset); - } - if(isValidBatch3){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out3 = fmax(out3, (COMPUTE_FLOAT2)0); -#endif - -#ifdef RELU6 - out3 = clamp(out3, (COMPUTE_FLOAT2)0, (COMPUTE_FLOAT2)6); -#endif - - vstore2(CONVERT_FLOAT2(out3), 0, output+out_offset); - } -#endif -} -__kernel void gemm_conv_c1_image(GLOBAL_SIZE_DIM2 - __global const FLOAT* input, - __read_only image2d_t weight, - __global const float *dequantScaleOffset, - __global const FLOAT *bias, - __global FLOAT* output, - __private const int dstChannelC4, - __private const int srcChannelC4, - __private const int srcChannel, - __private const int batch, - __private const int height, - __private const int width, - __private const int blockNum, - __private const int blockDim) { - const int out_c_w_idx = get_global_id(0); //c/4 w - const int out_b_h_idx = get_global_id(1); //b h - 
UNIFORM_BOUNDRY_CHECK(out_c_w_idx, out_b_h_idx); - - const int out_c_idx = out_c_w_idx / width; - const int out_w_idx = out_c_w_idx % width; -#ifdef BACTH_BLOCK4 - const int out_b_idx = (out_b_h_idx / height) << 2; -#else - const int out_b_idx = out_b_h_idx / height; -#endif - const int out_h_idx = out_b_h_idx % height; - - COMPUTE_FLOAT bias0 = bias[out_c_idx]; - COMPUTE_FLOAT out = bias0; - - int input_offset = ((out_b_idx * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int out_offset = (((out_b_idx * dstChannelC4 + out_c_idx/4)* height + out_h_idx) * width + out_w_idx) * 4 + (out_c_idx%4); - int wh = width * height * 4; -#ifdef BACTH_BLOCK4 - COMPUTE_FLOAT out1 = bias0, out2 = bias0, out3 = bias0; - int input_offset1 = (((out_b_idx + 1) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset2 = (((out_b_idx + 2) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - int input_offset3 = (((out_b_idx + 3) * srcChannelC4 * height + out_h_idx) * width + out_w_idx) * 4; - bool isValidBatch1 = out_b_idx + 1 < batch; - bool isValidBatch2 = out_b_idx + 2 < batch; - bool isValidBatch3 = out_b_idx + 3 < batch; -#endif - - const int loop = (blockDim + 15) / 16; - #ifdef INPUT_CHANNEL_LEAVE - const int loop_end = max(loop - 1, 0); - #else - const int loop_end = loop; - #endif - - for (int i = 0; i < blockNum; ++i){ - int kindex = i * dstChannelC4 * 4 * 2; - COMPUTE_FLOAT2 ScaleOffset = CONVERT_COMPUTE_FLOAT2(vload2(out_c_idx, dequantScaleOffset + kindex)); - for (int j = 0; j < loop_end; j++) { - int k = i * loop + j; - #ifndef WIDTH_HEIGHT_1 - int k4 = k << 2; - #endif - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0; - { - uchar8 charWeightsInt4 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; + uchar16 charWeightsInt40 = as_uchar16(read_imagei(weight, SAMPLER, (int2)(idn, k))); + char16 charWeights0, charWeights1; + UCHAR16_TO_2CHAR16(charWeights0, charWeights1, charWeightsInt40); + weights00 = CONVERT_COMPUTE_FLOAT16(charWeights0) * ScaleOffset.s0 + ScaleOffset.s1; + weights01 = CONVERT_COMPUTE_FLOAT16(charWeights1) * ScaleOffset.s0 + ScaleOffset.s1; + + PADZEROS(k, srcChannel, weights00);PADZEROS(k + 16, srcChannel, weights01); } - #endif { - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset1 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 
+ (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset2 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out2); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - #ifdef WIDTH_HEIGHT_1 - in = CONVERT_COMPUTE_FLOAT16(vload16(0, input + input_offset3 + k * 16)); - #else - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 1) * wh)); - in.s89ab = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 2) * wh)); - in.scdef = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + (k4 + 3) * wh)); - #endif - DOT16X16(in, weights0, out3); - } - #endif - } - #ifdef INPUT_CHANNEL_LEAVE - { - int k = i * loop + loop_end; + COMPUTE_FLOAT16 in0, in1; + in0.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k8 * 4)); + in0.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 1 < srcChannelC4 ? vload4(0, input + (k8 + 1) * 4) : (FLOAT4)0); + in0.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 2 < srcChannelC4 ? vload4(0, input + (k8 + 2) * 4) : (FLOAT4)0); + in0.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 3 < srcChannelC4 ? vload4(0, input + (k8 + 3) * 4) : (FLOAT4)0); + in1.s0123 = CONVERT_COMPUTE_FLOAT4(k8 + 4 < srcChannelC4 ? vload4(0, input + (k8 + 4) * 4) : (FLOAT4)0); + in1.s4567 = CONVERT_COMPUTE_FLOAT4(k8 + 5 < srcChannelC4 ? vload4(0, input + (k8 + 5) * 4) : (FLOAT4)0); + in1.s89ab = CONVERT_COMPUTE_FLOAT4(k8 + 6 < srcChannelC4 ? vload4(0, input + (k8 + 6) * 4) : (FLOAT4)0); + in1.scdef = CONVERT_COMPUTE_FLOAT4(k8 + 7 < srcChannelC4 ? vload4(0, input + (k8 + 7) * 4) : (FLOAT4)0); + DOT16X16(in0, weights00, out0);DOT16X16(in1, weights01, out0); + } + #else int k4 = k << 2; - #if (defined USE_LOW_BIT_WEIGHT_INT8) - COMPUTE_FLOAT16 weights0 = CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight, SAMPLER, (int2)(out_c_idx, k)))) * ScaleOffset.s0 + ScaleOffset.s1; - #elif (defined USE_LOW_BIT_WEIGHT_INT4) - COMPUTE_FLOAT16 weights0; - { - uchar8 charWeightsInt4 = as_uchar8(convert_ushort4(read_imageui(weight, SAMPLER, (int2)(out_c_idx, k)))); - char16 charWeights = 0; - UCHAR8_TO_CHAR16(charWeights, charWeightsInt4); - weights0 = CONVERT_COMPUTE_FLOAT16(charWeights) * ScaleOffset.s0 + ScaleOffset.s1; - } + COMPUTE_FLOAT16 weights; + #ifdef USE_IMAGE + weights = readWeight(weight, idn, k, ScaleOffset.s0, ScaleOffset.s1); + #else + weights = readWeight(weight + weight_offset + k * weight_oc_offset, 0, 0, ScaleOffset.s0, ScaleOffset.s1); #endif - PADZEROS(k, srcChannel, weights0); + PADZEROS(k, srcChannel, weights); { - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + input_offset + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out); - } - #ifdef BACTH_BLOCK4 - if(isValidBatch1){ COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset1 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset1 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out1); - } - if(isValidBatch2){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset2 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset2 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out2); - } - if(isValidBatch3){ - COMPUTE_FLOAT16 in; - in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + input_offset3 + k4 * wh)); - in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 1) * wh) : (FLOAT4)0); - in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 2) * wh) : (FLOAT4)0); - in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? vload4(0, input + input_offset3 + (k4 + 3) * wh) : (FLOAT4)0); - DOT16X16(in, weights0, out3); + in.s0123 = CONVERT_COMPUTE_FLOAT4(vload4(0, input + k4 * 4)); + in.s4567 = CONVERT_COMPUTE_FLOAT4(k4 + 1 < srcChannelC4 ? vload4(0, input + (k4 + 1) * 4) : (FLOAT4)0); + in.s89ab = CONVERT_COMPUTE_FLOAT4(k4 + 2 < srcChannelC4 ? vload4(0, input + (k4 + 2) * 4) : (FLOAT4)0); + in.scdef = CONVERT_COMPUTE_FLOAT4(k4 + 3 < srcChannelC4 ? 
vload4(0, input + (k4 + 3) * 4) : (FLOAT4)0); + DOT16X16(in, weights, out0); } #endif } #endif } - -#ifdef RELU - out = fmax(out, (COMPUTE_FLOAT)0); -#endif -#ifdef RELU6 - out = clamp(out, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - output[out_offset] = out; -#ifdef BACTH_BLOCK4 - if(isValidBatch1){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out1 = fmax(out1, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out1 = clamp(out1, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out1; - } - if(isValidBatch2){ - out_offset += dstChannelC4 * height * width * 4; -#ifdef RELU - out2 = fmax(out2, (COMPUTE_FLOAT)0); -#endif - -#ifdef RELU6 - out1 = clamp(out2, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out2; - } - if(isValidBatch3){ - out_offset += dstChannelC4 * height * width * 4; + #ifdef RELU - out3 = fmax(out3, (COMPUTE_FLOAT)0); + out0 = fmax(out0, (COMPUTE_FLOAT)0); #endif #ifdef RELU6 - out3 = clamp(out3, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); -#endif - - output[out_offset] = out3; - } + out0 = clamp(out0, (COMPUTE_FLOAT)0, (COMPUTE_FLOAT)6); #endif + output[out_offset] = out0; } - diff --git a/source/backend/opencl/execution/cl/grid_sample_buf.cl b/source/backend/opencl/execution/cl/grid_sample_buf.cl index 391a88163..758cb2295 100644 --- a/source/backend/opencl/execution/cl/grid_sample_buf.cl +++ b/source/backend/opencl/execution/cl/grid_sample_buf.cl @@ -61,7 +61,7 @@ __kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, __private const int input_width, __private const int output_height, __private const int output_width, - __private const int channelBlocks, + __private const int batch, __private const enum BorderMode paddingMode, __private const int alignCorners){ @@ -88,19 +88,13 @@ __kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, (xn,xn,xn,xn) (y5,y6,y7,y8) --------------------------- */ - const int slice = output_height_idx / 4; - const int slice_blocks = (output_height + 3) / 4; // output_width_block_idx means gird y offset, 2 means grid width - const int grid_offset = ((output_batch_idx * slice_blocks + slice) * output_width + output_width_block_idx) * 2; - COMPUTE_FLOAT4 grid_x = CONVERT_COMPUTE_FLOAT4(vload4(grid_offset, grid)); - COMPUTE_FLOAT4 grid_y = CONVERT_COMPUTE_FLOAT4(vload4(grid_offset + 1, grid)); + const int grid_offset = (output_batch_idx * output_height + output_height_idx) * output_width + output_width_block_idx; + COMPUTE_FLOAT2 grid_xy = CONVERT_COMPUTE_FLOAT2(vload2(grid_offset, grid)); - const float arr[8] = {grid_x.x, grid_y.x, grid_x.y, grid_y.y, grid_x.z, grid_y.z, grid_x.w, grid_y.w}; - // get grid x,y - const int arr_offset = output_height_idx % 4; - const float x = arr[2 * arr_offset]; - const float y = arr[2 * arr_offset + 1]; + const float x = (float)grid_xy.x; + const float y = (float)grid_xy.y; // convert grid x,y to input x,y coordinate range float in_grid_x = getPosition(x, input_width, alignCorners); @@ -110,10 +104,10 @@ __kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, int nw = floor(in_grid_x + 0.5f); int nh = floor(in_grid_y + 0.5f); - const int inp_offset_base = (output_batch_idx * channelBlocks + output_channel_block_idx) * input_height; + const int inp_offset_base = (output_batch_idx + output_channel_block_idx * batch) * input_height; COMPUTE_FLOAT4 value = sample(nh, nw, inp_offset_base, input, input_height, input_width, paddingMode); - const int output_offset = ((output_batch_idx * channelBlocks + 
output_channel_block_idx ) * output_height + output_height_idx) * output_width + output_width_block_idx; + const int output_offset = ((output_batch_idx + output_channel_block_idx * batch) * output_height + output_height_idx) * output_width + output_width_block_idx; vstore4(CONVERT_FLOAT4(value), output_offset, output); } @@ -124,7 +118,7 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, __private const int input_width, __private const int output_height, __private const int output_width, - __private const int channelBlocks, + __private const int batch, __private const enum BorderMode paddingMode, __private const int alignCorners){ @@ -137,19 +131,14 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, const int output_batch_idx = output_batch_height_block_idx / output_height; const int output_height_idx = output_batch_height_block_idx % output_height; - const int slice = output_height_idx / 4; - const int slice_blocks = (output_height + 3) / 4; // output_width_block_idx means gird y offset, 2 means grid width - const int grid_offset = ((output_batch_idx * slice_blocks + slice) * output_width + output_width_block_idx) * 2; - COMPUTE_FLOAT4 grid_x = CONVERT_COMPUTE_FLOAT4(vload4(grid_offset, grid)); - COMPUTE_FLOAT4 grid_y = CONVERT_COMPUTE_FLOAT4(vload4(grid_offset + 1, grid)); + const int grid_offset = (output_batch_idx * output_height + output_height_idx) * output_width + output_width_block_idx; + COMPUTE_FLOAT2 grid_xy = CONVERT_COMPUTE_FLOAT2(vload2(grid_offset, grid)); - const float arr[8] = {grid_x.x, grid_y.x, grid_x.y, grid_y.y, grid_x.z, grid_y.z, grid_x.w, grid_y.w}; // get grid x,y - const int arr_offset = output_height_idx % 4; - const float x = arr[2 * arr_offset]; - const float y = arr[2 * arr_offset + 1]; + const float x = (float)grid_xy.x; + const float y = (float)grid_xy.y; // convert grid x,y to input x,y coordinate range float in_grid_x = getPosition(x, input_width, alignCorners); @@ -164,7 +153,7 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, float y_weight = in_h1 - in_grid_y; // bilinear interpolation - const int inp_offset_base = (output_batch_idx * channelBlocks + output_channel_block_idx) * input_height; + const int inp_offset_base = (output_batch_idx + output_channel_block_idx * batch) * input_height; COMPUTE_FLOAT4 i00 = sample(in_h0, in_w0, inp_offset_base, input, input_height, input_width, paddingMode); COMPUTE_FLOAT4 i01 = sample(in_h0, in_w1, inp_offset_base, input, input_height, input_width, paddingMode); COMPUTE_FLOAT4 i10 = sample(in_h1, in_w0, inp_offset_base, input, input_height, input_width, paddingMode); @@ -173,6 +162,6 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, COMPUTE_FLOAT4 value = CONVERT_COMPUTE_FLOAT4(((COMPUTE_FLOAT4)x_weight * CONVERT_COMPUTE_FLOAT4(i00) + (COMPUTE_FLOAT4)(1.0f - x_weight) * CONVERT_COMPUTE_FLOAT4(i01)) * (COMPUTE_FLOAT4)y_weight + ((COMPUTE_FLOAT4)x_weight * CONVERT_COMPUTE_FLOAT4(i10) + (COMPUTE_FLOAT4)(1.0f - x_weight) * CONVERT_COMPUTE_FLOAT4(i11)) * (COMPUTE_FLOAT4)(1.0f- y_weight)); - const int output_offset = ((output_batch_idx * channelBlocks + output_channel_block_idx ) * output_height + output_height_idx) * output_width + output_width_block_idx; + const int output_offset = ((output_batch_idx + output_channel_block_idx * batch) * output_height + output_height_idx) * output_width + output_width_block_idx; vstore4(CONVERT_FLOAT4(value), output_offset, output); } diff --git 
a/source/backend/opencl/execution/cl/input_transe_buf.cl b/source/backend/opencl/execution/cl/input_transe_buf.cl index 352a7b1ab..1b86b6e7d 100644 --- a/source/backend/opencl/execution/cl/input_transe_buf.cl +++ b/source/backend/opencl/execution/cl/input_transe_buf.cl @@ -12,6 +12,7 @@ __kernel void conv_transe_c4_c1( __private const int input_width, __private const int input_height, __private const int input_channel, + __private const int batch, __private const int channel_blocks, __private const int input_pad_left, __private const int input_pad_right) @@ -29,10 +30,10 @@ __kernel void conv_transe_c4_c1( const uint input_x_pitch = 4; const uint input_y_pitch = input_x_pitch * input_width; const uint input_f_pitch = input_y_pitch * input_height; - const uint input_b_pitch = input_f_pitch * channel_blocks; + const uint input_b_pitch = input_f_pitch * batch; - const uint input_offset = b * input_b_pitch + - c * input_f_pitch + + const uint input_offset = b * input_f_pitch + + c * input_b_pitch + h * input_y_pitch + w * input_x_pitch; @@ -63,6 +64,7 @@ __kernel void conv_transe_c4_c16( int input_width, int input_height, int input_channel, + int batch, int channel_blocks, int input_pad_left, int input_pad_right) @@ -80,10 +82,10 @@ __kernel void conv_transe_c4_c16( const uint input_x_pitch = 4; const uint input_y_pitch = input_x_pitch * input_width; const uint input_f_pitch = input_y_pitch * input_height; - const uint input_b_pitch = input_f_pitch * channel_blocks; + const uint input_b_pitch = input_f_pitch * batch; - const uint input_offset = b * input_b_pitch + - c * input_f_pitch + + const uint input_offset = b * input_f_pitch + + c * input_b_pitch + h * input_y_pitch + w * input_x_pitch; @@ -110,4 +112,4 @@ __kernel void conv_transe_c4_c16( vstore4((FLOAT4)0, 0, output + pad_offset + i * output_x_pitch); } } -} \ No newline at end of file +} diff --git a/source/backend/opencl/execution/cl/interp_buf.cl b/source/backend/opencl/execution/cl/interp_buf.cl index 464997c15..99bcea8db 100644 --- a/source/backend/opencl/execution/cl/interp_buf.cl +++ b/source/backend/opencl/execution/cl/interp_buf.cl @@ -20,7 +20,7 @@ __kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, __private const int input_width, __private const int out_height, __private const int out_width, - __private const int channelBlocks) { + __private const int batch) { const int output_channel_block_idx = get_global_id(0); const int output_width_block_idx = get_global_id(1); const int output_batch_height_block_idx = get_global_id(2); @@ -40,10 +40,10 @@ __kernel void nearest_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, const int in_w_index = min(max(0, (int)floor(in_w_idx)), input_width-1); #endif - const int inp_offset = ((output_batch_idx * channelBlocks + output_channel_block_idx) * input_height + in_h_index) * input_width + in_w_index; + const int inp_offset = ((output_batch_idx + output_channel_block_idx*batch) * input_height + in_h_index) * input_width + in_w_index; FLOAT4 value = vload4(inp_offset, input); - const int out_offset = ((output_batch_idx * channelBlocks + output_channel_block_idx) * out_height + output_height_idx) * out_width + output_width_block_idx; + const int out_offset = ((output_batch_idx + output_channel_block_idx*batch) * out_height + output_height_idx) * out_width + output_width_block_idx; vstore4(value, out_offset, output); } @@ -57,7 +57,7 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, __private const int input_width, __private const int 
out_height, __private const int out_width, - __private const int channelBlocks) { + __private const int batch) { const int output_channel_block_idx = get_global_id(0); const int output_width_block_idx = get_global_id(1); const int output_batch_height_block_idx = get_global_id(2); @@ -77,7 +77,7 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, float factor_w = (in_w_idx - (int)floor(in_w_idx)); float factor_h = (in_h_idx - (int)floor(in_h_idx)); - const int inp_offset_base = (output_batch_idx * channelBlocks + output_channel_block_idx) * input_height; + const int inp_offset_base = (output_batch_idx + output_channel_block_idx*batch) * input_height; const int inp_offset_00 = (inp_offset_base + in_h0_index) * input_width + in_w0_index; const int inp_offset_01 = (inp_offset_base + in_h0_index) * input_width + in_w1_index; const int inp_offset_10 = (inp_offset_base + in_h1_index) * input_width + in_w0_index; @@ -90,7 +90,7 @@ __kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, FLOAT4 value = CONVERT_FLOAT4((float4)((1.0-factor_w)*(1.0-factor_h))*convert_float4(value_00) + (float4)(factor_w*(1.0-factor_h))*convert_float4(value_01) + (float4)((1.0-factor_w)*factor_h)*convert_float4(value_10) + (float4)(factor_w*factor_h)*convert_float4(value_11)); - const int out_offset = ((output_batch_idx * channelBlocks + output_channel_block_idx) * out_height + output_height_idx) * out_width + output_width_block_idx; + const int out_offset = ((output_batch_idx + output_channel_block_idx*batch) * out_height + output_height_idx) * out_width + output_width_block_idx; vstore4(value, out_offset, output); } @@ -109,7 +109,7 @@ __kernel void nearest3D_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, __private const int out_depth, __private const int out_height, __private const int out_width, - __private const int channelBlocks) { + __private const int batch) { const int output_channel_block_idx = get_global_id(0); const int output_height_width_block_idx = get_global_id(1); const int output_batch_depth_block_idx = get_global_id(2); @@ -129,11 +129,11 @@ __kernel void nearest3D_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input, const int in_h_index = min(max(0, (int)floor(in_h_idx)), input_height-1); const int in_w_index = min(max(0, (int)floor(in_w_idx)), input_width-1); - const int inp_offset = (((output_batch_idx * channelBlocks + output_channel_block_idx) + const int inp_offset = (((output_batch_idx + output_channel_block_idx*batch) * input_depth + in_d_index) * input_height + in_h_index) * input_width + in_w_index; - const int out_offset = (((output_batch_idx * channelBlocks + output_channel_block_idx) + const int out_offset = (((output_batch_idx + output_channel_block_idx*batch) * out_depth + output_depth_idx) * out_height + output_height_idx) * out_width + output_width_idx; FLOAT4 value = vload4(inp_offset, input); vstore4(value, out_offset, output); -} \ No newline at end of file +} diff --git a/source/backend/opencl/execution/cl/layernorm_buf.cl b/source/backend/opencl/execution/cl/layernorm_buf.cl index 3ee18e085..eab7caf8f 100644 --- a/source/backend/opencl/execution/cl/layernorm_buf.cl +++ b/source/backend/opencl/execution/cl/layernorm_buf.cl @@ -2,274 +2,48 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif -__kernel void layernorm_w_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global const FLOAT * input, - __global FLOAT * output, - __private const int width, - __private const int height, - __private 
const int channel, -#ifdef GAMMA_BETA - __global const FLOAT *gamma, - __global const FLOAT *beta, -#endif - __private float epsilon){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - float4 local sum[LOCAL_SIZE]; - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int h = pos.y % height; - const int c = pos.y / height; - const int b = pos.z; - const int lid = get_local_id(0); - const int channel4 = (channel + 3) / 4; - const int offset = ((b * channel4 + c) * height + h) * width * 4; - - float4 in_sum = 0; -#ifdef RMSNORM - float4 mean = 0; -#else - for(int i = lid; i < width; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset)); - in_sum += in; - } - sum[lid] = in_sum; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - - float4 mean = sum[0] / (float4)width; -#endif - in_sum = 0; - for(int i = lid; i < width; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset)); - in_sum += (in - mean) * (in - mean); - } - sum[lid] = in_sum; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - float4 square_sum = sum[0] / (float4)width; - float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); - for(int i = lid; i < width; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset)); -#ifdef GAMMA_BETA - float4 out = (in - mean) * value * (float4)gamma[i] + (float4)beta[i]; -#else - float4 out = (in - mean) * value; -#endif - vstore4(CONVERT_FLOAT4(out), i, output + offset); - } - } -} - - -__kernel void layernorm_hw_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global const FLOAT * input, - __global FLOAT * output, - __private const int width, - __private const int height, - __private const int channel, -#ifdef GAMMA_BETA - __global const FLOAT *gamma, - __global const FLOAT *beta, -#endif - __private float epsilon){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - float4 local sum[LOCAL_SIZE]; - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int c = pos.y; - const int b = pos.z; - const int height_width = height * width; - const int channel4 = (channel + 3) / 4; - const int lid = get_local_id(0); - const int offset = ((b * channel4 + c) * height) * width * 4; - - float4 in_sum = 0; -#ifdef RMSNORM - float4 mean = 0; -#else - for(int i = lid; i < height_width; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset)); - in_sum += in; - } - sum[lid] = in_sum; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - - float4 mean = sum[0] / (float4)height_width; -#endif - in_sum = 0; - for(int i = lid; i < height_width; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset)); - in_sum += (in - mean) * (in - mean); - } - sum[lid] = in_sum; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - float4 square_sum = sum[0] / (float4)height_width; - float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); - for(int i = lid; i < height_width; i+=LOCAL_SIZE){ - float4 in = 
convert_float4(vload4(i, input + offset)); -#ifdef GAMMA_BETA - float4 out = (in - mean) * value * (float4)gamma[i] + (float4)beta[i]; -#else - float4 out = (in - mean) * value; -#endif - vstore4(CONVERT_FLOAT4(out), i, output + offset); - } - } -} - -__kernel void layernorm_chw_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global const FLOAT * input, - __global FLOAT * output, - __private const int width, - __private const int height, - __private const int channel, -#ifdef GAMMA_BETA - __global const FLOAT *gamma, - __global const FLOAT *beta, -#endif - __private float epsilon){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - float local sum[LOCAL_SIZE]; - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int b = pos.z; - const int sum_size = width * height * channel; - const int reduce_size = width * height; - const int lid = get_local_id(0); - const int channel4 = (channel + 3) / 4; - const int channel_remain = channel - (channel4 - 1) * 4; - const int offset = ((b * channel4) * height) * width * 4; - const int wh_offset = height * width * 4; - - float4 in_sum = 0; - float4 in_sum_left = 0; - float *in_sum_left_ptr = (float*)(&in_sum_left); -#ifdef RMSNORM - float4 mean = 0; -#else - for(int c = 0; c < channel4 - 1; ++c){ - for(int i = lid; i < reduce_size; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset + c * wh_offset)); - in_sum += in; - } - } - for(int i = lid; i < reduce_size; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset + (channel4 - 1) * wh_offset)); - in_sum_left += in; - } - in_sum.x = in_sum.x + in_sum.y + in_sum.z + in_sum.w; - for(int i = 1; i < channel_remain; ++i){ - in_sum_left_ptr[0] += in_sum_left_ptr[i]; - } - sum[lid] = in_sum.x + in_sum_left.x; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - - float4 mean = sum[0] / (float4)sum_size; -#endif - in_sum = 0; - in_sum_left = 0; - for(int c = 0; c < channel4 - 1; ++c){ - for(int i = lid; i < reduce_size; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset + c * wh_offset)); - in_sum += (in - mean) * (in - mean); - } - } - - for(int i = lid; i < reduce_size; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset + (channel4 - 1) * wh_offset)); - in_sum_left += (in - mean) * (in - mean); - } - - in_sum.x = in_sum.x + in_sum.y + in_sum.z + in_sum.w; - for(int i = 1; i < channel_remain; ++i){ - in_sum_left_ptr[0] += in_sum_left_ptr[i]; - } - - sum[lid] = in_sum.x + in_sum_left.x; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = sum[lid] + sum[lid + i]; - barrier(CLK_LOCAL_MEM_FENCE); - } - float4 square_sum = sum[0] / (float4)sum_size; - float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); - for(int c = 0; c < channel4; ++c){ - for(int i = lid; i < reduce_size; i+=LOCAL_SIZE){ - float4 in = convert_float4(vload4(i, input + offset + c * wh_offset)); -#ifdef GAMMA_BETA - float4 out = (in - mean) * value * (float4)gamma[c * reduce_size + i] + (float4)beta[c * reduce_size + i]; -#else - float4 out = (in - mean) * value; -#endif - vstore4(CONVERT_FLOAT4(out), i, output + offset + c * wh_offset); - } - } - } -} - - -__kernel void layernorm_plain_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, +__kernel void 
layernorm_buf(__private int global_dim0, __private int global_dim1, __global const FLOAT * input, __global FLOAT * output, __private const int inside, - __private const int outside, #ifdef GAMMA_BETA __global const FLOAT *gamma, __global const FLOAT *beta, #endif __private float epsilon){ - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - COMPUTE_FLOAT local sum[LOCAL_SIZE]; - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int idx_out = pos.z; + int2 pos = (int2)(get_global_id(0), get_global_id(1)); +#if LOCAL_SIZE > 1 + float local sum[LOCAL_SIZE]; + if (pos.x < global_dim0 && pos.y < global_dim1) { const int lid = get_local_id(0); - const int offset = idx_out * inside; + const int offset = pos.y * inside; const int inside_v4 = (inside + 3) >> 2; + #ifdef PACK_LEAVE + const int loop = inside_v4 - 1; const int inside_remain = inside - ((inside_v4-1) << 2); - - COMPUTE_FLOAT4 in_sum = 0; + #else + const int loop = inside_v4; + #endif + + float4 in_sum = 0; int index = lid; - for(; index < inside_v4 - 1; index+=LOCAL_SIZE){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(index, input + offset)); + #ifdef RMSNORM + float4 mean = (float4)0; + #else + for(; index < loop; index+=LOCAL_SIZE){ + float4 in = convert_float4(vload4(index, input + offset)); in_sum += in; } sum[lid] = in_sum.x + in_sum.y + in_sum.z+ in_sum.w; - COMPUTE_FLOAT4 in_left = 0; + #ifdef PACK_LEAVE if(index == inside_v4 - 1) { - in_left = CONVERT_COMPUTE_FLOAT4(vload4(inside_v4 - 1, input + offset)); - sum[lid] = sum[lid] + in_left.x; - if(inside_remain > 1) { - sum[lid] = sum[lid] + in_left.y; - } - if(inside_remain > 2) { - sum[lid] = sum[lid] + in_left.z; - } - if(inside_remain > 3) { - sum[lid] = sum[lid] + in_left.w; + for(int i = 0; i < inside_remain; ++i){ + float in = input[offset + index * 4 + i]; + sum[lid] = sum[lid] + in; } } + #endif barrier(CLK_LOCAL_MEM_FENCE); for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ @@ -278,47 +52,87 @@ __kernel void layernorm_plain_buf(__private int global_dim0, __private int globa barrier(CLK_LOCAL_MEM_FENCE); } - COMPUTE_FLOAT4 mean = sum[0] / (COMPUTE_FLOAT4)inside; + float4 mean = sum[0] / (float4)inside; + #endif in_sum = 0; index = lid; - for(; index < inside_v4 - 1; index+=LOCAL_SIZE){ - COMPUTE_FLOAT4 in = CONVERT_COMPUTE_FLOAT4(vload4(index, input + offset)); + for(; index < loop; index+=LOCAL_SIZE){ + float4 in = convert_float4(vload4(index, input + offset)); in_sum += (in - mean) * (in - mean); } sum[lid] = in_sum.x + in_sum.y + in_sum.z + in_sum.w; - + #ifdef PACK_LEAVE if(index == inside_v4 - 1) { - COMPUTE_FLOAT4 in_left = CONVERT_COMPUTE_FLOAT4(vload4(inside_v4 - 1, input + offset)); - in_sum = (in_left - mean) * (in_left - mean); - sum[lid] = sum[lid] + in_sum.x; - if(inside_remain > 1) { - sum[lid] = sum[lid] + in_sum.y; - } - if(inside_remain > 2) { - sum[lid] = sum[lid] + in_sum.z; - } - if(inside_remain > 3) { - sum[lid] = sum[lid] + in_sum.w; + for(int i = 0; i < inside_remain; ++i){ + float in = input[offset + index * 4 + i]; + in = (in - mean.x) * (in - mean.x); + sum[lid] = sum[lid] + in; } } + #endif barrier(CLK_LOCAL_MEM_FENCE); for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ if (lid < i) sum[lid] = sum[lid] + sum[lid + i]; barrier(CLK_LOCAL_MEM_FENCE); } - COMPUTE_FLOAT4 square_sum = sum[0] / (COMPUTE_FLOAT4)inside; - COMPUTE_FLOAT4 value = (COMPUTE_FLOAT4)1.0f / (COMPUTE_FLOAT4)sqrt(square_sum + (COMPUTE_FLOAT4)epsilon); - - for(int i = lid; i < inside_v4; i+=LOCAL_SIZE){ - COMPUTE_FLOAT4 in = 
CONVERT_COMPUTE_FLOAT4(vload4(i, input + offset)); -#ifdef GAMMA_BETA - COMPUTE_FLOAT4 out = (in - mean) * value * CONVERT_COMPUTE_FLOAT4(vload4(i, gamma)) + CONVERT_COMPUTE_FLOAT4(vload4(i, beta)); + float4 square_sum = sum[0] / (float4)inside; + float4 value = (float4)1.0f / (float4)sqrt(square_sum + (float4)epsilon); + index = lid; + for(; index < loop; index+=LOCAL_SIZE){ + float4 in = convert_float4(vload4(index, input + offset)); + #ifdef GAMMA_BETA + float4 out = (in - mean) * value * convert_float4(vload4(index, gamma)) + convert_float4(vload4(index, beta)); + #else + float4 out = (in - mean) * value; + #endif + vstore4(CONVERT_FLOAT4(out), index, output + offset); + } + #ifdef PACK_LEAVE + if(index == inside_v4 - 1) { + for(int i = 0; i < inside_remain; ++i){ + float in = input[offset + index * 4 + i]; + #ifdef GAMMA_BETA + float out = (in - mean.x) * value.x * (float)gamma[index * 4 + i] + (float)beta[index * 4 + i]; + #else + float out = (in - mean.x) * value.x; + #endif + output[offset + index * 4 + i] = out; + } + } + #endif + } #else - COMPUTE_FLOAT4 out = (in - mean) * value; -#endif - vstore4(CONVERT_FLOAT4(out), i, output + offset); + if (pos.x < global_dim0 && pos.y < global_dim1) { + const int offset = pos.y * inside; + #ifdef RMSNORM + float mean = 0; + #else + float in_sum = 0; + for(int index = 0; index < inside; index++){ + in_sum += (float)input[offset + index]; + } + float mean = in_sum / inside; + #endif + + in_sum = 0; + for(int index = 0; index < inside; index++){ + float in = (float)input[offset + index]; + in_sum += (in - mean) * (in - mean); + } + float square_sum = in_sum / inside; + float value = 1.0f / sqrt(square_sum + epsilon); + for(int i = 0; i < inside; ++i){ + float in = input[offset + i]; + #ifdef GAMMA_BETA + float out = (in - mean) * value * (float)gamma[i] + (float)beta[i]; + #else + float out = (in - mean) * value; + #endif + output[offset + i] = out; } } + +#endif } diff --git a/source/backend/opencl/execution/cl/loop_buf.cl b/source/backend/opencl/execution/cl/loop_buf.cl index de7cb5725..c1a0521dc 100644 --- a/source/backend/opencl/execution/cl/loop_buf.cl +++ b/source/backend/opencl/execution/cl/loop_buf.cl @@ -21,7 +21,7 @@ #define TSH 8 // thread handle size H dimension #endif -// [N C4 H 1 4] -> [N H C 1] +// [C4 N H 1 4] -> [N H C 1] __kernel void tile_trans_3d_buf(__global INPUT_TYPE* input, __global OUTPUT_TYPE* output, __private const int widthPad, @@ -39,7 +39,6 @@ __kernel void tile_trans_3d_buf(__global INPUT_TYPE* input, // group id const int c = get_group_id(0) * WGSC; const int h = get_group_id(1) * WGSH; - const int channel_4 = (channel + 3) >> 2; int jc = lidc; int ih = lidh; @@ -53,7 +52,7 @@ __kernel void tile_trans_3d_buf(__global INPUT_TYPE* input, int offset_h = i * WGSH / TSH + ih; int offset_c = j * WGSC / TSC + jc ; // [TSH, WGSH / TSH] [TSC / 4, WGSC / TSC, 4] - localData[offset_h][offset_c] = (h + offset_h >= height || c + 4 * offset_c >= channel) ? (INPUT_TYPE4)0 : vload4(0, input + ((b * channel_4 + (c/4+offset_c)) * height + (h+offset_h)) * 4); + localData[offset_h][offset_c] = (h + offset_h >= height || c + 4 * offset_c >= channel) ? 
(INPUT_TYPE4)0 : vload4(0, input + ((b + (c/4+offset_c)*batch) * height + (h+offset_h)) * 4); } } @@ -78,7 +77,7 @@ __kernel void tile_trans_3d_buf(__global INPUT_TYPE* input, } } } -// [N C4 H W 4] -> [N C W H] +// [C4 N H W 4] -> [N C W H] __kernel void tile_trans_4d_buf(__global INPUT_TYPE* input, __global OUTPUT_TYPE* output, __private const int widthPad, @@ -99,7 +98,6 @@ __kernel void tile_trans_4d_buf(__global INPUT_TYPE* input, // group id const int w = get_group_id(0) * WGSW; const int h = get_group_id(1) * WGSH; - const int channel_4 = (channel + 3) >> 2; int jw = lidw; int ih = lidh; @@ -112,7 +110,7 @@ __kernel void tile_trans_4d_buf(__global INPUT_TYPE* input, for(int j = 0; j < TSW; j++) { int offset_h = h + ih + i * WGSH/TSH; int offset_w = w + jw + j * WGSW/TSW; - localData[ih + i * WGSH / TSH][jw + j * WGSW/TSW] = (offset_h >= height || offset_w >= width) ? (INPUT_TYPE4)0 : vload4(0, input + (((b * channel_4 + c4) * height + offset_h) * width + offset_w) * 4); + localData[ih + i * WGSH / TSH][jw + j * WGSW/TSW] = (offset_h >= height || offset_w >= width) ? (INPUT_TYPE4)0 : vload4(0, input + (((b + c4 * batch) * height + offset_h) * width + offset_w) * 4); } } @@ -234,8 +232,8 @@ __kernel void tile_buf(__private int global_dim0, __private int global_dim1, __p const int c = c_4 << 2; const int x_src_pitch = 4; const int y_src_pitch = x_src_pitch * width; - const int c_src_pitch = y_src_pitch * height; - const int b_src_pitch = c_src_pitch * ((channel + 3) / 4); + const int b_src_pitch = y_src_pitch * height; + const int c_src_pitch = b_src_pitch * batch; bool outBound = (w >= width || h >= height || c >= channel); #ifdef MNN_NHWC @@ -390,156 +388,32 @@ __kernel void pack_buf(__private int global_dim0, __private int global_dim1, __p } #ifdef LOOP_BINARY_OPERATOR -__kernel void broadcast_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, +__kernel void loop_binary_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, - __private const int8 src0_size, //(batch, channel, height, width) - __private const int4 src0C4_size, // nc4hw4 - __private const int8 src1_size, - __private const int4 src1C4_size, - __private const int8 dst_size, - __private const int dst_width, - __private const int dst_height, - __private const int dst_channel, - __private const int channel_block) { - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - - const int wo = pos.x; - const int ho = pos.y; - const int co = pos.z % channel_block; - const int no = pos.z / channel_block; - const int output_offset = ((((no * channel_block) + co) * dst_height + ho) * dst_width + wo) * 4; - int co4 = co << 2; - int4 covec = (int4)(co4 % dst_channel, (co4 + 1) % dst_channel, (co4 + 2) % dst_channel, (co4 + 3) % dst_channel); - int4 out_offset = ((no * dst_channel + covec) * dst_height + ho) * dst_width + wo; - int4 w = out_offset % (dst_size.s3 * dst_size.s4); out_offset /= (dst_size.s3 * dst_size.s4); - int4 h = out_offset % dst_size.s2; out_offset /= dst_size.s2; - int4 c = out_offset % dst_size.s1; out_offset /= dst_size.s1; - int4 n = out_offset % dst_size.s0; - float4 in0, in1; - -#ifdef BROADCAST_INPUT1 - in0 = convert_float4(vload4(0, input0 + output_offset)); - const int src1_channel_block = (src1C4_size.y + 3) / 4; - float* in1_ptr = (float*)&in1; - { - int4 
w0 = w % (src1_size.s3 * src1_size.s4); - int4 h0 = h % src1_size.s2; - int4 c0 = c % src1_size.s1; - int4 n0 = n % src1_size.s0; - int* w0_ptr = (int*)&w0; - int* h0_ptr = (int*)&h0; - int* c0_ptr = (int*)&c0; - int* n0_ptr = (int*)&n0; - for(int i = 0; i < 4; ++i){ - int c4offset = ((n0_ptr[i] * src1_size.s1 + c0_ptr[i]) * src1_size.s2 + h0_ptr[i]) * src1_size.s3 * src1_size.s4 + w0_ptr[i]; - int wc4 = c4offset % src1C4_size.w; c4offset /= src1C4_size.w; - int hc4 = c4offset % src1C4_size.z; c4offset /= src1C4_size.z; - int cc4 = c4offset % src1C4_size.y; c4offset /= src1C4_size.y; - int nc4 = c4offset % src1C4_size.x; - int cc4_offset = cc4 / 4; - int cc4_remain = cc4 % 4; - in1_ptr[i] = (float)input1[((((nc4 * src1_channel_block) + cc4_offset) * src1C4_size.z + hc4) * src1C4_size.w + wc4) * 4 + cc4_remain]; - } - } -#else - const int src0_channel_block = (src0C4_size.y + 3) / 4; - float* in0_ptr = (float*)&in0; - { - int4 w0 = w % (src0_size.s3 * src0_size.s4); - int4 h0 = h % src0_size.s2; - int4 c0 = c % src0_size.s1; - int4 n0 = n % src0_size.s0; - int* w0_ptr = (int*)&w0; - int* h0_ptr = (int*)&h0; - int* c0_ptr = (int*)&c0; - int* n0_ptr = (int*)&n0; - for(int i = 0; i < 4; ++i){ - int c4offset = ((n0_ptr[i] * src0_size.s1 + c0_ptr[i]) * src0_size.s2 + h0_ptr[i]) * src0_size.s3 * src0_size.s4 + w0_ptr[i]; - int wc4 = c4offset % src0C4_size.w; c4offset /= src0C4_size.w; - int hc4 = c4offset % src0C4_size.z; c4offset /= src0C4_size.z; - int cc4 = c4offset % src0C4_size.y; c4offset /= src0C4_size.y; - int nc4 = c4offset % src0C4_size.x; - int cc4_offset = cc4 / 4; - int cc4_remain = cc4 % 4; - in0_ptr[i] = (float)input0[((((nc4 * src0_channel_block) + cc4_offset) * src0C4_size.z + hc4) * src0C4_size.w + wc4) * 4 + cc4_remain]; - } - } - in1 = convert_float4(vload4(0, input1 + output_offset)); -#endif - float4 out = LOOP_BINARY_OPERATOR; - vstore4(CONVERT_OUTPUT4(out), 0, output + output_offset); - } -} - -__kernel void broadcast_binary_channel_equall_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, - __private const int8 src0_size, //(batch, channel, height, width) - __private const int4 src0C4_size, // nc4hw4 - __private const int8 src1_size, - __private const int4 src1C4_size, - __private const int8 dst_size, - __private const int dst_width, - __private const int dst_height, - __private const int dst_channel, - __private const int channel_block) { - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); - - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int wo = pos.x; - const int ho = pos.y; - const int co = pos.z % channel_block; - const int no = pos.z / channel_block; - const int output_offset = ((((no * channel_block) + co) * dst_height + ho) * dst_width + wo) * 4; -#ifdef BROADCAST_INPUT1 - const int src1_channel_block = (src1C4_size.y + 3) / 4; - const int input_offset = (((((no % src1_size.s0) * src1_channel_block) + co) * src1C4_size.z + (ho % src1_size.s2)) * src1C4_size.w + (wo % (src1_size.s3 * src1_size.s4))) * 4; - float4 in0 = convert_float4(vload4(0, input0 + output_offset)); - float4 in1 = convert_float4(vload4(0, input1 + input_offset)); -#else - const int src0_channel_block = (src0C4_size.y + 3) / 4; - const int input_offset = (((((no % src0_size.s0) * src0_channel_block) + co) * src0C4_size.z + (ho % src0_size.s2)) * src0C4_size.w + (wo % (src0_size.s3 * src0_size.s4))) * 4; - float4 in0 
= convert_float4(vload4(0, input0 + input_offset)); - float4 in1 = convert_float4(vload4(0, input1 + output_offset)); -#endif - float4 out = LOOP_BINARY_OPERATOR; - vstore4(CONVERT_OUTPUT4(out), 0, output + output_offset); - } -} - -//channel = 1 and dimmision = 1 -__kernel void broadcast_binary_dimmision1_channel1_buf(__private int global_dim0, __private int global_dim1, __private int global_dim2, - __global OUTPUT_TYPE* output, __global INPUT_TYPE* input0, __global INPUT_TYPE* input1, - __private const int8 src0_size, //(batch, channel, height, width) - __private const int4 src0C4_size, // nc4hw4 - __private const int8 src1_size, - __private const int4 src1C4_size, - __private const int8 dst_size, - __private const int dst_width, - __private const int dst_height, - __private const int dst_channel, - __private const int channel_block) { - int3 pos = (int3)(get_global_id(0), get_global_id(1), get_global_id(2)); + __private const int input0Stride0, + __private const int input0Stride1, + __private const int input0Stride2, + __private const int input1Stride0, + __private const int input1Stride1, + __private const int input1Stride2, + __private const int outputStride0, + __private const int outputStride1, + __private const int outputStride2 + ) { + + const int x = get_global_id(0); + const int y = get_global_id(1); + const int z = get_global_id(2); - if (pos.x < global_dim0 && pos.y < global_dim1 && pos.z < global_dim2) { - const int wo = pos.x; - const int ho = pos.y; - const int co = pos.z % channel_block; - const int no = pos.z / channel_block; + if (x < global_dim0 && y < global_dim1 && z < global_dim2) { - const int output_offset = ((((no * channel_block) + co) * dst_height + ho) * dst_width + wo) * 4; -#ifdef BROADCAST_INPUT1 - const int input_offset = ((no % src1_size.s0) * src1_size.s2 + (ho % src1_size.s2)) * src1_size.s3 * src1_size.s4 + (wo % (src1_size.s3 * src1_size.s4)); - float4 in0 = convert_float4(vload4(0, input0 + output_offset)); - float4 in1 = (float4)(input1[input_offset]); -#else - const int input_offset = ((no % src0_size.s0) * src0_size.s2 + (ho % src0_size.s2)) * src0_size.s3 * src0_size.s4 + (wo % (src0_size.s3 * src0_size.s4)); - float4 in0 = (float4)(input0[input_offset]); - float4 in1 = convert_float4(vload4(0, input1 + output_offset)); -#endif - float4 out = LOOP_BINARY_OPERATOR; - vstore4(CONVERT_OUTPUT4(out), 0, output + output_offset); + int inputIndex0 = z * input0Stride0 + y * input0Stride1 + x * input0Stride2; + int inputIndex1 = z * input1Stride0 + y * input1Stride1 + x * input1Stride2; + int outputIndex = z * outputStride0 + y * outputStride1 + x * outputStride2; + float in0 = (float)input0[inputIndex0]; + float in1 = (float)input1[inputIndex1]; + float out = LOOP_BINARY_OPERATOR; + output[outputIndex] = (OUTPUT_TYPE)out; } } #endif diff --git a/source/backend/opencl/execution/cl/matmul_buf.cl b/source/backend/opencl/execution/cl/matmul_buf.cl index 4d2b65756..c4ddd12d8 100644 --- a/source/backend/opencl/execution/cl/matmul_buf.cl +++ b/source/backend/opencl/execution/cl/matmul_buf.cl @@ -16,426 +16,170 @@ __kernel void matmul_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a, __global const FLOAT* input_c, #endif __global FLOAT* output_c, - __private const int channels, - __private const int channel_blocks, - __private const int width_blocks, - __private const int width) { - const int width_blocks_idx = get_global_id(0);// output W - const int height_idx = get_global_id(1);// output H - - DEAL_NON_UNIFORM_DIM2(width_blocks_idx, height_idx); - 
COMPUTE_FLOAT4 a; - COMPUTE_FLOAT4 b0 = 0, b1 = 0, b2 = 0, b3 = 0; - COMPUTE_FLOAT4 v_zero = (COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0); + __private const int M, + __private const int N, + __private const int K) { + int2 pos = (int2)(get_global_id(0), get_global_id(1)); // N M + + DEAL_NON_UNIFORM_DIM2(pos.x, pos.y); + const int idn = pos.x << 2; + const int idm = pos.y << 2; + + COMPUTE_FLOAT4 out[4]; #ifdef BIAS - COMPUTE_FLOAT4 temp = CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx, input_c)); - - COMPUTE_FLOAT result0 = temp.x; - COMPUTE_FLOAT result1 = temp.y; - COMPUTE_FLOAT result2 = temp.z; - COMPUTE_FLOAT result3 = temp.w; + COMPUTE_FLOAT4 bias = CONVERT_COMPUTE_FLOAT4(vload4(0, input_c + idn)); + #pragma unroll + for(int i = 0; i < 4; ++i){ + out[i] = bias; + } #else - COMPUTE_FLOAT result0 = 0; - COMPUTE_FLOAT result1 = 0; - COMPUTE_FLOAT result2 = 0; - COMPUTE_FLOAT result3 = 0; + #pragma unroll + for(int i = 0; i < 4; ++i){ + out[i] = (COMPUTE_FLOAT4)0; + } #endif - const int remain = channel_blocks*4 - channels; - for (short pos = 0; pos < channel_blocks - 1; pos += 1) { - const int inpa_offset = height_idx * channel_blocks + pos; - a = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - - const int inpb_offset = (pos*4) * width_blocks + width_blocks_idx; - - b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - b1 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks, input_b)); - b2 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*2, input_b)); - b3 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*3, input_b)); - - COMPUTE_FLOAT4 btmp0 = (COMPUTE_FLOAT4)(b0.s0, b1.s0, b2.s0, b3.s0); - COMPUTE_FLOAT4 btmp1 = (COMPUTE_FLOAT4)(b0.s1, b1.s1, b2.s1, b3.s1); - COMPUTE_FLOAT4 btmp2 = (COMPUTE_FLOAT4)(b0.s2, b1.s2, b2.s2, b3.s2); - COMPUTE_FLOAT4 btmp3 = (COMPUTE_FLOAT4)(b0.s3, b1.s3, b2.s3, b3.s3); - - result0 += dot(a, btmp0); - result1 += dot(a, btmp1); - result2 += dot(a, btmp2); - result3 += dot(a, btmp3); - } + const int K4 = (K + 3)/4; + #ifdef K_LEAVE + const int loop_end = max(K4 - 1, 0); + const int remain = K - loop_end*4; + #else + const int loop_end = K4; + #endif - { - const int inpa_offset = height_idx * channel_blocks + channel_blocks - 1; - a = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - - const int inpb_offset = ((channel_blocks - 1)*4) * width_blocks + width_blocks_idx; - - b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - b1 = (remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks, input_b)); - b2 = (remain >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*2, input_b)); - b3 = (remain >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*3, input_b)); - if (remain == 3) { - a.y = 0; - a.z = 0; - a.w = 0; - } else if (remain == 2) { - a.z = 0; - a.w = 0; - } else if (remain == 1) { - a.w = 0;; - } - - COMPUTE_FLOAT4 btmp0 = (COMPUTE_FLOAT4)(b0.s0, b1.s0, b2.s0, b3.s0); - COMPUTE_FLOAT4 btmp1 = (COMPUTE_FLOAT4)(b0.s1, b1.s1, b2.s1, b3.s1); - COMPUTE_FLOAT4 btmp2 = (COMPUTE_FLOAT4)(b0.s2, b1.s2, b2.s2, b3.s2); - COMPUTE_FLOAT4 btmp3 = (COMPUTE_FLOAT4)(b0.s3, b1.s3, b2.s3, b3.s3); - - result0 += dot(a, btmp0); - result1 += dot(a, btmp1); - result2 += dot(a, btmp2); - result3 += dot(a, btmp3); - } - - const int out_offset = height_idx * width_blocks + width_blocks_idx; - vstore4(CONVERT_FLOAT4((COMPUTE_FLOAT4)(result0, result1, result2, result3)), out_offset, output_c); -} - -__kernel void matmul_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a, - __global const FLOAT* input_b, - #ifdef BIAS - __global const FLOAT* input_c, - #endif - __global FLOAT* output_c, - __private const int channels, - __private const int channel_blocks, - __private const int width_blocks, - __private const int width) { - const int width_blocks_idx = get_global_id(0); - const int height_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(width_blocks_idx, height_idx); - COMPUTE_FLOAT4 a; - COMPUTE_FLOAT4 b0 = 0, b1 = 0, b2 = 0, b3 = 0; - COMPUTE_FLOAT4 v_zero = (COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0); - - #ifdef BIAS - COMPUTE_FLOAT4 temp = CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx, input_c)); - COMPUTE_FLOAT result0 = temp.x; - COMPUTE_FLOAT result1 = temp.y; - COMPUTE_FLOAT result2 = temp.z; - COMPUTE_FLOAT result3 = temp.w; + #ifdef TRANSPOSE_A + __global const FLOAT* input_a_offset = input_a + idm; // K x M #else - COMPUTE_FLOAT result0 = 0; - COMPUTE_FLOAT result1 = 0; - COMPUTE_FLOAT result2 = 0; - COMPUTE_FLOAT result3 = 0; + __global const FLOAT* input_a_offset = input_a + idm * K; // M x K #endif - - const int remaina = channel_blocks*4 - channels; - const int remainb = (width_blocks_idx+1)*4 - width; - for (short pos = 0; pos < channel_blocks - 1; pos += 1) { - const int inpa_offset = height_idx * channel_blocks + pos; - a = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - - const int inpb_offset = (width_blocks_idx*4) * channel_blocks + pos; - - b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - b1 = (remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks, input_b)); - b2 = (remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*2, input_b)); - b3 = (remainb >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*3, input_b)); - - result0 += dot(a, b0); - result1 += dot(a, b1); - result2 += dot(a, b2); - result3 += dot(a, b3); - } - { - const int inpa_offset = height_idx * channel_blocks + channel_blocks - 1; - a = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - - const int inpb_offset = (width_blocks_idx*4) * channel_blocks + channel_blocks - 1; - - b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - b1 = (remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks, input_b)); - b2 = (remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*2, input_b)); - b3 = (remainb >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*3, input_b)); - - if (remaina == 3) { - a.y = 0; - a.z = 0; - a.w = 0; - } else if (remaina == 2) { - a.z = 0; - a.w = 0; - } else if (remaina == 1) { - a.w = 0; - } - - result0 += dot(a, b0); - result1 += dot(a, b1); - result2 += dot(a, b2); - result3 += dot(a, b3); - } - const int out_offset = height_idx * width_blocks + width_blocks_idx; - vstore4(CONVERT_FLOAT4((COMPUTE_FLOAT4)(result0, result1, result2, result3)), out_offset, output_c); -} - - -__kernel void matmul_transA_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a, - __global const FLOAT* input_b, - #ifdef BIAS - __global const FLOAT* input_c, - #endif - __global FLOAT* output_c, - __private const int channels, - __private const int channel_blocks, - __private const int height, - __private const int height_blocks, - __private const int width_blocks, - __private const int width) { - const int width_blocks_idx = get_global_id(0); - const int height_blocks_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(width_blocks_idx, height_blocks_idx); - - COMPUTE_FLOAT4 v_zero = (COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0); - #ifdef BIAS - COMPUTE_FLOAT4 result0 = CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx, input_c)); - COMPUTE_FLOAT4 result1 = result0; - COMPUTE_FLOAT4 result2 = result0; - COMPUTE_FLOAT4 result3 = result0; + #ifdef TRANSPOSE_B + __global const FLOAT* input_b_offset = input_b + idn * K; // N x K #else - COMPUTE_FLOAT4 result0 = 0; - COMPUTE_FLOAT4 result1 = 0; - COMPUTE_FLOAT4 result2 = 0; - COMPUTE_FLOAT4 result3 = 0; + __global const FLOAT* input_b_offset = input_b + idn; // K x N #endif - const int remain = channel_blocks*4 - channels; - for (short pos = 0; pos < channel_blocks - 1; pos += 1) { - - const int inpa_offset = (4*pos) * height_blocks + height_blocks_idx; - COMPUTE_FLOAT4 a0 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - COMPUTE_FLOAT4 a1 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks, input_a)); - COMPUTE_FLOAT4 a2 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*2, input_a)); - COMPUTE_FLOAT4 a3 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*3, input_a)); - - const int inpb_offset = (4*pos) * width_blocks + width_blocks_idx; - COMPUTE_FLOAT4 b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - COMPUTE_FLOAT4 b1 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks, input_b)); - COMPUTE_FLOAT4 b2 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*2, input_b)); - COMPUTE_FLOAT4 b3 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*3, input_b)); - - COMPUTE_FLOAT4 a0_trans = (COMPUTE_FLOAT4)(a0.x, a1.x, a2.x, a3.x); - COMPUTE_FLOAT4 a1_trans = (COMPUTE_FLOAT4)(a0.y, a1.y, a2.y, a3.y); - COMPUTE_FLOAT4 a2_trans = (COMPUTE_FLOAT4)(a0.z, a1.z, a2.z, a3.z); - COMPUTE_FLOAT4 a3_trans = (COMPUTE_FLOAT4)(a0.w, a1.w, a2.w, a3.w); - - COMPUTE_FLOAT4 b0_trans = (COMPUTE_FLOAT4)(b0.x, b1.x, b2.x, b3.x); - COMPUTE_FLOAT4 b1_trans = (COMPUTE_FLOAT4)(b0.y, b1.y, b2.y, b3.y); - COMPUTE_FLOAT4 b2_trans = (COMPUTE_FLOAT4)(b0.z, b1.z, b2.z, b3.z); - COMPUTE_FLOAT4 b3_trans = (COMPUTE_FLOAT4)(b0.w, b1.w, b2.w, b3.w); - - //matmul - result0.x += dot(a0_trans, b0_trans); - result0.y += dot(a0_trans, b1_trans); - result0.z += dot(a0_trans, b2_trans); - result0.w += dot(a0_trans, b3_trans); - - result1.x += dot(a1_trans, b0_trans); - result1.y += dot(a1_trans, b1_trans); - result1.z += dot(a1_trans, b2_trans); - result1.w += dot(a1_trans, b3_trans); + for (int k = 0; k < loop_end; ++k) { + 
int kindex = k << 2; + COMPUTE_FLOAT4 A[4]; // m4 x k4 + COMPUTE_FLOAT4 B[4]; // k4 x n4 + #ifdef TRANSPOSE_A + { + COMPUTE_FLOAT4 tmp0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + kindex * M)); + COMPUTE_FLOAT4 tmp1 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + (kindex + 1) * M)); + COMPUTE_FLOAT4 tmp2 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + (kindex + 2) * M)); + COMPUTE_FLOAT4 tmp3 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + (kindex + 3) * M)); + + A[0] = (COMPUTE_FLOAT4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x); + A[1] = (COMPUTE_FLOAT4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y); + A[2] = (COMPUTE_FLOAT4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z); + A[3] = (COMPUTE_FLOAT4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w); + } + #else + A[0] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + kindex)); + A[1] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + kindex + K)); + A[2] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + kindex + 2 * K)); + A[3] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + kindex + 3 * K)); + #endif - result2.x += dot(a2_trans, b0_trans); - result2.y += dot(a2_trans, b1_trans); - result2.z += dot(a2_trans, b2_trans); - result2.w += dot(a2_trans, b3_trans); + #ifdef TRANSPOSE_B + { + COMPUTE_FLOAT4 tmp0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + kindex)); + COMPUTE_FLOAT4 tmp1 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + kindex + K)); + COMPUTE_FLOAT4 tmp2 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + kindex + 2 * K)); + COMPUTE_FLOAT4 tmp3 = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + kindex + 3 * K)); + + B[0] = (COMPUTE_FLOAT4)(tmp0.x, tmp1.x, tmp2.x, tmp3.x); + B[1] = (COMPUTE_FLOAT4)(tmp0.y, tmp1.y, tmp2.y, tmp3.y); + B[2] = (COMPUTE_FLOAT4)(tmp0.z, tmp1.z, tmp2.z, tmp3.z); + B[3] = (COMPUTE_FLOAT4)(tmp0.w, tmp1.w, tmp2.w, tmp3.w); + } + #else + B[0] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + kindex * N)); + B[1] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + (kindex + 1) * N)); + B[2] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + (kindex + 2) * N)); + B[3] = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + (kindex + 3) * N)); + #endif - result3.x += dot(a3_trans, b0_trans); - result3.y += dot(a3_trans, b1_trans); - result3.z += dot(a3_trans, b2_trans); - result3.w += dot(a3_trans, b3_trans); + #pragma unroll + for (int vec_m = 0; vec_m < 4; ++vec_m){ + out[vec_m] = mad((COMPUTE_FLOAT4)A[vec_m].x, B[0], out[vec_m]); + out[vec_m] = mad((COMPUTE_FLOAT4)A[vec_m].y, B[1], out[vec_m]); + out[vec_m] = mad((COMPUTE_FLOAT4)A[vec_m].z, B[2], out[vec_m]); + out[vec_m] = mad((COMPUTE_FLOAT4)A[vec_m].w, B[3], out[vec_m]); + } } - - { - const int inpa_offset = (4*(channel_blocks - 1)) * height_blocks + height_blocks_idx; - COMPUTE_FLOAT4 a0 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - COMPUTE_FLOAT4 a1 = ((remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks, input_a))); - COMPUTE_FLOAT4 a2 = ((remain >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*2, input_a))); - COMPUTE_FLOAT4 a3 = ((remain >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*3, input_a))); - - const int inpb_offset = (4*(channel_blocks - 1)) * width_blocks + width_blocks_idx; - COMPUTE_FLOAT4 b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - COMPUTE_FLOAT4 b1 = ((remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks, input_b))); - COMPUTE_FLOAT4 b2 = ((remain >= 3) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*2, input_b))); - COMPUTE_FLOAT4 b3 = ((remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + width_blocks*3, input_b))); - - COMPUTE_FLOAT4 a0_trans = (COMPUTE_FLOAT4)(a0.x, a1.x, a2.x, a3.x); - COMPUTE_FLOAT4 a1_trans = (COMPUTE_FLOAT4)(a0.y, a1.y, a2.y, a3.y); - COMPUTE_FLOAT4 a2_trans = (COMPUTE_FLOAT4)(a0.z, a1.z, a2.z, a3.z); - COMPUTE_FLOAT4 a3_trans = (COMPUTE_FLOAT4)(a0.w, a1.w, a2.w, a3.w); + #ifdef K_LEAVE + for (int k = loop_end << 2; k < K; ++k){ + COMPUTE_FLOAT4 A; // m4 + COMPUTE_FLOAT4 B; // n4 + #ifdef TRANSPOSE_A + A = CONVERT_COMPUTE_FLOAT4(vload4(0, input_a_offset + k * M)); + #else + A.x = (COMPUTE_FLOAT)input_a_offset[k]; + A.y = (COMPUTE_FLOAT)input_a_offset[k + K]; + A.z = (COMPUTE_FLOAT)input_a_offset[k + 2 * K]; + A.w = (COMPUTE_FLOAT)input_a_offset[k + 3 * K]; + #endif - COMPUTE_FLOAT4 b0_trans = (COMPUTE_FLOAT4)(b0.x, b1.x, b2.x, b3.x); - COMPUTE_FLOAT4 b1_trans = (COMPUTE_FLOAT4)(b0.y, b1.y, b2.y, b3.y); - COMPUTE_FLOAT4 b2_trans = (COMPUTE_FLOAT4)(b0.z, b1.z, b2.z, b3.z); - COMPUTE_FLOAT4 b3_trans = (COMPUTE_FLOAT4)(b0.w, b1.w, b2.w, b3.w); - - //matmul - result0.x += dot(a0_trans, b0_trans); - result0.y += dot(a0_trans, b1_trans); - result0.z += dot(a0_trans, b2_trans); - result0.w += dot(a0_trans, b3_trans); - - result1.x += dot(a1_trans, b0_trans); - result1.y += dot(a1_trans, b1_trans); - result1.z += dot(a1_trans, b2_trans); - result1.w += dot(a1_trans, b3_trans); - - result2.x += dot(a2_trans, b0_trans); - result2.y += dot(a2_trans, b1_trans); - result2.z += dot(a2_trans, b2_trans); - result2.w += dot(a2_trans, b3_trans); - - result3.x += dot(a3_trans, b0_trans); - result3.y += dot(a3_trans, b1_trans); - result3.z += dot(a3_trans, b2_trans); - result3.w += dot(a3_trans, b3_trans); + #ifdef TRANSPOSE_B + B.x = (COMPUTE_FLOAT)input_b_offset[k]; + B.y = (COMPUTE_FLOAT)input_b_offset[k + K]; + B.z = (COMPUTE_FLOAT)input_b_offset[k + 2 * K]; + B.w = (COMPUTE_FLOAT)input_b_offset[k + 3 * K]; + #else + B = CONVERT_COMPUTE_FLOAT4(vload4(0, input_b_offset + k * N)); + #endif + out[0] = mad((COMPUTE_FLOAT4)A.x, B, out[0]); + out[1] = mad((COMPUTE_FLOAT4)A.y, B, out[1]); + out[2] = mad((COMPUTE_FLOAT4)A.z, B, out[2]); + out[3] = mad((COMPUTE_FLOAT4)A.w, B, out[3]); } - - const int out_offset = (4*height_blocks_idx) * width_blocks + width_blocks_idx; - - vstore4(CONVERT_FLOAT4(result0), out_offset, output_c); - if(4*height_blocks_idx+1 >= height) return; - vstore4(CONVERT_FLOAT4(result1), out_offset + width_blocks, output_c); - if(4*height_blocks_idx+2 >= height) return; - vstore4(CONVERT_FLOAT4(result2), out_offset + width_blocks*2, output_c); - if(4*height_blocks_idx+3 >= height) return; - vstore4(CONVERT_FLOAT4(result3), out_offset + width_blocks*3, output_c); -} - -__kernel void matmul_transA_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a, - __global const FLOAT* input_b, - #ifdef BIAS - __global const FLOAT* input_c, - #endif - __global FLOAT* output_c, - __private const int channels, - __private const int channel_blocks, - __private const int height, - __private const int height_blocks, - __private const int width_blocks, - __private const int width) { - const int width_blocks_idx = get_global_id(0); - const int height_blocks_idx = get_global_id(1); - - DEAL_NON_UNIFORM_DIM2(width_blocks_idx, height_blocks_idx); - - COMPUTE_FLOAT4 v_zero = (COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0); - #ifdef BIAS - COMPUTE_FLOAT4 result0 = CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx, 
input_c)); - - COMPUTE_FLOAT4 result1 = result0; - COMPUTE_FLOAT4 result2 = result0; - COMPUTE_FLOAT4 result3 = result0; - #else - COMPUTE_FLOAT4 result0 = 0; - COMPUTE_FLOAT4 result1 = 0; - COMPUTE_FLOAT4 result2 = 0; - COMPUTE_FLOAT4 result3 = 0; #endif - const int remaina = channel_blocks * 4 - channels; - const int remainb = (width_blocks_idx + 1) * 4 - width; - for (short pos = 0; pos < channel_blocks - 1; pos += 1) { - const int inpa_offset = (4*pos) * height_blocks + height_blocks_idx; - COMPUTE_FLOAT4 a0 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - COMPUTE_FLOAT4 a1 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks, input_a)); - COMPUTE_FLOAT4 a2 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*2, input_a)); - COMPUTE_FLOAT4 a3 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*3, input_a)); - - const int inpb_offset = (4*width_blocks_idx) * channel_blocks + pos; - COMPUTE_FLOAT4 b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - COMPUTE_FLOAT4 b1 = ((remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks, input_b))); - COMPUTE_FLOAT4 b2 = ((remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*2, input_b))); - COMPUTE_FLOAT4 b3 = ((remainb >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*3, input_b))); - - COMPUTE_FLOAT4 a0_trans = (COMPUTE_FLOAT4)(a0.x, a1.x, a2.x, a3.x); - COMPUTE_FLOAT4 a1_trans = (COMPUTE_FLOAT4)(a0.y, a1.y, a2.y, a3.y); - COMPUTE_FLOAT4 a2_trans = (COMPUTE_FLOAT4)(a0.z, a1.z, a2.z, a3.z); - COMPUTE_FLOAT4 a3_trans = (COMPUTE_FLOAT4)(a0.w, a1.w, a2.w, a3.w); - - //matmul - result0.x += dot(a0_trans, b0); - result0.y += dot(a0_trans, b1); - result0.z += dot(a0_trans, b2); - result0.w += dot(a0_trans, b3); - - result1.x += dot(a1_trans, b0); - result1.y += dot(a1_trans, b1); - result1.z += dot(a1_trans, b2); - result1.w += dot(a1_trans, b3); - - result2.x += dot(a2_trans, b0); - result2.y += dot(a2_trans, b1); - result2.z += dot(a2_trans, b2); - result2.w += dot(a2_trans, b3); - - result3.x += dot(a3_trans, b0); - result3.y += dot(a3_trans, b1); - result3.z += dot(a3_trans, b2); - result3.w += dot(a3_trans, b3); - } - { - const int inpa_offset = (4*(channel_blocks-1)) * height_blocks + height_blocks_idx; - COMPUTE_FLOAT4 a0 = CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset, input_a)); - COMPUTE_FLOAT4 a1 = ((remaina >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks, input_a))); - COMPUTE_FLOAT4 a2 = ((remaina >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*2, input_a))); - COMPUTE_FLOAT4 a3 = ((remaina >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset + height_blocks*3, input_a))); - - const int inpb_offset = (4*width_blocks_idx) * channel_blocks + channel_blocks-1; - COMPUTE_FLOAT4 b0 = CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset, input_b)); - COMPUTE_FLOAT4 b1 = ((remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks, input_b))); - COMPUTE_FLOAT4 b2 = ((remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*2, input_b))); - COMPUTE_FLOAT4 b3 = ((remainb >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset + channel_blocks*3, input_b))); - - COMPUTE_FLOAT4 a0_trans = (COMPUTE_FLOAT4)(a0.x, a1.x, a2.x, a3.x); - COMPUTE_FLOAT4 a1_trans = (COMPUTE_FLOAT4)(a0.y, a1.y, a2.y, a3.y); - COMPUTE_FLOAT4 a2_trans = (COMPUTE_FLOAT4)(a0.z, a1.z, a2.z, a3.z); - COMPUTE_FLOAT4 a3_trans = (COMPUTE_FLOAT4)(a0.w, a1.w, a2.w, a3.w); - - //matmul - result0.x += dot(a0_trans, b0); - result0.y += dot(a0_trans, b1); - result0.z += dot(a0_trans, b2); - result0.w += dot(a0_trans, b3); - - result1.x += dot(a1_trans, b0); - result1.y += dot(a1_trans, b1); - result1.z += dot(a1_trans, b2); - result1.w += dot(a1_trans, b3); - - result2.x += dot(a2_trans, b0); - result2.y += dot(a2_trans, b1); - result2.z += dot(a2_trans, b2); - result2.w += dot(a2_trans, b3); - - result3.x += dot(a3_trans, b0); - result3.y += dot(a3_trans, b1); - result3.z += dot(a3_trans, b2); - result3.w += dot(a3_trans, b3); + const int out_offset = idm * N + idn; + #ifdef M_LEAVE + if(idm + 3 >= M){ + #ifdef N_LEAVE + if(idn + 3 >= N){ + for (int vec_m = 0; vec_m < M - idm; ++vec_m){ + COMPUTE_FLOAT *out_ptr = (COMPUTE_FLOAT*)&out[vec_m]; + for(int vec_n = 0; vec_n < N - idn; ++vec_n){ + output_c[out_offset + vec_m * N + vec_n] = out_ptr[vec_n]; + } + } + } else { + #endif + for (int vec_m = 0; vec_m < M - idm; ++vec_m){ + vstore4(CONVERT_FLOAT4(out[vec_m]), 0, output_c + out_offset + vec_m * N); + } + + #ifdef N_LEAVE + } + #endif + } else{ + #endif + #ifdef N_LEAVE + if(idn + 3 >= N){ + #pragma unroll + for (int vec_m = 0; vec_m < 4; ++vec_m){ + COMPUTE_FLOAT *out_ptr = (COMPUTE_FLOAT*)&out[vec_m]; + for(int vec_n = 0; vec_n < N - idn; ++vec_n){ + output_c[out_offset + vec_m * N + vec_n] = out_ptr[vec_n]; + } + } + } else { + #endif + #pragma unroll + for (int vec_m = 0; vec_m < 4; ++vec_m){ + vstore4(CONVERT_FLOAT4(out[vec_m]), 0, output_c + out_offset + vec_m * N); + } + #ifdef N_LEAVE + } + #endif + #ifdef M_LEAVE } - - const int out_offset = (4*height_blocks_idx) * width_blocks + width_blocks_idx; - - vstore4(CONVERT_FLOAT4(result0), out_offset, output_c); - if(4*height_blocks_idx+1 >= height) return; - vstore4(CONVERT_FLOAT4(result1), out_offset + width_blocks, output_c); - if(4*height_blocks_idx+2 >= height) return; - vstore4(CONVERT_FLOAT4(result2), out_offset + width_blocks*2, output_c); - if(4*height_blocks_idx+3 >= height) return; - vstore4(CONVERT_FLOAT4(result3), out_offset + width_blocks*3, output_c); + #endif } diff --git a/source/backend/opencl/execution/cl/matmul_params_buf.cl b/source/backend/opencl/execution/cl/matmul_params_buf.cl index c4520fc8e..a96f2caa7 100644 --- a/source/backend/opencl/execution/cl/matmul_params_buf.cl +++ b/source/backend/opencl/execution/cl/matmul_params_buf.cl @@ -83,6 +83,8 @@ // 2 -> with bias (eltwise_add) [M, N] // 3 -> with bias (eltwise_sub) [M, N] // 4 -> with bias (eltwise_sub and get negative) [M, N] +// 5 -> with bias (mask 0 for invalid) [M, N] + #ifndef BIAS_TYPE #define BIAS_TYPE 0 #endif @@ -95,6 +97,8 @@ #define DEAL_BIAS(x, a) x = x - a #elif BIAS_TYPE == 4 #define DEAL_BIAS(x, a) x = a - x +#elif BIAS_TYPE == 5 +#define DEAL_BIAS(x, a) x = (a == 0 ? (FLOAT)(-FLT_MAX) : x) #endif // By default the workgroup size requirement is enabled. 
For Qualcomm devices the workgroup size @@ -103,7 +107,32 @@ #define RELAX_WORKGROUP_SIZE 0 #endif -#define ZERO (FLOAT)0.0f +typedef float real_arg; +#define GetRealArg(x) (FLOAT)x +typedef FLOAT real; + +#ifndef PRECISION_COMPUTE +#define PRECISION_COMPUTE COMPUTE_FLOAT +#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x) +#endif +#ifndef PRECISION_COMPUTE2 +#define PRECISION_COMPUTE2 COMPUTE_FLOAT2 +#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x) +#endif +#ifndef PRECISION_COMPUTE4 +#define PRECISION_COMPUTE4 COMPUTE_FLOAT4 +#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x) +#endif +#ifndef PRECISION_COMPUTE8 +#define PRECISION_COMPUTE8 COMPUTE_FLOAT8 +#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x) +#endif +#ifndef PRECISION_COMPUTE16 +#define PRECISION_COMPUTE16 COMPUTE_FLOAT16 +#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x) +#endif + +#define ZERO (PRECISION_COMPUTE)0.0f // Sets a variable to zero #define SetToZero(a) a = ZERO #define IsZero(a) (a == ZERO) @@ -129,43 +158,72 @@ INLINE_FUNC int GetGroupID0() { return get_group_id(0); } // ================================================================================================= -// End of the C++11 raw string literal - -typedef float real_arg; -#define GetRealArg(x) (FLOAT)x -typedef FLOAT real; - // Data-widths in dimension M #if VWM == 1 typedef FLOAT realM; + #define COMPUTE_FLOATM PRECISION_COMPUTE + #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x) + #define CONVERT_FLOATM(x) CONVERT_FLOAT(x) #elif VWM == 2 typedef FLOAT2 realM; + #define COMPUTE_FLOATM PRECISION_COMPUTE2 + #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x) + #define CONVERT_FLOATM(x) CONVERT_FLOAT2(x) #elif VWM == 4 typedef FLOAT4 realM; + #define COMPUTE_FLOATM PRECISION_COMPUTE4 + #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x) + #define CONVERT_FLOATM(x) CONVERT_FLOAT4(x) #elif VWM == 8 typedef FLOAT8 realM; + #define COMPUTE_FLOATM PRECISION_COMPUTE8 + #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x) + #define CONVERT_FLOATM(x) CONVERT_FLOAT8(x) #elif VWM == 16 typedef FLOAT16 realM; + #define COMPUTE_FLOATM PRECISION_COMPUTE16 + #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x) + #define CONVERT_FLOATM(x) CONVERT_FLOAT16(x) #endif // Data-widths in dimension N #if VWN == 1 typedef FLOAT realN; + typedef int intN; + #define COMPUTE_FLOATN PRECISION_COMPUTE + #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x) + #define CONVERT_FLOATN(x) CONVERT_FLOAT(x) #elif VWN == 2 typedef FLOAT2 realN; + typedef int2 intN; + #define COMPUTE_FLOATN PRECISION_COMPUTE2 + #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x) + #define CONVERT_FLOATN(x) CONVERT_FLOAT2(x) #elif VWN == 4 typedef FLOAT4 realN; + typedef int4 intN; + #define COMPUTE_FLOATN PRECISION_COMPUTE4 + #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x) + #define CONVERT_FLOATN(x) CONVERT_FLOAT4(x) #elif VWN == 8 typedef FLOAT8 realN; + typedef int8 intN; + #define COMPUTE_FLOATN PRECISION_COMPUTE8 + #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x) + #define CONVERT_FLOATN(x) CONVERT_FLOAT8(x) #elif VWN == 16 typedef FLOAT16 realN; + typedef int16 intN; + #define COMPUTE_FLOATN PRECISION_COMPUTE16 + #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x) + #define CONVERT_FLOATN(x) CONVERT_FLOAT16(x) #endif // 
================================================================================================= // Initializes the accumulation registers to zero -INLINE_FUNC realM InitAccRegisters() { - realM result; +INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() { + COMPUTE_FLOATM result; #if VWM == 1 SetToZero(result); #elif VWM == 2 @@ -206,8 +264,8 @@ INLINE_FUNC realM InitAccRegisters() { return result; } -INLINE_FUNC realN InitAccRegistersN() { - realN result; +INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() { + COMPUTE_FLOATN result; #if VWN == 1 SetToZero(result); #elif VWN == 2 @@ -443,10 +501,10 @@ INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int #endif // The vectorised multiply-add function -INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bval) { +INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec, COMPUTE_FLOATM avec, PRECISION_COMPUTE bval) { #if USE_VECTOR_MAD == 1 #if USE_CL_MAD == 1 - cvec = mad(avec, (realM)bval, cvec); + cvec = mad(avec, (COMPUTE_FLOATM)bval, cvec); #else cvec += avec * bval; #endif @@ -493,10 +551,10 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva } // The vectorised multiply-add function -INLINE_FUNC realN MultiplyAddVectorN(realN cvec, const real avec, const realN bval) { +INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec, PRECISION_COMPUTE avec, COMPUTE_FLOATN bval) { #if USE_VECTOR_MAD == 1 #if USE_CL_MAD == 1 - cvec = mad((realN)avec, bval, cvec); + cvec = mad((COMPUTE_FLOATN)avec, bval, cvec); #else cvec += avec * bval; #endif @@ -571,8 +629,8 @@ INLINE_FUNC INT2 StoreIndexM() { } // layout : [N, M] -INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const INT2 baseOffset, const int _mi, const int _ni, - const int kSizeM, const real alpha, const real beta) { +INLINE_FUNC void StoreResultsM(__global realM* cgm, COMPUTE_FLOATM c_value, const INT2 baseOffset, const int _mi, const int _ni, + const int kSizeM, const PRECISION_COMPUTE alpha, const PRECISION_COMPUTE beta) { #if STRM == 0 int idm = _mi + baseOffset.index[0]; #elif STRM == 1 @@ -586,11 +644,11 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const INT2 ba int index = idn*(kSizeM/VWM) + idm; - realM result = c_value; + COMPUTE_FLOATM result = c_value; // The final multiplication with alpha (in case beta == 0) #ifdef ONLY_HAVE_ALPHA - realM xval = c_value; + COMPUTE_FLOATM xval = c_value; #if VWM == 1 Multiply(result, alpha, xval); #elif VWM == 2 @@ -632,8 +690,8 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const INT2 ba // The final multiplication with alpha and the addition with beta*C #ifdef HAVE_ALPHA_BETA - realM xval = c_value; - realM yval = cgm[index]; + COMPUTE_FLOATM xval = c_value; + COMPUTE_FLOATM yval = CONVERT_COMPUTE_FLOATM(cgm[index]); #if VWM == 1 AXPBY(result, alpha, xval, beta, yval); #elif VWM == 2 @@ -672,7 +730,7 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const INT2 ba AXPBY(result.sF, alpha, xval.sF, beta, yval.sF); #endif #endif - cgm[index] = result; + cgm[index] = CONVERT_FLOATM(result); } INLINE_FUNC INT2 StoreIndexN() { @@ -695,7 +753,7 @@ INLINE_FUNC INT2 StoreIndexN() { return res; } // layout : [M, N] -INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, +INLINE_FUNC void StoreResultsN(__global realN* cgn, COMPUTE_FLOATN c_value, const INT2 baseOffset, #if BIAS_TYPE > 0 #if BIAS_TYPE > 1 @@ -705,7 +763,7 @@ INLINE_FUNC void StoreResultsN(__global 
realN* cgn, realN c_value, #endif #endif const int _mi, const int _ni, - const int cstride/*kSizeN*/, const int dstride/*kSizeN*/, const real alpha, const real beta) { + const int cstride/*kSizeN*/, const int dstride/*kSizeN*/, const PRECISION_COMPUTE alpha, const PRECISION_COMPUTE beta) { #if STRM == 0 int idm = _mi + baseOffset.index[0]; @@ -720,11 +778,11 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, int index = idm * (cstride/VWN) + idn; - realN result = c_value; + COMPUTE_FLOATN result = c_value; // The final multiplication with alpha (in case beta == 0) #ifdef ONLY_HAVE_ALPHA - realN xval = c_value; + COMPUTE_FLOATN xval = c_value; #if VWN == 1 Multiply(result, alpha, xval); #elif VWN == 2 @@ -766,8 +824,8 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, // The final multiplication with alpha and the addition with beta*C #ifdef HAVE_ALPHA_BETA - realN xval = c_value; - realN yval = cgn[index]; + COMPUTE_FLOATN xval = c_value; + COMPUTE_FLOATN yval = CONVERT_COMPUTE_FLOATN(cgn[index]); #if VWN == 1 AXPBY(result, alpha, xval, beta, yval); #elif VWN == 2 @@ -810,29 +868,31 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, #if BIAS_TYPE > 0 #if BIAS_TYPE == 1 - realN eval = epm[_ni]; + COMPUTE_FLOATN eval = CONVERT_COMPUTE_FLOATN(epm[_ni]); + #elif BIAS_TYPE == 5 + int index_bias = idm * (dstride/VWN) + idn; + intN eval = ((__global intN*)egm)[index_bias]; #else - int index_bias = idm * (dstride/VWN) + idn; - realN eval = egm[index_bias]; + COMPUTE_FLOATN eval = CONVERT_COMPUTE_FLOATN(egm[index_bias]); #endif #if VWN == 1 DEAL_BIAS(result, eval); #ifdef RELU - result = fmax(result, (FLOAT)0); + result = fmax(result, (COMPUTE_FLOATN)0); #endif #ifdef RELU6 - result = clamp(result, (FLOAT)0, (FLOAT)6); + result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6); #endif #elif VWN == 2 DEAL_BIAS(result.x, eval.x); DEAL_BIAS(result.y, eval.y); #ifdef RELU - result = fmax(result, (FLOAT2)0); + result = fmax(result, (COMPUTE_FLOATN)0); #endif #ifdef RELU6 - result = clamp(result, (FLOAT2)0, (FLOAT2)6); + result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6); #endif #elif VWN == 4 DEAL_BIAS(result.x, eval.x); @@ -840,10 +900,10 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, DEAL_BIAS(result.z, eval.z); DEAL_BIAS(result.w, eval.w); #ifdef RELU - result = fmax(result, (FLOAT4)0); + result = fmax(result, (COMPUTE_FLOATN)0); #endif #ifdef RELU6 - result = clamp(result, (FLOAT4)0, (FLOAT4)6); + result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6); #endif #elif VWN == 8 DEAL_BIAS(result.s0, eval.s0); @@ -855,10 +915,10 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, DEAL_BIAS(result.s6, eval.s6); DEAL_BIAS(result.s7, eval.s7); #ifdef RELU - result = fmax(result, (FLOAT8)0); + result = fmax(result, (COMPUTE_FLOATN)0); #endif #ifdef RELU6 - result = clamp(result, (FLOAT8)0, (FLOAT8)6); + result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6); #endif #elif VWN == 16 DEAL_BIAS(result.s0, eval.s0); @@ -878,15 +938,15 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, DEAL_BIAS(result.sE, eval.sE); DEAL_BIAS(result.sF, eval.sF); #ifdef RELU - result = fmax(result, (FLOAT16)0); + result = fmax(result, (COMPUTE_FLOATN)0); #endif #ifdef RELU6 - result = clamp(result, (FLOAT16)0, (FLOAT16)6); + result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6); #endif #endif #endif - cgn[index] = result; + cgn[index] = CONVERT_FLOATN(result); } @@ -896,7 +956,7 @@ 
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #if BIAS_TYPE > 0 __global realN* restrict egm, #endif - __global realM* cgm, const real alpha, const real beta + __global realM* cgm, const real_arg alpha, const real_arg beta #if SA == 1 && SB == 1 , LOCAL_PTR realM* alm, LOCAL_PTR realN* blm #elif SA == 1 @@ -907,10 +967,10 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, ) { #ifdef OUTPUTMN #pragma promote_to_registers - realN cpn[MWI*(NWI/VWN)]; // MWI * NWI + COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI * NWI #else #pragma promote_to_registers - realM cpm[NWI*(MWI/VWM)]; // NWI * MWI + COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI * MWI #endif // Combined thread identifier (volatile to disable caching) @@ -941,9 +1001,9 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #if SA == 1 || SB == 1 // Allocates workitem-private memory (registers) #pragma promote_to_registers - realM apm[MWI/VWM]; // MWI * 1 + COMPUTE_FLOATM apm[MWI/VWM]; // MWI * 1 #pragma promote_to_registers - realN bpm[NWI/VWN]; // 1 * NWI + COMPUTE_FLOATN bpm[NWI/VWN]; // 1 * NWI for (int kwg = 0; kwg < kSizeK; kwg += KWG) { // Loads data: off-chip --> local (matrix A) @@ -970,10 +1030,10 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { // Loads data: local --> private (matrix A) #if SA == 1 - apm[_mi] = LocalToPrivateA(alm, _mi, kg); + apm[_mi] = CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm, _mi, kg)); // Loads data: off-chip --> private (matrix A) #elif SA == 0 - apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk); + apm[_mi] = CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm, _mi, kSizeM, idk)); #endif } @@ -983,10 +1043,10 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { // Loads data: local --> private (matrix B) #if SB == 1 - bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + bpm[_ni] = CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm, _ni, kg)); // Loads data: off-chip --> private (matrix B) #else - bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + bpm[_ni] = CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm, _ni, kSizeN, idk)); #endif } @@ -997,7 +1057,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - const realM aval = apm[_mi]; + const COMPUTE_FLOATM aval = apm[_mi]; #if VWM == 1 // [MWI/VWM, VWM, NWI/VWN, VWN] cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval, bpm[_ni]); @@ -1043,7 +1103,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = apm[_mi]; + const COMPUTE_FLOATM aval = apm[_mi]; #if VWN == 1 cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); #elif VWN == 2 @@ -1098,7 +1158,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _kj = 0; _kj < kSizeK; _kj += 4) { #ifdef OUTPUTMN #pragma promote_to_registers - realN bpm[NWI/VWN]; // 1 * NWI + COMPUTE_FLOATN bpm[NWI/VWN]; // 1 * NWI #pragma unroll for(int _ki = 0; _ki < 4; _ki += 1) { @@ -1106,12 +1166,12 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, 
#pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { // Loads data: off-chip --> private (matrix B) - bpm[_ni] = GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk); + bpm[_ni] = CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk)); } #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk); + const COMPUTE_FLOATM aval = CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk)); #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { #if VWM == 1 @@ -1158,22 +1218,22 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #else #pragma promote_to_registers - realM apm[MWI/VWM]; // MWI * 1 + COMPUTE_FLOATM apm[MWI/VWM]; // MWI * 1 #pragma unroll for(int _ki = 0; _ki < 4; _ki += 1) { int idk = _kj + _ki; #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { // Loads data: off-chip --> private (matrix B) - apm[_mi] = GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk); + apm[_mi] = CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk)); } #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - const realN bval = GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk); + const COMPUTE_FLOATN bval = CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk)); #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - const realM aval = apm[_mi]; + const COMPUTE_FLOATM aval = apm[_mi]; #if VWN == 1 cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval); #elif VWN == 2 @@ -1288,8 +1348,6 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, __private const int4 offset, __private const int4 stride ) { - const real alpha = GetRealArg(arg_alpha); - const real beta = GetRealArg(arg_beta); // Adds the offsets (in case of use of a single temporary buffer for A, B, and C) agm = (const __global realM*)((const __global real*)agm + offset.s0); @@ -1313,25 +1371,25 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, #if BIAS_TYPE > 0 egm, #endif - cgm, alpha, beta, alm, blm); + cgm, arg_alpha, arg_beta, alm, blm); #elif SA == 1 XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm, #if BIAS_TYPE > 0 egm, #endif - cgm, alpha, beta, alm); + cgm, arg_alpha, arg_beta, alm); #elif SB == 1 XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm, #if BIAS_TYPE > 0 egm, #endif - cgm, alpha, beta, blm); + cgm, arg_alpha, arg_beta, blm); #else XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm, #if BIAS_TYPE > 0 egm, #endif - cgm, alpha, beta); + cgm, arg_alpha, arg_beta); #endif } @@ -1346,29 +1404,32 @@ void XgemmBatched(const int kSizeM, const real_arg arg_alpha, const real_arg arg_beta, const __global realM* restrict agm, - const int batch_offset_a, const __global realN* restrict bgm, - const int batch_offset_b, #if BIAS_TYPE > 0 __global realN* restrict egm, - const int batch_offset_e, #endif __global realM* cgm, - const int batch_offset_c) { + const int4 batch_offset, // [batch_offset_a, batch_offset_b, batch_offset_c, batch_offset_e] + const int4 stride, // [stride_a, stride_b, stride_c, stride_e] + /* + total_batch -> [loop_y, loop_x] + with group batch -> [loop_y, loop_x/group_num] + group_size == loop_x/group_num + */ + const int4 group // [group_num_a, group_num_b, group_num_e, loop_x] +) { const int batch = 
get_group_id(2); - const real alpha = GetRealArg(arg_alpha); - const real beta = GetRealArg(arg_beta); // Sets the offsets - const int a_offset = batch * batch_offset_a; - const int b_offset = batch * batch_offset_b; - const int c_offset = batch * batch_offset_c; + const int a_offset = ((batch / group.w) * group.x + (batch % group.w) / group.x) * batch_offset.x; + const int b_offset = ((batch / group.w) * group.y + (batch % group.w) / group.y) * batch_offset.y; + const int c_offset = batch * batch_offset.z; const __global realM* restrict agm_ = &agm[a_offset / VWM]; const __global realN* restrict bgm_ = &bgm[b_offset / VWN]; __global realM* restrict cgm_ = &cgm[c_offset / VWM]; #if BIAS_TYPE > 0 - const int e_offset = batch * batch_offset_e; + const int e_offset = ((batch / group.w) * group.z + (batch % group.w) / group.z) * batch_offset.w; __global realN* restrict egm_ = &egm[e_offset / VWN]; #endif @@ -1379,40 +1440,32 @@ void XgemmBatched(const int kSizeM, #if SB == 1 __local realN blm[KWG * NWG/VWN]; #endif - int4 stride; - stride.s0 = kSizeM; - stride.s1 = kSizeN; - #ifdef OUTPUTMN - stride.s2 = kSizeN; - #else - stride.s2 = kSizeM; - #endif - stride.s3 = kSizeN; + // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_, #if BIAS_TYPE > 0 egm_, #endif - cgm_, alpha, beta, alm, blm); + cgm_, arg_alpha, arg_beta, alm, blm); #elif SA == 1 XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_, #if BIAS_TYPE > 0 egm_, #endif - cgm_, alpha, beta, alm); + cgm_, arg_alpha, arg_beta, alm); #elif SB == 1 XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_, #if BIAS_TYPE > 0 egm_, #endif - cgm_, alpha, beta, blm); + cgm_, arg_alpha, arg_beta, blm); #else XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_, #if BIAS_TYPE > 0 egm_, #endif - cgm_, alpha, beta); + cgm_, arg_alpha, arg_beta); #endif } diff --git a/source/backend/opencl/execution/cl/opencl_codegen.py b/source/backend/opencl/execution/cl/opencl_codegen.py index 56d497aee..b18dd59dc 100644 --- a/source/backend/opencl/execution/cl/opencl_codegen.py +++ b/source/backend/opencl/execution/cl/opencl_codegen.py @@ -41,7 +41,7 @@ def opencl_codegen(): for file_name_all in os.listdir(cl_kernel_dir): file_path = os.path.join(cl_kernel_dir, file_name_all) if file_path[-3:] == ".cl": - with open(file_path, "r") as f: + with open(file_path, "r", encoding = 'utf-8') as f: file_name = file_name_all[:-3] if file_name[-4:] == "_buf": opencl_source_map += "#ifndef MNN_OPENCL_BUFFER_CLOSED\n" diff --git a/source/backend/opencl/execution/cl/opencl_program.cc b/source/backend/opencl/execution/cl/opencl_program.cc index 7eebdd550..a4d2cb4f4 100644 --- a/source/backend/opencl/execution/cl/opencl_program.cc +++ b/source/backend/opencl/execution/cl/opencl_program.cc @@ -1346,7 +1346,7 @@ const char* deconv_2d = " #ifdef BIAS\n" " __global FLOAT* bias,\n" " #endif\n" -" __global FLOAT* output,\n" +" __global FLOAT* output,__private const int batch,\n" " #else\n" " __read_only image2d_t input,\n" " __read_only image2d_t weights,\n" @@ -1406,7 +1406,7 @@ const char* deconv_2d = " weights2=vload4(kernel_x_2*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n" " weights3=vload4(kernel_x_3*(out_channel_blocks*kernel_shape.x*kernel_shape.y)+kernel_y,weights);\n" " bool outBoundry=(idx_h<0 || idx_h >= input_shape.x || kernel_start_x<0 || in_width0 >= input_shape.y);\n" -" int 
inp_offset=(((out_b_idx*in_channel_blocks+ic)*input_shape.x+idx_h)*input_shape.y+in_width0)*4;\n" +" int inp_offset=(((out_b_idx+ic*batch)*input_shape.x+idx_h)*input_shape.y+in_width0)*4;\n" " in0=outBoundry ? (FLOAT4)0 : vload4(0,input+inp_offset);\n" " out0=mad(in0.x,weights0,out0);\n" " out0=mad(in0.y,weights1,out0);\n" @@ -1443,7 +1443,7 @@ const char* deconv_2d = " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" "#ifdef USE_BUFFER\n" -" const int out_offset=(((out_b_idx*out_channel_blocks+out_channel_blocks_idx)*output_shape.x+out_h_idx)*output_shape.y+out_w_idx)*4;\n" +" const int out_offset=(((out_b_idx+out_channel_blocks_idx*batch)*output_shape.x+out_h_idx)*output_shape.y+out_w_idx)*4;\n" " vstore4(out0,0,output+out_offset);\n" "#else\n" " int out_image_width_idx=mad24(out_channel_blocks_idx,output_shape.y,out_w_idx);\n" @@ -1542,7 +1542,7 @@ const char* grid_sample_buf = " __private const int input_width,\n" " __private const int output_height,\n" " __private const int output_width,\n" -" __private const int channelBlocks,\n" +" __private const int batch,\n" " __private const enum BorderMode paddingMode,\n" " __private const int alignCorners){\n" " \n" @@ -1567,27 +1567,21 @@ const char* grid_sample_buf = " (xn,xn,xn,xn) (y5,y6,y7,y8)\n" " ---------------------------\n" " */\n" -" const int slice=output_height_idx/4;\n" -" const int slice_blocks=(output_height+3)/4;\n" " // output_width_block_idx means gird y offset,2 means grid width\n" -" const int grid_offset=((output_batch_idx*slice_blocks+slice)*output_width+output_width_block_idx)*2;\n" -" COMPUTE_FLOAT4 grid_x=CONVERT_COMPUTE_FLOAT4(vload4(grid_offset,grid));\n" -" COMPUTE_FLOAT4 grid_y=CONVERT_COMPUTE_FLOAT4(vload4(grid_offset+1,grid));\n" -" const float arr[8]={grid_x.x,grid_y.x,grid_x.y,grid_y.y,grid_x.z,grid_y.z,grid_x.w,grid_y.w};\n" -" \n" +" const int grid_offset=(output_batch_idx*output_height+output_height_idx)*output_width+output_width_block_idx;\n" +" COMPUTE_FLOAT2 grid_xy=CONVERT_COMPUTE_FLOAT2(vload2(grid_offset,grid));\n" " // get grid x,y\n" -" const int arr_offset=output_height_idx % 4;\n" -" const float x=arr[2*arr_offset];\n" -" const float y=arr[2*arr_offset+1];\n" +" const float x=(float)grid_xy.x;\n" +" const float y=(float)grid_xy.y;\n" " // convert grid x,y to input x,y coordinate range\n" " float in_grid_x=getPosition(x,input_width,alignCorners);\n" " float in_grid_y=getPosition(y,input_height,alignCorners);\n" " // get nearest point\n" " int nw=floor(in_grid_x+0.5f);\n" " int nh=floor(in_grid_y+0.5f);\n" -" const int inp_offset_base=(output_batch_idx*channelBlocks+output_channel_block_idx)*input_height;\n" +" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n" " COMPUTE_FLOAT4 value=sample(nh,nw,inp_offset_base,input,input_height,input_width,paddingMode);\n" -" const int output_offset=((output_batch_idx*channelBlocks+output_channel_block_idx )*output_height+output_height_idx)*output_width+output_width_block_idx;\n" +" const int output_offset=((output_batch_idx+output_channel_block_idx*batch)*output_height+output_height_idx)*output_width+output_width_block_idx;\n" " vstore4(CONVERT_FLOAT4(value),output_offset,output);\n" "}\n" "__kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n" @@ -1597,7 +1591,7 @@ const char* grid_sample_buf = " __private const int input_width,\n" " __private const int output_height,\n" " __private const int output_width,\n" -" __private const int channelBlocks,\n" +" __private const int batch,\n" " __private 
const enum BorderMode paddingMode,\n" " __private const int alignCorners){\n" " const int output_channel_block_idx=get_global_id(0);\n" @@ -1606,18 +1600,13 @@ const char* grid_sample_buf = " DEAL_NON_UNIFORM_DIM3(output_channel_block_idx,output_width_block_idx,output_batch_height_block_idx);\n" " const int output_batch_idx=output_batch_height_block_idx/output_height;\n" " const int output_height_idx=output_batch_height_block_idx % output_height;\n" -" const int slice=output_height_idx/4;\n" -" const int slice_blocks=(output_height+3)/4;\n" " // output_width_block_idx means gird y offset,2 means grid width\n" -" const int grid_offset=((output_batch_idx*slice_blocks+slice)*output_width+output_width_block_idx)*2;\n" -" COMPUTE_FLOAT4 grid_x=CONVERT_COMPUTE_FLOAT4(vload4(grid_offset,grid));\n" -" COMPUTE_FLOAT4 grid_y=CONVERT_COMPUTE_FLOAT4(vload4(grid_offset+1,grid));\n" -" const float arr[8]={grid_x.x,grid_y.x,grid_x.y,grid_y.y,grid_x.z,grid_y.z,grid_x.w,grid_y.w};\n" +" const int grid_offset=(output_batch_idx*output_height+output_height_idx)*output_width+output_width_block_idx;\n" +" COMPUTE_FLOAT2 grid_xy=CONVERT_COMPUTE_FLOAT2(vload2(grid_offset,grid));\n" " \n" " // get grid x,y\n" -" const int arr_offset=output_height_idx % 4;\n" -" const float x=arr[2*arr_offset];\n" -" const float y=arr[2*arr_offset+1];\n" +" const float x=(float)grid_xy.x;\n" +" const float y=(float)grid_xy.y;\n" " // convert grid x,y to input x,y coordinate range\n" " float in_grid_x=getPosition(x,input_width,alignCorners);\n" " float in_grid_y=getPosition(y,input_height,alignCorners);\n" @@ -1628,7 +1617,7 @@ const char* grid_sample_buf = " float x_weight=in_w1-in_grid_x;\n" " float y_weight=in_h1-in_grid_y;\n" " // bilinear interpolation\n" -" const int inp_offset_base=(output_batch_idx*channelBlocks+output_channel_block_idx)*input_height;\n" +" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n" " COMPUTE_FLOAT4 i00=sample(in_h0,in_w0,inp_offset_base,input,input_height,input_width,paddingMode);\n" " COMPUTE_FLOAT4 i01=sample(in_h0,in_w1,inp_offset_base,input,input_height,input_width,paddingMode);\n" " COMPUTE_FLOAT4 i10=sample(in_h1,in_w0,inp_offset_base,input,input_height,input_width,paddingMode);\n" @@ -1636,7 +1625,7 @@ const char* grid_sample_buf = " COMPUTE_FLOAT4 value=CONVERT_COMPUTE_FLOAT4(((COMPUTE_FLOAT4)x_weight*CONVERT_COMPUTE_FLOAT4(i00)+(COMPUTE_FLOAT4)(1.0f-x_weight)*CONVERT_COMPUTE_FLOAT4(i01))*(COMPUTE_FLOAT4)y_weight +\n" " ((COMPUTE_FLOAT4)x_weight*CONVERT_COMPUTE_FLOAT4(i10)+(COMPUTE_FLOAT4)(1.0f-x_weight)*CONVERT_COMPUTE_FLOAT4(i11))*(COMPUTE_FLOAT4)(1.0f- y_weight));\n" " \n" -" const int output_offset=((output_batch_idx*channelBlocks+output_channel_block_idx )*output_height+output_height_idx)*output_width+output_width_block_idx;\n" +" const int output_offset=((output_batch_idx+output_channel_block_idx*batch)*output_height+output_height_idx)*output_width+output_width_block_idx;\n" " vstore4(CONVERT_FLOAT4(value),output_offset,output);\n" "}\n" ; @@ -1730,34 +1719,35 @@ const char* range_buf = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" -"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" -"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"__kernel void range_buf(GLOBAL_SIZE_3_DIMS\n" +"#define 
GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"__kernel void range_buf(GLOBAL_SIZE_2_DIMS\n" " __global const INPUT_TYPE* input0,\n" " __global const INPUT_TYPE* input2,\n" " __global OUTPUT_TYPE* output,\n" -" __private const int width,\n" -" __private const int height,\n" -" __private const int channel,\n" -" __private const int channelBlock\n" +" __private const int size\n" " ) {\n" -" const int width_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n" -" \n" -" const int batch_idx=batch_channel_idx/channelBlock;\n" -" const int channel_idx=batch_channel_idx % channelBlock;\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" " \n" -" const int offset=((((batch_idx*channelBlock)+channel_idx)*height+height_idx)*width+width_idx)*4;\n" -" const int channel4=channel_idx << 2;\n" -" int index=(((batch_idx*channel)+channel4)*height+height_idx)*width+width_idx;\n" -" int size=height*width;\n" -" int4 index4=(int4)(index,index+size,index+size*2,index+size*3);\n" +" int index=x << 2;\n" +" int4 index4=(int4)(index,index+1,index+2,index+3);\n" " INPUT_TYPE start=input0[0];\n" " INPUT_TYPE step=input2[0];\n" " OUTPUT_TYPE4 value=(OUTPUT_TYPE4)start+CONVERT_OUTPUT4(index4)*(OUTPUT_TYPE4)step;\n" -" vstore4(value,0,output+offset);\n" +"#ifdef PACK_LEAVE\n" +" if(index+3 >= size){\n" +" OUTPUT_TYPE* value_ptr=(OUTPUT_TYPE*)&value;\n" +" for(int i=0; iinside_len){\n" +" for(int i=lid+inside_len; i [N Y X]\n" -"__kernel void trans_3d_buf(__global const FLOAT* input,\n" +"__kernel void trans_3d_buf(GLOBAL_SIZE_3_DIMS\n" +" __global const FLOAT* input,\n" " __global FLOAT* output,\n" " __private const int batch,\n" " __private const int width,\n" " __private const int height\n" ") {\n" " int b=get_global_id(2);\n" -" \n" -" const int w=get_global_id(0) << 3;\n" -" const int h=get_global_id(1) << 3;\n" +" int w=get_global_id(0);\n" +" int h=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM3(w,h,b);\n" +" w=w << 3;\n" +" h=h << 3;\n" " \n" " const int inp_offset=(b*width+w)*height+h;\n" " const int out_offset=(b*height+h)*width+w;\n" @@ -2005,6 +2007,7 @@ const char* self_attention_buf = " __private const int seq_len_piece,\n" " __private const int head_num,\n" " __private const int head_dim,\n" +" __private const int batch,\n" " __private const int seq_index\n" ") {\n" " \n" @@ -2026,8 +2029,7 @@ const char* self_attention_buf = " \n" " const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n" " \n" -" const int offset_out=(((b*seq_len_4+seq_index*seq_len_piece/4+sl)*head_num+hn)*head_dim+4*hd)*4;\n" -" \n" +" const int offset_out=((((seq_index*seq_len_piece/4+sl)*batch+b)*head_num+hn)*head_dim+4*hd)*4;\n" " // Q\n" " FLOAT4 temp_0=vload4(0,input+offset_inp);\n" " FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n" @@ -2308,63 +2310,75 @@ const char* gemv_conv1x1_buf = "#define UCHAR8_TO_CHAR16(a, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8=(c.s4 >> 4)-8; a.s9=(c.s4 & 15)-8; a.sa=(c.s5 >> 4)-8; a.sb=(c.s5 & 15)-8; a.sc=(c.s6 >> 4)-8; a.sd=(c.s6 & 
15)-8; a.se=(c.s7 >> 4)-8; a.sf=(c.s7 & 15)-8;\n" "#define DOT16X16(a, b, c) "" c += dot(a.s0123, b.s0123); "" c += dot(a.s4567, b.s4567); "" c += dot(a.s89ab, b.s89ab); "" c += dot(a.scdef,b.scdef);\n" "#ifdef INPUT_CHANNEL_LEAVE\n" -" #define PADZEROS(k, channel, data) {"" COMPUTE_FLOAT* ptr = (COMPUTE_FLOAT*)&data; "" int remain = k + 15 - channel; "" for(int r = remain; r >= 0; r--){ "" ptr[15 - remain] = 0; "" } "" }\n" +" #define PADZEROS(k, channel, data) {"" COMPUTE_FLOAT* ptr = (COMPUTE_FLOAT*)&data; "" int remain = k + 15 - channel; "" for(int r = remain; r >= 0; r--){ "" ptr[15 - r] = 0; "" } "" }\n" "#else\n" " #define PADZEROS(k,channel,data)\n" "#endif\n" +"#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n" +"#define CHANNEL_PACK 32\n" +"#else\n" +"#define CHANNEL_PACK 16\n" +"#endif\n" +"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#define WEIGHT_STRIDE 16\n" +"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +"#define WEIGHT_STRIDE 8\n" +"#endif\n" "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" -"__kernel void gemm_conv_c4_buf(GLOBAL_SIZE_DIM2\n" +"#ifdef USE_IMAGE\n" +"inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(ix,iy))))*scale+offset;\n" +"}\n" +"#else\n" +"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"inline COMPUTE_FLOAT16 readWeight(__global const char *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" return CONVERT_COMPUTE_FLOAT16(vload16(0,weight))*scale+offset;\n" +"}\n" +"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +"inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" uchar16 charWeightsInt40=vload16(0,weight);\n" +" uchar8 charWeightsInt4=vload8(0,weight);\n" +" char16 charWeights=0;\n" +" UCHAR8_TO_CHAR16(charWeights,charWeightsInt4);\n" +" return CONVERT_COMPUTE_FLOAT16(charWeights)*scale+offset;\n" +"}\n" +"#endif\n" +"#endif\n" +"__kernel void gemv_conv_c4_buf(GLOBAL_SIZE_DIM2\n" " __global const FLOAT* input,\n" +"#ifdef USE_IMAGE\n" +" __read_only image2d_t weight,\n" +"#else\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *weight,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *weight,\n" "#endif\n" +"#endif\n" " __global const float *dequantScaleOffset,\n" " __global const FLOAT *bias,\n" " __global FLOAT* output,\n" " __private const int dstChannelC4,\n" " __private const int srcChannelC4,\n" " __private const int srcChannel,\n" -" __private const int batch,\n" -" __private const int height,\n" -" __private const int width,\n" +" __private const int bhw,\n" " __private const int blockNum,\n" " __private const int blockDim) {\n" -" const int out_c_w_idx=get_global_id(0); //c/4 w\n" -" const int out_b_h_idx=get_global_id(1); //b h\n" -" UNIFORM_BOUNDRY_CHECK(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=out_c_w_idx/width;\n" -" const int out_w_idx=out_c_w_idx % width;\n" -"#ifdef BACTH_BLOCK4\n" -" const int out_b_idx=(out_b_h_idx/height) << 2;\n" -"#else\n" -" const int out_b_idx=out_b_h_idx/height;\n" -"#endif\n" -" const int out_h_idx=out_b_h_idx % height;\n" -" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" -" COMPUTE_FLOAT4 out=bias0;\n" -"#ifdef BACTH_BLOCK4\n" -" COMPUTE_FLOAT4 out1=bias0,out2=bias0,out3=bias0;\n" -" int 
input_offset1=(((out_b_idx+1)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" int input_offset2=(((out_b_idx+2)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" int input_offset3=(((out_b_idx+3)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" bool isValidBatch1=out_b_idx+1= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" +"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"__kernel void buffer_set_zero(\n" +" GLOBAL_SIZE_2_DIMS\n" +" __global OUTPUT_TYPE *output\n" +" ) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" " \n" -" vstore2(CONVERT_FLOAT2(out1),0,output+out_offset);\n" -" }\n" -" if(isValidBatch2){\n" -" out_offset += dstChannelC4*height*width*4;\n" -"#ifdef RELU\n" -" out2=fmax(out2,(COMPUTE_FLOAT2)0);\n" -"#endif\n" -"#ifdef RELU6\n" -" out2=clamp(out2,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n" -"#endif\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" " \n" -" vstore2(CONVERT_FLOAT2(out2),0,output+out_offset);\n" -" }\n" -" if(isValidBatch3){\n" -" out_offset += dstChannelC4*height*width*4;\n" -"#ifdef RELU\n" -" out3=fmax(out3,(COMPUTE_FLOAT2)0);\n" -"#endif\n" -"#ifdef RELU6\n" -" out3=clamp(out3,(COMPUTE_FLOAT2)0,(COMPUTE_FLOAT2)6);\n" -"#endif\n" +" output[y*global_size_dim0+x]=(OUTPUT_TYPE)(0);\n" +"}\n" +"__kernel void image_set_zero(\n" +" GLOBAL_SIZE_2_DIMS\n" +" __write_only image2d_t output\n" +" ) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" " \n" -" vstore2(CONVERT_FLOAT2(out3),0,output+out_offset);\n" -" }\n" -"#endif\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" +" WI_DATA(output,(int2)(x,y),(OUTPUT_TYPE_I4)(0));\n" "}\n" -"__kernel void gemm_conv_c1_image(GLOBAL_SIZE_DIM2\n" -" __global const FLOAT* input,\n" -" __read_only image2d_t weight,\n" -" __global const float *dequantScaleOffset,\n" -" __global const FLOAT *bias,\n" -" __global FLOAT* output,\n" -" __private const int dstChannelC4,\n" -" __private const int srcChannelC4,\n" -" __private const int srcChannel,\n" -" __private const int batch,\n" -" __private const int height,\n" -" __private const int width,\n" -" __private const int blockNum,\n" -" __private const int blockDim) {\n" -" const int out_c_w_idx=get_global_id(0); //c/4 w\n" -" const int out_b_h_idx=get_global_id(1); //b h\n" -" UNIFORM_BOUNDRY_CHECK(out_c_w_idx,out_b_h_idx);\n" -" const int out_c_idx=out_c_w_idx/width;\n" -" const int out_w_idx=out_c_w_idx % width;\n" -"#ifdef BACTH_BLOCK4\n" -" const int out_b_idx=(out_b_h_idx/height) << 2;\n" -"#else\n" -" const int out_b_idx=out_b_h_idx/height;\n" -"#endif\n" -" const int out_h_idx=out_b_h_idx % height;\n" +"__kernel void raster_buffer_direct(\n" +" GLOBAL_SIZE_3_DIMS\n" +" __read_only image2d_t input,\n" +" __private const int inputOffset,\n" +" __private const int combineSrcOffset,\n" +" __private const int inputStride0,\n" +" __private const int inputStride1,\n" +" __private const int inputStride2,\n" +" __private const int src_width,\n" +" __private const int src_height,\n" +" __private const int src_channel,\n" +" __global OUTPUT_TYPE *output,\n" +" __private const int outputOffset,\n" +" __private const int combineDstOffset,\n" +" 
__private const int outputStride0,\n" +" __private const int outputStride1,\n" +" __private const int outputStride2,\n" +" __private const int global_size0\n" +" ) {\n" +" const int idx=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" const int z=get_global_id(2);\n" " \n" -" COMPUTE_FLOAT bias0=bias[out_c_idx];\n" -" COMPUTE_FLOAT out=bias0;\n" -" \n" -" int input_offset=((out_b_idx*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" int out_offset=(((out_b_idx*dstChannelC4+out_c_idx/4)* height+out_h_idx)*width+out_w_idx)*4+(out_c_idx%4);\n" -" int wh=width*height*4;\n" -"#ifdef BACTH_BLOCK4\n" -" COMPUTE_FLOAT out1=bias0,out2=bias0,out3=bias0;\n" -" int input_offset1=(((out_b_idx+1)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" int input_offset2=(((out_b_idx+2)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" int input_offset3=(((out_b_idx+3)*srcChannelC4*height+out_h_idx)*width+out_w_idx)*4;\n" -" bool isValidBatch1=out_b_idx+1= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" -"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" -"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" -"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"__kernel void buffer_set_zero(\n" -" GLOBAL_SIZE_2_DIMS\n" -" __global OUTPUT_TYPE *output\n" -" ) {\n" -" const int x=get_global_id(0);\n" -" const int y=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(x,y);\n" -" \n" -" output[y*global_size_dim0+x]=(OUTPUT_TYPE)(0);\n" -"}\n" -"__kernel void image_set_zero(\n" -" GLOBAL_SIZE_2_DIMS\n" -" __write_only image2d_t output\n" -" ) {\n" -" const int x=get_global_id(0);\n" -" const int y=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(x,y);\n" -" WI_DATA(output,(int2)(x,y),(OUTPUT_TYPE_I4)(0));\n" -"}\n" -"__kernel void raster_buffer_direct(\n" -" GLOBAL_SIZE_3_DIMS\n" -" __read_only image2d_t input,\n" -" __private const int inputOffset,\n" -" __private const int combineSrcOffset,\n" -" __private const int inputStride0,\n" -" __private const int inputStride1,\n" -" __private const int inputStride2,\n" -" __private const int src_width,\n" -" __private const int src_height,\n" -" __private const int src_channel,\n" -" __global OUTPUT_TYPE *output,\n" -" __private const int outputOffset,\n" -" __private const int combineDstOffset,\n" -" __private const int outputStride0,\n" -" __private const int outputStride1,\n" -" __private const int outputStride2,\n" -" __private const int global_size0\n" -" ) {\n" -" const int idx=get_global_id(0);\n" -" const int y=get_global_id(1);\n" -" const int z=get_global_id(2);\n" -" \n" -" DEAL_NON_UNIFORM_DIM3(idx,y,z);\n" -" const int x=idx % global_size0;\n" -" const int id=idx/global_size0;\n" +" DEAL_NON_UNIFORM_DIM3(idx,y,z);\n" +" const int x=idx % global_size0;\n" +" const int id=idx/global_size0;\n" " \n" " int inputIndex=inputOffset+id*combineSrcOffset+z*inputStride0+y*inputStride1+x*inputStride2;\n" " int outputIndex=outputOffset+id*combineDstOffset+z*outputStride0+y*outputStride1+x*outputStride2;\n" @@ -3673,6 +2983,7 @@ const char* conv_2d_c1_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " 
__private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -3699,11 +3010,11 @@ const char* conv_2d_c1_subgroup_buf = " const uint output_x_pitch=4;\n" " const uint output_y_pitch=output_x_pitch*output_width;\n" " const uint output_fs_pitch=output_y_pitch*output_height;\n" -" const uint output_b_pitch=output_fs_pitch*output_pack;\n" +" const uint output_b_pitch=output_fs_pitch*batch;\n" " \n" " \n" -" const uint output_offset=b*output_b_pitch +\n" -" f_block*4*output_fs_pitch +\n" +" const uint output_offset=b*output_fs_pitch +\n" +" f_block*4*output_b_pitch +\n" " y*output_y_pitch +\n" " x*output_x_pitch;\n" " const uint filter_isv_pitch=16;\n" @@ -3771,13 +3082,13 @@ const char* conv_2d_c1_subgroup_buf = " if ((f_block+1)*16 >= output_channel) {\n" " for (int i=0; i<2 && (x+i)= out_hw.y) return;\n" @@ -5086,6 +4404,7 @@ const char* conv_2d_int_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -5127,7 +4446,7 @@ const char* conv_2d_int_buf = " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" " int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" " for(int iy=in_h_idx_start; iy= 4) {\n" @@ -5245,6 +4564,7 @@ const char* conv_2d_int_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -5286,7 +4606,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 offset=(COMPUTE_FLOAT4)(ScaleOffset.s1,ScaleOffset.s3,ScaleOffset.s5,ScaleOffset.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" -" const int inp_offset_base=(out_b_idx*in_c_blocks+in_c_idx)*in_hw.x*in_hw.y*4;\n" +" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n" " for(int iy=0; iy= 4){\n" @@ -5414,6 +4734,7 @@ const char* conv_2d_int_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -5464,7 +4785,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" -" const int inp_offset_base=(out_b_idx*in_c_blocks+in_c_idx)*in_hw.x*in_hw.y*4;\n" +" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n" " for(int iy=0; iy= 4){\n" @@ -5642,7 +4963,7 @@ const char* conv_2d_int_buf = " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" @@ -5668,7 +4989,7 @@ const char* conv_2d_int_buf = " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" 
out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n" @@ -5689,6 +5010,7 @@ const char* conv_2d_int_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -5733,7 +5055,7 @@ const char* conv_2d_int_buf = " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " //weights NC4HW4 [1,4*icC4,ocC4*kh*kw,1] xic4\n" " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" -" const int inp_offset_base=(out_b_idx*in_c_blocks+in_c_idx)*in_hw.x*in_hw.y*4;\n" +" const int inp_offset_base=(out_b_idx+in_c_idx*batch)*in_hw.x*in_hw.y*4;\n" " for(int iy=0; iy= 2){\n" @@ -5871,7 +5193,7 @@ const char* conv_2d_int_buf = " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 2){\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" @@ -5886,7 +5208,7 @@ const char* conv_2d_int_buf = " return;\n" " }\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" "#endif\n" @@ -5905,6 +5227,7 @@ const char* conv_2d_int_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -5956,7 +5279,7 @@ const char* conv_2d_int_buf = " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" " int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" " for(int iy=in_h_idx_start; iy= 4){\n" @@ -6127,7 +5450,7 @@ const char* conv_2d_int_buf = "#ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks)return;\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" " }else if(remain == 3){\n" @@ -6143,7 +5466,7 @@ const char* conv_2d_int_buf = "#ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks)return;\n" "#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" "#endif\n" "}\n" @@ -6166,7 +5489,7 @@ const char* interp_buf = " __private const int input_width,\n" " __private const int out_height,\n" " __private const int out_width,\n" -" __private const int channelBlocks) {\n" +" __private 
const int batch) {\n" " const int output_channel_block_idx=get_global_id(0);\n" " const int output_width_block_idx=get_global_id(1);\n" " const int output_batch_height_block_idx=get_global_id(2);\n" @@ -6182,9 +5505,9 @@ const char* interp_buf = " const int in_h_index=min(max(0,(int)floor(in_h_idx)),input_height-1);\n" " const int in_w_index=min(max(0,(int)floor(in_w_idx)),input_width-1);\n" "#endif\n" -" const int inp_offset=((output_batch_idx*channelBlocks+output_channel_block_idx)*input_height+in_h_index)*input_width+in_w_index;\n" +" const int inp_offset=((output_batch_idx+output_channel_block_idx*batch)*input_height+in_h_index)*input_width+in_w_index;\n" " FLOAT4 value=vload4(inp_offset,input);\n" -" const int out_offset=((output_batch_idx*channelBlocks+output_channel_block_idx)*out_height+output_height_idx)*out_width+output_width_block_idx;\n" +" const int out_offset=((output_batch_idx+output_channel_block_idx*batch)*out_height+output_height_idx)*out_width+output_width_block_idx;\n" " vstore4(value,out_offset,output);\n" "}\n" "__kernel void bilinear_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT* input,\n" @@ -6197,7 +5520,7 @@ const char* interp_buf = " __private const int input_width,\n" " __private const int out_height,\n" " __private const int out_width,\n" -" __private const int channelBlocks) {\n" +" __private const int batch) {\n" " const int output_channel_block_idx=get_global_id(0);\n" " const int output_width_block_idx=get_global_id(1);\n" " const int output_batch_height_block_idx=get_global_id(2);\n" @@ -6215,7 +5538,7 @@ const char* interp_buf = " float factor_w=(in_w_idx-(int)floor(in_w_idx));\n" " float factor_h=(in_h_idx-(int)floor(in_h_idx));\n" " \n" -" const int inp_offset_base=(output_batch_idx*channelBlocks+output_channel_block_idx)*input_height;\n" +" const int inp_offset_base=(output_batch_idx+output_channel_block_idx*batch)*input_height;\n" " const int inp_offset_00=(inp_offset_base+in_h0_index)*input_width+in_w0_index;\n" " const int inp_offset_01=(inp_offset_base+in_h0_index)*input_width+in_w1_index;\n" " const int inp_offset_10=(inp_offset_base+in_h1_index)*input_width+in_w0_index;\n" @@ -6226,7 +5549,7 @@ const char* interp_buf = " FLOAT4 value_11=vload4(inp_offset_11,input);\n" " FLOAT4 value=CONVERT_FLOAT4((float4)((1.0-factor_w)*(1.0-factor_h))*convert_float4(value_00)+(float4)(factor_w*(1.0-factor_h))*convert_float4(value_01)+(float4)((1.0-factor_w)*factor_h)*convert_float4(value_10)+(float4)(factor_w*factor_h)*convert_float4(value_11));\n" " \n" -" const int out_offset=((output_batch_idx*channelBlocks+output_channel_block_idx)*out_height+output_height_idx)*out_width+output_width_block_idx;\n" +" const int out_offset=((output_batch_idx+output_channel_block_idx*batch)*out_height+output_height_idx)*out_width+output_width_block_idx;\n" " \n" " vstore4(value,out_offset,output);\n" "}\n" @@ -6244,7 +5567,7 @@ const char* interp_buf = " __private const int out_depth,\n" " __private const int out_height,\n" " __private const int out_width,\n" -" __private const int channelBlocks) {\n" +" __private const int batch) {\n" " const int output_channel_block_idx=get_global_id(0);\n" " const int output_height_width_block_idx=get_global_id(1);\n" " const int output_batch_depth_block_idx=get_global_id(2);\n" @@ -6259,9 +5582,9 @@ const char* interp_buf = " const int in_d_index=min(max(0,(int)floor(in_d_idx)),input_depth-1);\n" " const int in_h_index=min(max(0,(int)floor(in_h_idx)),input_height-1);\n" " const int 
in_w_index=min(max(0,(int)floor(in_w_idx)),input_width-1);\n" -" const int inp_offset=(((output_batch_idx*channelBlocks+output_channel_block_idx)\n" +" const int inp_offset=(((output_batch_idx+output_channel_block_idx*batch)\n" "*input_depth+in_d_index)*input_height+in_h_index)*input_width+in_w_index;\n" -" const int out_offset=(((output_batch_idx*channelBlocks+output_channel_block_idx)\n" +" const int out_offset=(((output_batch_idx+output_channel_block_idx*batch)\n" "*out_depth+output_depth_idx)*out_height+output_height_idx)*out_width+output_width_idx;\n" " FLOAT4 value=vload4(inp_offset,input);\n" " vstore4(value,out_offset,output);\n" @@ -6554,862 +5877,97 @@ const char* softmax = " \n" " /*Compute Result */\n" " for (int i=0; i= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n" -"#define GLOBAL_SIZE_DIM3 "" __private int global_size_dim0,__private int global_size_dim1,__private int global_size_dim2,\n" -"#define UNIFORM_BOUNDRY_CHECK3(index0, index1, index2) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1 || index2 >= global_size_dim2) { "" return; "" }\n" -"#define UCHAR16_TO_2CHAR16(a, b, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; "" b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; "" b.s8=(c.sc >> 4)-8; b.s9=(c.sc & 15)-8; b.sa=(c.sd >> 4)-8; b.sb=(c.sd & 15)-8; b.sc=(c.se >> 4)-8; b.sd=(c.se & 15)-8; b.se=(c.sf >> 4)-8; b.sf=(c.sf & 15)-8;\n" -"#define UCHAR8_TO_CHAR16(a, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8=(c.s4 >> 4)-8; a.s9=(c.s4 & 15)-8; a.sa=(c.s5 >> 4)-8; a.sb=(c.s5 & 15)-8; a.sc=(c.s6 >> 4)-8; a.sd=(c.s6 & 15)-8; a.se=(c.s7 >> 4)-8; a.sf=(c.s7 & 15)-8;\n" -"#define DOT16X16(a, b, c) "" c += dot(a.s0123, b.s0123); "" c += dot(a.s4567, b.s4567); "" c += dot(a.s89ab, b.s89ab); "" c += dot(a.scdef,b.scdef);\n" -"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" -"__kernel void reshape_nchw4_nhwc4(GLOBAL_SIZE_DIM3\n" -"__global const FLOAT* input,\n" -"__global FLOAT* output,\n" -"__private const int width_height,\n" -"__private const int batch,\n" -"__private const int channel,\n" -"__private const int channelC4){\n" -" const int x=get_global_id(0); //c\n" -" const int y=get_global_id(1); //b\n" -" const int wh=get_global_id(2); // w*h\n" -" UNIFORM_BOUNDRY_CHECK3(x,y,wh);\n" -" \n" -" const int x4=x << 2;\n" -" const int y4=y << 2;\n" -" const int channel4=channelC4*4;\n" -" const int stride=channel4*width_height;\n" -" const int input_offset=(y4*channel4+x4)*width_height+wh*4;\n" -" const int output_offset=((y*width_height+wh)*channel4+x4)*4;\n" -" FLOAT4 in0=vload4(0,input+input_offset);\n" -" FLOAT4 in1=(y4+1= channel){\n" -" FLOAT *in0_ptr=(FLOAT*)&in0;\n" -" FLOAT *in1_ptr=(FLOAT*)&in1;\n" -" FLOAT *in2_ptr=(FLOAT*)&in2;\n" -" FLOAT *in3_ptr=(FLOAT*)&in3;\n" -" int remain=x4+3-channel;\n" -" for(int i=remain; i >= 0; i--){\n" -" 
in0_ptr[3-remain]=0;\n" -" in1_ptr[3-remain]=0;\n" -" in2_ptr[3-remain]=0;\n" -" in3_ptr[3-remain]=0;\n" -" }\n" -" }\n" -"#endif\n" -" \n" -" FLOAT16 out=(FLOAT16)(in0.s0,in1.s0,in2.s0,in3.s0,in0.s1,in1.s1,in2.s1,in3.s1,in0.s2,in1.s2,in2.s2,in3.s2,in0.s3,in1.s3,in2.s3,in3.s3);\n" -" \n" -" vstore16(out,0,output+output_offset);\n" -"}\n" -"__kernel void reshape_nhwc4_nchw4(GLOBAL_SIZE_DIM3\n" -"__global const FLOAT* input,\n" -"__global FLOAT* output,\n" -"__private const int width_height,\n" -"__private const int batch,\n" -"__private const int channelC4){\n" -" const int x=get_global_id(0); //c\n" -" const int y=get_global_id(1); //b\n" -" const int wh=get_global_id(2); //w*h\n" -" UNIFORM_BOUNDRY_CHECK3(x,y,wh);\n" -" \n" -" const int x4=x << 2;\n" -" const int y4=y << 2;\n" -" const int channel4=channelC4*4;\n" -" const int stride=channel4*width_height;\n" -" const int input_offset=((y*width_height+wh)*channel4+x4)*4;\n" -" const int output_offset=(y4*channel4+x4)*width_height+wh*4;\n" -" FLOAT16 in=vload16(0,input+input_offset);\n" -" \n" -" FLOAT4 out0=(FLOAT4)(in.s0,in.s4,in.s8,in.sc);\n" -" FLOAT4 out1=(FLOAT4)(in.s1,in.s5,in.s9,in.sd);\n" -" FLOAT4 out2=(FLOAT4)(in.s2,in.s6,in.sa,in.se);\n" -" FLOAT4 out3=(FLOAT4)(in.s3,in.s7,in.sb,in.sf);\n" -" \n" -" vstore4(out0,0,output+output_offset);\n" -" if(y4+1 >= batch) return;\n" -" vstore4(out1,0,output+output_offset+stride);\n" -" if(y4+2 >= batch) return;\n" -" vstore4(out2,0,output+output_offset+2*stride);\n" -" if(y4+3 >= batch) return;\n" -" vstore4(out3,0,output+output_offset+3*stride);\n" -"}\n" -"__kernel void gemm_b4_c4_buf(GLOBAL_SIZE_DIM2\n" -" __global const FLOAT* input,\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" -" __global const char *weight,\n" -"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" __global const uchar *weight,\n" -"#endif\n" -" __global const float *dequantScaleOffset,\n" -" __global const FLOAT *bias,\n" -" __global FLOAT* output,\n" -" __private const int dstChannelC4,\n" -" __private const int srcChannelC4,\n" -" __private const int blockNum,\n" -" __private const int blockDim) {\n" -" const int x=get_global_id(0); //c\n" -" const int y=get_global_id(1); //b\n" -" UNIFORM_BOUNDRY_CHECK(x,y);\n" -" const int out_c_idx=x;\n" -" const int out_b_idx=y << 2;\n" -" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" -" COMPUTE_FLOAT4 out=(COMPUTE_FLOAT4)bias0.s0;\n" -" COMPUTE_FLOAT4 out1=(COMPUTE_FLOAT4)bias0.s1,out2=(COMPUTE_FLOAT4)bias0.s2,out3=(COMPUTE_FLOAT4)bias0.s3;\n" -" \n" -" int input_offset=out_b_idx*srcChannelC4*4;\n" -" int out_offset=(out_b_idx*dstChannelC4+out_c_idx*4)*4;\n" -"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n" -" int weight_offset=out_c_idx*4*8;\n" -" int weight_oc_offset=dstChannelC4*32;\n" -"#else\n" -" int weight_offset=out_c_idx*4*16;\n" -" int weight_oc_offset=dstChannelC4*64;\n" -"#endif\n" -" const int loop=(blockDim+15)/16;\n" -"#ifdef INPUT_CHANNEL_LEAVE\n" -" const int loop_end=max(loop-1,0);\n" -" const int remain=blockDim-loop_end*16;\n" -"#else\n" -" const int loop_end=loop;\n" -"#endif\n" -" \n" -" for (int i=0; i= size){\n" +" int remain=size-offset;\n" +" float4 in0,in1;\n" +" float* in0_ptr=(float*)&in0;\n" +" float* in1_ptr=(float*)&in1;\n" " \n" -" const int loop=(blockDim+15)/16;\n" -" #ifdef INPUT_CHANNEL_LEAVE\n" -" const int loop_end=max(loop-1,0);\n" -" const int remain=blockDim-loop_end*16;\n" +" for(int i=0; i> 2;\n" -" const int offset=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int 
offset=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n" " \n" " float4 in0=convert_float4(vload4(0,input0+offset*isFull.x));\n" @@ -7657,7 +6189,7 @@ const char* binary_subgroup_buf = " const int channel_idx=get_global_id(1);\n" " const int src_width=shape.z+input1_pad_left+input1_pad_right;\n" " const int channe_out_idx=channel_idx >> 2;\n" -" const int offset0=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset1=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left)*16+(channel_idx % 4)*4;\n" " float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n" " float4 in1=convert_float4(vload4(0,input1+offset1*isFull.y));\n" @@ -7691,7 +6223,7 @@ const char* binary_subgroup_buf = " const int channel_idx=get_global_id(1);\n" " const int src_width=shape.z+input0_pad_left+input0_pad_right;\n" " const int channe_out_idx=channel_idx >> 2;\n" -" const int offset1=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset1=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset0=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16+(channel_idx % 4)*4;\n" " \n" " float4 in0=convert_float4(vload4(0,input0+offset0*isFull.x));\n" @@ -7728,7 +6260,7 @@ const char* binary_subgroup_buf = " const int src_width=shape.z+input1_pad_left+input1_pad_right;\n" " const int dst_width=shape.z+output_pad_left+output_pad_right;\n" " const int channe_out_idx=channel_idx >> 2;\n" -" const int offset0=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset1=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input1_pad_left)*16+(channel_idx % 4)*4;\n" " const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n" " \n" @@ -7776,7 +6308,7 @@ const char* binary_subgroup_buf = " const int src_width=shape.z+input0_pad_left+input0_pad_right;\n" " const int dst_width=shape.z+output_pad_left+output_pad_right;\n" " const int channe_out_idx=channel_idx >> 2;\n" -" const int offset1=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset1=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset0=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16+(channel_idx % 4)*4;\n" " const int dst_offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n" " \n" @@ -7819,7 +6351,7 @@ const char* binary_subgroup_buf = " const int batch_idx=get_global_id(2);\n" " const int channel_idx=get_global_id(1);\n" " \n" -" const int offset0=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset1=channel_idx*4;\n" " \n" " float4 in0=convert_float4(vload4(0,input0+offset0));\n" @@ -7844,7 +6376,7 @@ const char* binary_subgroup_buf = " const int dst_width=shape.z+output_pad_left+output_pad_right;\n" " const int channe_out_idx=channel_idx >> 
2;\n" " \n" -" const int offset0=(((batch_idx*channel4+channel_idx)*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset0=(((batch_idx+channel_idx*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " const int offset1=channel_idx*4;\n" " const int offset=(((batch_idx*channel16+channe_out_idx)*shape.y+h_idx)*dst_width+w_idx+output_pad_left)*16+(channel_idx % 4)*4;\n" " float4 in0=convert_float4(vload4(0,input0+offset0));\n" @@ -7920,10 +6452,10 @@ const char* binary_subgroup_buf = " const int channel_idx=get_group_id(1);\n" " const int sglid=get_sub_group_local_id();\n" " const int src_width=shape.z+input0_pad_left+input0_pad_right;\n" -" const int width_height=shape.z*shape.y*4;\n" +" const int batch_width_height=shape.x*shape.z*shape.y*4;\n" " const int offset0=(((batch_idx*channel16+channel_idx)*shape.y+h_idx)*src_width+w_idx+input0_pad_left)*16;\n" " const int offset1=channel_idx*16;\n" -" const int offset=(((batch_idx*channel4+(channel_idx<<2))*shape.y+h_idx)*shape.z+w_idx)*4;\n" +" const int offset=(((batch_idx+(channel_idx<<2)*shape.x)*shape.y+h_idx)*shape.z+w_idx)*4;\n" " float4 in0=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input0+offset0))));\n" " float4 in1=(float4)(AS_INPUT_DATA(INTEL_SUB_GROUP_READ((__global INTEL_DATA*)(input1+offset1))));\n" " \n" @@ -7932,7 +6464,7 @@ const char* binary_subgroup_buf = " const int lid_y=sglid/4;\n" " int block_size=w_idx+4>shape.z ? (shape.z % 4) : 4;\n" " for (int i=0; ishape.z ? (shape.z % 4) : 4;\n" " for (int i=0; i= input_shape.x) {\n" @@ -8376,10 +6911,10 @@ const char* pooling_subgroup_buf = " }\n" " #endif\n" " \n" -" const int out_offset=(((b_idx*in_channel_block+c_idx)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n" +" const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n" " vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n" " #if RETURN_REDICE\n" -" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+(((b_idx*in_channel_block+c_idx)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4);\n" +" vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4);\n" " #endif\n" "}\n" "__kernel void pooling_c4_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input,\n" @@ -8389,6 +6924,7 @@ const char* pooling_subgroup_buf = " __global FLOAT *output,\n" " __global FLOAT *rediceOutput,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int in_channel_block,\n" " __private const int out_channel_block,\n" " __private const int input_pad_left,\n" @@ -8409,7 +6945,7 @@ const char* pooling_subgroup_buf = " \n" " #ifdef POOL_AVG\n" " COMPUTE_FLOAT4 result=(COMPUTE_FLOAT4)(0);\n" -" const int inp_offset=(((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n" +" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n" " #ifdef COUNT_INCLUDE_PADDING\n" " int total_count=(min(ih_start+KERNEL_Y,input_shape.x+pad_shape.x)-ih_start)*(min(iw_start+KERNEL_X,input_shape.y+pad_shape.y)-iw_start);\n" "#else\n" @@ -8438,7 +6974,7 @@ const char* pooling_subgroup_buf = " #if RETURN_REDICE\n" " int4 redice=(int4)0;\n" " #endif\n" -" const int inp_offset=(((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n" +" const int inp_offset=(((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4;\n" " for(int 
kh=0; kh= input_shape.x) {\n" @@ -8482,6 +7018,7 @@ const char* pooling_subgroup_buf = " __global FLOAT *output,\n" " __global FLOAT *rediceOutput,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int in_channel_block,\n" " __private const int out_channel_block,\n" " __private const int input_pad_left,\n" @@ -8624,6 +7161,7 @@ const char* pooling_subgroup_buf = " __global FLOAT *output,\n" " __global FLOAT *rediceOutput,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int in_channel_block,\n" " __private const int out_channel_block,\n" " __private const int input_pad_left,\n" @@ -8704,18 +7242,18 @@ const char* pooling_subgroup_buf = " const uint lid_x=sglid % 4;\n" " const uint lid_y=sglid/4;\n" " \n" -" const int out_offset=(((b_idx*out_channel_block+c_idx*4)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n" -" const int width_height=output_shape.y*output_shape.x*4;\n" +" const int out_offset=(((b_idx+c_idx*4*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx+output_pad_left)*4;\n" +" const int batch_width_height=batch*output_shape.y*output_shape.x*4;\n" "#if RETURN_REDICE\n" -" const int redice_offset=(((b_idx*out_channel_block+c_idx*4)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n" +" const int redice_offset=(((b_idx+c_idx*4*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n" "#endif\n" "#if OUTPUT_LEFTOVERS\n" " if ((c_idx+1)*16 >= channel) {\n" " for (int i=0; i<8; i++) {\n" " if ((c_idx*16+lid_y*4+lid_x= input_shape.x) {\n" @@ -8811,7 +7349,7 @@ const char* pooling_buf = " }\n" " #endif\n" " \n" -" const int out_offset=(((b_idx*channel_block+c_idx)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n" +" const int out_offset=(((b_idx+c_idx*batch)*output_shape.x+oh_idx)* output_shape.y+ow_idx)*4;\n" " vstore4(CONVERT_FLOAT4(result),0,output+out_offset);\n" " #if RETURN_REDICE\n" " vstore4(CONVERT_FLOAT4(redice),0,rediceOutput+out_offset);\n" @@ -8826,7 +7364,7 @@ const char* pooling_buf = " __private const int2 kernel_shape,\n" " __global FLOAT *output,\n" " __global FLOAT *rediceOutput,\n" -" __private const int channel_block) {\n" +" __private const int batch) {\n" " const int local_id=get_local_id(0);\n" " const int output_channel_idx=get_global_id(1);\n" " const int output_batch_idx=get_global_id(2);\n" @@ -8840,7 +7378,7 @@ const char* pooling_buf = "#endif\n" "#endif\n" " COMPUTE_FLOAT4 local sum[LOCAL_SIZE];\n" -" const int inp_offset=((output_batch_idx*channel_block+output_channel_idx)*input_shape.x)*input_shape.y*4;\n" +" const int inp_offset=((output_batch_idx+output_channel_idx*batch)*input_shape.x)*input_shape.y*4;\n" " const int size=input_shape.x*input_shape.y;\n" " for(int i=local_id; i= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define GLOBAL_SIZE_2_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" "inline float4 gelu(float4 in){\n" " float4 value=0.79788458f*(0.044715f*in*in*in+in);\n" " float4 x2=value*value;\n" @@ -9323,20 +7861,35 @@ const char* unary_buf = " (value*(135135.0f+x2*(17325.0f+x2*(378.0f+x2))))/(135135.0f+x2*(62370.0f+x2*(3150.0f+x2*28.0f))));\n" " return (1.0f+dst)*in*0.5f;\n" "}\n" -"__kernel void unary_buf(GLOBAL_SIZE_3_DIMS\n" +"__kernel void unary_buf(GLOBAL_SIZE_2_DIMS\n" " __global const INPUT_TYPE *input,\n" 
" __global OUTPUT_TYPE *output,\n" -" __private const int height) {\n" -" const int channel_block_idx=get_global_id(0);\n" -" const int w=get_global_id(1);\n" -" const int hb=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n" -" const int batch_idx=hb/height;\n" -" const int height_idx=hb % height;\n" -" const int offset=(((batch_idx*global_size_dim0+channel_block_idx)*height+height_idx)*global_size_dim1+w)*4;\n" +" __private const int size) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" +" const int offset=x << 2;\n" +"#ifdef PACK_LEAVE\n" +" if(offset+3 >= size){\n" +" int remain=size-offset;\n" +" float4 in;\n" +" float* in_ptr=(float*)∈\n" +" for(int i=0; i= in_hw.x) continue;\n" " \n" -" int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= 4) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9438,7 +7991,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9465,7 +8018,7 @@ const char* depthwise_conv2d_buf = " const int in_h_cur=in_h_start+kh*dilate_hw.x;\n" " if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n" " \n" -" int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= 2) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9502,7 +8055,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9526,7 +8079,7 @@ const char* depthwise_conv2d_buf = " const int in_h_cur=in_h_start+kh*dilate_hw.x;\n" " if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n" " \n" -" int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= in_hw.x) continue;\n" " \n" -" int inp_offset_c0=(((b_idx*c_blocks+c_idx+0)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" -" int inp_offset_c1=(((b_idx*c_blocks+c_idx+1)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset_c0=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset_c1=(((b_idx+(c_idx+1)*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c0));\n" @@ -9636,7 +8189,7 @@ const char* depthwise_conv2d_buf = " outValue6=clamp(outValue6,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " outValue7=clamp(outValue7,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((b_idx*c_blocks+c_idx)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n" +" int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n" " const int remain=out_hw.y-out_w4_idx;\n" " if (remain >= 4) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9656,7 +8209,7 @@ const char* depthwise_conv2d_buf = " \n" " if(c_idx+1 >= c_blocks) return;\n" " \n" -" out_offset += out_hw.x*out_hw.y*4;\n" +" out_offset += batch*out_hw.x*out_hw.y*4;\n" " if (remain >= 4) {\n" " vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n" @@ -9679,7 +8232,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9707,8 +8260,8 @@ const char* depthwise_conv2d_buf = " const int in_h_cur=in_h_start+kh;\n" " if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n" " \n" -" int inp_offset_c0=(((b_idx*c_blocks+c_idx+0)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" -" int inp_offset_c1=(((b_idx*c_blocks+c_idx+1)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset_c0=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset_c1=(((b_idx+(c_idx+1)*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset_c0));\n" @@ -9740,7 +8293,7 @@ const char* depthwise_conv2d_buf = " outValue4=clamp(outValue4,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " outValue5=clamp(outValue5,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" int out_offset=(((b_idx*c_blocks+c_idx)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n" +" int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n" " const int remain=out_hw.y-out_w2_idx;\n" " if (remain >= 2) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9751,7 +8304,7 @@ const char* depthwise_conv2d_buf = " \n" " if(c_idx+1 >= c_blocks) return;\n" " \n" -" out_offset += out_hw.x*out_hw.y*4;\n" +" out_offset += batch*out_hw.x*out_hw.y*4;\n" " if (remain >= 2) {\n" " vstore4(CONVERT_FLOAT4(outValue4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(outValue5),1,output+out_offset);\n" @@ -9765,7 +8318,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9796,7 +8349,7 @@ const char* depthwise_conv2d_buf = " const int in_h_cur=in_h_start+kh;\n" " if(in_h_cur<0 || in_h_cur >= in_hw.x) continue;\n" " \n" -" int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" +" int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_cur)* in_hw.y+in_w_start_0)*4;\n" " for (int kw=0; kw= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(kw+0,input+inp_offset));\n" @@ -9824,7 +8377,7 @@ const char* depthwise_conv2d_buf = " outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" const int out_offset=(((b_idx*c_blocks+c_idx)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n" +" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w4_idx)*4;\n" " const int remain=out_hw.y-out_w4_idx;\n" " if (remain >= 4) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9848,7 +8401,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9870,7 +8423,7 @@ const char* depthwise_conv2d_buf = " const int in_h_start=out_h_idx-pad_hw.x;\n" " COMPUTE_FLOAT4 inValue0,inValue1,inValue2,inValue3;\n" " //first line\n" -" const int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_start)* in_hw.y+in_w_start_0)*4;\n" +" const int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_start)* in_hw.y+in_w_start_0)*4;\n" " inValue0=(in_h_start<0 || in_w_start_0<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n" " inValue1=(in_h_start<0 || in_w_start_0+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n" " inValue2=(in_h_start<0 || in_w_start_0+2 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n" @@ -9935,7 +8488,7 @@ const char* depthwise_conv2d_buf = " outValue0=clamp(outValue0,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " outValue1=clamp(outValue1,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" const int out_offset=(((b_idx*c_blocks+c_idx)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n" +" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w2_idx)*4;\n" " const int remain=out_hw.y-out_w2_idx;\n" " if (remain >= 2) {\n" " vstore4(CONVERT_FLOAT4(outValue0),0,output+out_offset);\n" @@ -9950,7 +8503,7 @@ const char* depthwise_conv2d_buf = " __global const FLOAT *bias,\n" " __global FLOAT *output,\n" " __private const int2 in_hw,\n" -" __private const int channel,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 pad_hw,\n" @@ -9976,7 +8529,7 @@ const char* depthwise_conv2d_buf = " const int in_h_start=out_h2_idx-pad_hw.x;\n" " COMPUTE_FLOAT4 inValue0,inValue1,inValue2,inValue3;\n" " //first line\n" -" const int inp_offset=(((b_idx*c_blocks+c_idx)*in_hw.x+in_h_start)* in_hw.y+in_w_start)*4;\n" +" const int inp_offset=(((b_idx+c_idx*batch)*in_hw.x+in_h_start)* in_hw.y+in_w_start)*4;\n" " inValue0=(in_h_start<0 || in_w_start<0 ) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(0,input+inp_offset));\n" " inValue1=(in_h_start<0 || in_w_start+1 >= in_hw.y) ? (COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(1,input+inp_offset));\n" " inValue2=(in_h_start<0 || in_w_start+2 >= in_hw.y) ? 
(COMPUTE_FLOAT4)0 : CONVERT_COMPUTE_FLOAT4(vload4(2,input+inp_offset));\n" @@ -10059,7 +8612,7 @@ const char* depthwise_conv2d_buf = " outValue2=clamp(outValue2,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" " outValue3=clamp(outValue3,(COMPUTE_FLOAT4)0,(COMPUTE_FLOAT4)6);\n" "#endif\n" -" const int out_offset=(((b_idx*c_blocks+c_idx)*out_hw.x+out_h2_idx)*out_hw.y+out_w2_idx)*4;\n" +" const int out_offset=(((b_idx+c_idx*batch)*out_hw.x+out_h2_idx)*out_hw.y+out_w2_idx)*4;\n" " const int remain_w=out_hw.y-out_w2_idx;\n" " const int remain_h=out_hw.x-out_h2_idx;\n" " if(remain_w >= 2 && remain_h >= 2) {\n" @@ -10171,6 +8724,7 @@ const char* winogradTransform_buf = " __private const int srcWidth,// 6\n" " __private const int srcHeight,__private const int srcChannelC4,\n" " __private const int dstHeightPad,__private const int srcChannelPad,\n" +" __private const int batch,\n" " __private const int batchOffset) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10206,7 +8760,7 @@ const char* winogradTransform_buf = " FLOAT4 S23;\n" " FLOAT4 S33;\n" " \n" -" int inp_offset=(((batchIndex*srcChannelC4+srcZ)*srcHeight+syStart)*srcWidth+sxStart)*4;\n" +" int inp_offset=(((batchIndex+srcZ*batch)*srcHeight+syStart)*srcWidth+sxStart)*4;\n" " {\n" " int sx=0+sxStart;\n" " int sy=0+syStart;\n" @@ -10451,6 +9005,7 @@ const char* winogradTransform_buf = " __private const int dstChannelC4,\n" " __private const int srcWidthPad,\n" " __private const int dstChannelPad,\n" +" __private const int batch,\n" " __private const int batchOffset) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10498,7 +9053,7 @@ const char* winogradTransform_buf = " \n" " //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n" " //index: [batchIndex,oz,oyStart,oxStart]\n" -" int out_offset=(((batchIndex*dstChannelC4+ oz)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n" +" int out_offset=(((batchIndex+oz*batch)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n" " {\n" " int ox=oxStart+0;\n" " int oy=oyStart+0;\n" @@ -10578,6 +9133,7 @@ const char* winogradTransform_subgroup_buf = " __private const int srcWidth,// 6\n" " __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n" " __private const int batchOffset,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10657,6 +9213,7 @@ const char* winogradTransform_subgroup_buf = " __private const int dstHeight,\n" " __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n" " __private const int batchOffset,\n" +" __private const int batch,\n" " __private const int output_pad_left,__private const int output_pad_right) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10773,6 +9330,7 @@ const char* winogradTransform_subgroup_buf = " __private const int srcWidth,// 6\n" " __private const int srcHeight,__private const int srcChannelC4,__private const int srcChannelC16,__private const int dstHeight,\n" " __private const int batchOffset,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1)); \n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10800,7 +9358,7 @@ 
const char* winogradTransform_subgroup_buf = " FLOAT4 S23;\n" " FLOAT4 S33;\n" " \n" -" int inp_offset=(((batchOffset*srcChannelC4+pos.y)*srcHeight+syStart)*srcWidth+sxStart)*4;\n" +" int inp_offset=(((batchOffset+pos.y*batch)*srcHeight+syStart)*srcWidth+sxStart)*4;\n" " {\n" " int sx=0+sxStart;\n" " int sy=0+syStart;\n" @@ -10949,6 +9507,7 @@ const char* winogradTransform_subgroup_buf = " __private const int dstHeight,\n" " __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth,\n" " __private const int batchOffset,\n" +" __private const int batch,\n" " __private const int output_pad_left,__private const int output_pad_right) {\n" " int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" " UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" @@ -10992,7 +9551,7 @@ const char* winogradTransform_subgroup_buf = " \n" " //NC4HW4 [batch,dstChannelC4,dstHeight,dstWidth]\n" " //index: [batchOffset,pos.y,oyStart,oxStart]\n" -" int out_offset=(((batchOffset*dstChannelC4+ pos.y)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n" +" int out_offset=(((batchOffset+ pos.y*batch)*dstHeight+oyStart)*dstWidth+oxStart)*4;\n" " {\n" " int ox=oxStart+0;\n" " int oy=oyStart+0;\n" @@ -11126,7 +9685,7 @@ const char* splitgelu_buf = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" -"__kernel void splitgelu_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" +"__kernel void splitgelu_buf(__private int global_dim0,__private int global_dim1,\n" " __global const FLOAT*input,\n" " #ifdef DOUBLE_INPUTS\n" " __global const FLOAT*input1,\n" @@ -11134,46 +9693,55 @@ const char* splitgelu_buf = " __global FLOAT*output,\n" " __private const int4 shape\n" "){\n" -" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n" -" if (pos.x> 2;\n" -" const int area_4=(shape.z+3) >> 2;\n" -" const int in_offset=((b*channel_4+c_4)*area_4*2+hw_4)*16;\n" -" const int out_offset=((b*channel_4+c_4)*area_4+hw_4)*16;\n" +" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" +" if (pos.x> 2;\n" -" const int in_offset=((b*channel_4+c_4)*shape.z*2+hw)*4;\n" -" const int out_offset=((b*channel_4+c_4)*shape.z+hw)*4;\n" -" \n" +"// The product of W and H is a multiple of 4\n" +"#elif defined (WH_4)\n" +" const int in_offset=bc*shape.z*2+h*4;\n" +" const int out_offset=bc*shape.z+h*4;\n" " float4 valueL=convert_float4(vload4(0,input+in_offset));\n" -" float4 valueR=convert_float4(vload4(shape.z,input+in_offset));\n" +" float4 valueR=convert_float4(vload4(0,input+in_offset+shape.z));\n" " #ifdef DOUBLE_INPUTS\n" -" float valueConstL=input1[hw];\n" -" float valueConstR=input1[shape.z+hw];\n" -" valueL += (float4)valueConstL;\n" -" valueR += (float4)valueConstR;\n" +" float4 valueConstL=convert_float4(vload4(h,input1));\n" +" float4 valueConstR=convert_float4(vload4(h,input1+shape.z));\n" +" valueL += valueConstL;\n" +" valueR += valueConstR;\n" " #endif\n" " float4 out=(erf(valueR*(float4)0.7071067932881648)+(float4)1.0)*valueR*(float4)0.5;\n" " out *= valueL;\n" " vstore4(CONVERT_FLOAT4(out),0,output+out_offset);\n" +"#else\n" +" const int in_offset=bc*shape.z*2+h;\n" +" const int out_offset=bc*shape.z+h;\n" +" \n" +" float valueL=(float)input[in_offset];\n" +" float valueR=(float)input[in_offset+shape.z];\n" +" #ifdef DOUBLE_INPUTS\n" +" float valueConstL=input1[h];\n" +" float valueConstR=input1[shape.z+h];\n" +" valueL += valueConstL;\n" +" valueR += valueConstR;\n" +" #endif\n" +" float 
out=(erf(valueR*0.7071067932881648)+1.0)*valueR*0.5;\n" +" out *= valueL;\n" +" output[out_offset]=out;\n" "#endif\n" " }\n" "}\n" @@ -11492,27 +10060,28 @@ const char* buffer_convert_quant = " __write_only image2d_t output,\n" " __private const int input_channel,\n" " __private const int output_channel) {\n" -" int x=get_global_id(0); // ic/16\n" +" int x=get_global_id(0); // ic/32\n" " int y=get_global_id(1); // oc\n" " DEAL_NON_UNIFORM_DIM2(x,y);\n" -" const int xin=x << 4;\n" "#ifdef USE_LOW_BIT_WEIGHT_INT4\n" +" const int xin=x << 5;\n" "#ifdef CHANNEL_LEAVE\n" -" uchar8 out=0;\n" +" uchar16 out=0;\n" " uchar *out_ptr=(uchar*)&out;\n" -" for(int i=0; i<8; ++i){\n" +" for(int i=0; i<16; ++i){\n" " int index0=y*input_channel+xin+i*2;\n" " int index1=y*input_channel+xin+i*2+1;\n" " uchar s0=input_ptr[index0/2];\n" " uchar s1=input_ptr[index1/2];\n" " out_ptr[i]=((index0 % 2) == 0 ? (s0 & 0xf0) : (s0 << 4)) | ((index1 % 2) == 0 ? (s1 >> 4) : (s1 & 0x0f));\n" " }\n" -" write_imageui(output,(int2)(y,x),convert_uint4(as_ushort4(out)));\n" +" write_imagei(output,(int2)(y,x),as_int4(out));\n" "#else\n" " const int inputOffset=(y*input_channel+xin)/2;\n" -" write_imageui(output,(int2)(y,x),convert_uint4(as_ushort4(vload8(0,input_ptr+inputOffset))));\n" +" write_imagei(output,(int2)(y,x),as_int4(vload16(0,input_ptr+inputOffset)));\n" "#endif\n" "#else\n" +" const int xin=x << 4;\n" " const int inputOffset=y*input_channel+xin;\n" " write_imagei(output,(int2)(y,x),as_int4(vload16(0,input_ptr+inputOffset)));\n" "#endif\n" @@ -11539,7 +10108,6 @@ const char* buffer_convert_quant = "#ifdef USE_LOW_BIT_WEIGHT_INT4\n" " const int inputOffset=(yin*input_channel+xin)/2;\n" " const int outputOffset=((x*outputChannelC4+y)*icPack*ocPack)/2;\n" -"#ifdef CHANNEL_LEAVE\n" " for(int i=0; i> 4) : (s1 & 0x0f);\n" -" output_ptr[outputOffset+i*(ocPack/2)+j]=s0 | s1;\n" -" }\n" -" }\n" -"#else\n" -" for(int i=0; i> 4);\n" -" char d1=((s0 & 0x0f) << 4) | (s1 & 0x0f);\n" -" output_ptr[outputOffset+(i*2)*(ocPack/2)+j]=d0;\n" -" output_ptr[outputOffset+(i*2+1)*(ocPack/2)+j]=d1;\n" +" output_ptr[outputOffset+i*(ocPack/2)+j]=s0 | s1;\n" " }\n" " }\n" -"#endif\n" "#else\n" " const int inputOffset=yin*input_channel+xin;\n" " const int outputOffset=(x*outputChannelC4+y)*icPack*ocPack;\n" @@ -11581,103 +10137,7 @@ const char* gemm_buf = "#endif\n" "#define GLOBAL_SIZE_DIM2 "" __private int global_size_dim0,__private int global_size_dim1,\n" "#define UNIFORM_BOUNDRY_CHECK(index0, index1) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n" -"__kernel void gemm_buf(GLOBAL_SIZE_DIM2\n" -" __global const FLOAT* input0,\n" -" __global const FLOAT* input1,\n" -" __global FLOAT* output,\n" -" __private const int width,//UP_DIV(wUnit*hUnit,4)\n" -" __private const int height,//dstChannelC4\n" -" __private const int srcChannelC4,\n" -" __private const int alpha2) {\n" -" int2 pos=(int2)(get_global_id(0),get_global_id(1));\n" -" UNIFORM_BOUNDRY_CHECK(pos.x,pos.y);\n" -" const int pos_x=pos.x % width;\n" -" const int pos_y=pos.x/width;\n" -" const int pos_z=pos.y;\n" -" COMPUTE_FLOAT16 o=(COMPUTE_FLOAT16)0;\n" -" \n" -" int kenerlY=mad24(pos_z,height,pos_y);\n" -" for (int k=0; k> 1;\n" -" const int pos_x=(pos.x % width_block) << 1;\n" -" const int pos_y=pos.x/width_block;\n" -" const int pos_z=pos.y;\n" -" COMPUTE_FLOAT16 o0=(COMPUTE_FLOAT16)0;\n" -" COMPUTE_FLOAT16 o1=(COMPUTE_FLOAT16)0;\n" -" const int kenerlY=mad24(pos_z,height,pos_y);\n" -" const int kernel_base=mul24(kenerlY,srcChannelC4);\n" -" 
const int inp_base=(pos_z*srcChannelC4+0)*width+pos_x;\n" -" \n" -" for (int k=0; k= width) return;\n" -" vstore4(CONVERT_FLOAT4(o1.s0123),1,output+out_offset);\n" -" vstore4(CONVERT_FLOAT4(o1.s4567),1,output+out_offset+4*width);\n" -" vstore4(CONVERT_FLOAT4(o1.s89ab),1,output+out_offset+8*width);\n" -" vstore4(CONVERT_FLOAT4(o1.scdef),1,output+out_offset+12*width);\n" -"}\n" -"// [B,K/4,area,4] -> [alignK,alignM] (M=B*area)\n" +"// [K/4,M,4] -> [alignK,alignM]\n" "__kernel void transpose_pad(GLOBAL_SIZE_DIM2\n" " const int alignM,\n" " const int alignK,\n" @@ -11687,71 +10147,29 @@ const char* gemm_buf = " __global const FLOAT* input,\n" " __global FLOAT* output\n" " ) {\n" -"#ifdef AREA_EQUAL_1\n" " const int idx_m4=get_global_id(0); // idx M\n" " const int idx_k4=get_global_id(1); // idx K\n" " UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n" " const int idx_m=idx_m4 << 2;\n" " const int idx_k=idx_k4 << 2;\n" " const int K_4=(K+3) >> 2;\n" -" const int in_offset_base=(idx_m*K_4+idx_k4)*4;\n" +" const int in_offset_base=(idx_k4*M+idx_m)*4;\n" " const int out_offset_base=idx_k*alignM+idx_m;\n" " \n" " FLOAT4 m0k4=(idx_k4 >= K_4 || idx_m+0 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base);\n" -" FLOAT4 m1k4=(idx_k4 >= K_4 || idx_m+1 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+(K_4 << 2));\n" -" FLOAT4 m2k4=(idx_k4 >= K_4 || idx_m+2 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+(K_4 << 2)*2);\n" -" FLOAT4 m3k4=(idx_k4 >= K_4 || idx_m+3 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+(K_4 << 2)*3);\n" -" \n" -" vstore4((FLOAT4)(m0k4.x,m1k4.x,m2k4.x,m3k4.x),0,output+out_offset_base);\n" -" vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n" -" vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n" -" vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n" -"#elif defined BATCH_EQUAL_1\n" -" const int idx_m4=get_global_id(0); // idx M\n" -" const int idx_k4=get_global_id(1); // idx K\n" -" UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n" -" const int idx_m=idx_m4 << 2;\n" -" const int idx_k=idx_k4 << 2;\n" -" const int K_4=(K+3) >> 2;\n" -" const int in_offset_base=(idx_k4*area+idx_m)*4;\n" -" const int out_offset_base=idx_k*alignM+idx_m;\n" -" FLOAT4 m0k4=(idx_k4 >= K_4 || idx_m+0 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base);\n" " FLOAT4 m1k4=(idx_k4 >= K_4 || idx_m+1 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+4);\n" " FLOAT4 m2k4=(idx_k4 >= K_4 || idx_m+2 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+8);\n" " FLOAT4 m3k4=(idx_k4 >= K_4 || idx_m+3 >= M) ? 
(FLOAT4)0 : vload4(0,input+in_offset_base+12);\n" +" \n" " vstore4((FLOAT4)(m0k4.x,m1k4.x,m2k4.x,m3k4.x),0,output+out_offset_base);\n" " vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n" " vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n" " vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n" -"#else\n" -" const int idx_m=get_global_id(0); // idx M\n" -" const int idx_k4=get_global_id(1); // idx K\n" -" UNIFORM_BOUNDRY_CHECK(idx_m,idx_k4);\n" -" \n" -" const int K_4=(K+3) >> 2;\n" -" const int idx_k=idx_k4 << 2;\n" -" const int out_offset_base=idx_k*alignM+idx_m;\n" -" \n" -" if(idx_k4 >= K_4 || idx_m >= M) {\n" -" output[out_offset_base]=(FLOAT)0;\n" -" output[out_offset_base+alignM]=(FLOAT)0;\n" -" output[out_offset_base+alignM+alignM]=(FLOAT)0;\n" -" output[out_offset_base+alignM+alignM+alignM]=(FLOAT)0;\n" -" return;\n" -" }\n" -" const int idx_b=idx_m/area;\n" -" const int idx_area=idx_m % area;\n" -" \n" -" const int in_offset_base=((idx_b*K_4+idx_k4)*area+idx_area)*4;\n" -" FLOAT4 data=vload4(0,input+in_offset_base);\n" -" \n" -" output[out_offset_base]=data.x;\n" -" output[out_offset_base+alignM]=data.y;\n" -" output[out_offset_base+alignM+alignM]=data.z;\n" -" output[out_offset_base+alignM+alignM+alignM]=data.w;\n" -"#endif\n" "}\n" -"// [alignM,alignN] -> [B,N/4,area,4] (M=B*area)\n" +"#ifndef M_VEC\n" +"#define M_VEC 1\n" +"#endif\n" +"// [alignM,alignN] -> [N/4,B,area,N4] (M=B*area)\n" "__kernel void transpose_bias(GLOBAL_SIZE_DIM2\n" " const int alignM,\n" " const int alignN,\n" @@ -11762,28 +10180,15 @@ const char* gemm_buf = " __global const FLOAT* input1,\n" " __global FLOAT* output\n" " ) {\n" -"#ifdef AREA_EQUAL_1\n" -" const int idx_m=get_global_id(0); // idx M\n" -" const int idx_n_16=get_global_id(1); // idx N\n" -" UNIFORM_BOUNDRY_CHECK(idx_m,idx_n_16);\n" -" const int N_4=(N+3) >> 2;\n" -" const int N_16=(N+15) >> 4;\n" -" const int N_left=N & 15;\n" -" bool canVec16=(N_left == 0 || (N_left != 0 && idx_n_16= N_4) return;\n" -" res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+4);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+4);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+4);\n" -" \n" -" if(idx_n_16*4+2 >= N_4) return;\n" -" res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+8);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+8);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+8);\n" -" \n" -" if(idx_n_16*4+3 >= N_4) return;\n" -" res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+12);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+12);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+12);\n" +" vstore4(res,0,output+((idx_n4*M+idx_m+i) << 2));\n" " }\n" -"#else\n" -" const int idx_m=get_global_id(0); // idx M\n" -" const int idx_n_16=get_global_id(1); // idx N\n" -" UNIFORM_BOUNDRY_CHECK(idx_m,idx_n_16);\n" -" \n" -" const int N_4=(N+3) >> 2;\n" -" const int idx_b=idx_m/area;\n" -" const int idx_area=idx_m % area;\n" -" \n" -" const int 
inp_base_offset=idx_m*alignN+(idx_n_16 << 4);\n" -" const int out_base_offset=((idx_b*N_4+idx_n_16*4)*area+idx_area)*4;\n" -" \n" -" FLOAT4 res0=vload4(0,input0+inp_base_offset);\n" -" FLOAT4 res1=vload4(0,input1+(idx_n_16 << 4));\n" -" FLOAT4 res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+out_base_offset);\n" -" \n" -" if(idx_n_16*4+1 >= N_4) return;\n" -" res0=vload4(0,input0+inp_base_offset+4);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+4);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+out_base_offset+area*4);\n" -" \n" -" if(idx_n_16*4+2 >= N_4) return;\n" -" res0=vload4(0,input0+inp_base_offset+8);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+8);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+out_base_offset+area*8);\n" -" \n" -" if(idx_n_16*4+3 >= N_4) return;\n" -" res0=vload4(0,input0+inp_base_offset+12);\n" -" res1=vload4(0,input1+(idx_n_16 << 4)+12);\n" -" res=res0+res1;\n" -" #ifdef RELU\n" -" res=fmax(res,(FLOAT4)0);\n" -" #endif\n" -" #ifdef RELU6\n" -" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" -" #endif\n" -" vstore4(res,0,output+out_base_offset+area*12);\n" -"#endif\n" "}\n" ; #endif @@ -12287,235 +10597,58 @@ const char* loop = " int nc4=c4offset % src1C4_size.w;\n" " int cc4_offset=cc4/4;\n" " int cc4_remain=cc4 % 4;\n" -" float4 tmp=convert_float4(RI_DATA(input1,SAMPLER,(int2)(cc4_offset*src1C4_size.x+wc4,nc4*src1C4_size.y+hc4)));\n" -" float *tmp_ptr=(float*)&tmp;\n" -" in1_ptr[i]=tmp_ptr[cc4_remain];\n" -" }\n" -" }\n" -" \n" -" float4 out=LOOP_BINARY_OPERATOR;\n" -" WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n" -" }\n" -"}\n" -"#endif\n" -; -#ifndef MNN_OPENCL_BUFFER_CLOSED -const char* argmax_buf = -"#ifdef MNN_SUPPORT_FP16\n" -"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" -"#endif\n" -"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" -"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"#define ARGMAX_SELECT(A, B, C, D) "" if(A.x < B.x){ A.x = B.x; C.x = D; } "" if(A.y < B.y){ A.y = B.y; C.y = D; } "" if(A.z < B.z){ A.z = B.z; C.z = D; } "" if(A.w B.x){ A.x = B.x; C.x = D; } "" if(A.y > B.y){ A.y = B.y; C.y = D; } "" if(A.z > B.z){ A.z = B.z; C.z = D; } "" if(A.w>B.w){ A.w=B.w; C.w=D; } \n" -"__kernel void argmax_width_buf(GLOBAL_SIZE_3_DIMS\n" -" __global const FLOAT* input,\n" -" __global int* output,\n" -" __private const int inputWidth,\n" -" __private const int inputHeight,\n" -" __private const int inputChannel,\n" -" __private const int inputBatch,\n" -" __private const int inputChannelBlock,\n" -" __private const int oututWidth,\n" -" __private const int outputHeight,\n" -" __private const int outputChannel,\n" -" __private const int outputChannelBlock\n" -" ) {\n" -" const int x=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(x,height_idx,batch_channel_idx);\n" -" \n" -" const int 
batch_idx=batch_channel_idx/outputChannelBlock;\n" -" const int channel_idx=batch_channel_idx % outputChannelBlock;\n" -" \n" -" const int offset=((((batch_idx*inputChannelBlock)+channel_idx)*inputHeight+height_idx)*inputWidth+0)*4;\n" -" const int outputOffset=((((batch_idx*outputChannelBlock)+channel_idx)*outputHeight+height_idx)*oututWidth+0)*4;\n" -" int4 index=0;\n" -"#ifdef ARGMAX\n" -" FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n" -"#else\n" -" FLOAT4 maxValue=(FLOAT4)FLT_MAX;\n" -"#endif\n" -"#if ARGMAX_LOCAL_SIZE >= 4\n" -" int lid=get_local_id(0);\n" -" FLOAT4 local reduce[ARGMAX_LOCAL_SIZE];\n" -" int4 local index_reduce[ARGMAX_LOCAL_SIZE];\n" -" \n" -" for (int i=lid; i0; i /= 2){\n" -" if (lidreduce[lid+i].x){reduce[lid].x=reduce[lid+i].x; index_reduce[lid].x=index_reduce[lid+i].x;}\n" -" if(reduce[lid].y>reduce[lid+i].y){reduce[lid].y=reduce[lid+i].y; index_reduce[lid].y=index_reduce[lid+i].y;}\n" -" if(reduce[lid].z>reduce[lid+i].z){reduce[lid].z=reduce[lid+i].z; index_reduce[lid].z=index_reduce[lid+i].z;}\n" -" if(reduce[lid].w>reduce[lid+i].w){reduce[lid].w=reduce[lid+i].w; index_reduce[lid].w=index_reduce[lid+i].w;}\n" -"#endif\n" -" }\n" -" barrier(CLK_LOCAL_MEM_FENCE);\n" -" }\n" -" if(lid == 0){\n" -" vstore4(index_reduce[0],0,output+outputOffset);\n" -" }\n" -"#else\n" -" for(int i=0; i= 4\n" -" int lid=get_local_id(0);\n" -" FLOAT4 local reduce[ARGMAX_LOCAL_SIZE];\n" -" int4 local index_reduce[ARGMAX_LOCAL_SIZE];\n" -" \n" -" for (int i=lid; i0; i /= 2){\n" -" if (lidreduce[lid+i].x){reduce[lid].x=reduce[lid+i].x; index_reduce[lid].x=index_reduce[lid+i].x;}\n" -" if(reduce[lid].y>reduce[lid+i].y){reduce[lid].y=reduce[lid+i].y; index_reduce[lid].y=index_reduce[lid+i].y;}\n" -" if(reduce[lid].z>reduce[lid+i].z){reduce[lid].z=reduce[lid+i].z; index_reduce[lid].z=index_reduce[lid+i].z;}\n" -" if(reduce[lid].w>reduce[lid+i].w){reduce[lid].w=reduce[lid+i].w; index_reduce[lid].w=index_reduce[lid+i].w;}\n" -"#endif\n" +" float4 tmp=convert_float4(RI_DATA(input1,SAMPLER,(int2)(cc4_offset*src1C4_size.x+wc4,nc4*src1C4_size.y+hc4)));\n" +" float *tmp_ptr=(float*)&tmp;\n" +" in1_ptr[i]=tmp_ptr[cc4_remain];\n" " }\n" -" barrier(CLK_LOCAL_MEM_FENCE);\n" " }\n" -" if(lid == 0){\n" -" vstore4(index_reduce[0],0,output+outputOffset);\n" +" \n" +" float4 out=LOOP_BINARY_OPERATOR;\n" +" WI_DATA(output,(int2)(co*dst_width+wo,no*dst_height+ho),CONVERT_OUTPUT_I4(out));\n" " }\n" -"#else\n" -" for(int i=0; i= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define ARGMAX_SELECT(A, B, C, D) "" if(A.x < B.x){ A.x = B.x; C.x = D; } "" if(A.y < B.y){ A.y = B.y; C.y = D; } "" if(A.z < B.z){ A.z = B.z; C.z = D; } "" if(A.w B.x){ A.x = B.x; C.x = D; } "" if(A.y > B.y){ A.y = B.y; C.y = D; } "" if(A.z > B.z){ A.z = B.z; C.z = D; } "" if(A.w>B.w){ A.w=B.w; C.w=D; } \n" +"__kernel void argmax_buf(GLOBAL_SIZE_3_DIMS\n" " __global const FLOAT* input,\n" " __global int* output,\n" -" __private const int inputWidth,\n" -" __private const int inputHeight,\n" -" __private const int inputChannel,\n" -" __private const int inputBatch,\n" -" __private const int inputChannelBlock,\n" -" __private const int oututWidth,\n" -" __private const int outputHeight,\n" -" __private const int outputChannel,\n" -" __private const int outputChannelBlock\n" -" ) {\n" +" __private const int inside,\n" +" __private const int outside,\n" +" __private const int dim){\n" " const int x=get_global_id(0);\n" -" const int wh=get_global_id(1);\n" -" const int batch_idx=get_global_id(2);\n" 
-" DEAL_NON_UNIFORM_DIM3(x,wh,batch_idx);\n" +" const int y=get_global_id(1); // inside\n" +" const int z=get_global_id(2); // outside\n" " \n" -" const int width_idx=wh % oututWidth;\n" -" const int height_idx=wh/oututWidth;\n" -" const int offset=((((batch_idx*inputChannelBlock)+0)*inputHeight+height_idx)*inputWidth+width_idx)*4;\n" -"#ifdef ARGMAX_CHANNEL_DIM1\n" -" const int outputOffset=((batch_idx*outputHeight+height_idx)*oututWidth+width_idx);\n" -"#else\n" -" const int outputOffset=((((batch_idx*outputChannelBlock)+0)*outputHeight+height_idx)*oututWidth+width_idx)*4;\n" -"#endif\n" -" int remain=inputChannel-(inputChannelBlock-1)*4;\n" +" DEAL_NON_UNIFORM_DIM3(x,y,z);\n" +" int index=0;\n" "#ifdef ARGMAX\n" " FLOAT maxValue=(FLOAT)-FLT_MAX;\n" "#else\n" -" FLOAT maxValue=(FLOAT)FLT_MAX;\n" +"FLOAT maxValue=(FLOAT)FLT_MAX;\n" "#endif\n" -" int index=0;\n" -" FLOAT4 value;\n" -" FLOAT *valuePtr=(FLOAT*)&value;\n" +" const int offset=z*dim*inside+y;\n" "#if ARGMAX_LOCAL_SIZE >= 4\n" " int lid=get_local_id(0);\n" " FLOAT local reduce[ARGMAX_LOCAL_SIZE];\n" " int local index_reduce[ARGMAX_LOCAL_SIZE];\n" " \n" -" for (int i=lid; ivaluePtr[j]){\n" -" index=i*4+j;\n" -" maxValue=valuePtr[j];\n" -" }\n" +" if(maxValue>value){ maxValue=value; index=i; }\n" "#endif\n" " }\n" -" }\n" " reduce[lid]=maxValue;\n" " index_reduce[lid]=index;\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" @@ -12530,94 +10663,45 @@ const char* argmax_buf = " barrier(CLK_LOCAL_MEM_FENCE);\n" " }\n" " if(lid == 0){\n" -" maxValue=reduce[lid];\n" -" index=index_reduce[lid];\n" -" value=vload4((inputChannelBlock-1)*inputWidth*inputHeight,input+offset);\n" -" for(int j=0; jvaluePtr[j]){\n" -" index=(inputChannelBlock-1)*4+j;\n" -" maxValue=valuePtr[j];\n" -" }\n" -"#endif\n" -" }\n" -" output[outputOffset]=index;\n" +" output[z*inside+y]=index_reduce[0];\n" " }\n" "#else\n" -" for(int i=0; ivaluePtr[j]){\n" -" index=i*4+j;\n" -" maxValue=valuePtr[j];\n" -" }\n" -"#endif\n" -" }\n" -" }\n" -" value=vload4((inputChannelBlock-1)*inputWidth*inputHeight,input+offset);\n" -" for(int j=0; jvaluePtr[j]){\n" -" index=(inputChannelBlock-1)*4+j;\n" -" maxValue=valuePtr[j];\n" -" }\n" +" if(maxValue>value){ maxValue=value; index=i; }\n" "#endif\n" " }\n" -" output[outputOffset]=index;\n" +" output[z*inside+y]=index;\n" "#endif\n" "}\n" -"__kernel void argmax_batch_buf(GLOBAL_SIZE_3_DIMS\n" +"__kernel void argmax_v4_buf(GLOBAL_SIZE_3_DIMS\n" " __global const FLOAT* input,\n" " __global int* output,\n" -" __private const int inputWidth,\n" -" __private const int inputHeight,\n" -" __private const int inputChannel,\n" -" __private const int inputBatch,\n" -" __private const int inputChannelBlock,\n" -" __private const int oututWidth,\n" -" __private const int outputHeight,\n" -" __private const int outputChannel,\n" -" __private const int outputChannelBlock\n" -" ) {\n" +" __private const int inside,\n" +" __private const int outside,\n" +" __private const int dim){\n" " const int x=get_global_id(0);\n" -" const int wh=get_global_id(1);\n" -" const int channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(x,wh,channel_idx);\n" +" const int y=get_global_id(1) << 2; // inside\n" +" const int z=get_global_id(2); // outside\n" " \n" -" const int width_idx=wh % oututWidth;\n" -" const int height_idx=wh/oututWidth;\n" -" const int offset=((((0*inputChannelBlock)+channel_idx)*inputHeight+height_idx)*inputWidth+width_idx)*4;\n" -" const int outputOffset=((((0*outputChannelBlock)+channel_idx)*outputHeight+height_idx)*oututWidth+width_idx)*4;\n" +" 
DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " int4 index=0;\n" -" int batchOffset=inputChannelBlock*inputHeight*inputWidth;\n" "#ifdef ARGMAX\n" " FLOAT4 maxValue=(FLOAT4)-FLT_MAX;\n" "#else\n" " FLOAT4 maxValue=(FLOAT4)FLT_MAX;\n" "#endif\n" +" const int offset=z*dim*inside+y;\n" "#if ARGMAX_LOCAL_SIZE >= 4\n" " int lid=get_local_id(0);\n" " FLOAT4 local reduce[ARGMAX_LOCAL_SIZE];\n" " int4 local index_reduce[ARGMAX_LOCAL_SIZE];\n" " \n" -" for (int i=lid; i= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define DEAL_OUTER_SEQLEN_NOT_ALIGN(length) "" if(4 * sl + 3 >= length) {"" temp_3 = (FLOAT4)0;"" }"" if(4 * sl + 2 >= length) {"" temp_2 = (FLOAT4)0;"" }"" if(4 * sl + 1 >= length) {"" temp_1 = (FLOAT4)0;"" }\n" +"#define DEAL_INNER_HEADDIM_NOT_ALIGN(length) "" if(hd * 4 + 3 >= length) {"" temp_0.w = (FLOAT)0;"" temp_1.w = (FLOAT)0;"" temp_2.w = (FLOAT)0;"" temp_3.w = (FLOAT)0;"" }"" if(hd * 4 + 2 >= length) {"" temp_0.z = (FLOAT)0;"" temp_1.z = (FLOAT)0;"" temp_2.z = (FLOAT)0;"" temp_3.z = (FLOAT)0;"" }"" if(hd * 4 + 1 >= length) {"" temp_0.y = (FLOAT)0;"" temp_1.y = (FLOAT)0;"" temp_2.y = (FLOAT)0;"" temp_3.y = (FLOAT)0;"" }\n" +"__kernel void rearrange_qkv(GLOBAL_SIZE_3_DIMS\n" +" __global const FLOAT *input_q,//[batch,seqLenQ/4,headNum,headDim,seqLenQ_4]\n" +" __global const FLOAT *input_k,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n" +" __global const FLOAT *input_v,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n" +" __global FLOAT *output_q,// [batch*headNum,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenQ,mTileQ)]\n" +" __global FLOAT *output_k,// [batch*headNum/group,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenKV,mTileKV)]\n" +" __global FLOAT *output_v,// [batch*headNum/group,ROUND_UP(seqLenKV,mTileKV),ROUND_UP(headDim,mTileHDN)]\n" +" __global FLOAT *past_k,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n" +" __global FLOAT *past_v,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n" +" __private const int4 tile,// [mTileQ,mTileKV,mTileHDK,mTileHDN]\n" +" __private const int4 shape,// [seqLenQ,seqLenKV,headNum,headDim]\n" +" __private const int4 param // [group,batch]\n" +") {\n" +" const int sl=get_global_id(0); // seqLen/4 : max(seqLenPackQ/4,seqLenPackKV/4)\n" +" const int hd=get_global_id(1); // headDim/4 : max(headDimPackQK/4,headDimPackV/4)\n" +" const int z=get_global_id(2); // batch*headNum\n" +" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n" +" \n" +" const int seqLenQ=shape.x;\n" +" const int seqLenKV=shape.y;\n" +" const int headNum=shape.z;\n" +" const int headDim=shape.w;\n" +" const int group=param.x;\n" +" const int batch=param.y;\n" +" const int b=z % batch;\n" +" const int hn=z/batch;\n" +" \n" +" const int seqLenQ_4=(seqLenQ+3)/4;\n" +" //const int in_offset_q=(((b*seqLenQ_4+sl)*headNum+hn)*headDim+4*hd)*4;\n" +" const int in_offset_q=(((b*seqLenQ+sl*4)*headNum+hn)*headDim+4*hd);\n" +" const int seqLenPackQ=((seqLenQ+tile.x-1)/tile.x)*tile.x;\n" +" const int headDimPackQK=((headDim+tile.z-1)/tile.z)*tile.z;\n" +" const int out_offset_q=(((b*headNum+hn)*headDimPackQK+hd*4)*seqLenPackQ+sl*4);\n" +" \n" +" if(sl*4= seqLenQ || hd*4 >= headDim) {\n" +" vstore4((FLOAT4)0,0,output_q+out_offset_q);\n" +" vstore4((FLOAT4)0,0,output_q+out_offset_q+seqLenPackQ);\n" +" vstore4((FLOAT4)0,0,output_q+out_offset_q+2*seqLenPackQ);\n" +" vstore4((FLOAT4)0,0,output_q+out_offset_q+3*seqLenPackQ);\n" +" } else {\n" +" FLOAT4 temp_0=vload4(0,input_q+in_offset_q);\n" +" FLOAT4 temp_1=(sl*4+1 >= seqLenQ) ? 
(FLOAT4)0 : vload4(0,input_q+in_offset_q+headNum*headDim);\n" +" FLOAT4 temp_2=(sl*4+2 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+2*headNum*headDim);\n" +" FLOAT4 temp_3=(sl*4+3 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+3*headNum*headDim);\n" +" #ifdef HEADDIM_LEAVE\n" +" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n" +" #endif\n" +" #ifdef SEQLEN_LEAVE\n" +" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenQ)\n" +" #endif\n" +" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output_q+out_offset_q);\n" +" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output_q+out_offset_q+seqLenPackQ);\n" +" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output_q+out_offset_q+2*seqLenPackQ);\n" +" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output_q+out_offset_q+3*seqLenPackQ);\n" +" }\n" +" }\n" +" \n" +" if(hn >= headNum/group) {\n" +" return;\n" +" }\n" +" \n" +" const int seqLenPackKV=((seqLenKV+tile.y-1)/tile.y)*tile.y;\n" +" const int headDimPackV=((headDim+tile.w-1)/tile.w)*tile.w;\n" +" const int seqLenKV_4=(seqLenKV+3)/4;\n" +" const int in_offset_kv=(((b*seqLenKV+sl*4)*headNum/group+hn)*headDim+4*hd);\n" +" \n" +" if(sl*4= seqLenKV || hd*4 >= headDim) {\n" +" vstore4((FLOAT4)0,0,output_k+out_offset_k);\n" +" vstore4((FLOAT4)0,0,output_k+out_offset_k+seqLenPackKV);\n" +" vstore4((FLOAT4)0,0,output_k+out_offset_k+2*seqLenPackKV);\n" +" vstore4((FLOAT4)0,0,output_k+out_offset_k+3*seqLenPackKV);\n" +" } else {\n" +" FLOAT4 temp_0=vload4(0,input_k+in_offset_kv);\n" +" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+headNum*headDim/group);\n" +" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+2*headNum*headDim/group);\n" +" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+3*headNum*headDim/group);\n" +" #ifdef HEADDIM_LEAVE\n" +" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n" +" #endif\n" +" #ifdef SEQLEN_LEAVE\n" +" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n" +" #endif\n" +" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output_k+out_offset_k);\n" +" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output_k+out_offset_k+seqLenPackKV);\n" +" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output_k+out_offset_k+2*seqLenPackKV);\n" +" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output_k+out_offset_k+3*seqLenPackKV);\n" +" \n" +" // pastK\n" +" vstore4(temp_0,0,past_k+in_offset_kv);\n" +" if(sl*4+1= seqLenKV || hd*4 >= headDim) {\n" +" vstore4((FLOAT4)0,0,output_v+out_offset_v);\n" +" vstore4((FLOAT4)0,0,output_v+out_offset_v+headDimPackV);\n" +" vstore4((FLOAT4)0,0,output_v+out_offset_v+2*headDimPackV);\n" +" vstore4((FLOAT4)0,0,output_v+out_offset_v+3*headDimPackV);\n" +" } else {\n" +" FLOAT4 temp_0=vload4(0,input_v+in_offset_kv);\n" +" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+headNum*headDim/group);\n" +" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+2*headNum*headDim/group);\n" +" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? 
(FLOAT4)0 : vload4(0,input_v+in_offset_kv+3*headNum*headDim/group);\n" +" #ifdef HEADDIM_LEAVE\n" +" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n" +" #endif\n" +" #ifdef SEQLEN_LEAVE\n" +" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n" +" #endif\n" +" vstore4(temp_0,0,output_v+out_offset_v);\n" +" vstore4(temp_1,0,output_v+out_offset_v+headDimPackV);\n" +" vstore4(temp_2,0,output_v+out_offset_v+2*headDimPackV);\n" +" vstore4(temp_3,0,output_v+out_offset_v+3*headDimPackV);\n" +" \n" +" // pastV\n" +" vstore4(temp_0,0,past_v+in_offset_kv);\n" +" if(sl*4+1= shape.x || sl_kv*4 >= shape.y) {\n" +" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset);\n" +" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack);\n" +" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*2);\n" +" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*3);\n" +" } else {\n" +" int y_down_align4=(shape.y/4*4);\n" +" MASK_DTYPE4 temp_0,temp_1,temp_2,temp_3;\n" +" \n" +" if(sl_kv*4= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y);\n" +" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*2);\n" +" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*3);\n" +" } else if(sl_kv*4+1 == shape.y){\n" +" temp_0=(MASK_DTYPE4)(input_mask[in_offset],0,0,0);\n" +" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],0,0,0);//vload4(0,input_mask+in_offset+shape.y);\n" +" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],0,0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n" +" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],0,0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n" +" } else if(sl_kv*4+2 == shape.y){\n" +" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],0,0);\n" +" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (FLOAT4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],0,0);//vload4(0,input_mask+in_offset+shape.y);\n" +" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n" +" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n" +" } else if(sl_kv*4+3 == shape.y){\n" +" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],input_mask[in_offset+2],0);\n" +" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],input_mask[in_offset+shape.y+2],0);//vload4(0,input_mask+in_offset+shape.y);\n" +" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],input_mask[in_offset+shape.y*2+2],0);//vload4(0,input_mask+in_offset+shape.y*2);\n" +" temp_3=(sl*4+3 >= shape.x) ? 
(MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],input_mask[in_offset+shape.y*3+2],0);//vload4(0,input_mask+in_offset+shape.y*3);\n" +" }\n" +" vstore4(temp_0,0,output_mask+out_offset);\n" +" vstore4(temp_1,0,output_mask+out_offset+seq_len_kv_pack);\n" +" vstore4(temp_2,0,output_mask+out_offset+2*seq_len_kv_pack);\n" +" vstore4(temp_3,0,output_mask+out_offset+3*seq_len_kv_pack);\n" +" }\n" +"}\n" +"__kernel void qkv_transpose_output(GLOBAL_SIZE_3_DIMS\n" +" __global const FLOAT *input,// [Batch*mNumHead,ROUND_UP(mHeadDim,mTileHDN),ROUND_UP(seqLen,mTileQ)]\n" +" __global FLOAT *output,// [Batch,seqLen/4,mNumHead, mHeadDim,4]\n" +" __private const int tile_q,\n" +" __private const int tile_hdn,\n" +" __private const int seq_len,\n" +" __private const int head_num,\n" +" __private const int head_dim\n" +") {\n" +" \n" +" const int sl=get_global_id(0); // seqLen_4\n" +" const int hd=get_global_id(1); // mHeadDim_4\n" +" const int z=get_global_id(2); // Batch*mNumHead\n" +" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n" +" \n" +" const int b=z/head_num;\n" +" const int hn=z % head_num;\n" +" \n" +" const int seq_len_pack=((seq_len+tile_q-1)/tile_q)*tile_q;\n" +" const int head_dim_pack=((head_dim+tile_hdn-1)/tile_hdn)*tile_hdn;\n" +" \n" +" const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n" +" \n" +" const int offset_out=(((b*seq_len+sl*4)*head_num+hn)*head_dim+4*hd);\n" +" \n" +" // Q\n" +" FLOAT4 temp_0=vload4(0,input+offset_inp);\n" +" FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n" +" FLOAT4 temp_2=vload4(0,input+offset_inp+2*seq_len_pack);\n" +" FLOAT4 temp_3=vload4(0,input+offset_inp+3*seq_len_pack);\n" +" \n" +" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output+offset_out);\n" +" if(4*sl+1 >= seq_len) return;\n" +" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output+offset_out+head_num*head_dim);\n" +" if(4*sl+2 >= seq_len) return;\n" +" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output+offset_out+2*head_num*head_dim);\n" +" if(4*sl+3 >= seq_len) return;\n" +" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output+offset_out+3*head_num*head_dim);\n" +"}\n" +"#ifndef NUMHEAD_GROUP_SIZE\n" +"#define NUMHEAD_GROUP_SIZE 1\n" +"#endif\n" "__kernel void matmul_qk_div_mask(GLOBAL_SIZE_3_DIMS\n" -" __global const FLOAT *input0,// query [1 query_seq_len/4 head_num head_dim 4]\n" -" __global const FLOAT *input1,// key [1 key_seq_len/4 head_num head_dim 4]\n" -" __global FLOAT *output,// prefill [1 head_num query_seq_len/4 key_seq_len 4] decode[1 head_num key_seq_len/4 4]\n" -" __global FLOAT *past_key,// [1 head_num max_length/4 head_dim 4]\n" -"#ifdef ADD_MASK\n" +" __global const FLOAT *input0,// query [1 query_seq_len head_num head_dim]\n" +" __global const FLOAT *input1,// key [1 key_seq_len head_num head_dim]\n" +" __global FLOAT *output,// prefill [1 head_num query_seq_len key_seq_len] decode[1 head_num key_seq_len/4 4]\n" +" __global FLOAT *past_key,// [1 max_length head_num head_dim]\n" +" #ifdef ADD_MASK\n" " __global const FLOAT* mask,\n" -"#else\n" -" __global const int* mask,// [1 1 query_seq_len key_seq_len 4]\n" -"#endif\n" +" #else\n" +" __global const int* mask,// [1 1 query_seq_len key_seq_len]\n" +" #endif\n" " __private const float scale,\n" " __private const int query_seq_len,\n" " __private const int key_seq_len,\n" " __private const int head_num,\n" " __private const int kv_head_num,\n" " __private const int head_dim) {\n" -" const int 
x=get_global_id(0); // query_seq_len/4 for prefill 1 for decode\n" -" const int y=get_global_id(1); // head_num\n" -" const int z=get_global_id(2); // key_seq_len/4\n" +" \n" +" const int x=get_global_id(0); // key_seq_len\n" +" const int y=get_global_id(1); // query_seq_len for prefill 1 for decode\n" +" const int z=get_global_id(2); // head_num\n" " DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " \n" -" int yin=y/NUMHEAD_GROUP_SIZE;\n" -" const int offset=head_num*head_dim*4;\n" -" const int offset_head=y*head_dim*4;\n" -" __global const FLOAT *A_offset=input0+x*offset+offset_head;\n" -" __global FLOAT *Pastkey_offset=past_key+(z*kv_head_num+yin)*head_dim*4;\n" -" const int z4=z << 2;\n" -" float4 Vscale=(float4)scale;\n" +" int x4=x << 2;\n" +" int y4=y << 2;\n" +" int zin=z/NUMHEAD_GROUP_SIZE;\n" +" __global const FLOAT *A_offset=input0+(y4*head_num+z)*head_dim;\n" +" __global FLOAT *Pastkey_offset=past_key+(x4*kv_head_num+zin)*head_dim;\n" +" int strideA=head_num*head_dim;\n" +" int strideB=kv_head_num*head_dim;\n" "#ifdef OPENCL_PREFILL_ATTENTION\n" -" __global const FLOAT *B_offset=input1+(z*kv_head_num+yin)*head_dim*4;\n" -" const int x4=x << 2;\n" -" const int query_seq_len4=(query_seq_len+3)/4;\n" -" const int output_offset=y*query_seq_len4*key_seq_len*4;\n" +" __global const FLOAT *B_offset=input1+(x4*kv_head_num+zin)*head_dim;\n" +" int output_offset=(z*query_seq_len+y4)*key_seq_len+x4;\n" " float4 out0=0;\n" " float4 out1=0;\n" " float4 out2=0;\n" " float4 out3=0;\n" " \n" -" const int head_dim4=(head_dim+3)/4;\n" -"#ifdef HEADDIM_LEAVE\n" -" for(int i=0; i= key_seq_len) return;\n" -" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset+x*key_seq_len*4+(z4+1)*4);\n" -" if(z4+2 >= key_seq_len) return;\n" -" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset+x*key_seq_len*4+(z4+2)*4);\n" -" if(z4+3 >= key_seq_len) return;\n" -" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset+x*key_seq_len*4+(z4+3)*4);\n" +" #endif\n" +" if(B3_enable){\n" +" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n" +" if(!A1_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n" +" if(!A2_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n" +" if(!A3_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n" +" } else if(B2_enable){\n" +" vstore3(CONVERT_FLOAT3((float3)(out0.x,out0.y,out0.z)),0,output+output_offset);\n" +" if(!A1_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore3(CONVERT_FLOAT3((float3)(out1.x,out1.y,out1.z)),0,output+output_offset);\n" +" if(!A2_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore3(CONVERT_FLOAT3((float3)(out2.x,out2.y,out2.z)),0,output+output_offset);\n" +" if(!A3_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore3(CONVERT_FLOAT3((float3)(out3.x,out3.y,out3.z)),0,output+output_offset);\n" +" } else if(B1_enable){\n" +" vstore2(CONVERT_FLOAT2((float2)(out0.x,out0.y)),0,output+output_offset);\n" +" if(!A1_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore2(CONVERT_FLOAT2((float2)(out1.x,out1.y)),0,output+output_offset);\n" +" if(!A2_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore2(CONVERT_FLOAT2((float2)(out2.x,out2.y)),0,output+output_offset);\n" +" if(!A3_enable) return;\n" +" output_offset += key_seq_len;\n" +" vstore2(CONVERT_FLOAT2((float2)(out3.x,out3.y)),0,output+output_offset);\n" +" } else {\n" +" 
output[output_offset]=out0.x;\n" +" if(!A1_enable) return;\n" +" output[output_offset+key_seq_len]=out1.x;\n" +" if(!A2_enable) return;\n" +" output[output_offset+key_seq_len+key_seq_len]=out2.x;\n" +" if(!A3_enable) return;\n" +" output[output_offset+key_seq_len+key_seq_len+key_seq_len]=out3.x;\n" +" }\n" "#else\n" -" __global const FLOAT *B_offset=input1+yin*head_dim*4;\n" -" const int key_seq_len4=(key_seq_len+3)/4;\n" " float4 out=0;\n" " const int head_dim4=(head_dim+3)/4;\n" -" \n" -"#ifdef HEADDIM_LEAVE\n" +" int key_seq_len4=(key_seq_len+3)/4;\n" +" #ifdef HEADDIM_LEAVE\n" " for(int i=0; i= 4){\n" +" vstore4(CONVERT_FLOAT4(out),0,output+z*key_seq_len+x4);\n" +" } else if (remain >= 3){\n" +" vstore3(CONVERT_FLOAT3((float3)(out.x,out.y,out.z)),0,output+z*key_seq_len+x4);\n" +" } else if (remain >= 2){\n" +" vstore2(CONVERT_FLOAT2((float2)(out.x,out.y)),0,output+z*key_seq_len+x4);\n" +" } else {\n" +" output[z*key_seq_len+x4]=out.x;\n" " }\n" -" out *= Vscale;\n" -" vstore4(CONVERT_FLOAT4(out),0,output+y*key_seq_len4*4+z4);\n" "#endif\n" "}\n" "__kernel void matmul_qkv(GLOBAL_SIZE_3_DIMS\n" -" __global const FLOAT *input0,// qk prefill [1 head_num qk_seq_len/4 value_seq_len 4] decode[1 head_num value_seq_len/4 4]\n" -" __global const FLOAT *input1,// [1 value_seq_len/4 head_num head_dim 4]\n" -" __global FLOAT *output,// [1 qk_seq_len head_num*head_dim 1 4]\n" -" __global FLOAT *past_value,// [1 value_seq_len/4 head_num head_dim 4]\n" +" __global const FLOAT *input0,// qk prefill [1 head_num qk_seq_len value_seq_len] decode[1 head_num value_seq_len]\n" +" __global const FLOAT *input1,// [1 value_seq_len head_num head_dim]\n" +" __global FLOAT *output,// [1 qk_seq_len head_num head_dim]\n" +" __global FLOAT *past_value,// [1 value_seq_len head_num head_dim]\n" " __private const int qk_seq_len,\n" " __private const int value_seq_len,\n" " __private const int head_num,\n" " __private const int kv_head_num,\n" " __private const int head_dim) {\n" -" const int x=get_global_id(0); // prefill qk_seq_len/4 decode 1\n" +" \n" +" const int x=get_global_id(0); // head_dim << 2\n" " const int y=get_global_id(1); // head_num\n" -" const int z=get_global_id(2); // head_dim << 2\n" -" const int z4=z << 2;\n" +" const int z=get_global_id(2); // prefill qk_seq_len decode 1\n" +" \n" +" const int x4=x << 2;\n" " DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " \n" " const int yin=y/NUMHEAD_GROUP_SIZE;\n" "#ifdef OPENCL_PREFILL_ATTENTION\n" -" const int offset=head_num*head_dim*4;\n" -" const int stride=kv_head_num*head_dim*4;\n" -" const int offset_head=y*head_dim*4+z4*4;\n" -" const int value_seq_len4=(value_seq_len+3)/4;\n" -" const int qk_seq_len4=(qk_seq_len+3)/4;\n" -" __global const FLOAT *A_offset=input0+(y*qk_seq_len4+x)*value_seq_len*4;\n" -" __global const FLOAT *B_offset=input1+yin*head_dim*4+z4*4;\n" -" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim*4+z4*4;\n" +" int z4=z << 2;\n" +" int value_seq_len4=(value_seq_len+3)/4;\n" +" int loop_end=max(value_seq_len4-1,0);\n" +" const int stride=kv_head_num*head_dim;\n" +" __global const FLOAT *A_offset=input0+(y*qk_seq_len+z4)*value_seq_len;\n" +" __global const FLOAT *B_offset=input1+yin*head_dim+x4;\n" +" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim+x4;\n" " COMPUTE_FLOAT4 out0=0;\n" " COMPUTE_FLOAT4 out1=0;\n" " COMPUTE_FLOAT4 out2=0;\n" " COMPUTE_FLOAT4 out3=0;\n" " \n" -" for(int i=0; i= head_dim) return;\n" -" vstore4(CONVERT_FLOAT4(out1),0,output+x*offset+(y*head_dim+z4+1)*4);\n" -" 
vstore4(CONVERT_FLOAT4(B.s4567),1,Pastvalue_offset+(value_seq_len4-1)*stride);\n" -" if(z4+2 >= head_dim) return;\n" -" vstore4(CONVERT_FLOAT4(out2),0,output+x*offset+(y*head_dim+z4+2)*4);\n" -" vstore4(CONVERT_FLOAT4(B.s89ab),2,Pastvalue_offset+(value_seq_len4-1)*stride);\n" -" if(z4+3 >= head_dim) return;\n" -" vstore4(CONVERT_FLOAT4(out3),0,output+x*offset+(y*head_dim+z4+3)*4);\n" -" vstore4(CONVERT_FLOAT4(B.scdef),3,Pastvalue_offset+(value_seq_len4-1)*stride);\n" -"#else\n" -" COMPUTE_FLOAT16 B=CONVERT_COMPUTE_FLOAT16(vload16(0,B_offset+(value_seq_len4-1)*stride));\n" -" vstore16(CONVERT_FLOAT16(B),0,Pastvalue_offset+(value_seq_len4-1)*stride);\n" -" COMPUTE_FLOAT *B_ptr=(COMPUTE_FLOAT*)&B;\n" -" for(int i=(value_seq_len4-1)*4,j=0; i= 4){\n" +" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n" +" } else if(remain == 3){\n" +" vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out0.x,out0.y,out0.z)),0,output+output_offset);\n" +" } else if(remain == 2){\n" +" vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT3)(out0.x,out0.y)),0,output+output_offset);\n" +" } else{\n" +" output[output_offset]=out0.x;\n" +" }\n" +" if(z4+1 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" if(remain >= 4){\n" +" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n" +" } else if(remain == 3){\n" +" vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out1.x,out1.y,out1.z)),0,output+output_offset);\n" +" } else if(remain == 2){\n" +" vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT3)(out1.x,out1.y)),0,output+output_offset);\n" +" } else{\n" +" output[output_offset]=out1.x;\n" +" }\n" +" if(z4+2 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" if(remain >= 4){\n" +" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n" +" } else if(remain == 3){\n" +" vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out2.x,out2.y,out2.z)),0,output+output_offset);\n" +" } else if(remain == 2){\n" +" vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT3)(out2.x,out2.y)),0,output+output_offset);\n" +" } else{\n" +" output[output_offset]=out2.x;\n" +" }\n" +" if(z4+3 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" if(remain >= 4){\n" +" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n" +" } else if(remain == 3){\n" +" vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out3.x,out3.y,out3.z)),0,output+output_offset);\n" +" } else if(remain == 2){\n" +" vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT3)(out3.x,out3.y)),0,output+output_offset);\n" +" } else{\n" +" output[(x*head_num+y)*head_dim+z4]=out3.x;\n" +" }\n" +" #else\n" +" int output_offset=(z4*head_num+y)*head_dim+x4;\n" +" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n" +" if(z4+1 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" vstore4(CONVERT_FLOAT4(out1),0,output+output_offset);\n" +" if(z4+2 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" vstore4(CONVERT_FLOAT4(out2),0,output+output_offset);\n" +" if(z4+3 >= qk_seq_len) return;\n" +" output_offset += head_num*head_dim;\n" +" vstore4(CONVERT_FLOAT4(out3),0,output+output_offset);\n" +" #endif\n" "#else\n" -" const int value_seq_len4=(value_seq_len+3)/4;\n" -" const int stride=kv_head_num*head_dim*4;\n" -" const int offset=head_num*head_dim*4;\n" -" const int offset_head=y*head_dim*4+z4*4;\n" -" const int loop=(value_seq_len+2)/4;\n" -" __global const FLOAT *A_offset=input0+y*value_seq_len4*4;\n" -" __global const FLOAT *B_offset=input1+yin*head_dim*4+z4*4;\n" -" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim*4+z4*4;\n" +" int 
value_seq_len4=(value_seq_len-1+3)/4;\n" +" int loop_end=max(value_seq_len4-1,0);\n" +" const int stride=kv_head_num*head_dim;\n" +" __global const FLOAT *A_offset=input0+y*value_seq_len;\n" +" __global const FLOAT *B_offset=input1+yin*head_dim+x4;\n" +" __global FLOAT *Pastvalue_offset=past_value+yin*head_dim+x4;\n" " COMPUTE_FLOAT4 out=0;\n" " \n" -" for(int i=0; i> 2)*stride+((value_seq_len-1) % 4);\n" -" \n" -"#ifdef HEADDIM_LEAVE\n" -" Pastvalue_offset[index]=B0;\n" -" output[(y*head_dim+z4)*4]=out.s0;\n" -" if(z4+1 >= head_dim) return;\n" -" Pastvalue_offset[index+4]=B1;\n" -" output[(y*head_dim+z4+1)*4]=out.s1;\n" -" if(z4+2 >= head_dim) return;\n" -" Pastvalue_offset[index+8]=B2;\n" -" output[(y*head_dim+z4+2)*4]=out.s2;\n" -" if(z4+3 >= head_dim) return;\n" -" Pastvalue_offset[index+12]=B3;\n" -" output[(y*head_dim+z4+3)*4]=out.s3;\n" -"#else\n" -" Pastvalue_offset[index]=B0;\n" -" Pastvalue_offset[index+4]=B1;\n" -" Pastvalue_offset[index+8]=B2;\n" -" Pastvalue_offset[index+12]=B3;\n" +" COMPUTE_FLOAT4 B=CONVERT_COMPUTE_FLOAT4(vload4(0,B_offset));\n" +" out=mad(B,(COMPUTE_FLOAT4)A,out);\n" " \n" -" output[(y*head_dim+z4)*4]=out.s0;\n" -" output[(y*head_dim+z4+1)*4]=out.s1;\n" -" output[(y*head_dim+z4+2)*4]=out.s2;\n" -" output[(y*head_dim+z4+3)*4]=out.s3;\n" -"#endif\n" +" #ifdef HEADDIM_LEAVE\n" +" int remain=head_dim-x4;\n" +" if(remain >= 4){\n" +" vstore4(CONVERT_FLOAT4(out),0,output+y*head_dim+x4);\n" +" vstore4(CONVERT_FLOAT4(B),0,Pastvalue_offset+(value_seq_len-1)*stride);\n" +" } else if(remain == 3){\n" +" vstore3(CONVERT_FLOAT3((COMPUTE_FLOAT3)(out.x,out.y,out.z)),0,output+y*head_dim+x4);\n" +" vstore3(CONVERT_FLOAT4((COMPUTE_FLOAT3)(B.x,B.y,B.z)),0,Pastvalue_offset+(value_seq_len-1)*stride);\n" +" } else if(remain == 2){\n" +" vstore2(CONVERT_FLOAT2((COMPUTE_FLOAT3)(out.x,out.y)),0,output+y*head_dim+x4);\n" +" vstore2(CONVERT_FLOAT4((COMPUTE_FLOAT3)(B.x,B.y)),0,Pastvalue_offset+(value_seq_len-1)*stride);\n" +" } else{\n" +" output[(x*head_num+y)*head_dim+x4]=out.x;\n" +" Pastvalue_offset[(value_seq_len-1)*stride]=B.x;\n" +" }\n" +" #else\n" +" vstore4(CONVERT_FLOAT4(B),0,Pastvalue_offset+(value_seq_len-1)*stride);\n" +" vstore4(CONVERT_FLOAT4(out),0,output+y*head_dim+x4);\n" +" #endif\n" " \n" "#endif\n" "}\n" @@ -13482,6 +11953,7 @@ const char* unary_subgroup_buf = " __private const int width,\n" " __private const int height,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right,\n" " __private const int output_pad_left,__private const int output_pad_right) {\n" " const int channel_block_idx=get_global_id(0);\n" @@ -13490,8 +11962,7 @@ const char* unary_subgroup_buf = " DEAL_NON_UNIFORM_DIM3(channel_block_idx,w,hb);\n" " const int batch_idx=hb/height;\n" " const int height_idx=hb % height;\n" -" const int channel4=(channel+3)/4;\n" -" const int offset=(((batch_idx*channel4+channel_block_idx)*height+height_idx)*width+w)*4;\n" +" const int offset=(((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w)*4;\n" " float4 in=convert_float4(vload4(0,input+offset));\n" " float4 out=OPERATOR;\n" " vstore4(CONVERT_OUTPUT4(out),0,output+offset);\n" @@ -13502,6 +11973,7 @@ const char* unary_subgroup_buf = " __private const int width,\n" " __private const int height,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right,\n" " __private const int output_pad_left,__private const int output_pad_right) 
{\n" " const int channel_block_idx=get_global_id(0);\n" @@ -13511,10 +11983,9 @@ const char* unary_subgroup_buf = " const int batch_idx=hb/height;\n" " const int height_idx=hb % height;\n" " const int dst_width=output_pad_left+width+output_pad_right;\n" -" const int channel4=(channel+3)/4;\n" " const int channel16=(channel+15)/16;\n" " const int channe_out_idx=channel_block_idx >> 2;\n" -" const int offset=(((batch_idx*channel4+channel_block_idx)*height+height_idx)*width+w)*4;\n" +" const int offset=(((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w)*4;\n" " const int dst_offset=(((batch_idx*channel16+channe_out_idx)*height+height_idx)*dst_width+w+output_pad_left)*16+(channel_block_idx % 4)*4;\n" " float4 in=convert_float4(vload4(0,input+offset));\n" " float4 out=OPERATOR;\n" @@ -13537,6 +12008,7 @@ const char* unary_subgroup_buf = " __private const int width,\n" " __private const int height,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right,\n" " __private const int output_pad_left,__private const int output_pad_right) {\n" " const int channel_idx=get_group_id(0);\n" @@ -13578,6 +12050,7 @@ const char* unary_subgroup_buf = " __private const int width,\n" " __private const int height,\n" " __private const int channel,\n" +" __private const int batch,\n" " __private const int input_pad_left,__private const int input_pad_right,\n" " __private const int output_pad_left,__private const int output_pad_right) {\n" " const int channel_idx=get_group_id(0);\n" @@ -13587,10 +12060,9 @@ const char* unary_subgroup_buf = " const int batch_idx=hb/height;\n" " const int height_idx=hb % height;\n" " const int src_width=width+input_pad_left+input_pad_right;\n" -" const int channel4=(channel+3)/4;\n" " const int channel16=(channel+15)/16;\n" " const int src_offset=(((batch_idx*channel16+channel_idx)*height+height_idx)*src_width+w+input_pad_left)*16;\n" -" const int dst_offset=(((batch_idx*channel4+(channel_idx<<2))*height+height_idx)*width+w)*4;\n" +" const int dst_offset=(((batch_idx+(channel_idx<<2)*batch)*height+height_idx)*width+w)*4;\n" " const int height_width=height*width*4;\n" " \n" " float4 in=convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input+src_offset))));\n" @@ -14155,25 +12627,24 @@ const char* scale_buf = " __global const FLOAT* bias,\n" "#endif\n" " __global FLOAT* output,\n" -" __private const int4 shape) {//N,H,W,C4\n" -" const int out_w_c_idx=get_global_id(0);\n" -" const int out_h_b_idx=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(out_w_c_idx,out_h_b_idx);\n" -" const int out_b_idx=out_h_b_idx/shape.y;\n" -" const int out_h_idx=out_h_b_idx % shape.y;\n" -" const int out_c_idx=out_w_c_idx/shape.z;\n" -" const int out_w_idx=out_w_c_idx % shape.z;\n" +" __private const int channelBlock,\n" +" __private const int batch,\n" +" __private const int inside) {\n" +" const int x=get_global_id(0); // inside(width*height)\n" +" const int y=get_global_id(1); // channelBlock*batch\n" " \n" -" const int offset=(((out_b_idx*shape.w+out_c_idx)*shape.y+out_h_idx)*shape.z+out_w_idx)*4;\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" +" const int out_c_idx=y % channelBlock;\n" +" const int out_b_idx=y/channelBlock;\n" +" const int offset=((out_b_idx+out_c_idx*batch)*inside+x)*4;\n" " COMPUTE_FLOAT4 in_value=CONVERT_COMPUTE_FLOAT4(vload4(0,input+offset));\n" " COMPUTE_FLOAT4 scale_value=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,scale));\n" -"#ifdef BIAS\n" +" #ifdef BIAS\n" " 
COMPUTE_FLOAT4 bias_value=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias));\n" " COMPUTE_FLOAT4 out_value=in_value*scale_value+bias_value;\n" -"#else\n" +" #else\n" " COMPUTE_FLOAT4 out_value=in_value*scale_value;\n" -"#endif\n" +" #endif\n" " vstore4(CONVERT_FLOAT4(out_value),0,output+offset);\n" "}\n" ; @@ -14191,380 +12662,169 @@ const char* matmul_buf = " __global const FLOAT* input_c,\n" " #endif\n" " __global FLOAT* output_c,\n" -" __private const int channels,\n" -" __private const int channel_blocks,\n" -" __private const int width_blocks,\n" -" __private const int width) {\n" -" const int width_blocks_idx=get_global_id(0);// output W\n" -" const int height_idx=get_global_id(1);// output H\n" -" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n" -" COMPUTE_FLOAT4 a;\n" -" COMPUTE_FLOAT4 b0=0,b1=0,b2=0,b3=0;\n" -" COMPUTE_FLOAT4 v_zero=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0);\n" +" __private const int M,\n" +" __private const int N,\n" +" __private const int K) {\n" +" int2 pos=(int2)(get_global_id(0),get_global_id(1)); // N M\n" +" DEAL_NON_UNIFORM_DIM2(pos.x,pos.y);\n" +" const int idn=pos.x << 2;\n" +" const int idm=pos.y << 2;\n" +" \n" +" COMPUTE_FLOAT4 out[4];\n" " #ifdef BIAS\n" -" COMPUTE_FLOAT4 temp=CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx,input_c));\n" -" COMPUTE_FLOAT result0=temp.x;\n" -" COMPUTE_FLOAT result1=temp.y;\n" -" COMPUTE_FLOAT result2=temp.z;\n" -" COMPUTE_FLOAT result3=temp.w;\n" -" #else\n" -" COMPUTE_FLOAT result0=0;\n" -" COMPUTE_FLOAT result1=0;\n" -" COMPUTE_FLOAT result2=0;\n" -" COMPUTE_FLOAT result3=0;\n" -" #endif\n" -" const int remain=channel_blocks*4-channels;\n" -" for (short pos=0; pos= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks,input_b));\n" -" b2=(remain >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks*2,input_b));\n" -" b3=(remain >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks*3,input_b));\n" -" if (remain == 3) {\n" -" a.y=0;\n" -" a.z=0;\n" -" a.w=0;\n" -" } else if (remain == 2) {\n" -" a.z=0;\n" -" a.w=0;\n" -" } else if (remain == 1) {\n" -" a.w=0;;\n" +" COMPUTE_FLOAT4 bias=CONVERT_COMPUTE_FLOAT4(vload4(0,input_c+idn));\n" +" #pragma unroll\n" +" for(int i=0; i<4; ++i){\n" +" out[i]=bias;\n" " }\n" -" COMPUTE_FLOAT4 btmp0=(COMPUTE_FLOAT4)(b0.s0,b1.s0,b2.s0,b3.s0);\n" -" COMPUTE_FLOAT4 btmp1=(COMPUTE_FLOAT4)(b0.s1,b1.s1,b2.s1,b3.s1);\n" -" COMPUTE_FLOAT4 btmp2=(COMPUTE_FLOAT4)(b0.s2,b1.s2,b2.s2,b3.s2);\n" -" COMPUTE_FLOAT4 btmp3=(COMPUTE_FLOAT4)(b0.s3,b1.s3,b2.s3,b3.s3);\n" -" result0 += dot(a,btmp0);\n" -" result1 += dot(a,btmp1);\n" -" result2 += dot(a,btmp2);\n" -" result3 += dot(a,btmp3);\n" +" #else\n" +" #pragma unroll\n" +" for(int i=0; i<4; ++i){\n" +" out[i]=(COMPUTE_FLOAT4)0;\n" " }\n" -" const int out_offset=height_idx*width_blocks+width_blocks_idx;\n" -" vstore4(CONVERT_FLOAT4((COMPUTE_FLOAT4)(result0,result1,result2,result3)),out_offset,output_c);\n" -"}\n" -"__kernel void matmul_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a,\n" -" __global const FLOAT* input_b,\n" -" #ifdef BIAS\n" -" __global const FLOAT* input_c,\n" " #endif\n" -" __global FLOAT* output_c,\n" -" __private const int channels,\n" -" __private const int channel_blocks,\n" -" __private const int width_blocks,\n" -" __private const int width) {\n" -" const int width_blocks_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_idx);\n" -" COMPUTE_FLOAT4 a;\n" -" COMPUTE_FLOAT4 b0=0,b1=0,b2=0,b3=0;\n" -" COMPUTE_FLOAT4 v_zero=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0);\n" -" #ifdef BIAS\n" -" COMPUTE_FLOAT4 temp=CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx,input_c));\n" -" COMPUTE_FLOAT result0=temp.x;\n" -" COMPUTE_FLOAT result1=temp.y;\n" -" COMPUTE_FLOAT result2=temp.z;\n" -" COMPUTE_FLOAT result3=temp.w;\n" +" const int K4=(K+3)/4;\n" +" #ifdef K_LEAVE\n" +" const int loop_end=max(K4-1,0);\n" +" const int remain=K-loop_end*4;\n" " #else\n" -" COMPUTE_FLOAT result0=0;\n" -" COMPUTE_FLOAT result1=0;\n" -" COMPUTE_FLOAT result2=0;\n" -" COMPUTE_FLOAT result3=0;\n" -" #endif\n" -" const int remaina=channel_blocks*4-channels;\n" -" const int remainb=(width_blocks_idx+1)*4-width;\n" -" for (short pos=0; pos= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks,input_b));\n" -" b2=(remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*2,input_b));\n" -" b3=(remainb >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*3,input_b));\n" -" result0 += dot(a,b0);\n" -" result1 += dot(a,b1);\n" -" result2 += dot(a,b2);\n" -" result3 += dot(a,b3);\n" -" }\n" -" \n" -" {\n" -" const int inpa_offset=height_idx*channel_blocks+channel_blocks-1;\n" -" a=CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset,input_a));\n" -" const int inpb_offset=(width_blocks_idx*4)*channel_blocks+channel_blocks-1;\n" -" b0=CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset,input_b));\n" -" b1=(remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks,input_b));\n" -" b2=(remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*2,input_b));\n" -" b3=(remainb >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*3,input_b));\n" -" if (remaina == 3) {\n" -" a.y=0;\n" -" a.z=0;\n" -" a.w=0;\n" -" } else if (remaina == 2) {\n" -" a.z=0;\n" -" a.w=0;\n" -" } else if (remaina == 1) {\n" -" a.w=0;\n" -" }\n" -" result0 += dot(a,b0);\n" -" result1 += dot(a,b1);\n" -" result2 += dot(a,b2);\n" -" result3 += dot(a,b3);\n" -" }\n" -" const int out_offset=height_idx*width_blocks+width_blocks_idx;\n" -" vstore4(CONVERT_FLOAT4((COMPUTE_FLOAT4)(result0,result1,result2,result3)),out_offset,output_c);\n" -"}\n" -"__kernel void matmul_transA_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a,\n" -" __global const FLOAT* input_b,\n" -" #ifdef BIAS\n" -" __global const FLOAT* input_c,\n" +" const int loop_end=K4;\n" " #endif\n" -" __global FLOAT* output_c,\n" -" __private const int channels,\n" -" __private const int channel_blocks,\n" -" __private const int height,\n" -" __private const int height_blocks,\n" -" __private const int width_blocks,\n" -" __private const int width) {\n" -" const int width_blocks_idx=get_global_id(0);\n" -" const int height_blocks_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n" -" COMPUTE_FLOAT4 v_zero=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0);\n" -" #ifdef BIAS\n" -" COMPUTE_FLOAT4 result0=CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx,input_c));\n" -" COMPUTE_FLOAT4 result1=result0;\n" -" COMPUTE_FLOAT4 result2=result0;\n" -" COMPUTE_FLOAT4 result3=result0;\n" +" \n" +" #ifdef TRANSPOSE_A\n" +" __global const FLOAT* input_a_offset=input_a+idm; // K x M\n" " #else\n" -" COMPUTE_FLOAT4 result0=0;\n" -" COMPUTE_FLOAT4 result1=0;\n" -" COMPUTE_FLOAT4 result2=0;\n" -" COMPUTE_FLOAT4 result3=0;\n" -" #endif\n" -" \n" -" const int remain=channel_blocks*4-channels;\n" -" for (short pos=0; pos= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks,input_a)));\n" -" COMPUTE_FLOAT4 a2=((remain >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks*2,input_a)));\n" -" COMPUTE_FLOAT4 a3=((remain >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks*3,input_a)));\n" -" const int inpb_offset=(4*(channel_blocks-1))*width_blocks+width_blocks_idx;\n" -" COMPUTE_FLOAT4 b0=CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset,input_b));\n" -" COMPUTE_FLOAT4 b1=((remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks,input_b)));\n" -" COMPUTE_FLOAT4 b2=((remain >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks*2,input_b)));\n" -" COMPUTE_FLOAT4 b3=((remain >= 3) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+width_blocks*3,input_b)));\n" -" COMPUTE_FLOAT4 a0_trans=(COMPUTE_FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n" -" COMPUTE_FLOAT4 a1_trans=(COMPUTE_FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n" -" COMPUTE_FLOAT4 a2_trans=(COMPUTE_FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n" -" COMPUTE_FLOAT4 a3_trans=(COMPUTE_FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n" -" \n" -" COMPUTE_FLOAT4 b0_trans=(COMPUTE_FLOAT4)(b0.x,b1.x,b2.x,b3.x);\n" -" COMPUTE_FLOAT4 b1_trans=(COMPUTE_FLOAT4)(b0.y,b1.y,b2.y,b3.y);\n" -" COMPUTE_FLOAT4 b2_trans=(COMPUTE_FLOAT4)(b0.z,b1.z,b2.z,b3.z);\n" -" COMPUTE_FLOAT4 b3_trans=(COMPUTE_FLOAT4)(b0.w,b1.w,b2.w,b3.w);\n" -" //matmul\n" -" result0.x += dot(a0_trans,b0_trans);\n" -" result0.y += dot(a0_trans,b1_trans);\n" -" result0.z += dot(a0_trans,b2_trans);\n" -" result0.w += dot(a0_trans,b3_trans);\n" -" \n" -" result1.x += dot(a1_trans,b0_trans);\n" -" result1.y += dot(a1_trans,b1_trans);\n" -" result1.z += dot(a1_trans,b2_trans);\n" -" result1.w += dot(a1_trans,b3_trans);\n" -" \n" -" result2.x += dot(a2_trans,b0_trans);\n" -" result2.y += dot(a2_trans,b1_trans);\n" -" result2.z += dot(a2_trans,b2_trans);\n" -" result2.w += dot(a2_trans,b3_trans);\n" +" COMPUTE_FLOAT4 tmp0=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex));\n" +" COMPUTE_FLOAT4 tmp1=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+K));\n" +" COMPUTE_FLOAT4 tmp2=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+2*K));\n" +" COMPUTE_FLOAT4 tmp3=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex+3*K));\n" " \n" -" result3.x += dot(a3_trans,b0_trans);\n" -" result3.y += dot(a3_trans,b1_trans);\n" -" result3.z += dot(a3_trans,b2_trans);\n" -" result3.w += dot(a3_trans,b3_trans);\n" +" B[0]=(COMPUTE_FLOAT4)(tmp0.x,tmp1.x,tmp2.x,tmp3.x);\n" +" B[1]=(COMPUTE_FLOAT4)(tmp0.y,tmp1.y,tmp2.y,tmp3.y);\n" +" B[2]=(COMPUTE_FLOAT4)(tmp0.z,tmp1.z,tmp2.z,tmp3.z);\n" +" B[3]=(COMPUTE_FLOAT4)(tmp0.w,tmp1.w,tmp2.w,tmp3.w);\n" " }\n" -" \n" -" const int out_offset=(4*height_blocks_idx)*width_blocks+width_blocks_idx;\n" -" vstore4(CONVERT_FLOAT4(result0),out_offset,output_c);\n" -" if(4*height_blocks_idx+1 >= height) return;\n" -" vstore4(CONVERT_FLOAT4(result1),out_offset+width_blocks,output_c);\n" -" if(4*height_blocks_idx+2 >= height) return;\n" -" vstore4(CONVERT_FLOAT4(result2),out_offset+width_blocks*2,output_c);\n" -" if(4*height_blocks_idx+3 >= height) return;\n" -" vstore4(CONVERT_FLOAT4(result3),out_offset+width_blocks*3,output_c);\n" -"}\n" -"__kernel void matmul_transA_transB_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* input_a,\n" -" __global const FLOAT* input_b,\n" -" #ifdef BIAS\n" -" __global const FLOAT* input_c,\n" -" #endif\n" -" __global FLOAT* output_c,\n" -" __private const int channels,\n" -" __private const int channel_blocks,\n" -" __private const int height,\n" -" __private const int height_blocks,\n" -" __private const int width_blocks,\n" -" __private const int width) {\n" -" const int width_blocks_idx=get_global_id(0);\n" -" const int height_blocks_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(width_blocks_idx,height_blocks_idx);\n" -" COMPUTE_FLOAT4 v_zero=(COMPUTE_FLOAT4)((COMPUTE_FLOAT)0.0);\n" -" #ifdef BIAS\n" -" COMPUTE_FLOAT4 result0=CONVERT_COMPUTE_FLOAT4(vload4(width_blocks_idx,input_c));\n" -" COMPUTE_FLOAT4 result1=result0;\n" -" COMPUTE_FLOAT4 result2=result0;\n" -" COMPUTE_FLOAT4 result3=result0;\n" " #else\n" -" COMPUTE_FLOAT4 result0=0;\n" -" COMPUTE_FLOAT4 result1=0;\n" -" COMPUTE_FLOAT4 result2=0;\n" -" COMPUTE_FLOAT4 result3=0;\n" -" #endif\n" -" \n" -" const int 
remaina=channel_blocks*4-channels;\n" -" const int remainb=(width_blocks_idx+1)*4-width;\n" -" for (short pos=0; pos= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks,input_b)));\n" -" COMPUTE_FLOAT4 b2=((remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*2,input_b)));\n" -" COMPUTE_FLOAT4 b3=((remainb >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*3,input_b)));\n" -" COMPUTE_FLOAT4 a0_trans=(COMPUTE_FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n" -" COMPUTE_FLOAT4 a1_trans=(COMPUTE_FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n" -" COMPUTE_FLOAT4 a2_trans=(COMPUTE_FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n" -" COMPUTE_FLOAT4 a3_trans=(COMPUTE_FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n" -" //matmul\n" -" result0.x += dot(a0_trans,b0);\n" -" result0.y += dot(a0_trans,b1);\n" -" result0.z += dot(a0_trans,b2);\n" -" result0.w += dot(a0_trans,b3);\n" -" \n" -" result1.x += dot(a1_trans,b0);\n" -" result1.y += dot(a1_trans,b1);\n" -" result1.z += dot(a1_trans,b2);\n" -" result1.w += dot(a1_trans,b3);\n" -" \n" -" result2.x += dot(a2_trans,b0);\n" -" result2.y += dot(a2_trans,b1);\n" -" result2.z += dot(a2_trans,b2);\n" -" result2.w += dot(a2_trans,b3);\n" +" B[0]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+kindex*N));\n" +" B[1]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+1)*N));\n" +" B[2]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+2)*N));\n" +" B[3]=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+(kindex+3)*N));\n" +" #endif\n" " \n" -" result3.x += dot(a3_trans,b0);\n" -" result3.y += dot(a3_trans,b1);\n" -" result3.z += dot(a3_trans,b2);\n" -" result3.w += dot(a3_trans,b3);\n" +" #pragma unroll\n" +" for (int vec_m=0; vec_m<4; ++vec_m){\n" +" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].x,B[0],out[vec_m]);\n" +" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].y,B[1],out[vec_m]);\n" +" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].z,B[2],out[vec_m]);\n" +" out[vec_m]=mad((COMPUTE_FLOAT4)A[vec_m].w,B[3],out[vec_m]);\n" " }\n" +" }\n" +" #ifdef K_LEAVE\n" +" for (int k=loop_end << 2; k= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks,input_a)));\n" -" COMPUTE_FLOAT4 a2=((remaina >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks*2,input_a)));\n" -" COMPUTE_FLOAT4 a3=((remaina >= 1) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpa_offset+height_blocks*3,input_a)));\n" -" const int inpb_offset=(4*width_blocks_idx)*channel_blocks+channel_blocks-1;\n" -" COMPUTE_FLOAT4 b0=CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset,input_b));\n" -" COMPUTE_FLOAT4 b1=((remainb >= 3) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks,input_b)));\n" -" COMPUTE_FLOAT4 b2=((remainb >= 2) ? v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*2,input_b)));\n" -" COMPUTE_FLOAT4 b3=((remainb >= 1) ? 
v_zero : CONVERT_COMPUTE_FLOAT4(vload4(inpb_offset+channel_blocks*3,input_b)));\n" -" COMPUTE_FLOAT4 a0_trans=(COMPUTE_FLOAT4)(a0.x,a1.x,a2.x,a3.x);\n" -" COMPUTE_FLOAT4 a1_trans=(COMPUTE_FLOAT4)(a0.y,a1.y,a2.y,a3.y);\n" -" COMPUTE_FLOAT4 a2_trans=(COMPUTE_FLOAT4)(a0.z,a1.z,a2.z,a3.z);\n" -" COMPUTE_FLOAT4 a3_trans=(COMPUTE_FLOAT4)(a0.w,a1.w,a2.w,a3.w);\n" -" //matmul\n" -" result0.x += dot(a0_trans,b0);\n" -" result0.y += dot(a0_trans,b1);\n" -" result0.z += dot(a0_trans,b2);\n" -" result0.w += dot(a0_trans,b3);\n" +" #ifdef TRANSPOSE_B\n" +" B.x=(COMPUTE_FLOAT)input_b_offset[k];\n" +" B.y=(COMPUTE_FLOAT)input_b_offset[k+K];\n" +" B.z=(COMPUTE_FLOAT)input_b_offset[k+2*K];\n" +" B.w=(COMPUTE_FLOAT)input_b_offset[k+3*K];\n" +" #else\n" +" B=CONVERT_COMPUTE_FLOAT4(vload4(0,input_b_offset+k*N));\n" +" #endif\n" +" out[0]=mad((COMPUTE_FLOAT4)A.x,B,out[0]);\n" +" out[1]=mad((COMPUTE_FLOAT4)A.y,B,out[1]);\n" +" out[2]=mad((COMPUTE_FLOAT4)A.z,B,out[2]);\n" +" out[3]=mad((COMPUTE_FLOAT4)A.w,B,out[3]);\n" +" }\n" +" #endif\n" " \n" -" result1.x += dot(a1_trans,b0);\n" -" result1.y += dot(a1_trans,b1);\n" -" result1.z += dot(a1_trans,b2);\n" -" result1.w += dot(a1_trans,b3);\n" " \n" -" result2.x += dot(a2_trans,b0);\n" -" result2.y += dot(a2_trans,b1);\n" -" result2.z += dot(a2_trans,b2);\n" -" result2.w += dot(a2_trans,b3);\n" +" const int out_offset=idm*N+idn;\n" +" #ifdef M_LEAVE\n" +" if(idm+3 >= M){\n" +" #ifdef N_LEAVE\n" +" if(idn+3 >= N){\n" +" for (int vec_m=0; vec_m= height) return;\n" -" vstore4(CONVERT_FLOAT4(result1),out_offset+width_blocks,output_c);\n" -" if(4*height_blocks_idx+2 >= height) return;\n" -" vstore4(CONVERT_FLOAT4(result2),out_offset+width_blocks*2,output_c);\n" -" if(4*height_blocks_idx+3 >= height) return;\n" -" vstore4(CONVERT_FLOAT4(result3),out_offset+width_blocks*3,output_c);\n" +" #endif\n" +" } else{\n" +" #endif\n" +" #ifdef N_LEAVE\n" +" if(idn+3 >= N){\n" +" #pragma unroll\n" +" for (int vec_m=0; vec_m<4; ++vec_m){\n" +" COMPUTE_FLOAT *out_ptr=(COMPUTE_FLOAT*)&out[vec_m];\n" +" for(int vec_n=0; vec_n= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"#ifdef CONV_LOCAL_SIZE\n" +"__kernel\n" +"void conv_2d_1x1_local(__private const int out_w_blocks,\n" +" __global const FLOAT *input,\n" +" __global const FLOAT *kernel_ptr,\n" +" __global const FLOAT *bias_ptr,\n" +" __global FLOAT *output,\n" +" __private const int in_c_block,\n" +" __private const int batch,\n" +" __private const int out_h,\n" +" __private const int out_w,\n" +" __private const int out_c_block,\n" +" __private const int out_c_pack) {\n" +" const int lid=get_local_id(0);\n" +" const int out_c_w_idx=get_global_id(1); //c/4 w\n" +" const int out_b_h_idx=get_global_id(2); //b h\n" +" \n" +" COMPUTE_FLOAT4 local sum[CONV_LOCAL_SIZE];\n" +" \n" +" const int out_c_idx=out_c_w_idx/out_w_blocks;\n" +" const int out_w_idx=out_c_w_idx % out_w_blocks;\n" +" const int out_b_idx=out_b_h_idx/out_h; // equal to in_b_idx\n" +" const int out_h_idx=out_b_h_idx % out_h; // equal to in_h_idx\n" +" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n" +" COMPUTE_FLOAT4 out0=(COMPUTE_FLOAT4)0;\n" +" int offset=out_c_idx*4;\n" +" int inp_offset=(((out_b_idx+in_c_block*batch)*out_h+out_h_idx)* out_w+out_w_idx) << 2;\n" +" \n" +" const int inp_add=batch*out_h*out_w*4;\n" +" for (ushort in_channel_block_idx=lid; in_channel_block_idx0; i /= 2){\n" +" if (lid= 4) {\n" @@ -14823,6 +13144,7 @@ const char* conv_2d_buf = " __private const int in_c_block,\n" " __private const int 
out_h,\n" " __private const int out_w,\n" +" __private const int out_b,\n" " __private const int out_c_block,\n" " __private const int out_c_pack) {\n" " const int out_c_w_idx=get_global_id(0); //c/8 w/4\n" @@ -14843,10 +13165,10 @@ const char* conv_2d_buf = " COMPUTE_FLOAT4 out6=out4;\n" " COMPUTE_FLOAT4 out7=out4;\n" " const int intput_width_idx0=out_w4_idx;\n" +" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n" +" int offset=out_c_idx*8;\n" +" const int inp_add=out_b*out_h*out_w*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= 4) {\n" @@ -14972,6 +13297,7 @@ const char* conv_2d_buf = " __private const int in_c_block,\n" " __private const int out_h,\n" " __private const int out_w,\n" +" __private const int out_b,\n" " __private const int out_c_block,\n" " __private const int out_c_pack) {\n" " const int out_c_w_idx=get_global_id(0); //c/8 w/4\n" @@ -14989,10 +13315,10 @@ const char* conv_2d_buf = " COMPUTE_FLOAT4 out4=CONVERT_COMPUTE_FLOAT4(vload4((out_c_idx<<1)+1,bias_ptr));\n" " COMPUTE_FLOAT4 out5=out4;\n" " const int intput_width_idx0=out_w2_idx;\n" +" int inp_offset=((out_b_idx*out_h+out_h_idx)* out_w+intput_width_idx0)<<2;\n" +" int offset=out_c_idx*8;\n" +" const int inp_add=out_b*out_h*out_w*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx= 2) {\n" @@ -15075,6 +13404,7 @@ const char* conv_2d_buf = " __private const int in_c_block,\n" " __private const int out_h,\n" " __private const int out_w,\n" +" __private const int out_b,\n" " __private const int out_c_block,\n" " __private const int out_c_pack) {\n" " const int out_c_w_idx=get_global_id(0); //c/4 w\n" @@ -15086,12 +13416,12 @@ const char* conv_2d_buf = " const int out_h_idx=out_b_h_idx % out_h;//equal to in_h_idx\n" " COMPUTE_FLOAT4 out0=CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx,bias_ptr));\n" " const int intput_width_idx0=out_w_idx;\n" +" int offset=out_c_idx*4;\n" +" int inp_offset=((out_b_idx*out_h+out_h_idx)*out_w+intput_width_idx0)*4;\n" +" const int inp_add=out_b*out_h*out_w*4;\n" " \n" " for (int in_channel_block_idx=0; in_channel_block_idx= 2) {\n" @@ -15185,6 +13522,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15221,7 +13559,7 @@ const char* conv_2d_buf = " int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+kw_start)*4;\n" " for(int iy=in_h_idx_start; iy= out_hw.y) return;\n" @@ -15340,6 +13679,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15375,7 +13715,7 @@ const char* conv_2d_buf = " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" " int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" " for(int iy=in_h_idx_start; iy= 4) {\n" @@ -15451,6 +13791,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15486,7 +13827,7 @@ const char* conv_2d_buf = " for(ushort 
in_c_idx=0; in_c_idx= 4){\n" @@ -15573,6 +13914,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15613,7 +13955,7 @@ const char* conv_2d_buf = " for(ushort in_c_idx=0; in_c_idx= 4){\n" @@ -15715,12 +14057,12 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks){\n" " return;\n" " }\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" @@ -15741,12 +14083,12 @@ const char* conv_2d_buf = " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out2),2*out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),3*out_hw.y,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks){\n" " return;\n" " }\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out4),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out5),out_hw.y,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out6),2*out_hw.y,output+out_offset);\n" @@ -15762,6 +14104,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " __private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15797,7 +14140,7 @@ const char* conv_2d_buf = " for(ushort in_c_idx=0; in_c_idx= 2){\n" @@ -15860,12 +14203,12 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks){\n" " return;\n" " }\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 2){\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" @@ -15875,12 +14218,12 @@ const char* conv_2d_buf = "#else\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out1),out_hw.y,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks){\n" " return;\n" " }\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore4(CONVERT_FLOAT4(out2),0,output+out_offset);\n" " vstore4(CONVERT_FLOAT4(out3),out_hw.y,output+out_offset);\n" "#endif\n" @@ -15894,6 +14237,7 @@ const char* conv_2d_buf = " __private const int2 in_hw,\n" " __private const int inChannel,\n" " 
__private const int in_c_blocks,\n" +" __private const int batch,\n" " __private const int2 out_hw,\n" " __private const int2 filter_hw,\n" " __private const int2 stride_hw,\n" @@ -15936,7 +14280,7 @@ const char* conv_2d_buf = " //index: [0,4*in_c_idx,out_c_idx*kh*kw+kh_start*kw+kw_start,0]\n" " int weight_offset=((((4*in_c_idx+0)* out_c_blocks+out_c_idx) *filter_hw.x+kh_start)*filter_hw.y+0)*4;\n" " for(int iy=in_h_idx_start; iy= 4){\n" @@ -16032,10 +14376,10 @@ const char* conv_2d_buf = " }else if(remain == 1){\n" " vstore4(CONVERT_FLOAT4(out0),0,output+out_offset);\n" " }\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks)return;\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " if(remain >= 4){\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" " }else if(remain == 3){\n" @@ -16048,10 +14392,10 @@ const char* conv_2d_buf = " }\n" "#else\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out0,out1,out2,out3)),0,output+out_offset);\n" -"#ifdef CHANNEL_LEAVE\n" +" #ifdef CHANNEL_LEAVE\n" " if(out_c_idx+1 >= out_c_blocks)return;\n" -"#endif\n" -" out_offset=(((out_b_idx*out_c_blocks+out_c_idx+1)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" +" #endif\n" +" out_offset=(((out_b_idx+(out_c_idx+1)*batch)*out_hw.x+out_h_idx)*out_hw.y+out_w_idx)*4;\n" " vstore16(CONVERT_FLOAT16((COMPUTE_FLOAT16)(out4,out5,out6,out7)),0,output+out_offset);\n" "#endif\n" "}\n" @@ -16573,312 +14917,94 @@ const char* winogradTransformDest2_3_1 = " res=max(res,(FLOAT4)(0));\n" "#endif\n" "#ifdef RELU6\n" -" res=clamp(res,(FLOAT4)(0),(FLOAT4)(6));\n" -"#endif\n" -" WI_F(uOutput,(int2)(imageOx,imageOy),res);\n" -" }\n" -" }\n" -" {\n" -" int ox=oxStart+0;\n" -" int oy=oyStart+1;\n" -" if (ox0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid1\n" +" float local sum[LOCAL_SIZE];\n" +" if (pos.x> 2;\n" +" #ifdef PACK_LEAVE\n" +" const int loop=inside_v4-1;\n" " const int inside_remain=inside-((inside_v4-1) << 2);\n" -" COMPUTE_FLOAT4 in_sum=0;\n" +" #else\n" +" const int loop=inside_v4;\n" +" #endif\n" +" \n" +" float4 in_sum=0;\n" " int index=lid;\n" -" for(; index1) {\n" -" sum[lid]=sum[lid]+in_left.y;\n" -" }\n" -" if(inside_remain>2) {\n" -" sum[lid]=sum[lid]+in_left.z;\n" -" }\n" -" if(inside_remain>3) {\n" -" sum[lid]=sum[lid]+in_left.w;\n" +" for(int i=0; i0; i /= 2){\n" @@ -16887,47 +15013,86 @@ const char* layernorm_buf = " barrier(CLK_LOCAL_MEM_FENCE);\n" " }\n" " \n" -" COMPUTE_FLOAT4 mean=sum[0]/(COMPUTE_FLOAT4)inside;\n" +" float4 mean=sum[0]/(float4)inside;\n" +" #endif\n" " in_sum=0;\n" " index=lid;\n" -" for(; index1) {\n" -" sum[lid]=sum[lid]+in_sum.y;\n" -" }\n" -" if(inside_remain>2) {\n" -" sum[lid]=sum[lid]+in_sum.z;\n" -" }\n" -" if(inside_remain>3) {\n" -" sum[lid]=sum[lid]+in_sum.w;\n" +" for(int i=0; i0; i /= 2){\n" " if (lid= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS\n" +"__kernel void softmax_in1_buf(GLOBAL_SIZE_3_DIMS\n" " __global const FLOAT *input,\n" " __global FLOAT *output,\n" -" __private const int remain_channels,\n" -" __private const int4 shape) {//NCHW\n" +" __private const int inside,\n" +" __private const int outside,\n" +" 
__private const int dim) {\n" " const int x=get_global_id(0);\n" -" const int w=get_global_id(1);\n" -" const int bh=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(x,w,bh);\n" +" const int y=get_global_id(1); // inside=1\n" +" const int z=get_global_id(2); // outside\n" +" DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " \n" -" const int batch_idx=bh/shape.z;\n" -" const int height_idx=bh % shape.z;\n" -" const int offset=(((batch_idx*shape.y+0)*shape.z+height_idx)*shape.w+w)*4;\n" +" const int offset=z*dim+y;\n" +" const int dim4=(dim+3)/4;\n" +" const int loop_end=max(0,dim4-1);\n" "#if SOFTMAX_LOCAL_SIZE >= 4\n" " int lid=get_local_id(0);\n" -" COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n" +" COMPUTE_FLOAT local sum[SOFTMAX_LOCAL_SIZE];\n" +" // compute maxvalue\n" " COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)-FLT_MAX;\n" -" for (int i=lid; i0; i /= 2){\n" " if (lid0; i /= 2){\n" " if (lid= 4\n" " int lid=get_local_id(0);\n" -" COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n" -" \n" -" /*Compute Max */\n" -" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)(-FLT_MAX);\n" -" for (int i=lid; i= 4\n" " int lid=get_local_id(0);\n" " COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE];\n" -" \n" -" /*Compute Max */\n" -" COMPUTE_FLOAT4 maxValue=(COMPUTE_FLOAT4)(-FLT_MAX);\n" -" for (int i=lid; i> 2;\n" -" #ifdef GATHER_INPUT_NHWC\n" -" int off_c=offset_value % offset_dst_shape.z; offset_value /= offset_dst_shape.z;\n" -" int off_w=offset_value % offset_dst_shape.x; offset_value /= offset_dst_shape.x;\n" -" int off_h=offset_value % offset_dst_shape.y;\n" -" int off_b=offset_value/offset_dst_shape.y;\n" -" #else\n" -" int off_w=offset_value % offset_dst_shape.x; offset_value /= offset_dst_shape.x;\n" -" int off_h=offset_value % offset_dst_shape.y; offset_value /= offset_dst_shape.y;\n" -" int off_c=offset_value % offset_dst_shape.z;\n" -" int off_b=offset_value/offset_dst_shape.z;\n" -" #endif\n" -" int real_dst_offset=(((off_b*off_c4_size+off_c/4)*offset_dst_shape.y+off_h)*offset_dst_shape.x+off_w)*4+off_c % 4;\n" -" index.x=offset_dst_ptr[real_dst_offset];\n" -" }\n" -" #endif\n" -" \n" -" #ifdef OFFSET_SRC\n" -" {\n" -" int offset_value=pos.z;\n" -" int off_c4_size=(offset_src_shape.z+3) >> 2;\n" -" #ifdef GATHER_INPUT_NHWC\n" -" int off_c=offset_value % offset_src_shape.z; offset_value /= offset_src_shape.z;\n" -" int off_w=offset_value % offset_src_shape.x; offset_value /= offset_src_shape.x;\n" -" int off_h=offset_value % offset_src_shape.y;\n" -" int off_b=offset_value/offset_src_shape.y;\n" -" #else\n" -" int off_w=offset_value % offset_src_shape.x; offset_value /= offset_src_shape.x;\n" -" int off_h=offset_value % offset_src_shape.y; offset_value /= offset_src_shape.y;\n" -" int off_c=offset_value % offset_src_shape.z;\n" -" int off_b=offset_value/offset_src_shape.z;\n" -" #endif\n" -" int real_src_offset=(((off_b*off_c4_size+off_c/4)*offset_src_shape.y+off_h)*offset_src_shape.x+off_w)*4+off_c % 4;\n" -" index.y=offset_src_ptr[real_src_offset];\n" -" }\n" -" #endif\n" -" \n" +"#ifdef OFFSET_SRC\n" +" index.y=offset_src_ptr[pos.z];\n" +"#endif\n" " int2 offset=index*steps;\n" " int src_offset=offset.y+stride_src.w+x*stride_src.x+y*stride_src.y+pos.y*stride_src.z;\n" " int dst_offset=offset.x+stride_dst.w+x*stride_dst.x+y*stride_dst.y+pos.y*stride_dst.z;\n" -" int src_offsetC4,dst_offsetC4;\n" -" {\n" -"#ifdef GATHER_INPUT_NHWC\n" -" int c=src_offset % src_c4size.z; src_offset /= src_c4size.z;\n" -" int w=src_offset % src_c4size.x; src_offset /= src_c4size.x;\n" -" int h=src_offset % src_c4size.y;\n" -" int 
b=src_offset/src_c4size.y;\n" -" int c4_size=(src_c4size.z+3)/4;\n" -" src_offsetC4=(((b*c4_size+(c/4))*src_c4size.y+h)*src_c4size.x+w)*4+(c % 4);\n" -"#else\n" -" int w=src_offset % src_c4size.x; src_offset /= src_c4size.x;\n" -" int h=src_offset % src_c4size.y; src_offset /= src_c4size.y;\n" -" int c=src_offset % src_c4size.z;\n" -" int b=src_offset/src_c4size.z;\n" -" int c4_size=(src_c4size.z+3)/4;\n" -" src_offsetC4=(((b*c4_size+(c/4))*src_c4size.y+h)*src_c4size.x+w)*4+(c % 4);\n" -"#endif\n" -" }\n" -" {\n" -"#ifdef GATHER_OUTPUT_NHWC\n" -" int c=dst_offset % dst_c4size.z; dst_offset /= dst_c4size.z;\n" -" int w=dst_offset % dst_c4size.x; dst_offset /= dst_c4size.x;\n" -" int h=dst_offset % dst_c4size.y;\n" -" int b=dst_offset/dst_c4size.y;\n" -" int c4_size=(dst_c4size.z+3)/4;\n" -" dst_offsetC4=(((b*c4_size+(c/4))*dst_c4size.y+h)*dst_c4size.x+w)*4+(c % 4);\n" -"#else\n" -" int w=dst_offset % dst_c4size.x; dst_offset /= dst_c4size.x;\n" -" int h=dst_offset % dst_c4size.y; dst_offset /= dst_c4size.y;\n" -" int c=dst_offset % dst_c4size.z;\n" -" int b=dst_offset/dst_c4size.z;\n" -" int c4_size=(dst_c4size.z+3)/4;\n" -" dst_offsetC4=(((b*c4_size+(c/4))*dst_c4size.y+h)*dst_c4size.x+w)*4+(c % 4);\n" -"#endif\n" -" }\n" " if(offset.x >= 0){\n" " if(offset.y >= 0 && offset.y1\n" @@ -17576,6 +15598,7 @@ const char* conv_2d_c16_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -17603,9 +15626,9 @@ const char* conv_2d_c16_subgroup_buf = " const uint output_x_pitch=4;\n" " const uint output_y_pitch=output_x_pitch*output_width;\n" " const uint output_fs_pitch=output_y_pitch*output_height;\n" -" const uint output_b_pitch=output_fs_pitch*((output_channel+3)/4);\n" -" const uint output_offset=b*output_b_pitch +\n" -" (feature_block << 2)*output_fs_pitch +\n" +" const uint output_b_pitch=output_fs_pitch*batch;\n" +" const uint output_offset=b*output_fs_pitch +\n" +" (feature_block << 2)*output_b_pitch +\n" " y*output_y_pitch +\n" " x*output_x_pitch;\n" " const uint filter_isv_pitch=16;\n" @@ -17746,13 +15769,13 @@ const char* conv_2d_c16_subgroup_buf = " if ((feature_block+1)*16 >= output_channel) {\n" " for (int i=0; i<4 && (x+i)1\n" @@ -17772,6 +15795,7 @@ const char* conv_2d_c16_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -17799,9 +15823,9 @@ const char* conv_2d_c16_subgroup_buf = " const uint output_x_pitch=4;\n" " const uint output_y_pitch=output_x_pitch*output_width;\n" " const uint output_fs_pitch=output_y_pitch*output_height;\n" -" const uint output_b_pitch=output_fs_pitch*((output_channel+3)/4);\n" -" const uint output_offset=b*output_b_pitch +\n" -" (feature_block << 2)*output_fs_pitch +\n" +" const uint output_b_pitch=output_fs_pitch*batch;\n" +" const uint output_offset=b*output_fs_pitch +\n" +" (feature_block << 2)*output_b_pitch +\n" " y*output_y_pitch +\n" " x*output_x_pitch;\n" " const uint filter_isv_pitch=16;\n" @@ -17942,13 +15966,13 @@ const char* conv_2d_c16_subgroup_buf = " if ((feature_block+1)*16 >= output_channel) {\n" " for (int i=0; i<8 && (x+i)1\n" @@ -17968,6 +15992,7 @@ const char* 
conv_2d_c16_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -18176,6 +16201,7 @@ const char* conv_2d_c16_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -18383,6 +16409,7 @@ const char* conv_2d_c16_subgroup_buf = " __private const int output_width,\n" " __private const int output_height,\n" " __private const int output_channel,\n" +" __private const int batch,\n" " __private const int x_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right,\n" @@ -18597,6 +16624,7 @@ const char* input_transe_buf = " __private const int input_width,\n" " __private const int input_height,\n" " __private const int input_channel,\n" +" __private const int batch,\n" " __private const int channel_blocks,\n" " __private const int input_pad_left,\n" " __private const int input_pad_right)\n" @@ -18613,9 +16641,9 @@ const char* input_transe_buf = " const uint input_x_pitch=4;\n" " const uint input_y_pitch=input_x_pitch*input_width;\n" " const uint input_f_pitch=input_y_pitch*input_height;\n" -" const uint input_b_pitch=input_f_pitch*channel_blocks;\n" -" const uint input_offset=b*input_b_pitch +\n" -" c*input_f_pitch +\n" +" const uint input_b_pitch=input_f_pitch*batch;\n" +" const uint input_offset=b*input_f_pitch +\n" +" c*input_b_pitch +\n" " h*input_y_pitch +\n" " w*input_x_pitch;\n" " // Output offset calculations:\n" @@ -18643,6 +16671,7 @@ const char* input_transe_buf = " int input_width,\n" " int input_height,\n" " int input_channel,\n" +" int batch,\n" " int channel_blocks,\n" " int input_pad_left,\n" " int input_pad_right)\n" @@ -18660,10 +16689,10 @@ const char* input_transe_buf = " const uint input_x_pitch=4;\n" " const uint input_y_pitch=input_x_pitch*input_width;\n" " const uint input_f_pitch=input_y_pitch*input_height;\n" -" const uint input_b_pitch=input_f_pitch*channel_blocks;\n" +" const uint input_b_pitch=input_f_pitch*batch;\n" " \n" -" const uint input_offset=b*input_b_pitch +\n" -" c*input_f_pitch +\n" +" const uint input_offset=b*input_f_pitch +\n" +" c*input_b_pitch +\n" " h*input_y_pitch +\n" " w*input_x_pitch;\n" " \n" @@ -18687,360 +16716,103 @@ const char* input_transe_buf = " }\n" " pad_offset += (input_pad_left+input_width)*output_x_pitch;\n" " for(int i=0; i= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"__kernel void reduct_width_buf(GLOBAL_SIZE_3_DIMS\n" -" __global const INPUT_TYPE* input,\n" -" __global OUTPUT_TYPE* output,\n" -" __private const int inputWidth,\n" -" __private const int inputHeight,\n" -" __private const int inputChannel,\n" -" __private const int inputBatch,\n" -" __private const int inputChannelBlock,\n" -" __private const int oututWidth,\n" -" __private const int outputHeight,\n" -" __private const int outputChannel,\n" -" __private const int outputChannelBlock\n" -" ) {\n" -" const int width_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n" -" \n" -" const int 
batch_idx=batch_channel_idx/outputChannelBlock;\n" -" const int channel_idx=batch_channel_idx % outputChannelBlock;\n" -" const int offset=((((batch_idx*inputChannelBlock)+channel_idx)*inputHeight+height_idx)*inputWidth+0)*4;\n" -" const int outputOffset=((((batch_idx*outputChannelBlock)+channel_idx)*outputHeight+height_idx)*oututWidth+0)*4;\n" -" INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n" -" \n" -"#if LOCAL_SIZE>0\n" -" const int lid=get_local_id(0);\n" -" INPUT_TYPE4 local sum[LOCAL_SIZE];\n" -" for(int i=lid; i0; i /= 2){\n" -" if (lid0\n" -" const int width_local_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,batch_channel_idx);\n" -" \n" -" const int width_idx=get_group_id(0);\n" -" const int batch_idx=batch_channel_idx/outputChannelBlock;\n" -" const int channel_idx=batch_channel_idx % outputChannelBlock;\n" -" \n" -" const int offset=((((batch_idx*inputChannelBlock)+channel_idx)*inputHeight+0)*inputWidth+width_idx)*4;\n" -" const int outputOffset=((((batch_idx*outputChannelBlock)+channel_idx)*outputHeight+0)*oututWidth+width_idx)*4;\n" -" const int lid=get_local_id(0);\n" -" INPUT_TYPE4 local sum[LOCAL_SIZE];\n" -" INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n" -" for(int i=lid; i0; i /= 2){\n" -" if (lid0\n" -" const int width_local_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_idx=get_global_id(2);\n" -" \n" -" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,batch_idx);\n" -" const int width_idx=get_group_id(0);\n" -" \n" -" const int offset=((((batch_idx*inputChannelBlock)+0)*inputHeight+height_idx)*inputWidth+width_idx)*4;\n" -" const int outputOffset=((((batch_idx*outputChannelBlock)+0)*outputHeight+height_idx)*oututWidth+width_idx)*4;\n" -" int remain=inputChannel-(inputChannelBlock-1)*4;\n" -" const int lid=get_local_id(0);\n" -" INPUT_TYPE local sum[LOCAL_SIZE];\n" -" INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n" -" INPUT_TYPE4 in;\n" -" INPUT_TYPE *inPtr=(INPUT_TYPE*)∈\n" -" for(int i=lid; i0; i /= 2){\n" -" if (lid0\n" -" const int width_local_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_idx=get_global_id(2);\n" +" vstore4((FLOAT4)0,0,output+pad_offset+i*output_x_pitch);\n" +" }\n" +" }\n" +"}\n" +; +#endif +#ifndef MNN_OPENCL_BUFFER_CLOSED +const char* reduction_buf = +"// TODO: use INIT_SCALAR_VALUE,OPERATOR,FINAL_OPERATOR_ON_CHANNEL macro abstract and simplify code\n" +"// TODO: support reduce dims include batch\n" +"// TODO: support keep_dim=False\n" +"// TODO: fix channel reduce result re-pack problem\n" +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_2_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,\n" +"#define GLOBAL_SIZE_3_DIMS ""__private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"__kernel void reduct_buf(GLOBAL_SIZE_3_DIMS\n" +" __global const INPUT_TYPE *input,\n" +" __global OUTPUT_TYPE *output,\n" +" __private const int inside,\n" +" __private const int outside,\n" +" __private const int dim) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1); // inside\n" +" const int z=get_global_id(2); // outside\n" +" 
DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " \n" -" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,batch_idx);\n" -" const int width_idx=get_group_id(0);\n" +" INPUT_TYPE out=(INPUT_TYPE)VALUE;\n" +" const int offset=z*dim*inside+y;\n" " \n" -" const int offset=((((batch_idx*inputChannelBlock)+0)*inputHeight+height_idx)*inputWidth+width_idx)*4;\n" -" const int outputOffset=((batch_idx*outputHeight+height_idx)*oututWidth+width_idx);\n" -" int remain=inputChannel-(inputChannelBlock-1)*4;\n" +"#if REDUCT_LOCAL_SIZE>4\n" " const int lid=get_local_id(0);\n" -" INPUT_TYPE local sum[LOCAL_SIZE];\n" -" INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n" -" INPUT_TYPE4 in;\n" -" INPUT_TYPE *inPtr=(INPUT_TYPE*)∈\n" -" for(int i=lid; i0; i /= 2){\n" +" for(int i=REDUCT_LOCAL_SIZE/2; i>0; i /= 2){\n" " if (lid0\n" -" const int width_local_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_local_idx,height_idx,channel_idx);\n" -" const int width_idx=get_group_id(0);\n" +"__kernel void reduct_v4_buf(GLOBAL_SIZE_3_DIMS\n" +" __global const INPUT_TYPE *input,\n" +" __global OUTPUT_TYPE *output,\n" +" __private const int inside,\n" +" __private const int outside,\n" +" __private const int dim) {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1); // inside\n" +" const int z=get_global_id(2); // outside\n" +" DEAL_NON_UNIFORM_DIM3(x,y,z);\n" " \n" -" const int offset=((((0*inputChannelBlock)+channel_idx)*inputHeight+height_idx)*inputWidth+width_idx)*4;\n" -" const int outputOffset=((((0*outputChannelBlock)+channel_idx)*outputHeight+height_idx)*oututWidth+width_idx)*4;\n" -" int batchOffset=inputChannelBlock*inputHeight*inputWidth;\n" -" const int lid=get_local_id(0);\n" -" INPUT_TYPE4 local sum[LOCAL_SIZE];\n" " INPUT_TYPE4 out=(INPUT_TYPE4)VALUE;\n" -" for(int i=lid; i4\n" +" const int lid=get_local_id(0);\n" +" INPUT_TYPE4 local sum[REDUCT_LOCAL_SIZE];\n" +" for(int i=lid; i0; i /= 2){\n" +" for(int i=REDUCT_LOCAL_SIZE/2; i>0; i /= 2){\n" " if (lid with bias (eltwise_sub) [M,N]\n" "// 4 -> with bias (eltwise_sub and get negative) [M,N]\n" +"// 5 -> with bias (mask 0 for invalid) [M,N]\n" "#ifndef BIAS_TYPE\n" " #define BIAS_TYPE 0\n" "#endif\n" @@ -19233,13 +17006,38 @@ const char* matmul_params_buf = "#define DEAL_BIAS(x,a) x=x-a\n" "#elif BIAS_TYPE == 4\n" "#define DEAL_BIAS(x,a) x=a-x\n" +"#elif BIAS_TYPE == 5\n" +"#define DEAL_BIAS(x,a) x=(a == 0 ? (FLOAT)(-FLT_MAX) : x)\n" "#endif\n" "// By default the workgroup size requirement is enabled. 
For Qualcomm devices the workgroup size\n" "// requirement results in worse performance and is disabled (src/utilities/compile.cpp)\n" "#ifndef RELAX_WORKGROUP_SIZE\n" " #define RELAX_WORKGROUP_SIZE 0\n" "#endif\n" -"#define ZERO (FLOAT)0.0f\n" +"typedef float real_arg;\n" +"#define GetRealArg(x) (FLOAT)x\n" +"typedef FLOAT real;\n" +"#ifndef PRECISION_COMPUTE\n" +"#define PRECISION_COMPUTE COMPUTE_FLOAT\n" +"#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x)\n" +"#endif\n" +"#ifndef PRECISION_COMPUTE2\n" +"#define PRECISION_COMPUTE2 COMPUTE_FLOAT2\n" +"#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x)\n" +"#endif\n" +"#ifndef PRECISION_COMPUTE4\n" +"#define PRECISION_COMPUTE4 COMPUTE_FLOAT4\n" +"#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x)\n" +"#endif\n" +"#ifndef PRECISION_COMPUTE8\n" +"#define PRECISION_COMPUTE8 COMPUTE_FLOAT8\n" +"#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x)\n" +"#endif\n" +"#ifndef PRECISION_COMPUTE16\n" +"#define PRECISION_COMPUTE16 COMPUTE_FLOAT16\n" +"#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x)\n" +"#endif\n" +"#define ZERO (PRECISION_COMPUTE)0.0f\n" "// Sets a variable to zero\n" "#define SetToZero(a) a=ZERO\n" "#define IsZero(a) (a == ZERO)\n" @@ -19259,38 +17057,69 @@ const char* matmul_params_buf = "INLINE_FUNC int GetGroupID1() { return get_group_id(1); }\n" "INLINE_FUNC int GetGroupID0() { return get_group_id(0); }\n" "// =================================================================================================\n" -"// End of the C++11 raw string literal\n" -"typedef float real_arg;\n" -"#define GetRealArg(x) (FLOAT)x\n" -"typedef FLOAT real;\n" "// Data-widths in dimension M\n" "#if VWM == 1\n" " typedef FLOAT realM;\n" +" #define COMPUTE_FLOATM PRECISION_COMPUTE\n" +" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x)\n" +" #define CONVERT_FLOATM(x) CONVERT_FLOAT(x)\n" "#elif VWM == 2\n" " typedef FLOAT2 realM;\n" +" #define COMPUTE_FLOATM PRECISION_COMPUTE2\n" +" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x)\n" +" #define CONVERT_FLOATM(x) CONVERT_FLOAT2(x)\n" "#elif VWM == 4\n" " typedef FLOAT4 realM;\n" +" #define COMPUTE_FLOATM PRECISION_COMPUTE4\n" +" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x)\n" +" #define CONVERT_FLOATM(x) CONVERT_FLOAT4(x)\n" "#elif VWM == 8\n" " typedef FLOAT8 realM;\n" +" #define COMPUTE_FLOATM PRECISION_COMPUTE8\n" +" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x)\n" +" #define CONVERT_FLOATM(x) CONVERT_FLOAT8(x)\n" "#elif VWM == 16\n" " typedef FLOAT16 realM;\n" +" #define COMPUTE_FLOATM PRECISION_COMPUTE16\n" +" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x)\n" +" #define CONVERT_FLOATM(x) CONVERT_FLOAT16(x)\n" "#endif\n" "// Data-widths in dimension N\n" "#if VWN == 1\n" " typedef FLOAT realN;\n" +" typedef int intN;\n" +" #define COMPUTE_FLOATN PRECISION_COMPUTE\n" +" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x)\n" +" #define CONVERT_FLOATN(x) CONVERT_FLOAT(x)\n" "#elif VWN == 2\n" " typedef FLOAT2 realN;\n" +" typedef int2 intN;\n" +" #define COMPUTE_FLOATN PRECISION_COMPUTE2\n" +" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x)\n" +" #define CONVERT_FLOATN(x) CONVERT_FLOAT2(x)\n" "#elif VWN == 4\n" " typedef FLOAT4 realN;\n" +" typedef int4 intN;\n" +" #define COMPUTE_FLOATN PRECISION_COMPUTE4\n" +" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x)\n" +" #define CONVERT_FLOATN(x) CONVERT_FLOAT4(x)\n" 
"#elif VWN == 8\n" " typedef FLOAT8 realN;\n" +" typedef int8 intN;\n" +" #define COMPUTE_FLOATN PRECISION_COMPUTE8\n" +" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x)\n" +" #define CONVERT_FLOATN(x) CONVERT_FLOAT8(x)\n" "#elif VWN == 16\n" " typedef FLOAT16 realN;\n" +" typedef int16 intN;\n" +" #define COMPUTE_FLOATN PRECISION_COMPUTE16\n" +" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x)\n" +" #define CONVERT_FLOATN(x) CONVERT_FLOAT16(x)\n" "#endif\n" "// =================================================================================================\n" "// Initializes the accumulation registers to zero\n" -"INLINE_FUNC realM InitAccRegisters() {\n" -" realM result;\n" +"INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() {\n" +" COMPUTE_FLOATM result;\n" " #if VWM == 1\n" " SetToZero(result);\n" " #elif VWM == 2\n" @@ -19330,8 +17159,8 @@ const char* matmul_params_buf = " #endif\n" " return result;\n" "}\n" -"INLINE_FUNC realN InitAccRegistersN() {\n" -" realN result;\n" +"INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() {\n" +" COMPUTE_FLOATN result;\n" " #if VWN == 1\n" " SetToZero(result);\n" " #elif VWN == 2\n" @@ -19538,10 +17367,10 @@ const char* matmul_params_buf = "}\n" "#endif\n" "// The vectorised multiply-add function\n" -"INLINE_FUNC realM MultiplyAddVector(realM cvec,const realM avec,const real bval) {\n" +"INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec,COMPUTE_FLOATM avec,PRECISION_COMPUTE bval) {\n" " #if USE_VECTOR_MAD == 1\n" " #if USE_CL_MAD == 1\n" -" cvec=mad(avec,(realM)bval,cvec);\n" +" cvec=mad(avec,(COMPUTE_FLOATM)bval,cvec);\n" " #else\n" " cvec += avec*bval;\n" " #endif\n" @@ -19587,10 +17416,10 @@ const char* matmul_params_buf = " return cvec;\n" "}\n" "// The vectorised multiply-add function\n" -"INLINE_FUNC realN MultiplyAddVectorN(realN cvec,const real avec,const realN bval) {\n" +"INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec,PRECISION_COMPUTE avec,COMPUTE_FLOATN bval) {\n" " #if USE_VECTOR_MAD == 1\n" " #if USE_CL_MAD == 1\n" -" cvec=mad((realN)avec,bval,cvec);\n" +" cvec=mad((COMPUTE_FLOATN)avec,bval,cvec);\n" " #else\n" " cvec += avec*bval;\n" " #endif\n" @@ -19660,8 +17489,8 @@ const char* matmul_params_buf = " return res;\n" "}\n" "// layout : [N,M]\n" -"INLINE_FUNC void StoreResultsM(__global realM* cgm,realM c_value,const INT2 baseOffset,const int _mi,const int _ni,\n" -" const int kSizeM,const real alpha,const real beta) {\n" +"INLINE_FUNC void StoreResultsM(__global realM* cgm,COMPUTE_FLOATM c_value,const INT2 baseOffset,const int _mi,const int _ni,\n" +" const int kSizeM,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n" " #if STRM == 0\n" " int idm=_mi+baseOffset.index[0];\n" " #elif STRM == 1\n" @@ -19674,10 +17503,10 @@ const char* matmul_params_buf = " #endif\n" " \n" " int index=idn*(kSizeM/VWM)+idm;\n" -" realM result=c_value;\n" +" COMPUTE_FLOATM result=c_value;\n" " // The final multiplication with alpha (in case beta == 0)\n" " #ifdef ONLY_HAVE_ALPHA\n" -" realM xval=c_value;\n" +" COMPUTE_FLOATM xval=c_value;\n" " #if VWM == 1\n" " Multiply(result,alpha,xval);\n" " #elif VWM == 2\n" @@ -19718,8 +17547,8 @@ const char* matmul_params_buf = " #endif\n" " // The final multiplication with alpha and the addition with beta*C\n" " #ifdef HAVE_ALPHA_BETA\n" -" realM xval=c_value;\n" -" realM yval=cgm[index];\n" +" COMPUTE_FLOATM xval=c_value;\n" +" COMPUTE_FLOATM yval=CONVERT_COMPUTE_FLOATM(cgm[index]);\n" " #if VWM == 1\n" " AXPBY(result,alpha,xval,beta,yval);\n" 
" #elif VWM == 2\n" @@ -19758,7 +17587,7 @@ const char* matmul_params_buf = " AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n" " #endif\n" " #endif\n" -" cgm[index]=result;\n" +" cgm[index]=CONVERT_FLOATM(result);\n" "}\n" "INLINE_FUNC INT2 StoreIndexN() {\n" " INT2 res;\n" @@ -19780,7 +17609,7 @@ const char* matmul_params_buf = " return res;\n" "}\n" "// layout : [M,N]\n" -"INLINE_FUNC void StoreResultsN(__global realN* cgn,realN c_value,\n" +"INLINE_FUNC void StoreResultsN(__global realN* cgn,COMPUTE_FLOATN c_value,\n" " const INT2 baseOffset,\n" " #if BIAS_TYPE>0\n" " #if BIAS_TYPE>1\n" @@ -19790,7 +17619,7 @@ const char* matmul_params_buf = " #endif\n" " #endif\n" " const int _mi,const int _ni,\n" -" const int cstride/*kSizeN*/,const int dstride/*kSizeN*/,const real alpha,const real beta) {\n" +" const int cstride/*kSizeN*/,const int dstride/*kSizeN*/,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n" " #if STRM == 0\n" " int idm=_mi+baseOffset.index[0];\n" " #elif STRM == 1\n" @@ -19803,11 +17632,11 @@ const char* matmul_params_buf = " #endif\n" " int index=idm*(cstride/VWN)+idn;\n" " \n" -" realN result=c_value;\n" +" COMPUTE_FLOATN result=c_value;\n" " \n" " // The final multiplication with alpha (in case beta == 0)\n" " #ifdef ONLY_HAVE_ALPHA\n" -" realN xval=c_value;\n" +" COMPUTE_FLOATN xval=c_value;\n" " #if VWN == 1\n" " Multiply(result,alpha,xval);\n" " #elif VWN == 2\n" @@ -19848,8 +17677,8 @@ const char* matmul_params_buf = " #endif\n" " // The final multiplication with alpha and the addition with beta*C\n" " #ifdef HAVE_ALPHA_BETA\n" -" realN xval=c_value;\n" -" realN yval=cgn[index];\n" +" COMPUTE_FLOATN xval=c_value;\n" +" COMPUTE_FLOATN yval=CONVERT_COMPUTE_FLOATN(cgn[index]);\n" " #if VWN == 1\n" " AXPBY(result,alpha,xval,beta,yval);\n" " #elif VWN == 2\n" @@ -19892,29 +17721,31 @@ const char* matmul_params_buf = " \n" "#if BIAS_TYPE>0\n" " #if BIAS_TYPE == 1\n" -" realN eval=epm[_ni];\n" +" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(epm[_ni]);\n" +" #elif BIAS_TYPE == 5\n" +" int index_bias=idm*(dstride/VWN)+idn;\n" +" intN eval=((__global intN*)egm)[index_bias];\n" " #else\n" -" \n" " int index_bias=idm*(dstride/VWN)+idn;\n" -" realN eval=egm[index_bias];\n" +" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(egm[index_bias]);\n" " #endif\n" " \n" " #if VWN == 1\n" " DEAL_BIAS(result,eval);\n" " #ifdef RELU\n" -" result=fmax(result,(FLOAT)0);\n" +" result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" -" result=clamp(result,(FLOAT)0,(FLOAT)6);\n" +" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 2\n" " DEAL_BIAS(result.x,eval.x);\n" " DEAL_BIAS(result.y,eval.y);\n" " #ifdef RELU\n" -" result=fmax(result,(FLOAT2)0);\n" +" result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" -" result=clamp(result,(FLOAT2)0,(FLOAT2)6);\n" +" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 4\n" " DEAL_BIAS(result.x,eval.x);\n" @@ -19922,10 +17753,10 @@ const char* matmul_params_buf = " DEAL_BIAS(result.z,eval.z);\n" " DEAL_BIAS(result.w,eval.w);\n" " #ifdef RELU\n" -" result=fmax(result,(FLOAT4)0);\n" +" result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" -" result=clamp(result,(FLOAT4)0,(FLOAT4)6);\n" +" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 8\n" " DEAL_BIAS(result.s0,eval.s0);\n" @@ -19937,10 +17768,10 @@ const char* matmul_params_buf = " DEAL_BIAS(result.s6,eval.s6);\n" " 
DEAL_BIAS(result.s7,eval.s7);\n" " #ifdef RELU\n" -" result=fmax(result,(FLOAT8)0);\n" +" result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" -" result=clamp(result,(FLOAT8)0,(FLOAT8)6);\n" +" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 16\n" " DEAL_BIAS(result.s0,eval.s0);\n" @@ -19960,14 +17791,14 @@ const char* matmul_params_buf = " DEAL_BIAS(result.sE,eval.sE);\n" " DEAL_BIAS(result.sF,eval.sF);\n" " #ifdef RELU\n" -" result=fmax(result,(FLOAT16)0);\n" +" result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" -" result=clamp(result,(FLOAT16)0,(FLOAT16)6);\n" +" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #endif\n" "#endif\n" -" cgn[index]=result;\n" +" cgn[index]=CONVERT_FLOATN(result);\n" "}\n" "// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.\n" "INLINE_FUNC void XgemmBody(const int kSizeM,const int kSizeN,const int kSizeK,const int4 stride,\n" @@ -19975,7 +17806,7 @@ const char* matmul_params_buf = " #if BIAS_TYPE>0\n" " __global realN* restrict egm,\n" " #endif\n" -" __global realM* cgm,const real alpha,const real beta\n" +" __global realM* cgm,const real_arg alpha,const real_arg beta\n" " #if SA == 1 && SB == 1\n" " ,LOCAL_PTR realM* alm,LOCAL_PTR realN* blm\n" " #elif SA == 1\n" @@ -19986,10 +17817,10 @@ const char* matmul_params_buf = " ) {\n" " #ifdef OUTPUTMN\n" " #pragma promote_to_registers\n" -" realN cpn[MWI*(NWI/VWN)]; // MWI*NWI\n" +" COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI*NWI\n" " #else\n" " #pragma promote_to_registers\n" -" realM cpm[NWI*(MWI/VWM)]; // NWI*MWI\n" +" COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI*MWI\n" " #endif\n" " // Combined thread identifier (volatile to disable caching)\n" " #if SA == 1 || SB == 1\n" @@ -20017,9 +17848,9 @@ const char* matmul_params_buf = " #if SA == 1 || SB == 1\n" " // Allocates workitem-private memory (registers)\n" " #pragma promote_to_registers\n" -" realM apm[MWI/VWM]; // MWI*1\n" +" COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n" " #pragma promote_to_registers\n" -" realN bpm[NWI/VWN]; // 1*NWI\n" +" COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n" " \n" " for (int kwg=0; kwg local (matrix A)\n" @@ -20044,10 +17875,10 @@ const char* matmul_params_buf = " for (int _mi=0; _mi private (matrix A)\n" " #if SA == 1\n" -" apm[_mi]=LocalToPrivateA(alm,_mi,kg);\n" +" apm[_mi]=CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm,_mi,kg));\n" " // Loads data: off-chip --> private (matrix A)\n" " #elif SA == 0\n" -" apm[_mi]=GlobalToPrivateA(agm,_mi,kSizeM,idk);\n" +" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm,_mi,kSizeM,idk));\n" " #endif\n" " }\n" " // Loads matrix B (kernel 0) or matrix A (kernel 1)\n" @@ -20055,10 +17886,10 @@ const char* matmul_params_buf = " for (int _ni=0; _ni private (matrix B)\n" " #if SB == 1\n" -" bpm[_ni]=LocalToPrivateB(blm,_ni,kg);\n" +" bpm[_ni]=CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm,_ni,kg));\n" " // Loads data: off-chip --> private (matrix B)\n" " #else\n" -" bpm[_ni]=GlobalToPrivateB(bgm,_ni,kSizeN,idk);\n" +" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm,_ni,kSizeN,idk));\n" " #endif\n" " }\n" " // Performs the accumulation (Cpm += Apm*Bpm)\n" @@ -20067,7 +17898,7 @@ const char* matmul_params_buf = " for (int _mi=0; _mi private (matrix B)\n" -" bpm[_ni]=GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk);\n" +" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n" " }\n" " #pragma unroll\n" " for 
(int _mi=0; _mi private (matrix B)\n" -" apm[_mi]=GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk);\n" +" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n" " }\n" " #pragma unroll\n" " for (int _ni=0; _ni0\n" " egm,\n" " #endif\n" -" cgm,alpha,beta,alm);\n" +" cgm,arg_alpha,arg_beta,alm);\n" " #elif SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" -" cgm,alpha,beta,blm);\n" +" cgm,arg_alpha,arg_beta,blm);\n" " #else\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" -" cgm,alpha,beta);\n" +" cgm,arg_alpha,arg_beta);\n" " #endif\n" "}\n" "#if RELAX_WORKGROUP_SIZE == 1\n" @@ -20408,29 +18237,32 @@ const char* matmul_params_buf = " const real_arg arg_alpha,\n" " const real_arg arg_beta,\n" " const __global realM* restrict agm,\n" -" const int batch_offset_a,\n" " const __global realN* restrict bgm,\n" -" const int batch_offset_b,\n" " #if BIAS_TYPE>0\n" " __global realN* restrict egm,\n" -" const int batch_offset_e,\n" " #endif\n" " __global realM* cgm,\n" -" const int batch_offset_c) {\n" +" const int4 batch_offset,// [batch_offset_a,batch_offset_b,batch_offset_c,batch_offset_e]\n" +" const int4 stride,// [stride_a,stride_b,stride_c,stride_e]\n" +" /*\n" +" total_batch -> [loop_y,loop_x]\n" +" with group batch -> [loop_y,loop_x/group_num]\n" +" group_size == loop_x/group_num\n" +" */\n" +" const int4 group // [group_num_a,group_num_b,group_num_e,loop_x]\n" +") {\n" " const int batch=get_group_id(2);\n" -" const real alpha=GetRealArg(arg_alpha);\n" -" const real beta=GetRealArg(arg_beta);\n" " \n" " // Sets the offsets\n" -" const int a_offset=batch*batch_offset_a;\n" -" const int b_offset=batch*batch_offset_b;\n" -" const int c_offset=batch*batch_offset_c;\n" +" const int a_offset=((batch/group.w)*group.x+(batch % group.w)/group.x)*batch_offset.x;\n" +" const int b_offset=((batch/group.w)*group.y+(batch % group.w)/group.y)*batch_offset.y;\n" +" const int c_offset=batch*batch_offset.z;\n" " const __global realM* restrict agm_=&agm[a_offset/VWM];\n" " const __global realN* restrict bgm_=&bgm[b_offset/VWN];\n" " __global realM* restrict cgm_=&cgm[c_offset/VWM];\n" " \n" " #if BIAS_TYPE>0\n" -" const int e_offset=batch*batch_offset_e;\n" +" const int e_offset=((batch/group.w)*group.z+(batch % group.w)/group.z)*batch_offset.w;\n" " __global realN* restrict egm_=&egm[e_offset/VWN];\n" " #endif\n" " \n" @@ -20441,40 +18273,31 @@ const char* matmul_params_buf = " #if SB == 1\n" " __local realN blm[KWG*NWG/VWN];\n" " #endif\n" -" int4 stride;\n" -" stride.s0=kSizeM;\n" -" stride.s1=kSizeN;\n" -" #ifdef OUTPUTMN\n" -" stride.s2=kSizeN;\n" -" #else\n" -" stride.s2=kSizeM;\n" -" #endif\n" -" stride.s3=kSizeN;\n" " // Computes the matrix-multiplication and stores the result in global memory\n" " #if SA == 1 && SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" -" cgm_,alpha,beta,alm,blm);\n" +" cgm_,arg_alpha,arg_beta,alm,blm);\n" " #elif SA == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" -" cgm_,alpha,beta,alm);\n" +" cgm_,arg_alpha,arg_beta,alm);\n" " #elif SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" -" cgm_,alpha,beta,blm);\n" +" cgm_,arg_alpha,arg_beta,blm);\n" " #else\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" 
" #endif\n" -" cgm_,alpha,beta);\n" +" cgm_,arg_alpha,arg_beta);\n" " #endif\n" "}\n" ; @@ -20495,228 +18318,83 @@ const char* cast = " ) {\n" " const int width_idx=get_global_id(0);\n" " const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n" -" \n" -" const int batch_idx=batch_channel_idx/channelBlock;\n" -" const int channel_idx=batch_channel_idx % channelBlock;\n" -" \n" -"#ifdef TO_BOOL\n" -" int4 value=convert_int4(RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx)));\n" -" value=value == (int4)0 ? (int4)0 : (int4)1;\n" -" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n" -"#else\n" -" INPUT_TYPE_I4 value=RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx));\n" -" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n" -"#endif\n" -"}\n" -; -#ifndef MNN_OPENCL_BUFFER_CLOSED -const char* buffer_convert_buf = -"#ifdef MNN_SUPPORT_FP16\n" -"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" -"#endif\n" -"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" -"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" -"// convert data from buffer(nhwc) to buffer(nc4hw4)\n" -"__kernel void nhwc_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n" -" __global const INPUT_TYPE *input_ptr,\n" -" __private const int height,\n" -" __private const int width,__private const int channels,\n" -" __global OUTPUT_TYPE *output) {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int batch_idx=image_height_idx/height;\n" -" const int height_idx=image_height_idx % height;\n" -" const int width_idx=image_width_idx % width;\n" -" const int channel_4_idx=(image_width_idx/width) << 2;\n" -" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_4_idx;\n" -" const int remain_channel=channels-channel_4_idx;\n" -" float4 values=convert_float4(vload4(0,input_ptr+buffer_offset));\n" -" if (remain_channel == 3) {\n" -" values.w=0;\n" -" } else if (remain_channel == 2) {\n" -" values.z=0;\n" -" values.w=0;\n" -" } else if (remain_channel == 1) {\n" -" values.y=0;\n" -" values.z=0;\n" -" values.w=0;\n" -" }\n" -" const int out_offset=(((batch_idx*((channels+3)/4)+channel_4_idx/4)*height+height_idx)*width+width_idx)*4;\n" -" vstore4(CONVERT_OUTPUT4(values),0,output+out_offset);\n" -"}\n" -"// convert data from buffer(nchw) to buffer(nc4hw4)\n" -"__kernel void nchw_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n" -" __global const INPUT_TYPE *input_ptr,\n" -" __private const int height,__private const int width,__private const int channels,\n" -" __global OUTPUT_TYPE *output) {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int batch_idx=image_height_idx/height;\n" -" const int height_idx=image_height_idx % height;\n" -" const int width_idx=image_width_idx % width;\n" -" const int channel_4_idx=image_width_idx/width << 2;\n" -" const int buffer_offset=((batch_idx*channels+channel_4_idx)*height+height_idx)*width+width_idx;\n" -" const int 
remain_channel=channels-channel_4_idx;\n" -" const int height_width_size=height*width;\n" -" float4 output_values=0;\n" -" if (remain_channel >= 4) {\n" -" int offset=buffer_offset;\n" -" output_values.x=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.y=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.z=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.w=(float)*(input_ptr+offset);\n" -" } else if (remain_channel == 3) {\n" -" int offset=buffer_offset;\n" -" output_values.x=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.y=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.z=(float)*(input_ptr+offset);\n" -" } else if (remain_channel == 2) {\n" -" int offset=buffer_offset;\n" -" output_values.x=(float)*(input_ptr+offset);\n" -" offset += height_width_size;\n" -" output_values.y=(float)*(input_ptr+offset);\n" -" } else if (remain_channel == 1) {\n" -" int offset=buffer_offset;\n" -" output_values.x=(float)*(input_ptr+offset);\n" -" }\n" -" const int out_offset=(((batch_idx*((channels+3)/4)+channel_4_idx/4)*height+height_idx)*width+width_idx)*4;\n" -" vstore4(CONVERT_OUTPUT4(output_values),0,output+out_offset);\n" -"}\n" -"__kernel void nchw_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS\n" -" __global INPUT_TYPE *input_ptr,\n" -" __private const int height,__private const int width,__private const int channels,\n" -" __private const int input_pad_left,__private const int input_pad_right,\n" -" __private const int output_pad_left,__private const int output_pad_right,\n" -" __global OUTPUT_TYPE *output) {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int src_width=width+input_pad_left+input_pad_right;\n" -" const int dst_width=width+output_pad_left+output_pad_right;\n" -" const int batch_idx=image_height_idx/height;\n" -" const int height_idx=image_height_idx % height;\n" -" const int width_idx=image_width_idx % width;\n" -" const int channel_idx=image_width_idx/width;\n" -" const int in_offset=((batch_idx*channels+channel_idx)*height+height_idx)*src_width+width_idx+input_pad_left;\n" -" const int out_offset=((batch_idx*channels+channel_idx)*height+height_idx)*dst_width+width_idx+output_pad_left;\n" -" output[out_offset]=(OUTPUT_TYPE)input_ptr[in_offset];\n" -"}\n" -"// convert data from image(b h,ic/4 w ic4) to buffer(nhwc)\n" -"__kernel void nc4hw4_buffer_to_nhwc_buffer(GLOBAL_SIZE_2_DIMS\n" -" __global OUTPUT_TYPE *output,\n" -" __private const int height,__private const int width,\n" -" __private const int channels,\n" -" __global INPUT_TYPE *input_ptr) {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int batch_idx=image_height_idx/height;\n" -" const int height_idx=image_height_idx % height;\n" -" const int width_idx=image_width_idx % width;\n" -" const int channel_4_idx=(image_width_idx/width) << 2;\n" -" const int buffer_offset=((batch_idx*height+height_idx)*width+width_idx)*channels+channel_4_idx;\n" -" const int in_offset=(((batch_idx*((channels+3)/4)+channel_4_idx/4)*height+height_idx)*width+width_idx)*4;\n" -" \n" -" float4 values=convert_float4(vload4(0,input_ptr+in_offset));\n" -" const int remain_channel=channels-channel_4_idx;\n" -" if (remain_channel >= 4) {\n" -" 
vstore4(CONVERT_OUTPUT4(values),0,output+buffer_offset);\n" -" } else if (remain_channel == 3) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" offset++;\n" -" output[offset]=(OUTPUT_TYPE)values.y;\n" -" offset++;\n" -" output[offset]=(OUTPUT_TYPE)values.z;\n" -" } else if (remain_channel == 2) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" offset++;\n" -" output[offset]=(OUTPUT_TYPE)values.y;\n" -" } else if (remain_channel == 1) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" }\n" -"}\n" -"// convert data from buffer(nc4hw4) to buffer(nchw)\n" -"__kernel void nc4hw4_buffer_to_nchw_buffer(GLOBAL_SIZE_2_DIMS\n" -" __global OUTPUT_TYPE *output,\n" -" __private const int height,__private const int width,\n" -" __private const int channels,\n" -" __global INPUT_TYPE *input_ptr) {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" \n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int batch_idx=image_height_idx/height;\n" -" const int height_idx=image_height_idx % height;\n" -" const int width_idx=image_width_idx % width;\n" -" int channel_4_idx=(image_width_idx/width)*4;\n" -" int buffer_offset=((batch_idx*channels+channel_4_idx)*height+height_idx)*width+width_idx;\n" +" const int batch_channel_idx=get_global_id(2);\n" +" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n" " \n" -" const int in_offset=(((batch_idx*((channels+3)/4)+channel_4_idx/4)*height+height_idx)*width+width_idx)*4;\n" -" float4 values=convert_float4(vload4(0,input_ptr+in_offset));\n" -" const int height_width_size=height*width;\n" -" const int remain_channel=channels-channel_4_idx;\n" -" if (remain_channel >= 4) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.y;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.z;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.w;\n" -" } else if (remain_channel == 3) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.y;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.z;\n" -" } else if (remain_channel == 2) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" offset += height_width_size;\n" -" output[offset]=(OUTPUT_TYPE)values.y;\n" -" } else if (remain_channel == 1) {\n" -" int offset=buffer_offset;\n" -" output[offset]=(OUTPUT_TYPE)values.x;\n" -" }\n" +" const int batch_idx=batch_channel_idx/channelBlock;\n" +" const int channel_idx=batch_channel_idx % channelBlock;\n" +" \n" +"#ifdef TO_BOOL\n" +" int4 value=convert_int4(RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx)));\n" +" value=value == (int4)0 ? 
(int4)0 : (int4)1;\n" +" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n" +"#else\n" +" INPUT_TYPE_I4 value=RI_DATA(input,SAMPLER,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx));\n" +" WI_DATA(output,(int2)(channel_idx*width+width_idx,batch_idx*height+height_idx),CONVERT_OUTPUT_I4(value));\n" +"#endif\n" "}\n" -"__kernel void nc4hw4_buffer_to_nc4hw4_buffer(GLOBAL_SIZE_2_DIMS\n" +; +#ifndef MNN_OPENCL_BUFFER_CLOSED +const char* buffer_convert_buf = +"#ifdef MNN_SUPPORT_FP16\n" +"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" +"#endif\n" +"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"#define GLOBAL_SIZE_3_DIMS __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" +"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" +"#define MNN_DATA_FORMAT_NCHW 0\n" +"#define MNN_DATA_FORMAT_NHWC 1\n" +"#define MNN_DATA_FORMAT_NC4HW4 2\n" +"#define MNN_DATA_FORMAT_C4NHW4 3\n" +"__kernel void buffer_convert_to_buffer(GLOBAL_SIZE_3_DIMS\n" " __global const INPUT_TYPE *input_ptr,\n" -" __private const int2 output_shape,\n" -" __private const int2 src_stride,\n" -" __private const int2 dst_stride,\n" -" __global OUTPUT_TYPE *output\n" +" __private const int4 shape,// N C H W\n" +" __global OUTPUT_TYPE *output_ptr\n" ") {\n" -" int image_width_idx=get_global_id(0);\n" -" int image_height_idx=get_global_id(1);\n" -" DEAL_NON_UNIFORM_DIM2(image_width_idx,image_height_idx);\n" -" const int batch_idx=image_height_idx/output_shape.x;\n" -" const int height_idx=image_height_idx % output_shape.x;\n" -" const int width_idx=image_width_idx % output_shape.y;\n" -" const int channel_block_idx=image_width_idx/output_shape.y;\n" -" int2 src_bc_offset=src_stride*(int2)(batch_idx,channel_block_idx);\n" -" int2 dst_bc_offset=dst_stride*(int2)(batch_idx,channel_block_idx);\n" -" int src_buffer_offset =\n" -" (((src_bc_offset.x+src_bc_offset.y)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n" -" int dst_buffer_offset =\n" -" (((dst_bc_offset.x+dst_bc_offset.y)*output_shape.x+height_idx)*output_shape.y+width_idx)*4;\n" -" \n" -" vstore4(CONVERT_OUTPUT4(vload4(0,input_ptr+src_buffer_offset)),0,output+dst_buffer_offset);\n" +" int wh=get_global_id(0);\n" +" int c=get_global_id(1);\n" +" int n=get_global_id(2);\n" +" DEAL_NON_UNIFORM_DIM3(wh,c,n);\n" +" int w=wh % shape.w;\n" +" int h=wh/shape.w;\n" +" \n" +"#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" +" int input_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n" +"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" +" int input_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n" +"#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" +" int input_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n" +"#endif\n" +"#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW\n" +" int output_offset=((n*shape.y+c)*shape.z+h)*shape.w+w;\n" +"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC\n" +" int output_offset=((n*shape.z+h)*shape.w+w)*shape.y+c;\n" +"#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4\n" +" int output_offset=((((c/4)*shape.x+n)*shape.z+h)*shape.w+w)*4+(c % 4);\n" +"#endif\n" +" output_ptr[output_offset]=input_ptr[input_offset];\n" +"}\n" +"__kernel 
void buffer_copy_to_buffer(GLOBAL_SIZE_2_DIMS\n" +" __global const INPUT_TYPE *input_ptr,\n" +" __global OUTPUT_TYPE *output_ptr,\n" +" __private const int size // N C H W\n" +") {\n" +" const int x=get_global_id(0);\n" +" const int y=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(x,y);\n" +" const int offset=x << 2;\n" +"#ifdef PACK_LEAVE\n" +" if(offset+3 >= size){\n" +" for(int i=0; i [N H C 1]\n" +"// [C4 N H 1 4] -> [N H C 1]\n" "__kernel void tile_trans_3d_buf(__global INPUT_TYPE* input,\n" " __global OUTPUT_TYPE* output,\n" " __private const int widthPad,\n" @@ -21281,7 +18959,6 @@ const char* loop_buf = " // group id\n" " const int c=get_group_id(0)*WGSC;\n" " const int h=get_group_id(1)*WGSH;\n" -" const int channel_4=(channel+3) >> 2;\n" " int jc=lidc;\n" " int ih=lidh;\n" " \n" @@ -21294,7 +18971,7 @@ const char* loop_buf = " int offset_h=i*WGSH/TSH+ih;\n" " int offset_c=j*WGSC/TSC+jc ;\n" " // [TSH,WGSH/TSH] [TSC/4,WGSC/TSC,4]\n" -" localData[offset_h][offset_c]=(h+offset_h >= height || c+4*offset_c >= channel) ? (INPUT_TYPE4)0 : vload4(0,input+((b*channel_4+(c/4+offset_c))*height+(h+offset_h))*4);\n" +" localData[offset_h][offset_c]=(h+offset_h >= height || c+4*offset_c >= channel) ? (INPUT_TYPE4)0 : vload4(0,input+((b+(c/4+offset_c)*batch)*height+(h+offset_h))*4);\n" " }\n" " }\n" " \n" @@ -21316,7 +18993,7 @@ const char* loop_buf = " }\n" " }\n" "}\n" -"// [N C4 H W 4] -> [N C W H]\n" +"// [C4 N H W 4] -> [N C W H]\n" "__kernel void tile_trans_4d_buf(__global INPUT_TYPE* input,\n" " __global OUTPUT_TYPE* output,\n" " __private const int widthPad,\n" @@ -21337,7 +19014,6 @@ const char* loop_buf = " // group id\n" " const int w=get_group_id(0)*WGSW;\n" " const int h=get_group_id(1)*WGSH;\n" -" const int channel_4=(channel+3) >> 2;\n" " int jw=lidw;\n" " int ih=lidh;\n" " \n" @@ -21349,7 +19025,7 @@ const char* loop_buf = " for(int j=0; j= height || offset_w >= width) ? (INPUT_TYPE4)0 : vload4(0,input+(((b*channel_4+c4)*height+offset_h)*width+offset_w)*4);\n" +" localData[ih+i*WGSH/TSH][jw+j*WGSW/TSW]=(offset_h >= height || offset_w >= width) ? 
(INPUT_TYPE4)0 : vload4(0,input+(((b+c4*batch)*height+offset_h)*width+offset_w)*4);\n" " }\n" " }\n" " \n" @@ -21469,8 +19145,8 @@ const char* loop_buf = " const int c=c_4 << 2;\n" " const int x_src_pitch=4;\n" " const int y_src_pitch=x_src_pitch*width;\n" -" const int c_src_pitch=y_src_pitch*height;\n" -" const int b_src_pitch=c_src_pitch*((channel+3)/4);\n" +" const int b_src_pitch=y_src_pitch*height;\n" +" const int c_src_pitch=b_src_pitch*batch;\n" " \n" " bool outBound=(w >= width || h >= height || c >= channel);\n" "#ifdef MNN_NHWC\n" @@ -21621,154 +19297,32 @@ const char* loop_buf = " }\n" "}\n" "#ifdef LOOP_BINARY_OPERATOR\n" -"__kernel void broadcast_binary_buf(__private int global_dim0,__private int global_dim1,__private int global_dim2,\n" -" __global OUTPUT_TYPE* output,__global INPUT_TYPE* input0,__global INPUT_TYPE* input1,\n" -" __private const int8 src0_size,//(batch,channel,height,width)\n" -" __private const int4 src0C4_size,// nc4hw4\n" -" __private const int8 src1_size,\n" -" __private const int4 src1C4_size,\n" -" __private const int8 dst_size,\n" -" __private const int dst_width,\n" -" __private const int dst_height,\n" -" __private const int dst_channel,\n" -" __private const int channel_block) {\n" -" int3 pos=(int3)(get_global_id(0),get_global_id(1),get_global_id(2));\n" -" \n" -" if (pos.x0; i /= 2){\n" +" if (lid0; i /= 2){\n" +" if (lid0; i /= 2){\n" " if (lid0; i /= 2){\n" " if (lid= global_size_dim0 || index1 >= global_size_dim1) { "" return; "" }\n" +"#define GLOBAL_SIZE_DIM3 "" __private int global_size_dim0,__private int global_size_dim1,__private int global_size_dim2,\n" +"#define UNIFORM_BOUNDRY_CHECK3(index0, index1, index2) "" if(index0 >= global_size_dim0 || index1 >= global_size_dim1 || index2 >= global_size_dim2) { "" return; "" }\n" +"#define UCHAR16_TO_2CHAR16(a, b, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8 = (c.s4 >> 4) - 8; a.s9 = (c.s4 & 15) - 8; a.sa = (c.s5 >> 4) - 8; a.sb = (c.s5 & 15) - 8; a.sc = (c.s6 >> 4) - 8; a.sd = (c.s6 & 15) - 8; a.se = (c.s7 >> 4) - 8; a.sf = (c.s7 & 15) - 8; "" b.s0 = (c.s8 >> 4) - 8; b.s1 = (c.s8 & 15) - 8; b.s2 = (c.s9 >> 4) - 8; b.s3 = (c.s9 & 15) - 8; b.s4 = (c.sa >> 4) - 8; b.s5 = (c.sa & 15) - 8; b.s6 = (c.sb >> 4) - 8; b.s7 = (c.sb & 15) - 8; "" b.s8=(c.sc >> 4)-8; b.s9=(c.sc & 15)-8; b.sa=(c.sd >> 4)-8; b.sb=(c.sd & 15)-8; b.sc=(c.se >> 4)-8; b.sd=(c.se & 15)-8; b.se=(c.sf >> 4)-8; b.sf=(c.sf & 15)-8;\n" +"#define UCHAR8_TO_CHAR16(a, c) "" a.s0 = (c.s0 >> 4) - 8; a.s1 = (c.s0 & 15) - 8; a.s2 = (c.s1 >> 4) - 8; a.s3 = (c.s1 & 15) - 8; a.s4 = (c.s2 >> 4) - 8; a.s5 = (c.s2 & 15) - 8; a.s6 = (c.s3 >> 4) - 8; a.s7 = (c.s3 & 15) - 8; "" a.s8=(c.s4 >> 4)-8; a.s9=(c.s4 & 15)-8; a.sa=(c.s5 >> 4)-8; a.sb=(c.s5 & 15)-8; a.sc=(c.s6 >> 4)-8; a.sd=(c.s6 & 15)-8; a.se=(c.s7 >> 4)-8; a.sf=(c.s7 & 15)-8;\n" +"#define DOT16X16(a, b, c) "" c += dot(a.s0123, b.s0123); "" c += dot(a.s4567, b.s4567); "" c += dot(a.s89ab, b.s89ab); "" c += dot(a.scdef,b.scdef);\n" +"#if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n" +"#define CHANNEL_PACK 32\n" +"#else\n" +"#define CHANNEL_PACK 16\n" +"#endif\n" +"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"#define WEIGHT_STRIDE 16\n" +"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +"#define WEIGHT_STRIDE 8\n" +"#endif\n" +"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" 
+"#ifdef USE_IMAGE\n" +"inline COMPUTE_FLOAT16 readWeight(__read_only image2d_t weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" return CONVERT_COMPUTE_FLOAT16(as_char16(read_imagei(weight,SAMPLER,(int2)(ix,iy))))*scale+offset;\n" +"}\n" +"#else\n" +"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +"inline COMPUTE_FLOAT16 readWeight(__global const char *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" return CONVERT_COMPUTE_FLOAT16(vload16(0,weight))*scale+offset;\n" +"}\n" +"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +"inline COMPUTE_FLOAT16 readWeight(__global const uchar *weight,int ix,int iy,COMPUTE_FLOAT scale,COMPUTE_FLOAT offset){\n" +" uchar16 charWeightsInt40=vload16(0,weight);\n" +" uchar8 charWeightsInt4=vload8(0,weight);\n" +" char16 charWeights=0;\n" +" UCHAR8_TO_CHAR16(charWeights,charWeightsInt4);\n" +" return CONVERT_COMPUTE_FLOAT16(charWeights)*scale+offset;\n" +"}\n" +"#endif\n" +"#endif\n" +"__kernel void inverse_quant_weight(GLOBAL_SIZE_DIM2\n" +" #ifdef USE_IMAGE\n" +" __read_only image2d_t weight,\n" +" #else\n" +" #if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +" __global const char *weight,\n" +" #elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +" __global const uchar *weight,\n" +" #endif\n" +" #endif\n" +" __global const float *dequantScaleOffset,\n" +" __global FLOAT* output,\n" +" __private const int outputChannelAlign,\n" +" __private const int outputChannel4Align,\n" +" __private const int blockDim){\n" +" const int x=get_global_id(0); //ic\n" +" const int y=get_global_id(1); //oc\n" +" UNIFORM_BOUNDRY_CHECK(x,y);\n" +" #if defined(USE_LOW_BIT_WEIGHT_INT4) && defined(USE_IMAGE)\n" +" \n" +" const int ic=x << 5;\n" +" const int oc=y << 2;\n" +" const int output_offset=ic*outputChannelAlign+oc;\n" +" int kindex=(ic/blockDim)*outputChannel4Align*2;\n" +" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(0,dequantScaleOffset+kindex+oc*2));\n" +" COMPUTE_FLOAT16 weights00,weights01,weights10,weights11,weights20,weights21,weights30,weights31;\n" +" {\n" +" uchar16 charWeightsInt40=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc,x)));\n" +" uchar16 charWeightsInt41=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+1,x)));\n" +" uchar16 charWeightsInt42=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+2,x)));\n" +" uchar16 charWeightsInt43=as_uchar16(read_imagei(weight,SAMPLER,(int2)(oc+3,x)));\n" +" char16 charWeights0,charWeights1;\n" +" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt40);\n" +" weights00=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s0+ScaleOffset.s1;\n" +" weights01=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s0+ScaleOffset.s1;\n" +" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt41);\n" +" weights10=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s2+ScaleOffset.s3;\n" +" weights11=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s2+ScaleOffset.s3;\n" +" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt42);\n" +" weights20=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s4+ScaleOffset.s5;\n" +" weights21=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s4+ScaleOffset.s5;\n" +" UCHAR16_TO_2CHAR16(charWeights0,charWeights1,charWeightsInt43);\n" +" weights30=CONVERT_COMPUTE_FLOAT16(charWeights0)*ScaleOffset.s6+ScaleOffset.s7;\n" +" weights31=CONVERT_COMPUTE_FLOAT16(charWeights1)*ScaleOffset.s6+ScaleOffset.s7;\n" +" }\n" +" COMPUTE_FLOAT *weights00_ptr=(COMPUTE_FLOAT *)&weights00;\n" +" COMPUTE_FLOAT *weights10_ptr=(COMPUTE_FLOAT *)&weights10;\n" +" COMPUTE_FLOAT 
*weights20_ptr=(COMPUTE_FLOAT *)&weights20;\n" +" COMPUTE_FLOAT *weights30_ptr=(COMPUTE_FLOAT *)&weights30;\n" +" COMPUTE_FLOAT *weights01_ptr=(COMPUTE_FLOAT *)&weights01;\n" +" COMPUTE_FLOAT *weights11_ptr=(COMPUTE_FLOAT *)&weights11;\n" +" COMPUTE_FLOAT *weights21_ptr=(COMPUTE_FLOAT *)&weights21;\n" +" COMPUTE_FLOAT *weights31_ptr=(COMPUTE_FLOAT *)&weights31;\n" +" #pragma unroll\n" +" for (int i=0; i<16; ++i){\n" +" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights00_ptr[i],weights10_ptr[i],weights20_ptr[i],weights30_ptr[i]));\n" +" vstore4(out,0,output+output_offset+i*outputChannelAlign);\n" +" }\n" +" #pragma unroll\n" +" for (int i=0; i<16; ++i){\n" +" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights01_ptr[i],weights11_ptr[i],weights21_ptr[i],weights31_ptr[i]));\n" +" vstore4(out,0,output+output_offset+(i+16)*outputChannelAlign);\n" +" }\n" +" #else\n" +" const int ic=x << 4;\n" +" const int oc=y << 2;\n" +"#ifndef USE_IMAGE\n" +" #if (defined USE_LOW_BIT_WEIGHT_INT4)\n" +" int weight_offset=oc*8;\n" +" int weight_oc_offset=outputChannel4Align*8;\n" +" int weight_stride=8;\n" +" #else\n" +" int weight_offset=oc*16;\n" +" int weight_oc_offset=outputChannel4Align*16;\n" +" int weight_stride=16;\n" +" #endif\n" +"#endif\n" +" const int output_offset=ic*outputChannelAlign+oc;\n" +" int kindex=(ic/blockDim)*outputChannel4Align*2;\n" +" COMPUTE_FLOAT8 ScaleOffset=CONVERT_COMPUTE_FLOAT8(vload8(0,dequantScaleOffset+kindex+oc*2));\n" +" #ifdef USE_IMAGE\n" +" COMPUTE_FLOAT16 weights0=readWeight(weight,oc,x,ScaleOffset.s0,ScaleOffset.s1);\n" +" COMPUTE_FLOAT16 weights1=readWeight(weight,oc+1,x,ScaleOffset.s2,ScaleOffset.s3);\n" +" COMPUTE_FLOAT16 weights2=readWeight(weight,oc+2,x,ScaleOffset.s4,ScaleOffset.s5);\n" +" COMPUTE_FLOAT16 weights3=readWeight(weight,oc+3,x,ScaleOffset.s6,ScaleOffset.s7);\n" +" #else\n" +" COMPUTE_FLOAT16 weights0=readWeight(weight+weight_offset+x*weight_oc_offset,0,0,ScaleOffset.s0,ScaleOffset.s1);\n" +" COMPUTE_FLOAT16 weights1=readWeight(weight+weight_offset+x*weight_oc_offset+weight_stride,0,0,ScaleOffset.s2,ScaleOffset.s3);\n" +" COMPUTE_FLOAT16 weights2=readWeight(weight+weight_offset+x*weight_oc_offset+2*weight_stride,0,0,ScaleOffset.s4,ScaleOffset.s5);\n" +" COMPUTE_FLOAT16 weights3=readWeight(weight+weight_offset+x*weight_oc_offset+3*weight_stride,0,0,ScaleOffset.s6,ScaleOffset.s7);\n" +" #endif\n" +" COMPUTE_FLOAT *weights0_ptr=(COMPUTE_FLOAT*)&weights0;\n" +" COMPUTE_FLOAT *weights1_ptr=(COMPUTE_FLOAT*)&weights1;\n" +" COMPUTE_FLOAT *weights2_ptr=(COMPUTE_FLOAT*)&weights2;\n" +" COMPUTE_FLOAT *weights3_ptr=(COMPUTE_FLOAT*)&weights3;\n" +" #pragma unroll\n" +" for (int i=0; i<16; ++i){\n" +" FLOAT4 out=CONVERT_FLOAT4((COMPUTE_FLOAT4)(weights0_ptr[i],weights1_ptr[i],weights2_ptr[i],weights3_ptr[i]));\n" +" vstore4(out,0,output+output_offset+i*outputChannelAlign);\n" +" }\n" +" #endif\n" +"}\n" +"__kernel void reshape_nchw4_nhwc4(GLOBAL_SIZE_DIM2\n" +"__global const FLOAT* input,\n" +"__global FLOAT* output,\n" +"__private const int bhw,\n" +"__private const int channel,\n" +"__private const int channelAlign){\n" +" const int x=get_global_id(0); //c\n" +" const int y=get_global_id(1); //bhw\n" +" UNIFORM_BOUNDRY_CHECK(x,y);\n" +" \n" +" const int x4=x << 2;\n" +" const int y4=y << 2;\n" +" const int input_offset=(x*bhw+y4)*4;\n" +" FLOAT4 in0=vload4(0,input+input_offset);\n" +" FLOAT4 in1=(y4+1= channel){\n" +" FLOAT *in0_ptr=(FLOAT*)&in0;\n" +" FLOAT *in1_ptr=(FLOAT*)&in1;\n" +" FLOAT *in2_ptr=(FLOAT*)&in2;\n" +" FLOAT *in3_ptr=(FLOAT*)&in3;\n" +" int 
remain=x4+3-channel;\n" +" for(int i=remain; i >= 0; i--){\n" +" in0_ptr[3-i]=0;\n" +" in1_ptr[3-i]=0;\n" +" in2_ptr[3-i]=0;\n" +" in3_ptr[3-i]=0;\n" +" }\n" +" }\n" +"#endif\n" +" \n" +"#ifdef FORMAT_CNHW\n" +" int idx=x/4;\n" +" int idy=x % 4;\n" +" const int bhw4=(bhw+3)/4*4;\n" +" int output_offset=((idx*bhw4+y4)*4+idy)*4; // [c/16 b 4 4]\n" +" vstore4(in0,0,output+output_offset);\n" +" vstore4(in1,0,output+output_offset+16);\n" +" vstore4(in2,0,output+output_offset+32);\n" +" vstore4(in3,0,output+output_offset+48);\n" +"#else\n" +" FLOAT16 out=(FLOAT16)(in0.s0,in1.s0,in2.s0,in3.s0,in0.s1,in1.s1,in2.s1,in3.s1,in0.s2,in1.s2,in2.s2,in3.s2,in0.s3,in1.s3,in2.s3,in3.s3);\n" +" const int output_offset=(y*channelAlign+x4)*4;\n" +" vstore16(out,0,output+output_offset);\n" +"#endif\n" +"}\n" +"__kernel void reshape_nhwc4_nchw4(GLOBAL_SIZE_DIM2\n" +"__global const FLOAT* input,\n" +"__global FLOAT* output,\n" +"__private const int bhw,\n" +"__private const int channelAlign){\n" +" const int x=get_global_id(0); //c\n" +" const int y=get_global_id(1); //bhw\n" +" UNIFORM_BOUNDRY_CHECK(x,y);\n" +" \n" +" const int x4=x << 2;\n" +" const int y4=y << 2;\n" +" const int output_offset=(x*bhw+y4)*4;\n" +" \n" +" const int input_offset=(y*channelAlign+x4)*4;\n" +" FLOAT16 in=vload16(0,input+input_offset);\n" +" \n" +" FLOAT4 out0=(FLOAT4)(in.s0,in.s4,in.s8,in.sc);\n" +" FLOAT4 out1=(FLOAT4)(in.s1,in.s5,in.s9,in.sd);\n" +" FLOAT4 out2=(FLOAT4)(in.s2,in.s6,in.sa,in.se);\n" +" FLOAT4 out3=(FLOAT4)(in.s3,in.s7,in.sb,in.sf);\n" +" vstore4(out0,0,output+output_offset);\n" +" if(y4+1 >= bhw) return;\n" +" vstore4(out1,0,output+output_offset+4);\n" +" if(y4+2 >= bhw) return;\n" +" vstore4(out2,0,output+output_offset+8);\n" +" if(y4+3 >= bhw) return;\n" +" vstore4(out3,0,output+output_offset+12);\n" +"}\n" +"__kernel void gemm_b4_c4_buf(GLOBAL_SIZE_DIM2\n" +" __global const FLOAT* input,\n" +"#ifdef USE_IMAGE\n" +" __read_only image2d_t weight,\n" +"#else\n" +"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" +" __global const char *weight,\n" +"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" +" __global const uchar *weight,\n" +"#endif\n" +"#endif\n" +" __global const float *dequantScaleOffset,\n" +" __global const FLOAT *bias,\n" +" __global FLOAT* output,\n" +" __private const int bhw4,\n" +" __private const int dstChannelAlign,\n" +" __private const int srcChannelAlign,\n" +" __private const int blockNum,\n" +" __private const int blockDim) {\n" +" const int x=get_global_id(0); //c\n" +" const int y=get_global_id(1); //b\n" +" UNIFORM_BOUNDRY_CHECK(x,y);\n" +" const int out_c_idx=x << 2;\n" +" const int out_b_idx=y << 2;\n" +" COMPUTE_FLOAT4 bias0=CONVERT_COMPUTE_FLOAT4(vload4(0,bias+out_c_idx));\n" +" COMPUTE_FLOAT4 out=(COMPUTE_FLOAT4)bias0.s0;\n" +" COMPUTE_FLOAT4 out1=(COMPUTE_FLOAT4)bias0.s1,out2=(COMPUTE_FLOAT4)bias0.s2,out3=(COMPUTE_FLOAT4)bias0.s3;\n" +"#ifdef FORMAT_CNHW\n" +" int input_offset=out_b_idx*16;\n" +"#else\n" +" int input_offset=out_b_idx*srcChannelAlign;\n" +"#endif\n" +" int out_offset=out_b_idx*dstChannelAlign+out_c_idx*4;\n" +" \n" +"#ifndef USE_IMAGE\n" +" int weight_offset=out_c_idx*WEIGHT_STRIDE;\n" +" int weight_oc_offset=dstChannelAlign*WEIGHT_STRIDE;\n" +"#endif\n" +" const int loop=(blockDim+CHANNEL_PACK-1)/CHANNEL_PACK;\n" +" \n" +" for (int i=0; i0; i /= 2){\n" -" if (lid0; i /= 2){\n" -" if (lid= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" -"__kernel void cast_buf(GLOBAL_SIZE_3_DIMS\n" +"#define GLOBAL_SIZE_2_DIMS ""__private const 
int global_size_dim0,__private const int global_size_dim1,\n" +"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" +"__kernel void cast_buf(GLOBAL_SIZE_2_DIMS\n" " __global INPUT_TYPE* input,\n" " __global OUTPUT_TYPE* output,\n" -" __private const int width,\n" -" __private const int height,\n" -" __private const int channelBlock\n" +" __private const int size\n" " ) {\n" -" const int width_idx=get_global_id(0);\n" -" const int height_idx=get_global_id(1);\n" -" const int batch_channel_idx=get_global_id(2);\n" -" DEAL_NON_UNIFORM_DIM3(width_idx,height_idx,batch_channel_idx);\n" -" \n" -" const int batch_idx=batch_channel_idx/channelBlock;\n" -" const int channel_idx=batch_channel_idx % channelBlock;\n" -" \n" -" const int inp_offset=((((batch_idx*channelBlock)+channel_idx)*height+height_idx)*width+width_idx)*4;\n" -"#ifdef TO_BOOL\n" +" const int idx=get_global_id(0);\n" +" const int idy=get_global_id(1);\n" +" DEAL_NON_UNIFORM_DIM2(idx,idy);\n" +" const int inp_offset=idx*4;\n" +"#ifdef PACK_LEAVE\n" +" if(inp_offset+3 >= size){\n" +" int remain=size-inp_offset;\n" +" for(int i=0; i OpenCLProgramMap = #ifndef MNN_OPENCL_BUFFER_CLOSED { "binary_buf", binary_buf }, #endif -#ifndef MNN_OPENCL_BUFFER_CLOSED - { "gemm_quant_batch_buf", gemm_quant_batch_buf }, -#endif #ifndef MNN_OPENCL_BUFFER_CLOSED { "raster_buf", raster_buf }, #endif @@ -338,6 +335,9 @@ const std::map OpenCLProgramMap = { "roi_pooling", roi_pooling }, { "depthwise_conv2d", depthwise_conv2d }, { "layernorm", layernorm }, +#ifndef MNN_OPENCL_BUFFER_CLOSED + { "gemm_conv1x1_buf", gemm_conv1x1_buf }, +#endif { "winogradTransformDest2_5_1", winogradTransformDest2_5_1 }, #ifndef MNN_OPENCL_BUFFER_CLOSED { "cast_buf", cast_buf }, diff --git a/source/backend/opencl/execution/cl/pooling_buf.cl b/source/backend/opencl/execution/cl/pooling_buf.cl index 1340973d1..300e25185 100644 --- a/source/backend/opencl/execution/cl/pooling_buf.cl +++ b/source/backend/opencl/execution/cl/pooling_buf.cl @@ -16,7 +16,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __private const int2 kernel_shape, __global FLOAT *output, __global FLOAT *rediceOutput, - __private const int channel_block) { + __private const int batch) { const int ow_idx = get_global_id(0); const int b_oh_idx = get_global_id(1); @@ -31,7 +31,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #ifdef POOL_AVG COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(0); - const int inp_offset = (((b_idx*channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; #ifdef COUNT_INCLUDE_PADDING int total_count = (min(ih_start + kernel_shape.x, input_shape.x + pad_shape.x) - ih_start) * (min(iw_start + kernel_shape.y, input_shape.y + pad_shape.y) - iw_start); #else @@ -60,7 +60,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #if RETURN_REDICE int4 redice = (int4)0; #endif - const int inp_offset = (((b_idx*channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start)*4; for(int kh=0; kh= input_shape.x) { @@ -80,7 +80,7 @@ __kernel void pooling(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, } #endif - const int out_offset = (((b_idx*channel_block + c_idx)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4; + const int out_offset = (((b_idx 
+ c_idx*batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4; vstore4(CONVERT_FLOAT4(result), 0, output+out_offset); #if RETURN_REDICE vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+out_offset); @@ -96,7 +96,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __private const int2 kernel_shape, __global FLOAT *output, __global FLOAT *rediceOutput, - __private const int channel_block) { + __private const int batch) { const int local_id = get_local_id(0); const int output_channel_idx = get_global_id(1); const int output_batch_idx = get_global_id(2); @@ -112,7 +112,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #endif COMPUTE_FLOAT4 local sum[LOCAL_SIZE]; - const int inp_offset = ((output_batch_idx*channel_block+output_channel_idx)*input_shape.x)*input_shape.y*4; + const int inp_offset = ((output_batch_idx+output_channel_idx*batch)*input_shape.x)*input_shape.y*4; const int size = input_shape.x * input_shape.y; for(int i = local_id; i < size; i+=LOCAL_SIZE){ int w = i % input_shape.y;; @@ -152,7 +152,7 @@ __kernel void global_pooling_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, output_result /= (input_shape.x * input_shape.y); #endif - const int out_offset = (output_batch_idx*channel_block + output_channel_idx)*4; + const int out_offset = (output_batch_idx + output_channel_idx*batch)*4; vstore4(CONVERT_FLOAT4(output_result), 0, output+out_offset); #if RETURN_REDICE redice = rediceId[0]; diff --git a/source/backend/opencl/execution/cl/pooling_subgroup_buf.cl b/source/backend/opencl/execution/cl/pooling_subgroup_buf.cl index 304c3b903..6311116ff 100644 --- a/source/backend/opencl/execution/cl/pooling_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/pooling_subgroup_buf.cl @@ -15,9 +15,10 @@ __kernel void pooling_c4_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __global FLOAT *output, __global FLOAT *rediceOutput, __private const int channel, + __private const int batch, __private const int in_channel_block, __private const int out_channel_block, - __private const int input_pad_left, + __private const int input_pad_left, __private const int input_pad_right, __private const int output_pad_left, __private const int output_pad_right) { @@ -35,7 +36,7 @@ __kernel void pooling_c4_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #ifdef POOL_AVG COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(0); - const int inp_offset = (((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; #ifdef COUNT_INCLUDE_PADDING int total_count = (min(ih_start + KERNEL_Y, input_shape.x + pad_shape.x) - ih_start) * (min(iw_start + KERNEL_X, input_shape.y + pad_shape.y) - iw_start); #else @@ -64,7 +65,7 @@ __kernel void pooling_c4_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #if RETURN_REDICE int4 redice = (int4)0; #endif - const int inp_offset = (((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; for(int kh=0; kh= input_shape.x) { @@ -84,10 +85,10 @@ __kernel void pooling_c4_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, } #endif - const int out_offset = (((b_idx*in_channel_block + c_idx)*output_shape.x + oh_idx)* output_shape.y + ow_idx + output_pad_left)*4; + const int out_offset = (((b_idx + 
c_idx*batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx + output_pad_left)*4; vstore4(CONVERT_FLOAT4(result), 0, output+out_offset); #if RETURN_REDICE - vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+(((b_idx*in_channel_block + c_idx)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4); + vstore4(CONVERT_FLOAT4(redice), 0, rediceOutput+(((b_idx + c_idx*batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4); #endif } @@ -98,6 +99,7 @@ __kernel void pooling_c4_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __global FLOAT *output, __global FLOAT *rediceOutput, __private const int channel, + __private const int batch, __private const int in_channel_block, __private const int out_channel_block, __private const int input_pad_left, @@ -119,7 +121,7 @@ __kernel void pooling_c4_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #ifdef POOL_AVG COMPUTE_FLOAT4 result = (COMPUTE_FLOAT4)(0); - const int inp_offset = (((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; #ifdef COUNT_INCLUDE_PADDING int total_count = (min(ih_start + KERNEL_Y, input_shape.x + pad_shape.x) - ih_start) * (min(iw_start + KERNEL_X, input_shape.y + pad_shape.y) - iw_start); #else @@ -148,7 +150,7 @@ __kernel void pooling_c4_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #if RETURN_REDICE int4 redice = (int4)0; #endif - const int inp_offset = (((b_idx*in_channel_block+c_idx)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; + const int inp_offset = (((b_idx+c_idx*batch)*input_shape.x+ih_start)*input_shape.y+iw_start+input_pad_left)*4; for(int kh=0; kh= input_shape.x) { @@ -194,6 +196,7 @@ __kernel void pooling_c16_c16(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __global FLOAT *output, __global FLOAT *rediceOutput, __private const int channel, + __private const int batch, __private const int in_channel_block, __private const int out_channel_block, __private const int input_pad_left, @@ -343,6 +346,7 @@ __kernel void pooling_c16_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __global FLOAT *output, __global FLOAT *rediceOutput, __private const int channel, + __private const int batch, __private const int in_channel_block, __private const int out_channel_block, __private const int input_pad_left, @@ -429,18 +433,18 @@ __kernel void pooling_c16_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, const uint lid_x = sglid % 4; const uint lid_y = sglid / 4; - const int out_offset = (((b_idx*out_channel_block + c_idx * 4)*output_shape.x + oh_idx)* output_shape.y + ow_idx + output_pad_left)*4; - const int width_height = output_shape.y * output_shape.x * 4; + const int out_offset = (((b_idx + c_idx * 4 * batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx + output_pad_left)*4; + const int batch_width_height = batch * output_shape.y * output_shape.x * 4; #if RETURN_REDICE - const int redice_offset = (((b_idx*out_channel_block + c_idx * 4)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4; + const int redice_offset = (((b_idx + c_idx * 4 * batch)*output_shape.x + oh_idx)* output_shape.y + ow_idx)*4; #endif #if OUTPUT_LEFTOVERS if ((c_idx+1)*16 >= channel) { for (int i = 0; i < 8; i++) { if ((c_idx*16 + lid_y * 4 + lid_x < channel) && (ow_idx + i) < output_shape.y) - output[out_offset + lid_y * width_height + i * 4 + lid_x] = result[i]; + output[out_offset + lid_y * batch_width_height + i * 4 + lid_x] = result[i]; #if 
RETURN_REDICE - rediceOutput[redice_offset + lid_y * width_height + i * 4 + lid_x] = redice[i]; + rediceOutput[redice_offset + lid_y * batch_width_height + i * 4 + lid_x] = redice[i]; #endif } } @@ -448,9 +452,9 @@ __kernel void pooling_c16_c4(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, #endif { for (int i = 0; i < 8 && (ow_idx + i) < output_shape.y; i++) { - output[out_offset + lid_y * width_height + i * 4 + lid_x] = result[i]; + output[out_offset + lid_y * batch_width_height + i * 4 + lid_x] = result[i]; #if RETURN_REDICE - rediceOutput[redice_offset + lid_y * width_height + i * 4 + lid_x] = redice[i]; + rediceOutput[redice_offset + lid_y * batch_width_height + i * 4 + lid_x] = redice[i]; #endif } } diff --git a/source/backend/opencl/execution/cl/range_buf.cl b/source/backend/opencl/execution/cl/range_buf.cl index 79ea69a69..fbadf98e7 100644 --- a/source/backend/opencl/execution/cl/range_buf.cl +++ b/source/backend/opencl/execution/cl/range_buf.cl @@ -2,39 +2,40 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif -#define GLOBAL_SIZE_3_DIMS \ -__private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, +#define GLOBAL_SIZE_2_DIMS \ +__private const int global_size_dim0, __private const int global_size_dim1, -#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ - if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ +#define DEAL_NON_UNIFORM_DIM2(input1, input2) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ return; \ } -__kernel void range_buf(GLOBAL_SIZE_3_DIMS +__kernel void range_buf(GLOBAL_SIZE_2_DIMS __global const INPUT_TYPE* input0, __global const INPUT_TYPE* input2, __global OUTPUT_TYPE* output, - __private const int width, - __private const int height, - __private const int channel, - __private const int channelBlock + __private const int size ) { - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); + const int x = get_global_id(0); + const int y = get_global_id(1); - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_channel_idx); + DEAL_NON_UNIFORM_DIM2(x, y); - const int batch_idx = batch_channel_idx / channelBlock; - const int channel_idx = batch_channel_idx % channelBlock; - - const int offset = ((((batch_idx * channelBlock) + channel_idx) * height + height_idx) * width + width_idx)*4; - const int channel4 = channel_idx << 2; - int index = (((batch_idx * channel) + channel4) * height + height_idx) * width + width_idx; - int size = height * width; - int4 index4 = (int4)(index, index + size, index + size * 2, index + size * 3); + int index = x << 2; + int4 index4 = (int4)(index, index + 1, index + 2, index + 3); INPUT_TYPE start = input0[0]; INPUT_TYPE step = input2[0]; OUTPUT_TYPE4 value = (OUTPUT_TYPE4)start + CONVERT_OUTPUT4(index4) * (OUTPUT_TYPE4)step; - vstore4(value, 0, output + offset); +#ifdef PACK_LEAVE + if(index + 3 >= size){ + OUTPUT_TYPE* value_ptr = (OUTPUT_TYPE*)&value; + for(int i = 0; i < size - index; ++i){ + output[index + i] = value_ptr[i]; + } + }else{ +#endif + vstore4(value, 0, output + index); +#ifdef PACK_LEAVE + } +#endif } diff --git a/source/backend/opencl/execution/cl/raster_buf.cl b/source/backend/opencl/execution/cl/raster_buf.cl index 7770f09e2..947910084 100644 --- a/source/backend/opencl/execution/cl/raster_buf.cl +++ b/source/backend/opencl/execution/cl/raster_buf.cl @@ -32,31 +32,69 @@ __kernel void buffer_set_zero( 
output[y*global_size_dim0 + x] = (OUTPUT_TYPE)(0.0f); } -__kernel void raster_buffer( +#define MNN_DATA_FORMAT_NCHW 0 +#define MNN_DATA_FORMAT_NHWC 1 +#define MNN_DATA_FORMAT_NC4HW4 2 +__kernel void raster_direct_buffer( GLOBAL_SIZE_3_DIMS + __private const int size_x, __global INPUT_TYPE *input, __private const int inputOffset, + __private const int combineSrcOffset, __private const int inputStride0, __private const int inputStride1, __private const int inputStride2, + __private const int src_width, + __private const int src_height, + __private const int src_channel, + __private const int src_batch, __global OUTPUT_TYPE *output, __private const int outputOffset, + __private const int combineDstOffset, __private const int outputStride0, __private const int outputStride1, - __private const int outputStride2 + __private const int outputStride2, + __private const int dst_width, + __private const int dst_height, + __private const int dst_channel, + __private const int dst_batch ) { - const int x = get_global_id(0); + const int idx = get_global_id(0); const int y = get_global_id(1); const int z = get_global_id(2); - DEAL_NON_UNIFORM_DIM3(x, y, z); + DEAL_NON_UNIFORM_DIM3(idx, y, z); + const int x = idx % size_x; + const int id = idx / size_x; - int inputIndex = inputOffset + z * inputStride0 + y * inputStride1 + x * inputStride2; - int outputIndex = outputOffset + z * outputStride0 + y * outputStride1 + x * outputStride2; - output[outputIndex] = (OUTPUT_TYPE)input[inputIndex]; + int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2; + int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2; +#if INPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int inputIndexReal = inputIndex; +#elif INPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int inputIndexReal = inputIndex; +#elif INPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int in_w = inputIndex % src_width; inputIndex /= src_width; + int in_h = inputIndex % src_height; inputIndex /= src_height; + int in_c = inputIndex % src_channel; + int in_b = inputIndex / src_channel; + int inputIndexReal = (((in_b + (in_c / 4) * src_batch) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4); +#endif + +#if OUTPUT_FORMAT == MNN_DATA_FORMAT_NCHW + int outputIndexReal = outputIndex; +#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NHWC + int outputIndexReal = outputIndex; +#elif OUTPUT_FORMAT == MNN_DATA_FORMAT_NC4HW4 + int out_w = outputIndex % dst_width; outputIndex /= dst_width; + int out_h = outputIndex % dst_height; outputIndex /= dst_height; + int out_c = outputIndex % dst_channel; + int out_b = outputIndex / dst_channel; + int outputIndexReal = (((out_b + (out_c / 4) * dst_batch) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4); +#endif + output[outputIndexReal] = (OUTPUT_TYPE)input[inputIndexReal]; } - __kernel void raster_nc4hw4_buffer( GLOBAL_SIZE_3_DIMS __global INPUT_TYPE *input, @@ -85,72 +123,6 @@ __kernel void raster_nc4hw4_buffer( int inputIndex = inputOffset + (z * inputStride0 + y * inputStride1 + x * inputStride2) * 4; int outputIndex = outputOffset + (z * outputStride0 + y * outputStride1 + x * outputStride2) * 4; - vstore4(CONVERT_OUTPUT4(vload4(0, input+inputIndex)), 0, output+outputIndex); -} - -__kernel void raster_direct_buffer( - GLOBAL_SIZE_3_DIMS - __private const int size_x, - __global INPUT_TYPE *input, - __private const int inputOffset, - __private const int combineSrcOffset, - __private const int inputStride0, - __private const int 
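// --- Illustrative aside, not part of the patch ---
// For NC4HW4 tensors the new raster_direct_buffer first builds a logical NCHW
// linear index from the raster strides, then decodes it and re-encodes it into
// the packed, batch-interleaved physical layout. The same decode, written as a
// stand-alone helper (illustrative name):
static inline int nchwToNc4hw4(int idx, int W, int H, int C, int batch) {
    int w = idx % W; idx /= W;
    int h = idx % H; idx /= H;
    int c = idx % C;
    int b = idx / C;
    return (((b + (c / 4) * batch) * H + h) * W + w) * 4 + (c % 4);
}
// NCHW and NHWC sources need no remapping, which is why those branches simply
// reuse the linear index.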
inputStride1, - __private const int inputStride2, - __private const int src_width, - __private const int src_height, - __private const int src_channel, - __global OUTPUT_TYPE *output, - __private const int outputOffset, - __private const int combineDstOffset, - __private const int outputStride0, - __private const int outputStride1, - __private const int outputStride2, - __private const int dst_width, - __private const int dst_height, - __private const int dst_channel - ) { - const int idx = get_global_id(0); - const int y = get_global_id(1); - const int z = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(idx, y, z); - const int x = idx % size_x; - const int id = idx / size_x; - - int inputIndex = inputOffset + id * combineSrcOffset + z * inputStride0 + y * inputStride1 + x * inputStride2; - int outputIndex = outputOffset + id * combineDstOffset + z * outputStride0 + y * outputStride1 + x * outputStride2; -#ifdef INPUT_DATA_FORMAT_NHWC - int in_c = inputIndex % src_channel; inputIndex /= src_channel; - int in_w = inputIndex % src_width; inputIndex /= src_width; - int in_h = inputIndex % src_height; - int in_b = inputIndex / src_height; - int src_channel4 = (src_channel + 3) / 4; - int inputIndexC4 = (((in_b * src_channel4 + (in_c / 4)) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4); -#else - int in_w = inputIndex % src_width; inputIndex /= src_width; - int in_h = inputIndex % src_height; inputIndex /= src_height; - int in_c = inputIndex % src_channel; - int in_b = inputIndex / src_channel; - int src_channel4 = (src_channel + 3) / 4; - int inputIndexC4 = (((in_b * src_channel4 + (in_c / 4)) * src_height + in_h) * src_width + in_w) * 4 + (in_c % 4); -#endif - -#ifdef OUTPUT_DATA_FORMAT_NHWC - int out_c = outputIndex % dst_channel; outputIndex /= dst_channel; - int out_w = outputIndex % dst_width; outputIndex /= dst_width; - int out_h = outputIndex % dst_height; - int out_b = outputIndex / dst_height; - int dst_channel4 = (dst_channel + 3) / 4; - int outputIndexC4 = (((out_b * dst_channel4 + (out_c / 4)) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4); -#else - int out_w = outputIndex % dst_width; outputIndex /= dst_width; - int out_h = outputIndex % dst_height; outputIndex /= dst_height; - int out_c = outputIndex % dst_channel; - int out_b = outputIndex / dst_channel; - int dst_channel4 = (dst_channel + 3) / 4; - int outputIndexC4 = (((out_b * dst_channel4 + (out_c / 4)) * dst_height + out_h) * dst_width + out_w) * 4 + (out_c % 4); -#endif - - output[outputIndexC4] = (OUTPUT_TYPE)input[inputIndexC4]; + OUTPUT_TYPE4 values = CONVERT_OUTPUT4(vload4(0, (__global INPUT_TYPE *)(input+inputIndex))); + vstore4(values, 0, (__global OUTPUT_TYPE *)(output+outputIndex)); } diff --git a/source/backend/opencl/execution/cl/reduction_buf.cl b/source/backend/opencl/execution/cl/reduction_buf.cl index aa5b00960..daf033545 100644 --- a/source/backend/opencl/execution/cl/reduction_buf.cl +++ b/source/backend/opencl/execution/cl/reduction_buf.cl @@ -17,355 +17,88 @@ __private const int global_size_dim0, __private const int global_size_dim1, __pr return; \ } -__kernel void reduct_width_buf(GLOBAL_SIZE_3_DIMS - __global const INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int 
outputChannelBlock - ) { - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_channel_idx); - - const int batch_idx = batch_channel_idx / outputChannelBlock; - const int channel_idx = batch_channel_idx % outputChannelBlock; - const int offset = ((((batch_idx * inputChannelBlock) + channel_idx) * inputHeight + height_idx) * inputWidth + 0)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + channel_idx) * outputHeight + height_idx) * oututWidth + 0)*4; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; +__kernel void reduct_buf(GLOBAL_SIZE_3_DIMS + __global const INPUT_TYPE *input, + __global OUTPUT_TYPE *output, + __private const int inside, + __private const int outside, + __private const int dim) { + + const int x = get_global_id(0); + const int y = get_global_id(1); // inside + const int z = get_global_id(2); // outside + DEAL_NON_UNIFORM_DIM3(x, y, z); -#if LOCAL_SIZE > 0 - const int lid = get_local_id(0); - INPUT_TYPE4 local sum[LOCAL_SIZE]; - for(int i = lid; i < inputWidth; i+=LOCAL_SIZE){ - INPUT_TYPE4 in = vload4(i, input + offset); - out = OPERATE(out, in); - } - sum[lid] = out; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = OPERATE(sum[lid], sum[lid + i]); - barrier(CLK_LOCAL_MEM_FENCE); - } - out = sum[0]; -#else - for(int i = 0; i < inputWidth; ++i){ - INPUT_TYPE4 in = vload4(i, input + offset); - out = OPERATE(out, in); - } -#endif - -#ifdef GET_AVG - out = out / inputWidth; -#endif - vstore4(CONVERT_OUTPUT4(out), 0, output + outputOffset); -} - - -__kernel void reduct_height_buf(GLOBAL_SIZE_3_DIMS - __global const INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { -#if LOCAL_SIZE > 0 - const int width_local_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_local_idx, height_idx, batch_channel_idx); - - const int width_idx = get_group_id(0); - const int batch_idx = batch_channel_idx / outputChannelBlock; - const int channel_idx = batch_channel_idx % outputChannelBlock; + INPUT_TYPE out = (INPUT_TYPE)VALUE; + const int offset = z * dim * inside + y; - const int offset = ((((batch_idx * inputChannelBlock) + channel_idx) * inputHeight + 0) * inputWidth + width_idx)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + channel_idx) * outputHeight + 0) * oututWidth + width_idx)*4; +#if REDUCT_LOCAL_SIZE > 4 const int lid = get_local_id(0); - INPUT_TYPE4 local sum[LOCAL_SIZE]; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - for(int i = lid; i < inputHeight; i+=LOCAL_SIZE){ - INPUT_TYPE4 in = vload4(i * inputWidth, input + offset); + INPUT_TYPE local sum[REDUCT_LOCAL_SIZE]; + for(int i = lid; i < dim; i+=REDUCT_LOCAL_SIZE){ + INPUT_TYPE in = (INPUT_TYPE)input[offset + i * inside]; out = OPERATE(out, in); } sum[lid] = out; barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + for(int i = REDUCT_LOCAL_SIZE/2; i > 0; i /= 2){ if (lid < i) sum[lid] = OPERATE(sum[lid], sum[lid + i]); barrier(CLK_LOCAL_MEM_FENCE); } out = sum[0]; #else - - 
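// --- Illustrative aside, not part of the patch ---
// The single reduct_buf kernel replaces the per-axis width/height/channel/batch
// variants: when REDUCT_LOCAL_SIZE > 4, each work item accumulates a strided
// slice of the reduced axis and the partial results are combined with a halving
// tree in local memory. Serial C++ model of that tree step, assuming the group
// size is a power of two (as the local sizes used here are); sum[] plays the
// role of the __local array and op the OPERATE macro:
template <typename T, typename Op>
T treeReduce(T* sum, int groupSize, Op op) {
    for (int stride = groupSize / 2; stride > 0; stride /= 2) {
        for (int lid = 0; lid < stride; ++lid) {   // on the GPU: one work item each,
            sum[lid] = op(sum[lid], sum[lid + stride]); // with a barrier between passes
        }
    }
    return sum[0];
}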
const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_channel_idx); - - const int batch_idx = batch_channel_idx / outputChannelBlock; - const int channel_idx = batch_channel_idx % outputChannelBlock; - - const int offset = ((((batch_idx * inputChannelBlock) + channel_idx) * inputHeight + 0) * inputWidth + width_idx)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + channel_idx) * outputHeight + 0) * oututWidth + width_idx)*4; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - for(int i = 0; i < inputHeight; ++i){ - INPUT_TYPE4 in = vload4(i * inputWidth, input + offset); + for(int i = 0; i < dim; ++i){ + INPUT_TYPE in = (INPUT_TYPE)input[offset + i * inside]; out = OPERATE(out, in); } #endif - -#ifdef GET_AVG - out = out / inputHeight; -#endif - vstore4(CONVERT_OUTPUT4(out), 0, output + outputOffset); -} -__kernel void reduct_channel_buf(GLOBAL_SIZE_3_DIMS - __global const INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { -#if LOCAL_SIZE > 0 - const int width_local_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_local_idx, height_idx, batch_idx); - const int width_idx = get_group_id(0); - - const int offset = ((((batch_idx * inputChannelBlock) + 0) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + 0) * outputHeight + height_idx) * oututWidth + width_idx)*4; - int remain = inputChannel - (inputChannelBlock - 1) * 4; - const int lid = get_local_id(0); - INPUT_TYPE local sum[LOCAL_SIZE]; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - INPUT_TYPE4 in; - INPUT_TYPE *inPtr = (INPUT_TYPE*)∈ - for(int i = lid; i < inputChannelBlock - 1; i += LOCAL_SIZE){ - in = vload4(i * inputWidth * inputHeight, input + offset); - out = OPERATE(out, in); - } - out.x = OPERATE(out.x, out.y); - out.x = OPERATE(out.x, out.z); - out.x = OPERATE(out.x, out.w); - sum[lid] = out.x; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = OPERATE(sum[lid], sum[lid + i]); - barrier(CLK_LOCAL_MEM_FENCE); - } - out.x = sum[0]; - in = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 0; j < remain; ++j){ - out.x = OPERATE(out.x, inPtr[j]); - } -#ifdef GET_AVG - out.x = out.x / inputChannel; -#endif - output[outputOffset] = (OUTPUT_TYPE)out.x; - -#else - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_idx); - - const int offset = ((((batch_idx * inputChannelBlock) + 0) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((((batch_idx * outputChannelBlock) + 0) * outputHeight + height_idx) * oututWidth + width_idx)*4; - int remain = inputChannel - (inputChannelBlock - 1) * 4; - - INPUT_TYPE out = (INPUT_TYPE)VALUE; - INPUT_TYPE4 in; - INPUT_TYPE *inPtr = (INPUT_TYPE*)∈ - for(int i = 0; i < inputChannelBlock - 1; ++i){ - in = vload4(i * inputWidth * 
inputHeight, input + offset); - for(int j = 0; j < 4; ++j){ - out = OPERATE(out, inPtr[j]); - } - } - in = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 0; j < remain; ++j){ - out = OPERATE(out, inPtr[j]); - } #ifdef GET_AVG - out = out / inputChannel; -#endif - output[outputOffset] = (OUTPUT_TYPE)out; + out = out / dim; #endif + output[z * inside + y] = (OUTPUT_TYPE)out; } -__kernel void reduct_channel_dim1_buf(GLOBAL_SIZE_3_DIMS - __global const INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { -#if LOCAL_SIZE > 0 - const int width_local_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_local_idx, height_idx, batch_idx); - const int width_idx = get_group_id(0); +__kernel void reduct_v4_buf(GLOBAL_SIZE_3_DIMS + __global const INPUT_TYPE *input, + __global OUTPUT_TYPE *output, + __private const int inside, + __private const int outside, + __private const int dim) { + + const int x = get_global_id(0); + const int y = get_global_id(1); // inside + const int z = get_global_id(2); // outside + DEAL_NON_UNIFORM_DIM3(x, y, z); - const int offset = ((((batch_idx * inputChannelBlock) + 0) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((batch_idx * outputHeight + height_idx) * oututWidth + width_idx); - int remain = inputChannel - (inputChannelBlock - 1) * 4; - const int lid = get_local_id(0); - INPUT_TYPE local sum[LOCAL_SIZE]; INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - INPUT_TYPE4 in; - INPUT_TYPE *inPtr = (INPUT_TYPE*)∈ - for(int i = lid; i < inputChannelBlock - 1; i += LOCAL_SIZE){ - in = vload4(i * inputWidth * inputHeight, input + offset); - out = OPERATE(out, in); - } - out.x = OPERATE(out.x, out.y); - out.x = OPERATE(out.x, out.z); - out.x = OPERATE(out.x, out.w); - sum[lid] = out.x; - barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ - if (lid < i) - sum[lid] = OPERATE(sum[lid], sum[lid + i]); - barrier(CLK_LOCAL_MEM_FENCE); - } - out.x = sum[0]; - in = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 0; j < remain; ++j){ - out.x = OPERATE(out.x, inPtr[j]); - } -#ifdef GET_AVG - out.x = out.x / inputChannel; -#endif - output[outputOffset] = (OUTPUT_TYPE)out.x; + const int offset = z * dim * inside + (y << 2); -#else - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int batch_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, batch_idx); - const int offset = ((((batch_idx * inputChannelBlock) + 0) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((batch_idx * outputHeight + height_idx) * oututWidth + width_idx); - int remain = inputChannel - (inputChannelBlock - 1) * 4; - INPUT_TYPE out = (INPUT_TYPE)VALUE; - INPUT_TYPE4 in; - INPUT_TYPE *inPtr = (INPUT_TYPE*)∈ - for(int i = 0; i < inputChannelBlock - 1; ++i){ - in = vload4(i * inputWidth * inputHeight, input + offset); - for(int j = 0; j < 4; ++j){ - out = OPERATE(out, inPtr[j]); - } - } - in = vload4((inputChannelBlock - 1) * inputWidth * inputHeight, input + offset); - for(int j = 
0; j < remain; ++j){ - out = OPERATE(out, inPtr[j]); - } -#ifdef GET_AVG - out = out / inputChannel; -#endif - output[outputOffset] = (OUTPUT_TYPE)out; -#endif -} - - -__kernel void reduct_batch_buf(GLOBAL_SIZE_3_DIMS - __global const INPUT_TYPE* input, - __global OUTPUT_TYPE* output, - __private const int inputWidth, - __private const int inputHeight, - __private const int inputChannel, - __private const int inputBatch, - __private const int inputChannelBlock, - __private const int oututWidth, - __private const int outputHeight, - __private const int outputChannel, - __private const int outputChannelBlock - ) { -#if LOCAL_SIZE > 0 - const int width_local_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_local_idx, height_idx, channel_idx); - const int width_idx = get_group_id(0); - - const int offset = ((((0 * inputChannelBlock) + channel_idx) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((((0 * outputChannelBlock) + channel_idx) * outputHeight + height_idx) * oututWidth + width_idx)*4; - int batchOffset = inputChannelBlock * inputHeight * inputWidth; +#if REDUCT_LOCAL_SIZE > 4 const int lid = get_local_id(0); - INPUT_TYPE4 local sum[LOCAL_SIZE]; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - for(int i = lid; i < inputBatch; i+=LOCAL_SIZE){ - INPUT_TYPE4 in = vload4(i * batchOffset, input + offset); + INPUT_TYPE4 local sum[REDUCT_LOCAL_SIZE]; + for(int i = lid; i < dim; i+=REDUCT_LOCAL_SIZE){ + INPUT_TYPE4 in = vload4(0, input + offset + i * inside); out = OPERATE(out, in); } sum[lid] = out; barrier(CLK_LOCAL_MEM_FENCE); - for(int i = LOCAL_SIZE/2; i > 0; i /= 2){ + for(int i = REDUCT_LOCAL_SIZE/2; i > 0; i /= 2){ if (lid < i) sum[lid] = OPERATE(sum[lid], sum[lid + i]); barrier(CLK_LOCAL_MEM_FENCE); } out = sum[0]; -#ifdef GET_AVG - out = out / inputBatch; -#endif - vstore4(CONVERT_OUTPUT4(out), 0, output + outputOffset); #else - const int width_idx = get_global_id(0); - const int height_idx = get_global_id(1); - const int channel_idx = get_global_id(2); - - DEAL_NON_UNIFORM_DIM3(width_idx, height_idx, channel_idx); - - const int offset = ((((0 * inputChannelBlock) + channel_idx) * inputHeight + height_idx) * inputWidth + width_idx)*4; - const int outputOffset = ((((0 * outputChannelBlock) + channel_idx) * outputHeight + height_idx) * oututWidth + width_idx)*4; - int batchOffset = inputChannelBlock * inputHeight * inputWidth; - INPUT_TYPE4 out = (INPUT_TYPE4)VALUE; - for(int i = 0; i < inputBatch; ++i){ - INPUT_TYPE4 in = vload4(i * batchOffset, input + offset); + for(int i = 0; i < dim; ++i){ + INPUT_TYPE4 in = vload4(0, input + offset + i * inside); out = OPERATE(out, in); } -#ifdef GET_AVG - out = out / inputBatch; #endif - vstore4(CONVERT_OUTPUT4(out), 0, output + outputOffset); + +#ifdef GET_AVG + out = out / (INPUT_TYPE4)dim; #endif + vstore4(CONVERT_OUTPUT4(out), 0, output + z * inside + (y << 2)); } diff --git a/source/backend/opencl/execution/cl/scale_buf.cl b/source/backend/opencl/execution/cl/scale_buf.cl index 72d3b90fd..f1d722d36 100644 --- a/source/backend/opencl/execution/cl/scale_buf.cl +++ b/source/backend/opencl/execution/cl/scale_buf.cl @@ -17,26 +17,25 @@ __kernel void scale_buf(GLOBAL_SIZE_2_DIMS __global const FLOAT* bias, #endif __global FLOAT* output, - __private const int4 shape) {//N, H, W, C4 + __private const int channelBlock, + __private const int batch, + __private const int inside) { - const int out_w_c_idx = get_global_id(0); - 
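// --- Illustrative aside, not part of the patch ---
// The reworked reduction (and, below, softmax and scale) kernels no longer
// reason about N/C/H/W at all: the host collapses the tensor around the
// reduced axis into (outside, dim, inside), and the kernels index with
//     offset = outside_idx * dim * inside + d * inside + inside_idx.
// Sketch of the host-side collapse for a given axis (illustrative helper):
#include <vector>
void collapseAxis(const std::vector<int>& shape, int axis,
                  int* outside, int* dim, int* inside) {
    *outside = 1; *inside = 1; *dim = shape[axis];
    for (int i = 0; i < axis; ++i)                       *outside *= shape[i];
    for (int i = axis + 1; i < (int)shape.size(); ++i)   *inside  *= shape[i];
}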
const int out_h_b_idx = get_global_id(1); + const int x = get_global_id(0); // inside(width * height) + const int y = get_global_id(1); // channelBlock * batch - DEAL_NON_UNIFORM_DIM2(out_w_c_idx, out_h_b_idx); + DEAL_NON_UNIFORM_DIM2(x, y); - const int out_b_idx = out_h_b_idx / shape.y; - const int out_h_idx = out_h_b_idx % shape.y; - const int out_c_idx = out_w_c_idx / shape.z; - const int out_w_idx = out_w_c_idx % shape.z; - - const int offset = (((out_b_idx * shape.w + out_c_idx) * shape.y + out_h_idx) * shape.z + out_w_idx) * 4; + const int out_c_idx = y % channelBlock; + const int out_b_idx = y / channelBlock; + const int offset = ((out_b_idx + out_c_idx * batch) * inside + x) * 4; COMPUTE_FLOAT4 in_value = CONVERT_COMPUTE_FLOAT4(vload4(0, input+offset)); COMPUTE_FLOAT4 scale_value = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, scale)); -#ifdef BIAS + #ifdef BIAS COMPUTE_FLOAT4 bias_value = CONVERT_COMPUTE_FLOAT4(vload4(out_c_idx, bias)); COMPUTE_FLOAT4 out_value = in_value * scale_value + bias_value; -#else + #else COMPUTE_FLOAT4 out_value = in_value * scale_value; -#endif + #endif vstore4(CONVERT_FLOAT4(out_value), 0, output+offset); } diff --git a/source/backend/opencl/execution/cl/self_attention_buf.cl b/source/backend/opencl/execution/cl/self_attention_buf.cl index 8dc4f8b78..2b6cf9d9d 100644 --- a/source/backend/opencl/execution/cl/self_attention_buf.cl +++ b/source/backend/opencl/execution/cl/self_attention_buf.cl @@ -53,6 +53,7 @@ __kernel void split_transpose_qkv(GLOBAL_SIZE_3_DIMS __private const int seq_len, __private const int head_num, __private const int head_dim, + __private const int batch, __private const int seq_index ) { const int sl = get_global_id(0); // seqLen_4 @@ -80,8 +81,8 @@ __kernel void split_transpose_qkv(GLOBAL_SIZE_3_DIMS return; } - const int offset_inp = (((b * seq_len_4 + seq_index * seq_len_piece / 4 + sl) * head_num + hn) * 3 * head_dim + 4 * hd) * 4; - + const int offset_inp = ((((seq_index * seq_len_piece / 4 + sl) * batch + b) * head_num + hn) * 3 * head_dim + 4 * hd) * 4; + if(sl * 4 < seq_len_piece) { FLOAT4 temp_0 = vload4(0, input + offset_inp); FLOAT4 temp_1 = vload4(0, input + offset_inp + 4); @@ -125,7 +126,8 @@ __kernel void split_transpose_qkv(GLOBAL_SIZE_3_DIMS } - const int offset_inp = (((b * seq_len_4 + sl) * head_num + hn) * 3 * head_dim + 4 * hd) * 4; + const int offset_inp = (((sl * batch + b) * head_num + hn) * 3 * head_dim + 4 * hd) * 4; + if(sl * 4 < seq_len_piece) { FLOAT4 temp_0 = vload4(0, input + offset_inp); @@ -238,7 +240,7 @@ __kernel void softmax_inside(GLOBAL_SIZE_3_DIMS const int out_offset = (outside * shape.z + 0) * shape.y + axis; #endif /*Compute Result */ - for (int i=lid; i inside_len){ + for(int i = lid + inside_len; i < shape.z; i+=SOFTMAX_LOCAL_SIZE){ + #ifdef OUTPUT_TRANSPOSE + output[out_offset+ i*shape.y] = (FLOAT)0; + #else + output[offset+ i] = (FLOAT)0; + #endif + } + } } // [N X Y4 4] -> [N Y X] -__kernel void trans_3d_buf(__global const FLOAT* input, +__kernel void trans_3d_buf(GLOBAL_SIZE_3_DIMS + __global const FLOAT* input, __global FLOAT* output, __private const int batch, __private const int width, __private const int height ) { int b = get_global_id(2); - - const int w = get_global_id(0) << 3; - const int h = get_global_id(1) << 3; + int w = get_global_id(0); + int h = get_global_id(1); + DEAL_NON_UNIFORM_DIM3(w, h, b); + + w = w << 3; + h = h << 3; const int inp_offset = (b * width + w) * height + h; const int out_offset = (b * height + h) * width + w; @@ -290,6 +305,7 @@ __kernel void 
clip_transpose_qkv(GLOBAL_SIZE_3_DIMS __private const int seq_len_piece, __private const int head_num, __private const int head_dim, + __private const int batch, __private const int seq_index ) { @@ -311,8 +327,8 @@ __kernel void clip_transpose_qkv(GLOBAL_SIZE_3_DIMS const int offset_inp = ((b * head_num + hn) * head_dim_pack + 4 * hd) * seq_len_pack + 4 * sl; - const int offset_out = (((b * seq_len_4 + seq_index * seq_len_piece / 4 + sl) * head_num + hn) * head_dim + 4 * hd) * 4; - + const int offset_out = ((((seq_index * seq_len_piece / 4 + sl) * batch + b) * head_num + hn) * head_dim + 4 * hd) * 4; + // Q FLOAT4 temp_0 = vload4(0, input + offset_inp); FLOAT4 temp_1 = vload4(0, input + offset_inp + seq_len_pack); diff --git a/source/backend/opencl/execution/cl/softmax_buf.cl b/source/backend/opencl/execution/cl/softmax_buf.cl index 52dd91c61..fa30bf5e2 100644 --- a/source/backend/opencl/execution/cl/softmax_buf.cl +++ b/source/backend/opencl/execution/cl/softmax_buf.cl @@ -12,173 +12,120 @@ } -__kernel void softmax_channel(GLOBAL_SIZE_3_DIMS +__kernel void softmax_in1_buf(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, __global FLOAT *output, - __private const int remain_channels, - __private const int4 shape) {//NCHW + __private const int inside, + __private const int outside, + __private const int dim) { const int x = get_global_id(0); - const int w = get_global_id(1); - const int bh = get_global_id(2); - DEAL_NON_UNIFORM_DIM3(x, w, bh); + const int y = get_global_id(1); // inside = 1 + const int z = get_global_id(2); // outside + DEAL_NON_UNIFORM_DIM3(x, y, z); - const int batch_idx = bh / shape.z; - const int height_idx = bh % shape.z; - const int offset = (((batch_idx*shape.y+0)*shape.z+height_idx)*shape.w+w)*4; + const int offset = z * dim + y; + const int dim4 = (dim + 3) / 4; + const int loop_end = max(0, dim4 - 1); #if SOFTMAX_LOCAL_SIZE >= 4 int lid = get_local_id(0); - COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE]; + COMPUTE_FLOAT local sum[SOFTMAX_LOCAL_SIZE]; + // compute maxvalue COMPUTE_FLOAT4 maxValue = (COMPUTE_FLOAT4)-FLT_MAX; - for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) { - maxValue = fmax(maxValue, CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset))); + for (int i = lid; i < loop_end; i+=SOFTMAX_LOCAL_SIZE) { + maxValue = fmax(maxValue, CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset))); } - sum[lid] = maxValue; + sum[lid] = fmax(fmax(fmax(maxValue.x, maxValue.y), maxValue.z), maxValue.w); barrier(CLK_LOCAL_MEM_FENCE); for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){ if (lid < i) sum[lid] = fmax(sum[lid], sum[lid + i]); barrier(CLK_LOCAL_MEM_FENCE); } - maxValue = sum[0]; - - maxValue.x = fmax(maxValue.x, maxValue.y); - maxValue.x = fmax(maxValue.x, maxValue.z); - maxValue.x = fmax(maxValue.x, maxValue.w); - - COMPUTE_FLOAT4 input_data = CONVERT_COMPUTE_FLOAT4(vload4((shape.y - 1) *shape.z*shape.w, input+offset)); - if (remain_channels == 0) { - maxValue.x = fmax(maxValue.x, input_data.x); - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.z); - maxValue.x = fmax(maxValue.x, input_data.w); - } else if (remain_channels == 1) { - maxValue.x = fmax(maxValue.x, input_data.z); - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.x); - } else if (remain_channels == 2) { - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.x); - } else if (remain_channels == 3) { - maxValue.x = fmax(maxValue.x, input_data.x); + maxValue.x = sum[0]; + 
for(int i = loop_end << 2; i < dim; ++i){ + maxValue.x = fmax(maxValue.x, (COMPUTE_FLOAT)(input[offset+i])); } + // compute sumvalue COMPUTE_FLOAT4 sumValue = (COMPUTE_FLOAT4)0; - for (int i = lid; i < shape.y - 1; i+=SOFTMAX_LOCAL_SIZE) { - sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset)) - (COMPUTE_FLOAT4)maxValue.x); + for (int i = lid; i < loop_end; i+=SOFTMAX_LOCAL_SIZE) { + sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset)) - (COMPUTE_FLOAT4)maxValue.x); } - sum[lid] = sumValue; + sum[lid] = sumValue.x + sumValue.y + sumValue.z + sumValue.w; barrier(CLK_LOCAL_MEM_FENCE); for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){ if (lid < i) sum[lid] = sum[lid] + sum[lid + i]; barrier(CLK_LOCAL_MEM_FENCE); } - sumValue = sum[0]; - sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w; - + sumValue.x = sum[0]; + for(int i = loop_end << 2; i < dim; ++i){ + sumValue.x += exp((COMPUTE_FLOAT)(input[offset+i]) - maxValue.x); + } - input_data -= maxValue.x; - if (remain_channels == 0) { - sumValue.x += exp(input_data.w); - sumValue.x += exp(input_data.z); - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 1) { - sumValue.x += exp(input_data.z); - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 2) { - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 3) { - sumValue.x += exp(input_data.x); + // store result + for(int i = lid; i < loop_end; i+=SOFTMAX_LOCAL_SIZE){ + vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset)) - (COMPUTE_FLOAT4)maxValue.x) / (COMPUTE_FLOAT4)sumValue.x), 0, output + offset + i * 4); } - for(int i = lid; i < shape.y; i+=SOFTMAX_LOCAL_SIZE){ - COMPUTE_FLOAT4 value = exp(CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset)) - maxValue.x) / sumValue.x; - vstore4(CONVERT_FLOAT4(value), i*shape.z*shape.w, output+offset); + for(int i = loop_end << 2; i < dim; ++i){ + output[offset + i] = (FLOAT)exp((COMPUTE_FLOAT)(input[offset + i]) - maxValue.x) / sumValue.x; } #else + // compute maxvalue COMPUTE_FLOAT4 maxValue = (COMPUTE_FLOAT4)-FLT_MAX; - for (int i = 0; i < shape.y - 1; i++) { - maxValue = fmax(maxValue, CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset))); + for (int i = 0; i < loop_end; i++) { + maxValue = fmax(maxValue, CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset))); } - - maxValue.x = fmax(maxValue.x, maxValue.y); - maxValue.x = fmax(maxValue.x, maxValue.z); - maxValue.x = fmax(maxValue.x, maxValue.w); - - COMPUTE_FLOAT4 input_data = CONVERT_COMPUTE_FLOAT4(vload4((shape.y - 1) *shape.z*shape.w, input+offset)); - if (remain_channels == 0) { - maxValue.x = fmax(maxValue.x, input_data.x); - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.z); - maxValue.x = fmax(maxValue.x, input_data.w); - } else if (remain_channels == 1) { - maxValue.x = fmax(maxValue.x, input_data.z); - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.x); - } else if (remain_channels == 2) { - maxValue.x = fmax(maxValue.x, input_data.y); - maxValue.x = fmax(maxValue.x, input_data.x); - } else if (remain_channels == 3) { - maxValue.x = fmax(maxValue.x, input_data.x); + maxValue.x = fmax(fmax(fmax(maxValue.x, maxValue.y), maxValue.z), maxValue.w); + for(int i = loop_end << 2; i < dim; ++i){ + maxValue.x = fmax(maxValue.x, (COMPUTE_FLOAT)(input[offset+i])); } - + + // compute sumvalue 
COMPUTE_FLOAT4 sumValue = (COMPUTE_FLOAT4)0; - for (int i = 0; i < shape.y - 1; i++) { - sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset)) - (COMPUTE_FLOAT4)maxValue.x); + for (int i = 0; i < loop_end; i++) { + sumValue += exp(CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset)) - (COMPUTE_FLOAT4)maxValue.x); } sumValue.x = sumValue.x + sumValue.y + sumValue.z + sumValue.w; - input_data -= maxValue.x; - if (remain_channels == 0) { - sumValue.x += exp(input_data.w); - sumValue.x += exp(input_data.z); - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 1) { - sumValue.x += exp(input_data.z); - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 2) { - sumValue.x += exp(input_data.y); - sumValue.x += exp(input_data.x); - } else if (remain_channels == 3) { - sumValue.x += exp(input_data.x); + for(int i = loop_end << 2; i < dim; ++i){ + sumValue.x += exp((COMPUTE_FLOAT)(input[offset+i]) - maxValue.x); + } + + // store result + for(int i = 0; i < loop_end; i++){ + vstore4(CONVERT_FLOAT4(exp(CONVERT_COMPUTE_FLOAT4(vload4(i, input+offset)) - (COMPUTE_FLOAT4)maxValue.x) / (COMPUTE_FLOAT4)sumValue.x), 0, output + offset + i * 4); } - for(int i = 0; i < shape.y; i++){ - COMPUTE_FLOAT4 value = exp(CONVERT_COMPUTE_FLOAT4(vload4(i*shape.z*shape.w, input+offset)) - maxValue.x) / sumValue.x; - vstore4(CONVERT_FLOAT4(value), i*shape.z*shape.w, output+offset); + for(int i = loop_end << 2; i < dim; ++i){ + output[offset + i] = (FLOAT)exp((COMPUTE_FLOAT)(input[offset + i]) - maxValue.x) / sumValue.x; } #endif } +__kernel void softmax_buf(GLOBAL_SIZE_3_DIMS + __global const FLOAT *input, + __global FLOAT *output, + __private const int inside, + __private const int outside, + __private const int dim) { -__kernel void softmax_height(GLOBAL_SIZE_3_DIMS - __global const FLOAT *input, - __global FLOAT *output, - __private const int remain_channels, - __private const int4 shape // NCHW - ) { const int x = get_global_id(0); - const int wc = get_global_id(1); - const int b = get_global_id(2); - DEAL_NON_UNIFORM_DIM3(x, wc, b); + const int y = get_global_id(1); // inside + const int z = get_global_id(2); // outside + DEAL_NON_UNIFORM_DIM3(x, y, z); - const int c = wc / shape.w; - const int w = wc % shape.w; - const int offset = (((b*shape.y+c)*shape.z+0)*shape.w+w)*4; + const int offset = z * dim * inside + y; #if SOFTMAX_LOCAL_SIZE >= 4 int lid = get_local_id(0); - COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE]; - - /*Compute Max */ - COMPUTE_FLOAT4 maxValue = (COMPUTE_FLOAT4)(-FLT_MAX); - for (int i=lid; i 0; i /= 2){ @@ -187,11 +134,10 @@ __kernel void softmax_height(GLOBAL_SIZE_3_DIMS barrier(CLK_LOCAL_MEM_FENCE); } maxValue = sum[0]; - - /*Compute Exp Sum*/ - COMPUTE_FLOAT4 sumValue = (COMPUTE_FLOAT4)0; - for (int i=lid; i= 4 int lid = get_local_id(0); COMPUTE_FLOAT4 local sum[SOFTMAX_LOCAL_SIZE]; - - /*Compute Max */ - COMPUTE_FLOAT4 maxValue = (COMPUTE_FLOAT4)(-FLT_MAX); - for (int i=lid; i 0; i /= 2){ @@ -259,11 +196,10 @@ __kernel void softmax_width(GLOBAL_SIZE_3_DIMS barrier(CLK_LOCAL_MEM_FENCE); } maxValue = sum[0]; - - /*Compute Exp Sum*/ + COMPUTE_FLOAT4 sumValue = (COMPUTE_FLOAT4)0; - for (int i=lid; i> 2; - const int area_4 = (shape.z + 3) >> 2; - const int in_offset = ((b * channel_4 + c_4) * area_4 * 2 + hw_4) * 16; - const int out_offset = ((b * channel_4 + c_4) * area_4 + hw_4) * 16; - - float16 valueL = convert_float16(vload16(0, input + in_offset)); - float16 valueR = 
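// --- Illustrative aside, not part of the patch ---
// Both new softmax kernels keep the classic numerically stable three passes
// over the dim axis (max, sum of exp(x - max), normalize); the remain_channels
// special-casing disappears because the only extra wrinkle left is a float4
// main loop over dim/4 plus a scalar tail. Scalar reference of the math:
#include <algorithm>
#include <cfloat>
#include <cmath>
void softmax1D(const float* in, float* out, int dim) {
    float maxValue = -FLT_MAX;
    for (int i = 0; i < dim; ++i) maxValue = std::max(maxValue, in[i]);
    float sum = 0.f;
    for (int i = 0; i < dim; ++i) sum += std::exp(in[i] - maxValue);
    for (int i = 0; i < dim; ++i) out[i] = std::exp(in[i] - maxValue) / sum;
}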
convert_float16(vload16(area_4, input + in_offset)); - - #ifdef DOUBLE_INPUTS - float4 valueConstL = convert_float4(vload4(hw, input1)); - float4 valueConstR = convert_float4(vload4(area_4+hw, input1)); - valueL += (float16)((float4)valueConstL.x, (float4)valueConstL.y, (float4)valueConstL.z, (float4)valueConstL.w); - valueR += (float16)((float4)valueConstR.x, (float4)valueConstR.y, (float4)valueConstR.z, (float4)valueConstR.w); - #endif - float16 out = (erf(valueR * (float16)0.7071067932881648) + (float16)1.0) * valueR * (float16)0.5; - out *= valueL; - vstore16(CONVERT_FLOAT16(out), 0, output + out_offset); +#elif defined (WH_4) + + const int in_offset = bc * shape.z * 2 + h * 4; + const int out_offset = bc * shape.z + h * 4; + + float4 valueL = convert_float4(vload4(0, input + in_offset)); + float4 valueR = convert_float4(vload4(0, input + in_offset + shape.z)); + + #ifdef DOUBLE_INPUTS + float4 valueConstL = convert_float4(vload4(h, input1)); + float4 valueConstR = convert_float4(vload4(h, input1 + shape.z)); + valueL += valueConstL; + valueR += valueConstR; + #endif + float4 out = (erf(valueR * (float4)0.7071067932881648) + (float4)1.0) * valueR * (float4)0.5; + out *= valueL; + vstore4(CONVERT_FLOAT4(out), 0, output + out_offset); #else - const int hw = pos.z; - - const int channel_4 = (shape.y + 3) >> 2; - const int in_offset = ((b * channel_4 + c_4) * shape.z * 2 + hw) * 4; - const int out_offset = ((b * channel_4 + c_4) * shape.z + hw) * 4; - - float4 valueL = convert_float4(vload4(0, input + in_offset)); - float4 valueR = convert_float4(vload4(shape.z, input + in_offset)); - - #ifdef DOUBLE_INPUTS - float valueConstL = input1[hw]; - float valueConstR = input1[shape.z+hw]; - valueL += (float4)valueConstL; - valueR += (float4)valueConstR; - #endif - float4 out = (erf(valueR * (float4)0.7071067932881648) + (float4)1.0) * valueR * (float4)0.5; - out *= valueL; - vstore4(CONVERT_FLOAT4(out), 0, output + out_offset); + const int in_offset = bc * shape.z * 2 + h; + const int out_offset = bc * shape.z + h; + + float valueL = (float)input[in_offset]; + float valueR = (float)input[in_offset + shape.z]; + + #ifdef DOUBLE_INPUTS + float valueConstL = input1[h]; + float valueConstR = input1[shape.z+h]; + valueL += valueConstL; + valueR += valueConstR; + #endif + float out = (erf(valueR * 0.7071067932881648) + 1.0) * valueR * 0.5; + out *= valueL; + output[out_offset] = out; #endif } } diff --git a/source/backend/opencl/execution/cl/unary_buf.cl b/source/backend/opencl/execution/cl/unary_buf.cl index 67565b1b3..9a93d83ce 100644 --- a/source/backend/opencl/execution/cl/unary_buf.cl +++ b/source/backend/opencl/execution/cl/unary_buf.cl @@ -2,11 +2,11 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable #endif -#define GLOBAL_SIZE_3_DIMS \ - __private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, +#define GLOBAL_SIZE_2_DIMS \ + __private const int global_size_dim0, __private const int global_size_dim1, -#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) \ - if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { \ +#define DEAL_NON_UNIFORM_DIM2(input1, input2) \ + if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { \ return; \ } inline float4 gelu(float4 in){ @@ -17,22 +17,35 @@ inline float4 gelu(float4 in){ return (1.0f + dst) * in * 0.5f; } -__kernel void unary_buf(GLOBAL_SIZE_3_DIMS +__kernel void unary_buf(GLOBAL_SIZE_2_DIMS __global const INPUT_TYPE *input, __global OUTPUT_TYPE 
*output, - __private const int height) { - const int channel_block_idx = get_global_id(0); - const int w = get_global_id(1); - const int hb = get_global_id(2); + __private const int size) { + const int x = get_global_id(0); + const int y = get_global_id(1); - DEAL_NON_UNIFORM_DIM3(channel_block_idx, w, hb); - - const int batch_idx = hb / height; - const int height_idx = hb % height; - - const int offset = (((batch_idx*global_size_dim0+channel_block_idx)*height+height_idx)*global_size_dim1+w) * 4; - float4 in = convert_float4(vload4(0, input+offset)); - float4 out = OPERATOR; - vstore4(CONVERT_OUTPUT4(out), 0, output+offset); + DEAL_NON_UNIFORM_DIM2(x, y); + const int offset = x << 2; +#ifdef PACK_LEAVE + if(offset + 3 >= size){ + int remain = size - offset; + float4 in; + float* in_ptr = (float*)∈ + for(int i = 0; i < remain; ++i){ + in_ptr[i] = (float)input[offset + i]; + } + float4 out = OPERATOR; + float* out_ptr = (float*)&out; + for(int i = 0; i < remain; ++i){ + output[offset + i] = (OUTPUT_TYPE)out_ptr[i]; + } + }else { +#endif + float4 in = convert_float4(vload4(0, input + offset)); + float4 out = OPERATOR; + vstore4(CONVERT_OUTPUT4(out), 0, output + offset); +#ifdef PACK_LEAVE + } +#endif } diff --git a/source/backend/opencl/execution/cl/unary_subgroup_buf.cl b/source/backend/opencl/execution/cl/unary_subgroup_buf.cl index d2d8b9528..ffdc8f8f3 100644 --- a/source/backend/opencl/execution/cl/unary_subgroup_buf.cl +++ b/source/backend/opencl/execution/cl/unary_subgroup_buf.cl @@ -23,6 +23,7 @@ __kernel void unary_buf_c4_c4(GLOBAL_SIZE_3_DIMS __private const int width, __private const int height, __private const int channel, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right, __private const int output_pad_left, __private const int output_pad_right) { const int channel_block_idx = get_global_id(0); @@ -33,9 +34,8 @@ __kernel void unary_buf_c4_c4(GLOBAL_SIZE_3_DIMS const int batch_idx = hb / height; const int height_idx = hb % height; - const int channel4 = (channel + 3) / 4; - const int offset = (((batch_idx*channel4+channel_block_idx)*height+height_idx)*width+w) * 4; + const int offset = (((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w) * 4; float4 in = convert_float4(vload4(0, input+offset)); float4 out = OPERATOR; vstore4(CONVERT_OUTPUT4(out), 0, output+offset); @@ -47,6 +47,7 @@ __kernel void unary_buf_c4_c16(GLOBAL_SIZE_3_DIMS __private const int width, __private const int height, __private const int channel, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right, __private const int output_pad_left, __private const int output_pad_right) { const int channel_block_idx = get_global_id(0); @@ -58,11 +59,10 @@ __kernel void unary_buf_c4_c16(GLOBAL_SIZE_3_DIMS const int batch_idx = hb / height; const int height_idx = hb % height; const int dst_width = output_pad_left+width+output_pad_right; - const int channel4 = (channel + 3) / 4; const int channel16 = (channel + 15) / 16; const int channe_out_idx = channel_block_idx >> 2; - const int offset = (((batch_idx*channel4+channel_block_idx)*height+height_idx)*width+w) * 4; + const int offset = (((batch_idx+channel_block_idx*batch)*height+height_idx)*width+w) * 4; const int dst_offset = (((batch_idx*channel16+channe_out_idx)*height+height_idx)*dst_width+w+output_pad_left) * 16 + (channel_block_idx % 4) * 4; float4 in = convert_float4(vload4(0, input+offset)); float4 out = OPERATOR; @@ -86,6 +86,7 @@ __kernel void 
unary_buf_c16_c16(GLOBAL_SIZE_3_DIMS __private const int width, __private const int height, __private const int channel, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right, __private const int output_pad_left, __private const int output_pad_right) { const int channel_idx = get_group_id(0); @@ -132,6 +133,7 @@ __kernel void unary_buf_c16_c4(GLOBAL_SIZE_3_DIMS __private const int width, __private const int height, __private const int channel, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right, __private const int output_pad_left, __private const int output_pad_right) { const int channel_idx = get_group_id(0); @@ -142,12 +144,11 @@ __kernel void unary_buf_c16_c4(GLOBAL_SIZE_3_DIMS const int batch_idx = hb / height; const int height_idx = hb % height; const int src_width = width + input_pad_left + input_pad_right; - const int channel4 = (channel + 3) / 4; const int channel16 = (channel + 15) / 16; const int src_offset = (((batch_idx*channel16+channel_idx)*height+height_idx)*src_width+w+input_pad_left) * 16; - const int dst_offset = (((batch_idx*channel4+(channel_idx<<2))*height+height_idx)*width+w) * 4; + const int dst_offset = (((batch_idx+(channel_idx<<2)*batch)*height+height_idx)*width+w) * 4; const int height_width = height * width * 4; float4 in = convert_float4(AS_INPUT_DATA4(INTEL_SUB_GROUP_READ4((__global INTEL_DATA*)(input + src_offset)))); diff --git a/source/backend/opencl/execution/cl/winogradTransform_buf.cl b/source/backend/opencl/execution/cl/winogradTransform_buf.cl index 0caf484b0..87424d940 100644 --- a/source/backend/opencl/execution/cl/winogradTransform_buf.cl +++ b/source/backend/opencl/execution/cl/winogradTransform_buf.cl @@ -97,6 +97,7 @@ __kernel void winoTransSrcBuf2_3_1(GLOBAL_SIZE_DIM2 __private const int srcWidth, // 6 __private const int srcHeight, __private const int srcChannelC4, __private const int dstHeightPad, __private const int srcChannelPad, + __private const int batch, __private const int batchOffset) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -133,7 +134,7 @@ __kernel void winoTransSrcBuf2_3_1(GLOBAL_SIZE_DIM2 FLOAT4 S23; FLOAT4 S33; - int inp_offset = (((batchIndex * srcChannelC4 + srcZ) * srcHeight + syStart) * srcWidth + sxStart) * 4; + int inp_offset = (((batchIndex + srcZ * batch) * srcHeight + syStart) * srcWidth + sxStart) * 4; { int sx = 0 + sxStart; int sy = 0 + syStart; @@ -395,6 +396,7 @@ __kernel void winoTransDstBuf2_3_1(GLOBAL_SIZE_DIM2 __private const int dstChannelC4, __private const int srcWidthPad, __private const int dstChannelPad, + __private const int batch, __private const int batchOffset) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -447,7 +449,7 @@ __kernel void winoTransDstBuf2_3_1(GLOBAL_SIZE_DIM2 //NC4HW4 [batch, dstChannelC4, dstHeight, dstWidth] //index: [batchIndex, oz, oyStart, oxStart] - int out_offset = (((batchIndex * dstChannelC4+ oz) * dstHeight + oyStart) * dstWidth + oxStart)*4; + int out_offset = (((batchIndex + oz * batch) * dstHeight + oyStart) * dstWidth + oxStart)*4; { int ox = oxStart + 0; int oy = oyStart + 0; diff --git a/source/backend/opencl/execution/cl/winogradTransform_subgroup_buf.cl b/source/backend/opencl/execution/cl/winogradTransform_subgroup_buf.cl index 4a7d903b3..05833ba41 100644 --- a/source/backend/opencl/execution/cl/winogradTransform_subgroup_buf.cl +++ 
b/source/backend/opencl/execution/cl/winogradTransform_subgroup_buf.cl @@ -20,6 +20,7 @@ __kernel void winoTransSrcBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2 __private const int srcWidth, // 6 __private const int srcHeight, __private const int srcChannelC4, __private const int srcChannelC16, __private const int dstHeight, __private const int batchOffset, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -101,6 +102,7 @@ __kernel void winoTransDstBuf2_3_1_c16_c16(GLOBAL_SIZE_DIM2 __private const int dstHeight, __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth, __private const int batchOffset, + __private const int batch, __private const int output_pad_left, __private const int output_pad_right) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -225,6 +227,7 @@ __kernel void winoTransSrcBuf2_3_1_c4_c16(GLOBAL_SIZE_DIM2 __private const int srcWidth, // 6 __private const int srcHeight, __private const int srcChannelC4, __private const int srcChannelC16, __private const int dstHeight, __private const int batchOffset, + __private const int batch, __private const int input_pad_left, __private const int input_pad_right) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -253,7 +256,7 @@ __kernel void winoTransSrcBuf2_3_1_c4_c16(GLOBAL_SIZE_DIM2 FLOAT4 S23; FLOAT4 S33; - int inp_offset = (((batchOffset * srcChannelC4 + pos.y) * srcHeight + syStart) * srcWidth + sxStart) * 4; + int inp_offset = (((batchOffset + pos.y * batch) * srcHeight + syStart) * srcWidth + sxStart) * 4; { int sx = 0 + sxStart; int sy = 0 + syStart; @@ -417,6 +420,7 @@ __kernel void winoTransDstBuf2_3_1_c16_c4(GLOBAL_SIZE_DIM2 __private const int dstHeight, __private const int dstChannelC4,__private const int dstChannelC16,__private const int srcWidth, __private const int batchOffset, + __private const int batch, __private const int output_pad_left, __private const int output_pad_right) { int2 pos = (int2)(get_global_id(0), get_global_id(1)); UNIFORM_BOUNDRY_CHECK(pos.x, pos.y); @@ -463,7 +467,7 @@ __kernel void winoTransDstBuf2_3_1_c16_c4(GLOBAL_SIZE_DIM2 //NC4HW4 [batch, dstChannelC4, dstHeight, dstWidth] //index: [batchOffset, pos.y, oyStart, oxStart] - int out_offset = (((batchOffset * dstChannelC4+ pos.y) * dstHeight + oyStart) * dstWidth + oxStart)*4; + int out_offset = (((batchOffset+ pos.y * batch) * dstHeight + oyStart) * dstWidth + oxStart)*4; { int ox = oxStart + 0; int oy = oyStart + 0; diff --git a/source/backend/opencl/execution/image/ConvExecution.cpp b/source/backend/opencl/execution/image/ConvExecution.cpp index d5315ffee..d2f6d288a 100644 --- a/source/backend/opencl/execution/image/ConvExecution.cpp +++ b/source/backend/opencl/execution/image/ConvExecution.cpp @@ -491,7 +491,7 @@ class ConvolutionCreator : public OpenCLBackend::Creator { std::vector inputShape = tensorShapeFormat(inputs[0]); const int inputChannels = inputShape.at(3); #if defined(MNN_LOW_MEMORY) && not defined(MNN_OPENCL_BUFFER_CLOSED) - { + if (static_cast(backend)->getMemory() == BackendConfig::Memory_Low){ auto conv2dParams = op->main_as_Convolution2D(); if (conv2dParams->quanParameter() != nullptr) { if (((conv2dParams->quanParameter()->type() == 4) || diff --git a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp 
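// --- Illustrative aside, not part of the patch ---
// With the ConvExecution guard above, the OpenCL low-memory (weight-quant)
// convolution path is only chosen when the backend was created with
// Memory_Low. A caller would opt in roughly as below; this assumes the usual
// public MNN session API and is not something introduced by this patch:
#include <MNN/Interpreter.hpp>
#include <MNN/MNNForwardType.h>
void createLowMemorySession(MNN::Interpreter* net) {
    MNN::BackendConfig backendConfig;
    backendConfig.memory = MNN::BackendConfig::Memory_Low;  // request the low-memory path
    MNN::ScheduleConfig config;
    config.type = MNN_FORWARD_OPENCL;
    config.backendConfig = &backendConfig;                  // read during createSession
    net->createSession(config);
}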
b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp index f40f3d644..717bab14a 100644 --- a/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp +++ b/source/backend/opencl/execution/image/ConvLowMemoryExecution.cpp @@ -86,9 +86,6 @@ bool ConvLowMemoryExecution::convertToQuantWeight1x1Buffer(cl::Buffer input, int // int4 case buildOptions.emplace("-DUSE_LOW_BIT_WEIGHT_INT4"); } else {/* More types to be supported. */} - if(mResource->mInputChannel % icPack != 0){ - buildOptions.emplace("-DCHANNEL_LEAVE"); - } mBufferToConv1x1Kernel = runtime->buildKernelWithCache("buffer_convert_quant", kernelName, buildOptions); auto kernel = mBufferToConv1x1Kernel->get(); @@ -495,6 +492,9 @@ ConvLowMemoryExecution::ConvLowMemoryExecution(const std::vector &inpu setGeneralWeightLowMemory(mFilterDataPtr, quanCommon); } // Create Kernel + if (mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1 && mResource->mDilations[0] == 1 && mResource->mDilations[1] == 1) { + mResource->mBuildOptions.emplace("-DMNN_CONV_S1D1"); + } mResource->mBuildOptions.emplace("-DBIAS"); if (conv2dCommonParams->relu()) { mResource->mBuildOptions.emplace("-DRELU"); diff --git a/source/backend/opengl/GLBackend.cpp b/source/backend/opengl/GLBackend.cpp index d0d9ba2c7..c8c460407 100644 --- a/source/backend/opengl/GLBackend.cpp +++ b/source/backend/opengl/GLBackend.cpp @@ -439,7 +439,7 @@ bool GLBackend::isCreateError() const { } -Backend* GLRuntime::onCreate(const BackendConfig* config) const { +Backend* GLRuntime::onCreate(const BackendConfig* config, Backend* origin) const { BackendConfig::PrecisionMode precision = BackendConfig::Precision_Normal; BackendConfig::PowerMode power = BackendConfig::Power_Normal; if (nullptr != mInfo.user) { @@ -477,7 +477,7 @@ class GLRuntimeCreator : public RuntimeCreator { public: virtual Runtime *onCreate(const Backend::Info &info) const override { auto rt = new GLRuntime(info); - auto bn = (GLBackend*)(rt->onCreate(nullptr)); + auto bn = (GLBackend*)(rt->onCreate(nullptr, nullptr)); if (bn->isCreateError()) { delete bn; delete rt; diff --git a/source/backend/opengl/GLBackend.hpp b/source/backend/opengl/GLBackend.hpp index b36140258..2c0307faa 100644 --- a/source/backend/opengl/GLBackend.hpp +++ b/source/backend/opengl/GLBackend.hpp @@ -35,7 +35,7 @@ class GLRuntime : public Runtime { @brief create backend @return created backend */ - virtual Backend* onCreate(const BackendConfig* config) const override; + virtual Backend* onCreate(const BackendConfig* config, Backend* origin) const override; /** @brief clear unuseful resource diff --git a/source/backend/tensorrt/backend/TRTBackend.cpp b/source/backend/tensorrt/backend/TRTBackend.cpp index 49d954b10..66fde8932 100755 --- a/source/backend/tensorrt/backend/TRTBackend.cpp +++ b/source/backend/tensorrt/backend/TRTBackend.cpp @@ -54,7 +54,7 @@ TRTRuntime::TRTRuntime(const Backend::Info& info) { TRTRuntime::~TRTRuntime() { } -Backend* TRTRuntime::onCreate(const BackendConfig* config) const { +Backend* TRTRuntime::onCreate(const BackendConfig* config, Backend* origin) const { return new TRTBackend(this); } diff --git a/source/backend/tensorrt/backend/TRTBackend.hpp b/source/backend/tensorrt/backend/TRTBackend.hpp index c7e14fa6c..adde390c2 100644 --- a/source/backend/tensorrt/backend/TRTBackend.hpp +++ b/source/backend/tensorrt/backend/TRTBackend.hpp @@ -34,7 +34,7 @@ class TRTRuntime : public Runtime { TRTRuntime(const Backend::Info& info); virtual ~TRTRuntime(); - virtual Backend* onCreate(const BackendConfig* config) 
const override; + virtual Backend* onCreate(const BackendConfig* config, Backend* origin) const override; virtual void onGabageCollect(int level) override; // If buffer is not nullptr, try copy cache, else delete cache virtual bool onSetCache(const void* buffer, size_t size) override { diff --git a/source/backend/vulkan/component/VulkanPipeline.cpp b/source/backend/vulkan/component/VulkanPipeline.cpp index e0da6bcdd..5b26ca094 100644 --- a/source/backend/vulkan/component/VulkanPipeline.cpp +++ b/source/backend/vulkan/component/VulkanPipeline.cpp @@ -128,7 +128,7 @@ VulkanLayout::DescriptorSet* VulkanPipeline::createSet() const { } void VulkanPipeline::changePipeline(const std::vector& localSize) const{ - VkPipeline pipeline = VK_NULL_HANDLE; + mDevice.destroyPipeline(mPipeline); /*for localSize_x_id = 0,localSize_y_id = 1,localSize_z_id = 2*/ std::vector specializationMapEntry; /*localSize data description*/ std::shared_ptr specializationInfo = std::make_shared(); @@ -145,11 +145,10 @@ void VulkanPipeline::changePipeline(const std::vector& localSize) cons specializationInfo->mapEntryCount = specializationMapEntry.size(); } - auto res = mDevice.createComputePipeline(pipeline, mShader->get(), mLayout->get(), mCache->get(), specializationInfo.get()); + auto res = mDevice.createComputePipeline(mPipeline, mShader->get(), mLayout->get(), mCache->get(), specializationInfo.get()); if (VK_SUCCESS != res) { FUNC_PRINT(1); } - mPipeline = pipeline; } VulkanLayout::DescriptorSet* VulkanLayout::createSet() const { diff --git a/source/backend/vulkan/runtime/VulkanRuntime.cpp b/source/backend/vulkan/runtime/VulkanRuntime.cpp index 795c24f99..191158113 100644 --- a/source/backend/vulkan/runtime/VulkanRuntime.cpp +++ b/source/backend/vulkan/runtime/VulkanRuntime.cpp @@ -165,7 +165,7 @@ void VulkanRuntime::onGabageCollect(int level) { mPipelineFactory->reset(); } -Backend* VulkanRuntime::onCreate(const BackendConfig* config) const { +Backend* VulkanRuntime::onCreate(const BackendConfig* config, Backend* origin) const { // FIXME: Use config return new VulkanBackend(this, mInfo); } diff --git a/source/backend/vulkan/runtime/VulkanRuntime.hpp b/source/backend/vulkan/runtime/VulkanRuntime.hpp index c8dfa56ac..3c04c9808 100644 --- a/source/backend/vulkan/runtime/VulkanRuntime.hpp +++ b/source/backend/vulkan/runtime/VulkanRuntime.hpp @@ -26,7 +26,7 @@ class VulkanRuntime : public Runtime { public: virtual ~ VulkanRuntime(); - virtual Backend* onCreate(const BackendConfig* config) const override; + virtual Backend* onCreate(const BackendConfig* config, Backend* origin) const override; enum GPUType { ADRENO = 0, MALI = 1, OTHER = 2 }; virtual void onGabageCollect(int level) override; virtual float onGetMemoryInMB() override; diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 0d199bd90..2e0b2548b 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -34,11 +34,12 @@ struct RuntimeHint { int cpuDecreaseRate = 50; int dynamicQuantOption = 0; - // 0: Do not quantize kvcache, just store float - // 1: Only quantize key cache, use int8 asymmetric quantization - // 2: Only quantize value cache, use fp8 quantization - // 3: quantize both key and value cache as described above - int kvcacheQuantOption = 0; + // 0: Do not quantize + // 1: Only quantize key, use int8 asymmetric quantization + // 2: Only quantize value, use fp8 quantization + // 3: quantize both key and value + // 4: quantize query, key and value, and use gemm int8 kernel to compute K*V + int qkvQuantOption = 0; // the kvcache 
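// --- Illustrative aside, not part of the patch ---
// Runtime::onCreate gains a second, optional origin backend (see the
// core/Backend.hpp hunk below); every in-tree Runtime override simply adds the
// extra parameter and may ignore it. Hypothetical out-of-tree runtime, shown
// for the signature only; XPURuntime is not real MNN code and the body is a
// placeholder rather than a working backend:
#include "core/Backend.hpp"   // MNN internal header
class XPURuntime : public MNN::Runtime {
public:
    virtual MNN::Backend* onCreate(const MNN::BackendConfig* config,
                                   MNN::Backend* origin) const override {
        (void)origin;          // safe to ignore when there is no state to share
        return nullptr;        // create and return the concrete Backend here
    }
    virtual void onGabageCollect(int level) override { /* drop caches */ }
};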
size limit of each layer // if the size of kvcache in memory exceeds the limit @@ -48,6 +49,9 @@ struct RuntimeHint { // path of the kvcache directory std::string kvcacheDirPath = "/tmp"; + + std::string midMemoryPath; + std::string weightMemoryPath; }; /** abstract backend */ class Backend : public NonCopyable { @@ -267,7 +271,7 @@ class Runtime : public NonCopyable { @brief create backend @return created backend */ - virtual Backend* onCreate(const BackendConfig* config = nullptr) const = 0; + virtual Backend* onCreate(const BackendConfig* config = nullptr, Backend* origin = nullptr) const = 0; /** @brief reset runtime diff --git a/source/core/BufferAllocator.cpp b/source/core/BufferAllocator.cpp index 43104da80..676cf97fc 100644 --- a/source/core/BufferAllocator.cpp +++ b/source/core/BufferAllocator.cpp @@ -6,8 +6,10 @@ // Copyright © 2018, Alibaba Group Holding Limited // +#include #include "core/BufferAllocator.hpp" #include "core/Macro.h" +#include "MNNFileUtils.h" // #define DUMP_USAGE //#define MNN_DEBUG_MEMORY @@ -54,6 +56,62 @@ class DefaultAllocator : public BufferAllocator::Allocator { MNNMemoryFreeAlign(chunk.first); } }; +class MmapAllocator : public BufferAllocator::Allocator { +private: + std::map> mCache; + std::string mFileName; + std::string mPosfix; + int mAllocTimes = 0; + bool mRemove; +public: + MmapAllocator(const char* dirName, const char* posfix, bool autoRemove) { + if (nullptr != dirName) { + mFileName = dirName; + if (!MNNDirExist(dirName)) { + MNN_ERROR("%s not exist\n", dirName); + } + } + if (nullptr != posfix) { + mPosfix = posfix; + } + mRemove = autoRemove; + } + virtual ~ MmapAllocator() { + for (auto& iter : mCache) { + MNNUnmapFile(iter.first, std::get<1>(iter.second)); + MNNCloseFile(std::get<0>(iter.second)); + if (mRemove) { + MNNRemoveFile(std::get<2>(iter.second).c_str()); + } + } + } + virtual MemChunk onAlloc(size_t size, size_t align) { + MNN_ASSERT(size > 0); + std::string fileName = MNNFilePathConcat(mFileName, std::to_string(mAllocTimes) + "." 
+ mPosfix); + auto file = MNNCreateFile(fileName.c_str()); + size = UP_DIV(size, align) * align; + MNNSetFileSize(file, size); + void* ptr = MNNMmapFile(file, size); + mCache.insert(std::make_pair(ptr, std::make_tuple(file, size, fileName))); + mAllocTimes++; + return MemChunk(ptr, 0); + } + virtual void onRelease(MemChunk chunk) { + MNN_ASSERT(chunk.second == 0); + auto iter = mCache.find(chunk.first); + if (iter == mCache.end()) { + MNN_ASSERT(false); + MNN_ERROR("Invalid free for MMAPAllocator\n"); + return; + } + MNNUnmapFile(iter->first, std::get<1>(iter->second)); + MNNCloseFile(std::get<0>(iter->second)); + if (mRemove) { + MNNRemoveFile(std::get<2>(iter->second).c_str()); + } + mCache.erase(iter); + } +}; class RecurseAllocator : public BufferAllocator::Allocator { public: RecurseAllocator(BufferAllocator* parent) { @@ -72,14 +130,17 @@ class RecurseAllocator : public BufferAllocator::Allocator { BufferAllocator* mParent; }; -ErrorCode BufferAllocator::compute() { - return NO_ERROR; -} std::shared_ptr BufferAllocator::Allocator::createDefault() { std::shared_ptr _res; _res.reset(new DefaultAllocator); return _res; } +std::shared_ptr BufferAllocator::Allocator::createMmap(const char* dirName, const char* posfix, bool autoRemove) { + std::shared_ptr _res; + _res.reset(new MmapAllocator(dirName, posfix, autoRemove)); + return _res; +} + std::shared_ptr BufferAllocator::Allocator::createRecurse(BufferAllocator* parent) { std::shared_ptr _res; _res.reset(new RecurseAllocator(parent)); @@ -113,23 +174,48 @@ MemChunk EagerBufferAllocator::alloc(size_t size, bool separate, size_t align) { return MemChunk(pointer); } } + auto allocSize = size; + if (mMinAllocSize != 0) { + allocSize = ALIMAX(mMinAllocSize, size); + } // alloc otherwise - auto chunk = mAllocator->onAlloc(size, align); + auto chunk = mAllocator->onAlloc(allocSize, align); pointer.first = chunk.first; pointer.second = chunk.second; if (nullptr == pointer.first) { return chunk; } - mTotalSize += size; + mTotalSize += allocSize; // save node SharedPtr node(new Node); - node->size = size; + node->size = allocSize; node->pointer = pointer; - mUsedList[pointer] = node; node->outside = mAllocator.get(); MNN_ASSERT(pointer.second % align == 0); + if (allocSize > size) { + // Split + SharedPtr first(new Node); + first->parent = node; + first->size = size; + first->pointer = pointer; + mUsedList.insert(std::make_pair(pointer, first)); + node->useCount = 1; + + SharedPtr second(new Node); + second->parent = node; + second->size = allocSize - size; + second->pointer.first = pointer.first; + second->pointer.second = pointer.second + size; + if (nullptr != mCurrentFreeList) { + mCurrentFreeList->insert(std::make_pair(second->size, second)); + } else { + mFreeList.insert(std::make_pair(second->size, second)); + } + } else { + mUsedList[pointer] = node; + } #ifdef DUMP_USAGE MNN_PRINT("mTotalSize: %f\n", mTotalSize / 1024.0f / 1024.0f); #endif @@ -290,13 +376,40 @@ std::pair EagerBufferAllocator::getFromFreeList(FREELIST* list, s static void _CPUMemChunkApplyToTensor(uint8_t* ptr, size_t offset, Tensor* t) { t->buffer().host = ptr + offset; } +SingleBufferWithAllocator::~ SingleBufferWithAllocator() { + release(); +} +void SingleBufferWithAllocator::release() { + if (current.first != nullptr) { + root->onRelease(current); + current.first = nullptr; + current.second = 0; + currentSize = 0; + } +} + +ErrorCode SingleBufferWithAllocator::realloc(size_t size, size_t align) { + if (currentSize < size) { + if (nullptr != current.first) { + 
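`Allocator::createMmap` (declared in the BufferAllocator.hpp hunk below) exposes this file-backed allocator, which names its files `<allocIndex>.<posfix>` inside `dirName` and unmaps (and, with `autoRemove`, deletes) them on release; the new `minAllocSize` argument of `EagerBufferAllocator` rounds small requests up and parks the remainder in the free list. A minimal sketch combining the two; the directory, suffix, and sizes are illustrative assumptions, not values fixed by this patch:

```cpp
// Assumes /tmp/mnn_buffers already exists (the constructor only warns if it does not).
#include "core/BufferAllocator.hpp"
using namespace MNN;

void useFileBackedBuffers() {
    // Backing files appear as /tmp/mnn_buffers/0.feat, 1.feat, ...
    auto fileBacked = BufferAllocator::Allocator::createMmap("/tmp/mnn_buffers", "feat", /*autoRemove=*/true);
    // Requests smaller than 1 MB are rounded up; the unused tail stays in the free list.
    EagerBufferAllocator allocator(fileBacked, MNN_MEMORY_ALIGN_DEFAULT, 1 << 20);
    auto chunk = allocator.alloc(64 * 1024);   // served from a 1 MB mmapped file
    // ... write through chunk.ptr() ...
    allocator.free(chunk);
    allocator.release(true);                   // unmaps and removes the backing files
}
```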
root->onRelease(current); + } + current = root->onAlloc(size, align); + if (current.first == nullptr) { + return OUT_OF_MEMORY; + } + currentSize = size; + } + return NO_ERROR; +} + -DeferBufferAllocator::DeferBufferAllocator(std::shared_ptr parent, size_t align, MemChunkApplyToTensor func) : mAllocator(parent), mAlign(align) { +DeferBufferAllocator::DeferBufferAllocator(SingleBufferWithAllocator* root, size_t align, MemChunkApplyToTensor func) : mAlign(align) { if (nullptr == func) { mApplyFunction = _CPUMemChunkApplyToTensor; } else { mApplyFunction = func; } + mParent = root; } //------------------------------- DeferBufferAllocator -----------------------------------// @@ -371,10 +484,6 @@ void DeferBufferAllocator::release(bool allRelease) { } } -size_t DeferBufferAllocator::totalSize() const { - return mTotalSize; -} - void DeferBufferAllocator::barrierBegin() { MNN_ASSERT(!mBarrrier); mBarrrier = true; @@ -398,12 +507,8 @@ void DeferBufferAllocator::reset() { mTotalSize = 0; mChunks.clear(); mFreeList.clear(); - // mPtr.reset(nullptr); - if (mPtr.ptr()) { - mAllocator->onRelease(mPtr); - mPtr.first = nullptr; - mPtr.second = 0; - } + mPtr.first = nullptr; + mPtr.second = 0; mHead = nullptr; mTail = nullptr; mBarrrier = false; @@ -411,7 +516,7 @@ void DeferBufferAllocator::reset() { } ErrorCode DeferBufferAllocator::compute() { - if (mPtr.ptr()) { + if (mTotalSize > 0) { return NO_ERROR; } mTotalSize = 0; @@ -431,10 +536,28 @@ ErrorCode DeferBufferAllocator::compute() { mTotalSize += chunk->size; chunk = chunk->right; } - mPtr = mAllocator->onAlloc(mTotalSize, mAlign); - if (mPtr.ptr() == nullptr) { - return OUT_OF_MEMORY; + return apply(); +} +ErrorCode DeferBufferAllocator::apply() { + if (mFreeList.empty()) { + // Not alloc + return NO_ERROR; + } + auto& chunk = mParent->current; + bool needApply = false; + if (mParent->currentSize < mTotalSize) { + needApply = true; + auto code = mParent->realloc(mTotalSize, mAlign); + if (NO_ERROR != code) { + return code; + } + } else if (mPtr.first != chunk.first || mPtr.second != chunk.second) { + needApply = true; + } + if (!needApply) { + return NO_ERROR; } + mPtr = chunk; for (auto& chunk : mChunks) { chunk->base = mPtr.ptr(); for (auto t : chunk->tensors) { @@ -474,7 +597,7 @@ void DeferBufferAllocator::erase_node(MemNode* chunk) { } if (right) { right->left = nullptr; - mTail = right; + mHead = right; return; } mHead = mTail = nullptr; diff --git a/source/core/BufferAllocator.hpp b/source/core/BufferAllocator.hpp index 78e0b8ee7..b5c406a4f 100644 --- a/source/core/BufferAllocator.hpp +++ b/source/core/BufferAllocator.hpp @@ -85,6 +85,7 @@ class MNN_PUBLIC BufferAllocator : public NonCopyable { virtual MemChunk onAlloc(size_t size, size_t align) = 0; virtual void onRelease(MemChunk chunk) = 0; static std::shared_ptr createDefault(); + static std::shared_ptr createMmap(const char* dirName, const char* posfix, bool autoRemove = true); static std::shared_ptr createRecurse(BufferAllocator* parent); }; BufferAllocator() = default; @@ -92,13 +93,22 @@ class MNN_PUBLIC BufferAllocator : public NonCopyable { virtual MemChunk alloc(size_t size, bool separate = false, size_t align = 0) = 0; virtual bool free(MemChunk chunk) = 0; virtual void release(bool allRelease = true) = 0; - virtual size_t totalSize() const = 0; + size_t totalSize() const { + return mTotalSize; + } virtual void barrierBegin() {} virtual void barrierEnd() {} virtual void beginGroup() {} virtual void endGroup() {} virtual void reset() {} - virtual ErrorCode compute(); + virtual 
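`SingleBufferWithAllocator` owns a single growable chunk that one or more `DeferBufferAllocator` plans map onto via `compute()`/`apply()`; because plans built on the same root reuse `root.current` whenever it is already large enough, models that never run concurrently can share one dynamic buffer (the behaviour SequenceMemoryTest checks later in this patch). A hedged sketch of the flow, following the same pattern as the updated BufferAllocatorTest:

```cpp
// Sketch only: sizes are arbitrary; chunks are not dereferenced before compute().
#include "core/BufferAllocator.hpp"
using namespace MNN;

void planDeferredMemory() {
    SingleBufferWithAllocator root;
    root.root = BufferAllocator::Allocator::createDefault();
    DeferBufferAllocator plan(&root);

    auto a = plan.alloc(600 * 1024);   // only records the request
    auto b = plan.alloc(500 * 1024);
    plan.free(a);                      // a's range becomes reusable for later requests
    auto c = plan.alloc(200 * 1024);   // may land inside a's old range
    plan.free(b);
    plan.free(c);
    // compute() lays the chunks out and, through apply(), grows root.current if needed.
    plan.compute();
    // A second DeferBufferAllocator built on &root will reuse root.current
    // as long as its planned size fits.
}
```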
ErrorCode compute() { + return NO_ERROR; + } + virtual ErrorCode apply() { + return NO_ERROR; + } +protected: + size_t mTotalSize = 0; }; @@ -108,7 +118,7 @@ class MNN_PUBLIC EagerBufferAllocator : public BufferAllocator { * @brief init buffer allocator with pointer alignment. * @param align given pointer alignment. */ - EagerBufferAllocator(std::shared_ptr parent, size_t align = MNN_MEMORY_ALIGN_DEFAULT) : mAllocator(parent), mAlign(align) { + EagerBufferAllocator(std::shared_ptr parent, size_t align = MNN_MEMORY_ALIGN_DEFAULT, size_t minAllocSize = 0) : mAllocator(parent), mAlign(align), mMinAllocSize(minAllocSize) { // nothing to do } /** @@ -145,14 +155,6 @@ class MNN_PUBLIC EagerBufferAllocator : public BufferAllocator { */ void release(bool allRelease = true) override; - /** - * @brief query total size allocated indeed. - * @return total size allocated indeed. - */ - size_t totalSize() const override { - return mTotalSize; - } - /* For multi thread case, we must assume that the memory use by different thread don't conflict @@ -184,40 +186,47 @@ class MNN_PUBLIC EagerBufferAllocator : public BufferAllocator { std::map, SharedPtr> mUsedList; FREELIST mFreeList; - size_t mTotalSize = 0; FREELIST* mCurrentFreeList = nullptr; std::vector> mGroups; std::shared_ptr mAllocator; size_t mAlign; + size_t mMinAllocSize = 0; }; typedef void(*MemChunkApplyToTensor)(uint8_t* ptr, size_t offset, Tensor* tensor); +class MNN_PUBLIC SingleBufferWithAllocator { +public: + ~ SingleBufferWithAllocator(); + ErrorCode realloc(size_t size, size_t align); + void release(); + std::shared_ptr root; + MemChunk current; + size_t currentSize = 0; +}; class MNN_PUBLIC DeferBufferAllocator : public BufferAllocator { public: - DeferBufferAllocator(std::shared_ptr parent, size_t align = MNN_MEMORY_ALIGN_DEFAULT, MemChunkApplyToTensor func = nullptr); - ~DeferBufferAllocator() { - reset(); + DeferBufferAllocator(SingleBufferWithAllocator* parent, size_t align = MNN_MEMORY_ALIGN_DEFAULT, MemChunkApplyToTensor func = nullptr); + virtual ~DeferBufferAllocator() { + // Donothing } public: MemChunk alloc(size_t size, bool separate = false, size_t align = 0) override; bool free(MemChunk chunk) override; void release(bool allRelease = true) override; - size_t totalSize() const override; void barrierBegin() override; void barrierEnd() override; void beginGroup() override; void endGroup() override; void reset() override; ErrorCode compute() override; + ErrorCode apply() override; private: std::vector> mChunks; MemNode *mHead = nullptr, *mTail = nullptr; std::multiset mFreeList; // std::unique_ptr mPtr; MemChunk mPtr; - size_t mTotalSize = 0; - std::shared_ptr mAllocator; size_t mAlign; // barrier bool mBarrrier = false; @@ -231,6 +240,7 @@ class MNN_PUBLIC DeferBufferAllocator : public BufferAllocator { void eraseFree(MemNode* chunk); void visiChildren(MemNode* chunk); MemChunkApplyToTensor mApplyFunction; + SingleBufferWithAllocator* mParent; }; } // namespace MNN #endif diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index 8c5596312..4a62fb4db 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -647,9 +647,13 @@ static bool _RebuildExternalOp(FileLoader* external, const MNN::Op* origin, flat external->offset(param->external[0] + param->external[1] + param->external[2]); } if (param->bias.empty() && param->external.size() > 3) { - param->bias.resize(param->external[3]/sizeof(float)); - external->read((char*)param->bias.data(), param->external[3]); - } + if 
(param->external[3] > 0) { + param->bias.resize(param->external[3]/sizeof(float)); + external->read((char*)param->bias.data(), param->external[3]); + } else { + param->bias.resize(param->common->outputCount); + } + } if (param->quanParameter->index.empty() && param->external.size() > 4) { param->quanParameter->index.resize(param->external[4]/sizeof(uint32_t)); external->read((char*)param->quanParameter->index.data(), param->external[4]); diff --git a/source/core/Pipeline.cpp b/source/core/Pipeline.cpp index 30266df05..0108aa5e6 100644 --- a/source/core/Pipeline.cpp +++ b/source/core/Pipeline.cpp @@ -269,6 +269,8 @@ ErrorCode Pipeline::encode(bool supportDebug, bool permitCodegen) { } } else { #ifndef MNN_BUILD_MINI + mBackend->onClearBuffer(); + mBackupBackend->onClearBuffer(); mContext.clear(); mContext.mNeedRelease = mGeometryNeedRelease; FileLoader l(mExternalFile.c_str()); @@ -897,6 +899,10 @@ ErrorCode Pipeline::fixResizeCache() { info.cacheBuffer.extras.clear(); } } + mInfo.first.cache.first->onResizeBegin(); + mInfo.first.cache.first->onResizeEnd(); + mInfo.first.cache.second->onResizeBegin(); + mInfo.first.cache.second->onResizeEnd(); auto res = mInfo.first.cache.first->onSelectDynamicAllocator(1, 2); res = res && mInfo.first.cache.second->onSelectDynamicAllocator(1, 2); if (!res) { @@ -1094,8 +1100,6 @@ ErrorCode Pipeline::allocMemory(bool firstMalloc, bool forbidReplace) { /* Create Execution Begin */ auto& mBackend = mInfo.first.cache.first; auto& mBackupBackend = mInfo.first.cache.second; - mBackend->onClearBuffer(); - mBackupBackend->onClearBuffer(); // Check If we need a lone time for init if (mBackend->type() != MNN_FORWARD_CPU && mBackend->type() != MNN_FORWARD_CPU_EXTENSION && mTuneAttr.autoSetOpType) { Runtime::OpInfo dstInfo; @@ -1144,10 +1148,12 @@ ErrorCode Pipeline::allocMemory(bool firstMalloc, bool forbidReplace) { } } /* Create Execution End */ - + mBackend->onClearBuffer(); + mBackupBackend->onClearBuffer(); _SetTensorBackend(mInfo, mAllocInput); // Insert Wrap If needed { + // Reset memory allocator for backend auto insertCode = _InsertCopy(mInfo, mCacheConstTensors, mWrapTensors, mAllocInput, forbidReplace); if (NO_ERROR != insertCode) { return insertCode; diff --git a/source/core/Session.cpp b/source/core/Session.cpp index a424898ba..48148ab28 100644 --- a/source/core/Session.cpp +++ b/source/core/Session.cpp @@ -18,10 +18,8 @@ #include "core/TensorUtils.hpp" #include "utils/InitNet.hpp" -using namespace std; - namespace MNN { -static void _createPipelineBackend(Schedule::PipelineInfo& iter, RuntimeInfo& runtime) { +void Session::createPipelineBackend(Schedule::PipelineInfo& iter, RuntimeInfo& runtime) { if (iter.first.cache.first != nullptr) { return; } @@ -41,7 +39,16 @@ static void _createPipelineBackend(Schedule::PipelineInfo& iter, RuntimeInfo& ru // We need create a new backend to do size compute / not support op compute BackendConfig defaultConfig; defaultConfig.flags = 4; - iter.first.cache.second.reset(cpuRuntime->onCreate(&defaultConfig)); + if (iter.first.info.user != nullptr) { + // Don't change default Precision + defaultConfig.memory = iter.first.info.user->memory; + defaultConfig.power = iter.first.info.user->power; + } + Backend* origin = nullptr; + if (cpuRuntime.get() == rt) { + origin = iter.first.cache.first.get(); + } + iter.first.cache.second.reset(cpuRuntime->onCreate(&defaultConfig, origin)); } } void Session::ModeGroup::setMode(Interpreter::SessionMode mode) { @@ -84,8 +91,8 @@ void Session::ModeGroup::setHint(Interpreter::HintMode 
mode, int hint) { case Interpreter::DYNAMIC_QUANT_OPTIONS: runtimeHint.dynamicQuantOption = hint; break; - case Interpreter::KVCACHE_QUANT_OPTIONS: - runtimeHint.kvcacheQuantOption = hint; + case Interpreter::QKV_QUANT_OPTIONS: + runtimeHint.qkvQuantOption = hint; break; case Interpreter::KVCACHE_SIZE_LIMIT: runtimeHint.kvcacheSizeLimit = hint; @@ -100,6 +107,12 @@ void Session::ModeGroup::setExternalPath(std::string path, int type) { case MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR: runtimeHint.kvcacheDirPath = path; break; + case MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR: + runtimeHint.midMemoryPath = path; + break; + case MNN::Interpreter::EXTERNAL_WEIGHT_DIR: + runtimeHint.weightMemoryPath = path; + break; default: break; } @@ -114,7 +127,7 @@ Session::Session(Schedule::ScheduleInfo&& info, const ModeGroup& mode, RuntimeIn } mInfo = std::move(info); for (auto& iter : mInfo.pipelineInfo) { - _createPipelineBackend(iter, mRuntime); + createPipelineBackend(iter, mRuntime); Pipeline::TuningAttr attr; attr.maxTuningNumber = mode.maxTuningNumber; attr.autoSetOpType = mode.backendMode == Interpreter::Session_Backend_Auto; @@ -473,7 +486,7 @@ Session* Session::clone(RuntimeInfo&& runtime, std::shared_ptr sharedConst); + static void createPipelineBackend(Schedule::PipelineInfo& iter, RuntimeInfo& runtime); + public: /** * @brief infer. diff --git a/source/cv/ImageProcess.cpp b/source/cv/ImageProcess.cpp index 7d57a7200..d6592e1d8 100644 --- a/source/cv/ImageProcess.cpp +++ b/source/cv/ImageProcess.cpp @@ -28,7 +28,6 @@ #include "backend/cpu/x86_x64/cpu_id.h" #endif -#define CACHE_SIZE 256 namespace MNN { void registerBackend(); diff --git a/source/cv/ImageProcessUtils.cpp b/source/cv/ImageProcessUtils.cpp index c8662ed36..daf8ffd9d 100644 --- a/source/cv/ImageProcessUtils.cpp +++ b/source/cv/ImageProcessUtils.cpp @@ -28,7 +28,6 @@ #include "backend/cpu/x86_x64/cpu_id.h" #endif -#define CACHE_SIZE 256 namespace MNN { using namespace CV; #define CHECKFORMAT(src, dst, func) if (source == src && dest == dst) return func @@ -240,9 +239,14 @@ ErrorCode ImageProcessUtils::selectImageProcer(bool identity, bool hasBackend, b return NO_ERROR; } // Choose sampler. - mInside->mSampler = choose(mInside->config.sourceFormat, mInside->config.filterType, identity); - if (nullptr == mInside->mSampler) { - return INPUT_DATA_ERROR; + if (false == identity || mInside->config.sourceFormat == YUV_NV12 || mInside->config.sourceFormat == YUV_NV21 || mInside->config.sourceFormat == YUV_I420) { + mInside->mSampler = choose(mInside->config.sourceFormat, mInside->config.filterType, identity); + if (nullptr == mInside->mSampler) { + MNN_ERROR("Do not support resize convert.\n"); + return INPUT_DATA_ERROR; + } + } else { + mInside->mSampler = nullptr; } // Choose blitter. if ((ImageFormatType)mInside->config.sourceFormat != (ImageFormatType)mInside->config.destFormat) { @@ -366,11 +370,17 @@ static std::pair _computeClip(CV::Point* points, int iw, int ih, const return std::make_pair(sta, end); } +static inline float __clamp(float v, float minV, float maxV) { + return std::max(std::min(v, maxV), minV); +} + ErrorCode ImageProcessUtils::transformImage(const uint8_t* source, uint8_t* dst, uint8_t* samplerDest, uint8_t* blitDest, int tileCount, int destBytes, const int32_t* regions) { CV::Point points[2]; if (mInside->mStride == 0) { mInside->mStride = mInside->iw * mInside->ic; } + float xMax = mInside->iw - 1; + float yMax = mInside->ih - 1; for (int i = 0; i < mInside->oh; ++i) { int dy = mInside->mDraw ? 
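Both new directories are supplied through `setExternalPath`, the same entry point PrearrangeTest uses further down for the feature-map directory. A sketch with illustrative paths:

```cpp
// Paths are illustrative. EXTERNAL_FEATUREMAP_DIR feeds RuntimeHint::midMemoryPath and
// EXTERNAL_WEIGHT_DIR feeds RuntimeHint::weightMemoryPath via ModeGroup::setExternalPath.
#include <memory>
#include <vector>
#include <MNN/Interpreter.hpp>
#include <MNN/expr/Executor.hpp>

void configureExternalDirs() {
    MNN::ScheduleConfig config;
    std::vector<MNN::ScheduleConfig> configs = {config};
    std::shared_ptr<MNN::Express::Executor::RuntimeManager> rtMgr(
        MNN::Express::Executor::RuntimeManager::createRuntimeManager(configs));
    rtMgr->setExternalPath("/tmp/mnn_featuremaps", MNN::Interpreter::EXTERNAL_FEATUREMAP_DIR);
    rtMgr->setExternalPath("/tmp/mnn_weights", MNN::Interpreter::EXTERNAL_WEIGHT_DIR);
}
```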
regions[3 * i] : i; auto dstY = (uint8_t*)dst + dy * destBytes * mInside->ow * mInside->oc; @@ -390,7 +400,9 @@ ErrorCode ImageProcessUtils::transformImage(const uint8_t* source, uint8_t* dst, samplerDest = blitDest; } + const uint8_t* blitSrc = samplerDest; // For draw // Sample + const uint8_t* sourcePos = nullptr; // for sampler is null. if (!mInside->mDraw) { // Compute position points[0].fX = xStart; @@ -432,16 +444,28 @@ ErrorCode ImageProcessUtils::transformImage(const uint8_t* source, uint8_t* dst, } points[1].fX = (deltaX) / (float)(count); points[1].fY = (deltaY) / (float)(count); - - mInside->mSampler(source, samplerDest, points, sta, end - sta, count, mInside->iw, mInside->ih, mInside->mStride); + + if (mInside->mSampler) { + mInside->mSampler(source, samplerDest, points, sta, end - sta, count, mInside->iw, mInside->ih, mInside->mStride); + blitSrc = samplerDest; + } else { + int y = (int)roundf(__clamp(points[0].fY, 0, yMax)); + int x = (int)roundf(__clamp(points[0].fX, 0, xMax)); + sourcePos = source + (y * mInside->mStride + mInside->ic* x); + blitSrc = sourcePos; // update blitSrc when not draw. + } } // Convert format if (mInside->mBlitter) { - mInside->mBlitter(samplerDest, blitDest, count); + mInside->mBlitter(blitSrc, blitDest, count); } // Turn float if (mInside->mBlitFloat) { - mInside->mBlitFloat(blitDest, (float*)dstStart, mInside->config.mean, mInside->config.normal, count); + if (mInside->mSampler) { + mInside->mBlitFloat(blitDest, (float*)dstStart, mInside->config.mean, mInside->config.normal, count); + } else { + mInside->mBlitFloat(sourcePos, (float*)dstStart, mInside->config.mean, mInside->config.normal, count); + } } } } @@ -493,10 +517,10 @@ static CV::ImageFormat _correctImageFormat(int outputBpp, halide_type_t type, CV } ErrorCode ImageProcessUtils::execFunc(const uint8_t *source, int stride, void *dest) { - uint8_t sampleDest[4 * 256]; - uint8_t blitDest[4 * 256]; + uint8_t sampleDest[4 * CACHE_SIZE]; + uint8_t blitDest[4 * CACHE_SIZE]; int destBytes = mInside->mDtype.bytes(); - int tileCount = UP_DIV(mInside->ow, 256); + int tileCount = UP_DIV(mInside->ow, CACHE_SIZE); if (mInside->mDraw) { tileCount = 1; } @@ -512,7 +536,7 @@ void ImageProcessUtils::setDraw() { } void ImageProcessUtils::draw(uint8_t* img, int w, int h, int c, const int* regions, int num, uint8_t* color) { - uint8_t blitDest[4 * 256]; + uint8_t blitDest[4 * CACHE_SIZE]; int destBytes = mInside->mDtype.bytes(); mInside->oh = num; transformImage(img, img, color, blitDest, 1, destBytes, regions); diff --git a/source/cv/ImageProcessUtils.hpp b/source/cv/ImageProcessUtils.hpp index e8e901bd3..baced1bc8 100644 --- a/source/cv/ImageProcessUtils.hpp +++ b/source/cv/ImageProcessUtils.hpp @@ -15,6 +15,7 @@ #include "backend/cpu/compute/CommonOptFunction.h" +#define CACHE_SIZE 512 namespace MNN { typedef void (*BLITTER)(const unsigned char* source, unsigned char* dest, size_t count); typedef void (*BLIT_FLOAT)(const unsigned char* source, float* dest, const float* mean, const float* normal, size_t count); diff --git a/source/geometry/GeometryComputerUtils.cpp b/source/geometry/GeometryComputerUtils.cpp index 207d29e5d..f2aae2a6d 100644 --- a/source/geometry/GeometryComputerUtils.cpp +++ b/source/geometry/GeometryComputerUtils.cpp @@ -265,14 +265,20 @@ ErrorCode GeometryComputerUtils::shapeComputeAndGeometryTransform( auto& c = *cp; std::shared_ptr tmpStorge; if (nullptr == c.execution) { - auto exe = OpCommonUtils::createExecutionWithExternal(backupBackend.get(), c.inputs, c.outputs, c.op, 
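The effect of the new fast path: when the geometric transform is the identity and the source is not a planar YUV format, `ImageProcess` now skips the sampler entirely and blits straight from the clamped source row. A sketch of such a pure format conversion, built on the same `create(source, dest)` / `convert` calls the new speed test uses; the packed-NHWC output tensor is an assumption of this example:

```cpp
// Sketch only: converts a packed RGBA image to BGR without resizing, so the
// identity fast path applies. stride = 0 lets convert() infer width * bpp.
#include <memory>
#include <vector>
#include <MNN/ImageProcess.hpp>
#include <MNN/Tensor.hpp>

void convertRgbaToBgr(const uint8_t* rgba, int width, int height, uint8_t* bgrOut) {
    std::shared_ptr<MNN::CV::ImageProcess> process(
        MNN::CV::ImageProcess::create(MNN::CV::RGBA, MNN::CV::BGR));
    std::shared_ptr<MNN::Tensor> dest(MNN::Tensor::create<uint8_t>(
        std::vector<int>{1, height, width, 3}, bgrOut, MNN::Tensor::TENSORFLOW));
    process->convert(rgba, width, height, /*stride=*/0, dest.get());
}
```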
external, tmpStorge); - c.execution.reset(exe); + auto opIter = info.executionCache.find(c.op); + if (opIter != info.executionCache.end()) { + c.execution = opIter->second; + } else { + auto exe = OpCommonUtils::createExecutionWithExternal(backupBackend.get(), c.inputs, c.outputs, c.op, external, tmpStorge); + c.execution.reset(exe); + } } auto exe = c.execution; if (nullptr == exe.get()) { MNN_ERROR("Const Folder Error for %s\n", info.op->name()->c_str()); return NO_EXECUTION; } + backupBackend->onResizeBegin(); for (auto t : c.outputs) { auto des = TensorUtils::getDescribeOrigin(t); TensorUtils::setLinearLayout(t); @@ -282,7 +288,6 @@ ErrorCode GeometryComputerUtils::shapeComputeAndGeometryTransform( } des->setBackend(backupBackend.get()); } - backupBackend->onResizeBegin(); auto code = exe->onResize(c.inputs, c.outputs); if (NO_ERROR != code) { return NOT_SUPPORT; diff --git a/test.sh b/test.sh index 81ef7c647..52bd6c6d3 100755 --- a/test.sh +++ b/test.sh @@ -175,6 +175,7 @@ android_static_build() { -DMNN_OPENCL=true \ -DMNN_SUPPORT_BF16=true \ -DMNN_OPENCL=true -DMNN_ARM82=true \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ -DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. $1 $2 $3 make -j16 android_build_wrong=$[$? > 0] @@ -205,7 +206,8 @@ android_static_build() { -DMNN_OPENCL=true \ -DMNN_BUILD_MINI=true \ -DMNN_SUPPORT_BF16=true \ - -DMNN_OPENCL=true\ + -DMNN_OPENCL=true \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ -DNATIVE_LIBRARY_OUTPUT=. -DNATIVE_INCLUDE_OUTPUT=. make -j16 android_build_wrong=$[$? > 0] @@ -249,6 +251,7 @@ linux_build() { -DMNN_BUILD_OPENCV=ON \ -DMNN_LOW_MEMORY=ON \ -DMNN_IMGCODECS=ON \ + -DMNN_SUPPORT_TRANSFORMER_FUSE=ON \ -DMNN_ENABLE_COVERAGE=$COVERAGE make -j16 @@ -477,33 +480,34 @@ coverage_report() { # # ############################################################################################# android_unit_test() { - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out all 0 0 1 $1" + memory_mode=$2 + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out all 0 0 1 $1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试失败,测试终止!' failed fi - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op 0 0 4 multi$1" + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op 0 0 4 multi$1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试多线程失败,测试终止!' failed fi - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/convolution 0 2 4 fp16multi$1" + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/convolution 0 2 4 fp16multi$1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试卷积FP16多线程失败,测试终止!' failed fi - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/col2im 0 2 4 fp16col2im$1" + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/col2im 0 2 4 fp16col2im$1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试FP16-col2im多线程失败,测试终止!' failed fi - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/R 0 2 4 fp16roipooling$1" + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/R 0 2 4 fp16roipooling$1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试FP16-roipooling多线程失败,测试终止!' 
failed fi if [ "$OPENCL_CHANGE" ]; then - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op 3 1 4 $1" + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op 3 1 4 $1 $memory_mode" if [ $? -ne 0 ]; then echo '### Android单元测试OpenCL失败,测试终止!' failed @@ -592,25 +596,58 @@ android_model_test() { fi fi } -android_unit_test_low_memory() { +android_unit_test_low_memory_armv8() { adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 1 1 $1 2" if [ $? -ne 0 ]; then - echo '### Android 64位Low Memory, precision=1 单元测试失败,测试终止!' + echo '### Android 64位Low Memory,动态量化, precision=1, thread=1 单元测试失败,测试终止!' failed fi adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 2 1 $1 2" if [ $? -ne 0 ]; then - echo '### Android 64位Low Memory, precision=2 单元测试失败,测试终止!' + echo '### Android 64位Low Memory,动态量化, precision=2, thread=1 单元测试失败,测试终止!' + failed + fi + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 1 4 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 64位Low Memory,动态量化, precision=1, thread=4 单元测试失败,测试终止!' + failed + fi + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 2 4 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 64位Low Memory,动态量化, precision=2, thread=4 单元测试失败,测试终止!' failed fi adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 1 1 $1" if [ $? -ne 0 ]; then - echo '### Android 64位 权值量化调用1x1Strassen, precision=1 单元测试失败,测试终止!' + echo '### Android 64位Low Memory 权重反量化, precision=1 单元测试失败,测试终止!' failed fi adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 2 1 $1" if [ $? -ne 0 ]; then - echo '### Android 64位 权值量化调用1x1Strassen, precision=2 单元测试失败,测试终止!' + echo '### Android 64位Low Memory 权重反量化, precision=2 单元测试失败,测试终止!' + failed + fi +} + +android_unit_test_low_memory_armv7() { + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 1 1 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 32位Low Memory,动态量化, precision=1, thread=1 单元测试失败,测试终止!' + failed + fi + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 2 1 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 32位Low Memory,动态量化, precision=2, thread=1 单元测试失败,测试终止!' + failed + fi + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 1 4 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 32位Low Memory,动态量化, precision=1, thread=4 单元测试失败,测试终止!' + failed + fi + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./run_test.out op/lowMemory 0 2 4 $1 2" + if [ $? -ne 0 ]; then + echo '### Android 32位Low Memory,动态量化, precision=2, thread=4 单元测试失败,测试终止!' failed fi } @@ -620,7 +657,7 @@ android_test() { # 1. build Android32 mkdir build_32 pushd build_32 - ../build_32.sh -DMNN_BUILD_TRAIN=OFF -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DMNN_OPENCL=true + ../build_32.sh -DMNN_BUILD_TRAIN=OFF -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DMNN_OPENCL=true -DMNN_LOW_MEMORY=ON -DMNN_SUPPORT_TRANSFORMER_FUSE=ON android32_build_wrong=$[$? > 0] mnn32_size=$(ls -lh libMNN.so | awk '{print $5}') expr32_size=$(ls -lh libMNN_Express.so | awk '{print $5}') @@ -631,14 +668,15 @@ android_test() { failed fi ../updateTest.sh - android_unit_test 32 + android_unit_test 32bit 1 + android_unit_test_low_memory_armv7 32bit android_model_test 32 popd # 3. 
build Android64 mkdir build_64 pushd build_64 - ../build_64.sh -DMNN_BUILD_TRAIN=OFF -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_LOW_MEMORY=true + ../build_64.sh -DMNN_BUILD_TRAIN=OFF -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DMNN_ARM82=true -DMNN_OPENCL=true -DMNN_LOW_MEMORY=true -DMNN_SUPPORT_TRANSFORMER_FUSE=ON android64_build_wrong=$[$? > 0] mnn64_size=$(ls -lh libMNN.so | awk '{print $5}') expr64_size=$(ls -lh libMNN_Express.so | awk '{print $5}') @@ -651,8 +689,8 @@ android_test() { # 4. test Android64 ../updateTest.sh - android_unit_test 64 - android_unit_test_low_memory 64 + android_unit_test 64 0 + android_unit_test_low_memory_armv8 64 android_model_test 64 popd diff --git a/test/MNNTestSuite.cpp b/test/MNNTestSuite.cpp index f37f1c038..804b52543 100644 --- a/test/MNNTestSuite.cpp +++ b/test/MNNTestSuite.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include "MNNTestSuite.h" MNNTestSuite* MNNTestSuite::gInstance = NULL; @@ -39,7 +40,7 @@ static void printTestResult(int wrong, int right, const char* flag) { int MNNTestSuite::run(const char* key, int precision, const char* flag) { if (key == NULL || strlen(key) == 0) return 0; - std::map runTimes; + std::vector> runTimes; auto suite = MNNTestSuite::get(); std::string prefix = key; std::vector wrongs; @@ -51,12 +52,15 @@ int MNNTestSuite::run(const char* key, int precision, const char* flag) { MNN_PRINT("\trunning %s.\n", test->name.c_str()); MNN::Timer _t; auto res = test->run(precision); - runTimes.insert(std::make_pair(test->name, _t.durationInUs() / 1000.0f)); + runTimes.emplace_back(std::make_pair(test->name, _t.durationInUs() / 1000.0f)); if (!res) { wrongs.emplace_back(test->name); } } } + std::sort(runTimes.begin(), runTimes.end(), [](const std::pair& left, const std::pair& right) { + return left.second < right.second; + }); for (auto& iter : runTimes) { MNN_PRINT("%s cost time: %.3f ms\n", iter.first.c_str(), iter.second); } @@ -73,7 +77,7 @@ int MNNTestSuite::run(const char* key, int precision, const char* flag) { int MNNTestSuite::runAll(int precision, const char* flag) { auto suite = MNNTestSuite::get(); std::vector wrongs; - std::map runTimes; + std::vector> runTimes; for (int i = 0; i < suite->mTests.size(); ++i) { MNNTestCase* test = suite->mTests[i]; if (test->name.find("speed") != std::string::npos) { @@ -87,11 +91,14 @@ int MNNTestSuite::runAll(int precision, const char* flag) { MNN_PRINT("\trunning %s.\n", test->name.c_str()); MNN::Timer _t; auto res = test->run(precision); - runTimes.insert(std::make_pair(test->name, _t.durationInUs() / 1000.0f)); + runTimes.emplace_back(std::make_pair(test->name, _t.durationInUs() / 1000.0f)); if (!res) { wrongs.emplace_back(test->name); } } + std::sort(runTimes.begin(), runTimes.end(), [](const std::pair& left, const std::pair& right) { + return left.second < right.second; + }); for (auto& iter : runTimes) { MNN_PRINT("%s cost time: %.3f ms\n", iter.first.c_str(), iter.second); } diff --git a/test/core/BufferAllocatorTest.cpp b/test/core/BufferAllocatorTest.cpp index 968d83646..423b5e011 100644 --- a/test/core/BufferAllocatorTest.cpp +++ b/test/core/BufferAllocatorTest.cpp @@ -30,7 +30,9 @@ class BufferAllocatorTest : public MNNTestCase { printf("BufferAllocator total size : %lu B, %f M\n", allocator.totalSize(), allocator.totalSize() / 1024.f / 1024.f); } static void defer_allocator_test(const std::vector& seqs) { - DeferBufferAllocator allocator(BufferAllocator::Allocator::createDefault()); + SingleBufferWithAllocator root; + root.root 
= BufferAllocator::Allocator::createDefault(); + DeferBufferAllocator allocator(&root); std::vector allocs; int usage_num = 0; for (int i = 0; i < seqs.size(); i++) { diff --git a/test/cv/ImageProcessTest.cpp b/test/cv/ImageProcessTest.cpp index 7f689fac5..959b7d5a6 100644 --- a/test/cv/ImageProcessTest.cpp +++ b/test/cv/ImageProcessTest.cpp @@ -11,9 +11,12 @@ #include #include #include "MNNTestSuite.h" +#include +#include using namespace MNN; using namespace MNN::CV; +using namespace MNN::Express; static std::vector genSourceData(int h, int w, int bpp) { std::vector source(h * w * bpp); @@ -148,7 +151,7 @@ class ImageProcessGrayToGrayBilinearTransformTest : public MNNTestCase { ImageProcess::Config config; config.sourceFormat = GRAY; config.destFormat = GRAY; - config.filterType = BILINEAR; + config.filterType = MNN::CV::Filter::BILINEAR; config.wrap = CLAMP_TO_EDGE; std::shared_ptr process(ImageProcess::create(config)); @@ -189,7 +192,7 @@ class ImageProcessGrayToGrayNearestTransformTest : public MNNTestCase { ImageProcess::Config config; config.sourceFormat = GRAY; config.destFormat = GRAY; - config.filterType = NEAREST; + config.filterType = MNN::CV::Filter::NEAREST; config.wrap = ZERO; std::shared_ptr process(ImageProcess::create(config)); @@ -444,7 +447,7 @@ class ImageProcessRGBAToGrayBilinearTransformTest : public MNNTestCase { ImageProcess::Config config; config.sourceFormat = RGBA; config.destFormat = GRAY; - config.filterType = BILINEAR; + config.filterType = MNN::CV::Filter::BILINEAR; config.wrap = CLAMP_TO_EDGE; std::shared_ptr process(ImageProcess::create(config)); @@ -483,7 +486,7 @@ class ImageProcessRGBAToGrayNearestTransformTest : public MNNTestCase { ImageProcess::Config config; config.sourceFormat = RGBA; config.destFormat = GRAY; - config.filterType = NEAREST; + config.filterType = MNN::CV::Filter::NEAREST; config.wrap = CLAMP_TO_EDGE; std::shared_ptr process(ImageProcess::create(config)); @@ -772,7 +775,7 @@ class ImageProcessColorResizeTest: public MNNTestCase { // Test: first color then resize and first resize then color, these two results are same. 
virtual ~ImageProcessColorResizeTest() = default; virtual bool run(int precison) { - std::vector filters(NEAREST, BILINEAR); + std::vector filters = {MNN::CV::Filter::NEAREST, MNN::CV::Filter::BILINEAR}; for (int iw = 2; iw < 200; iw += 17) { for (int ih = 7; ih < 200; ih += 19) { for (int ow = 2; ow < 200; ow += 17) { @@ -802,5 +805,472 @@ class ImageProcessColorResizeTest: public MNNTestCase { return true; } }; -MNNTestSuiteRegister(ImageProcessColorResizeTest, "cv/image_process/color_resize_test"); +// MNNTestSuiteRegister(ImageProcessColorResizeTest, "cv/image_process/color_resize_test"); +static int format2Channel(CV::ImageFormat format) { + switch (format) { + case CV::RGB: + case CV::BGR: + case CV::YCrCb: + case CV::YUV: + case CV::HSV: + case CV::XYZ: + case CV::YUV_NV21: + case CV::YUV_NV12: + case CV::YUV_I420: + return 3; + case CV::BGR555: + case CV::BGR565: + return 2; + case CV::GRAY: + return 1; + case CV::RGBA: + case CV::BGRA: + return 4; + default: + return 3; + } +} + +static VARP cvtImpl(VARP src, ImageFormat srcformat, ImageFormat dstformat,int h, int w) { + int oc = format2Channel(dstformat); + auto type = halide_type_of(); + auto dest = Tensor::create({1, h, w, oc}, type); + std::unique_ptr process(CV::ImageProcess::create(srcformat, dstformat)); + process->convert(src->readMap(), w, h, 0, dest); + auto res = Express::Variable::create(Express::Expr::create(dest, true), 0); + return _Squeeze(res, {0}); +} + +static void getVARPSize(VARP var, int* height, int* width, int* channel) { + auto info = var->getInfo(); + auto dims = info->dim; + int num = dims.size(); + if (num < 2) return; + if (num == 2) { + *height = dims[0]; + *width = dims[1]; + *channel = 1; + } else if (num == 3) { + *height = dims[0]; + *width = dims[1]; + *channel = dims[2]; + } else if (info->order == NHWC) { + *channel = dims[num - 1]; + *width = dims[num - 2]; + *height = dims[num - 3]; + } else { // NCHW + *width = dims[num - 1]; + *height = dims[num - 2]; + *channel = dims[num - 3]; + } +} + +static VARP cvtColor(VARP src, ImageFormat srcformat, ImageFormat dstformat) { + int h, w, c; + getVARPSize(src, &h, &w, &c); + return cvtImpl(src, srcformat, dstformat, h, w); +} + +class ImageProcessSpeed: public MNNTestCase { + virtual ~ImageProcessSpeed() = default; + virtual bool run(int precison) { + int LOOP = 10000; + int warmup = 2; + int ih = 240, iw = 240; + { + int ic = 4; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGBA, BGR); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGBA, BGR); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGBA->BGR: cost time=%.3f ms\n", duration); + } + { + int ic = 4; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGBA, BGRA); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGBA, BGRA); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGBA->BGRA: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto 
inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, BGR); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, BGR); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->BGR: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, RGBA); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, RGBA); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->RGBA: cost time=%.3f ms\n", duration); + } + + { + int ic = 4; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, BGRA, BGR); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, BGRA, BGR); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("BRGA->BGR: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, GRAY); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, GRAY); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->GRAY: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, BGR, GRAY); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, BGR, GRAY); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("BGR->GRAY: cost time=%.3f ms\n", duration); + } + + { + int ic = 4; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, BGRA, GRAY); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, BGRA, GRAY); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("BGRA->GRAY: cost time=%.3f ms\n", duration); + } + + { + int ic = 4; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGBA, GRAY); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGBA, GRAY); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGBA->GRAY: cost time=%.3f ms\n", duration); + } + + { + int ic = 1; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = 
_Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, GRAY, RGBA); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, GRAY, RGBA); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("GRAY->RGBA: cost time=%.3f ms\n", duration); + } + + { + int ic = 1; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, GRAY, RGB); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, GRAY, RGB); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("GRAY->RGB: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, YUV); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, YUV); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->YUV: cost time=%.3f ms\n", duration); + } + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, XYZ); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, XYZ); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->XYZ: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, HSV); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, HSV); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->HSV: cost time=%.3f ms\n", duration); + } + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, BGR555); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, BGR555); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->BGR555: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, BGR, BGR555); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, BGR, BGR555); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("BGR->BGR555: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto 
srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, BGR, BGR565); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, BGR, BGR565); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("BGR->BGR565: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, RGB, BGR565); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, RGB, BGR565); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("RGB->BGR565: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, YUV_NV21, RGB); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, YUV_NV21, RGB); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("YUV_NV21->RGB: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, YUV_NV21, BGR); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, YUV_NV21, BGR); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("YUV_NV21->BGR: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, YUV_NV21, BGRA); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, YUV_NV21, BGRA); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("YUV_NV21->BGRA: cost time=%.3f ms\n", duration); + } + + { + int ic = 3; + auto srcvec = genSourceData(ih, iw, ic); + auto srcVar = _Input({ih, iw, ic}, NHWC, halide_type_of()); + auto inputPtr = srcVar->writeMap(); + memcpy(inputPtr, srcvec.data(), srcVar->getInfo()->size * sizeof(uint8_t)); + + for (int i = 0; i < warmup; ++i) { + cvtColor(srcVar, YUV_NV21, RGBA); + } + Timer l_; + for (int i = 0; i < LOOP; ++i) { + cvtColor(srcVar, YUV_NV21, RGBA); + } + auto duration = (float)l_.durationInUs() / 1000.f / LOOP; + printf("YUV_NV21->RGBA: cost time=%.3f ms\n", duration); + } + return true; + } +}; +// MNNTestSuiteRegister(ImageProcessSpeed, "cv/image_process/speed"); diff --git a/test/expr/ModuleTest.cpp b/test/expr/ModuleTest.cpp index 56664bfe7..233711fda 100644 --- a/test/expr/ModuleTest.cpp +++ b/test/expr/ModuleTest.cpp @@ -33,7 +33,7 @@ static VARP convBlock(VARP x, INTS channels, int stride) { static VARP convBlocTemp(VARP x, INTS channels, int stride) { int inputChannel = channels[0], 
outputChannel = channels[1]; int group = inputChannel; - x = _Conv(0.002f, 1.0f, x, {inputChannel, inputChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}); + x = _Conv(0.002f, 1.0f, x, {inputChannel, inputChannel}, {3, 3}, SAME, {stride, stride}, {1, 1}, inputChannel); x = _Conv(0.05f, -2.0f, x, {inputChannel, outputChannel}, {1, 1}, SAME, {1, 1}, {1, 1}, 1); return x; } @@ -1190,3 +1190,144 @@ class WinogradMemoryTest : public MNNTestCase { } }; MNNTestSuiteRegister(WinogradMemoryTest, "expr/WinogradMemoryTest"); + + +class SequenceMemoryTest : public MNNTestCase { +public: + virtual bool run(int precision) { + auto res = _run(precision, false); + if (!res) { + FUNC_PRINT(1); + return false; + } + return _run(precision, true); + } + virtual bool _run(int precision, bool shapeMultable) { + BackendConfig bnConfig; + auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 1); + ExecutorScope scope(exe); + Module::Config config; + config.shapeMutable = shapeMultable; + config.rearrange = true; + std::vector buffer; + { + // Make Buffer + auto x0 = _Input({1, 3, -1, -1}, NCHW, halide_type_of()); + x0->setName("x0"); + auto y0 = _mobileNetV1Expr(_Convert(x0, NC4HW4), false); + y0->setName("y0"); + buffer = Variable::save({y0}); + } + auto rtInfo = Express::ExecutorScope::Current()->getRuntime(); + auto rt = rtInfo.first.begin()->second; + MNN::ScheduleConfig sconfig; + std::vector sconfigs = {sconfig}; + std::shared_ptr rtMgr(Executor::RuntimeManager::createRuntimeManager(sconfigs)); + rtMgr->setMode(Interpreter::Session_Memory_Collect); + std::shared_ptr m0(Module::load({"x0"}, {"y0"}, (const unsigned char*)buffer.data(), buffer.size(), rtMgr, &config), Module::destroy); + std::shared_ptr m1(Module::load({"x0"}, {"y0"}, (const unsigned char*)buffer.data(), buffer.size(), rtMgr, &config), Module::destroy); + float memoryInit = 0.0f; + rtMgr->getInfo(Interpreter::MEMORY, &memoryInit); + FUNC_PRINT_ALL(memoryInit, f); + auto x = _Input({1, 3, 96, 96}, NCHW, halide_type_of()); + x->writeMap(); + x->unMap(); + auto x1 = _Input({1, 3, 97, 97}, NCHW, halide_type_of()); + x1->writeMap(); + x1->unMap(); + auto x2 = _Input({1, 3, 95, 95}, NCHW, halide_type_of()); + x2->writeMap(); + x2->unMap(); + float memoryCurrent = 0.0f; + auto compute = [&](){ + m0->onForward({x}); + rtMgr->getInfo(Interpreter::MEMORY, &memoryCurrent); + auto dynamic0 = memoryCurrent - memoryInit; + FUNC_PRINT_ALL(dynamic0, f); + m1->onForward({x1}); + rtMgr->getInfo(Interpreter::MEMORY, &memoryCurrent); + auto dynamic1 = memoryCurrent - memoryInit; + + FUNC_PRINT_ALL(dynamic1, f); + m1->onForward({x2}); + rtMgr->getInfo(Interpreter::MEMORY, &memoryCurrent); + auto dynamic2 = memoryCurrent - memoryInit; + FUNC_PRINT_ALL(dynamic2, f); + + if (dynamic1 > dynamic0 * 1.1f || dynamic2 > dynamic1) { + MNN_ERROR("Dynamic Memory reuse error\n"); + return false; + } + return true; + }; + bool res = compute(); + if (!res) { + return false; + } + exe->gc(MNN::Express::Executor::FULL); + rtMgr->getInfo(Interpreter::MEMORY, &memoryCurrent); + auto dynamic3 = memoryCurrent - memoryInit; + FUNC_PRINT_ALL(dynamic3, f); + if (dynamic3 > 0.2) { + MNN_ERROR("Dynamic Memory GC error\n"); + return false; + } + res = compute(); + if (!res) { + return false; + } + return true; + } +}; +MNNTestSuiteRegister(SequenceMemoryTest, "expr/SequenceMemoryTest"); + +class PrearrangeTest : public MNNTestCase { +public: + virtual bool run(int precision) { + // Make Model include convolution in shape compute and content compute + auto x = _Input({1, 3, 24, 24}, 
NCHW, halide_type_of()); + x->setName("x"); + auto xs = _Convert(_Reshape(_Cast(_Shape(x, NCHW)), {1, 1, 2, 2}), NC4HW4); + xs = _Convert(_Conv(1.0f, 0.0f, xs, {1, 1}, {2, 2}), NCHW); + auto y = _Conv(0.1f, 0.0f, _Convert(x, NC4HW4), {3, 1}, {3, 3}); + y = _Convert(y, NCHW); + y = _ReduceMean(y); + y = y * _Reciprocal(xs); + auto info = y->getInfo(); + y->setName("y"); + auto buffer = Variable::save({y}); + MNN::ScheduleConfig sconfig; + BackendConfig bnConfig; + bnConfig.precision = MNN::BackendConfig::Precision_Low; + sconfig.backendConfig = &bnConfig; + auto exe = Executor::newExecutor(MNN_FORWARD_CPU, bnConfig, 4); + ExecutorScope scope(exe); + std::vector sconfigs = {sconfig}; + std::shared_ptr rtMgr(Executor::RuntimeManager::createRuntimeManager(sconfigs)); + rtMgr->setMode(Interpreter::Session_Memory_Collect); + Module::Config config; + config.rearrange = false; + std::shared_ptr m0(Module::load({"x"}, {"y"}, (const unsigned char*)buffer.data(), buffer.size(), rtMgr, &config), Module::destroy); + config.rearrange = true; + std::shared_ptr m1(Module::load({"x"}, {"y"}, (const unsigned char*)buffer.data(), buffer.size(), rtMgr, &config), Module::destroy); + auto size = x->getInfo()->size; + auto xPtr = x->writeMap(); + for (int v=0; vonForward({x})[0]->readMap()[0]; + auto y1 = m1->onForward({x})[0]->readMap()[0]; + if (fabsf(y0 - y1) > 0.000001f) { + return false; + } + rtMgr->setExternalPath(".", Interpreter::EXTERNAL_FEATUREMAP_DIR); + std::shared_ptr m2(Module::load({"x"}, {"y"}, (const unsigned char*)buffer.data(), buffer.size(), rtMgr, &config), Module::destroy); + auto y2 = m2->onForward({x})[0]->readMap()[0]; + if (fabsf(y0 - y2) > 0.000001f) { + return false; + } + return true; + } +}; +MNNTestSuiteRegister(PrearrangeTest, "expr/PrearrangeTest"); + diff --git a/test/op/AttentionTest.cpp b/test/op/AttentionTest.cpp new file mode 100644 index 000000000..402391952 --- /dev/null +++ b/test/op/AttentionTest.cpp @@ -0,0 +1,241 @@ +// +// AttentionTest.cpp +// MNNTests +// +// Created by MNN on 2024/07/23. +// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef MNN_SUPPORT_TRANSFORMER_FUSE +#include +#include +#include "MNNTestSuite.h" +#include "TestUtils.h" +#include +#include + +using namespace MNN::Express; + +int NumHead = 16; +int KvNumHead = 2; +int HeadDim = 128; +const float diff_threshold = 0.001; +const float diff_percent_threshold = 0.1; + +static std::vector< std::vector< std::vector > > generateRandTensor(int C, int H, int W) { + std::vector< std::vector< std::vector > > a; + a.resize(C); + for (int i = 0; i < C; i++) { + a[i].resize(H); + for (int j = 0; j < H; j++) { + a[i][j].resize(W); + for (int k = 0; k < W; k++) { + a[i][j][k] = (float)rand() / (float)RAND_MAX * 10.0 * (rand() % 2 ? 
1 : -1); + } + } + } + return a; +} + +VARP vector_to_var(std::vector< std::vector< std::vector > > & a) { + int C = a.size(); + int H = a[0].size(); + int W = a[0][0].size(); + VARP var = _Input({1, C, H, W}, NCHW, halide_type_of()); + float * ptr = var->writeMap(); + for (int i = 0; i < C; i++) { + for (int j = 0; j < H; j++) { + for (int k = 0; k < W; k++) { + ptr[i * H * W + j * W + k] = a[i][j][k]; + } + } + } + var->unMap(); + return var; +} + +VARP vector_to_var(std::vector< std::vector > & a) { + int H = a.size(); + int W = a[0].size(); + VARP var = _Input({1, 1, H, W}, NCHW, halide_type_of()); + int * ptr = var->writeMap(); + for (int i = 0; i < H; i++) { + for (int j = 0; j < W; j++) { + ptr[i * W + j] = a[i][j]; + } + } + var->unMap(); + return var; +} + +static std::vector< std::vector< std::vector > > +computeAttention ( + std::vector< std::vector< std::vector > > & query, + std::vector< std::vector< std::vector > > & key, + std::vector< std::vector< std::vector > > & value, + std::vector< std::vector > & mask, + int seq_len, int kv_seq_len ) +{ + int group_size = NumHead / KvNumHead; + std::vector< std::vector< std::vector > > output(seq_len); + for (int i = 0; i < seq_len; i++) { + output[i].resize(NumHead); + for (int j = 0; j < NumHead; j++) { + output[i][j].resize(HeadDim); + } + } + for (int h = 0; h < NumHead; h++) { + int kv_h = h / group_size; + /*---- Q * K ----*/ + std::vector< std::vector > qk(seq_len, std::vector(kv_seq_len, 0.0f)); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < kv_seq_len; j++) { + qk[i][j] = 0.0f; + for (int k = 0; k < HeadDim; k++) { + qk[i][j] += query[i][h][k] * key[j][kv_h][k]; + } + } + } + /*---- Mask QK ----*/ + float scale = 1.0 / sqrt(HeadDim); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < kv_seq_len; j++) { + if (mask[i][j] == 1) { + qk[i][j] *= scale; + } else { + qk[i][j] = std::numeric_limits::lowest(); + } + } + } + /*---- Softmax QK ----*/ + for (int i = 0; i < seq_len; i++) { + float maxValue = qk[i][0]; + for (int j = 1; j < kv_seq_len; j++) { + maxValue = ALIMAX(maxValue, qk[i][j]); + } + for (int j = 0; j < kv_seq_len; j++) { + qk[i][j] -= maxValue; + } + float sum = 0.0f; + for (int j = 0; j < kv_seq_len; j++) { + sum += exp(qk[i][j]); + } + for (int j = 0; j < kv_seq_len; j++) { + qk[i][j] = exp(qk[i][j]) / sum; + } + } + /*---- QK * V ----*/ + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < HeadDim; j++) { + output[i][h][j] = 0.0f; + for (int k = 0; k < kv_seq_len; k++) { + output[i][h][j] += qk[i][k] * value[k][kv_h][j]; + } + } + } + } + return output; +} + +class NaiveAttention { + private: + std::vector< std::vector< std::vector > > mPastKey, mPastValue; + int mPastLen; + public: + NaiveAttention() : mPastLen(0) {} + ~NaiveAttention() = default; + std::vector< std::vector< std::vector > > onExecute ( + std::vector< std::vector< std::vector > > & query, + std::vector< std::vector< std::vector > > & key, + std::vector< std::vector< std::vector > > & value, + std::vector< std::vector > & mask, + int seq_len ) + { + for (int i = 0; i < seq_len; i++) { + mPastKey.push_back(key[i]); + mPastValue.push_back(value[i]); + } + mPastLen += seq_len; + return computeAttention(query, mPastKey, mPastValue, mask, seq_len, mPastLen); + } +}; + +class AttentionTest : public MNNTestCase { + protected: + std::vector< std::vector< std::vector > > query; + std::vector< std::vector< std::vector > > key; + std::vector< std::vector< std::vector > > value; + std::vector< std::vector > mask; + std::vector< 
std::vector< std::vector > > expected_result; + VARP Query, Key, Value, Mask, Output; +public: + AttentionTest() = default; + virtual ~AttentionTest() = default; + + void generateInput(int seq_len) { + query = generateRandTensor(seq_len, NumHead, HeadDim); + key = generateRandTensor(seq_len, KvNumHead, HeadDim); + value = generateRandTensor(seq_len, KvNumHead, HeadDim); + Query = vector_to_var(query); + Key = vector_to_var(key); + Value = vector_to_var(value); + } + + void generateMask(int seq_len, int kv_seq_len) { + mask.resize(seq_len); + for (int i = 0; i < seq_len; i++) { + mask[i].resize(kv_seq_len); + for (int j = 0; j < kv_seq_len; j++) { + if (j - i <= kv_seq_len - seq_len) { + mask[i][j] = 1; + } else { + mask[i][j] = 0; + } + } + } + Mask = vector_to_var(mask); + } + + bool compareResult(int seq_len) { + const float * resultPtr = Output->readMap(); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < NumHead; j++) { + for (int k = 0; k < HeadDim; k++) { + float diff = fabs(resultPtr[i * NumHead * HeadDim + j * HeadDim + k] - expected_result[i][j][k]); + float diff_percent = fabs(diff / expected_result[i][j][k]); + if (diff > diff_threshold && diff_percent > diff_percent_threshold) { + printf("Result Mismatch: expected %lf but got %lf in CPU Attention Test\n", expected_result[i][j][k], resultPtr[i * NumHead * HeadDim + j * HeadDim + k]); + printf("Error Position: Output[%d][%d][%d]\n", i, j, k); + return false; + } + } + } + } + Output->unMap(); + return true; + } + + virtual bool run(int precision) { + srand(2024); + std::shared_ptr naiveAttention(new NaiveAttention); + std::shared_ptr attention(new MNN::OpT); + attention->type = MNN::OpType_Attention; + attention->main.type = MNN::OpParameter_AttentionParam; + attention->main.value = new MNN::AttentionParamT; + attention->main.AsAttentionParam()->kv_cache = true; + int seq_len = 10; + generateInput(seq_len); + generateMask(seq_len, seq_len); + expected_result = naiveAttention->onExecute(query, key, value, mask, seq_len); + Output = Variable::create(Expr::create(attention.get(), {Query, Key, Value, Mask})); + bool pass = compareResult(seq_len); + if (pass) { + printf("CPU attention unit test passed!\n"); + } else { + printf("Error: CPU attention unit test failed!\n"); + } + return pass; + } +}; + +MNNTestSuiteRegister(AttentionTest, "op/cpu_attention"); +#endif diff --git a/test/op/RasterTest.cpp b/test/op/RasterTest.cpp index 517da5374..2dd10eb73 100644 --- a/test/op/RasterTest.cpp +++ b/test/op/RasterTest.cpp @@ -8,6 +8,7 @@ #include #include +#include "RuntimeAttr.hpp" #include "MNNTestSuite.h" #include "TestUtils.h" @@ -211,6 +212,12 @@ class ReduceBlitTest : public MNNTestCase { return true; } virtual bool run(int precision) { + // TODO: Other Backend Support Reduce Blit + auto attr = ExecutorScope::Current()->getAttr(); + if (attr->firstType != MNN_FORWARD_CPU) { + MNN_ERROR("Currently only cpu backend support reduce blit\n"); + return true; + } ExecutorScope::Current()->lazyEval = false; auto res = _run(precision, false); if (!res) { diff --git a/test/speed/HybridConvSpeedTest.cpp b/test/speed/HybridConvSpeedTest.cpp index 2354c4c58..42968330d 100644 --- a/test/speed/HybridConvSpeedTest.cpp +++ b/test/speed/HybridConvSpeedTest.cpp @@ -132,20 +132,32 @@ class HybridConvSpeedInt8Test : public HybridConvSpeedTestCommon { class HybridConvInt8Test : public HybridConvSpeedTestCommon { public: virtual bool run(int precision) { - std::vector< std::vector> channels = {{7, 9}, {2048, 6144}, {1, 10}, {20, 153}, {9, 18}}; + 
INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h} int testBatchCount = 5; // std::vector batch(testBatchCount); std::vector batch = {1, 23, 1479, 38, 29}; std::vector kernels = {1, 1}; - std::vector weightBits = {8}; bool lowmemory = true; - for (auto& bits : weightBits) { + { + std::vector< std::vector> channels = {{7, 9}, {2048, 6144}, {1, 10}, {20, 153}, {9, 18}}; + for (int i = 0; i < channels.size(); ++i) { + for (int n = 0; n < 5; ++n) { + auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], 8, precision); + if (!res) { + MNN_ERROR("Error: low memory hybridConv when bits=8, n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); + return false; + } + } + } + } + { + std::vector< std::vector> channels = {{2048, 6144}, {8, 8}, {8, 9}, {8, 16}}; for (int i = 0; i < channels.size(); ++i) { - for (int n = 0; n < batch.size(); ++n) { - auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], bits, precision); + for (int n = 0; n < 5; ++n) { + auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], 4, precision); if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); + MNN_ERROR("Error: low memory hybridConv when bits=4, n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); return false; } } diff --git a/tools/converter/include/config.hpp b/tools/converter/include/config.hpp index befc9d51f..63e72052d 100644 --- a/tools/converter/include/config.hpp +++ b/tools/converter/include/config.hpp @@ -53,6 +53,7 @@ class MNN_PUBLIC modelConfig { bool detectSparseSpeedUp = true; bool convertMatmulToConv = true; bool transformerFuse = false; + bool allowCustomOp = false; std::string customOpLibs = ""; std::string authCode = ""; std::string testDir = ""; diff --git a/tools/converter/source/common/cli.cpp b/tools/converter/source/common/cli.cpp index 89e83ab26..bc2399b36 100644 --- a/tools/converter/source/common/cli.cpp +++ b/tools/converter/source/common/cli.cpp @@ -295,6 +295,11 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv "transformerFuse", "fuse attention op, like fmhaV2/fmhca/splitGelu/groupNorm. default: false", cxxopts::value() + ) + ( + "allowCustomOp", + "allow custom op when convert. 
default: false", + cxxopts::value() ); auto result = options.parse(argc, argv); @@ -489,6 +494,9 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv if (result.count("transformerFuse")) { modelPath.transformerFuse = true; } + if (result.count("allowCustomOp")) { + modelPath.allowCustomOp = true; + } return true; } @@ -595,7 +603,7 @@ static void computeUnaryBuffer(MNN::NetT* net) { auto inputId = op->inputIndexes[0]; if (describes.find(inputId) == describes.end()) { auto iter = describes.find(outputId); - + } unaryDes = describes.find(inputId)->second; float inpScale = unaryDes->quantInfo->scale; @@ -704,7 +712,7 @@ bool Cli::convertModel(modelConfig& modelPath) { MNN_PRINT("MNN net has tensor quant info\n"); computeUnaryBuffer(newNet.get()); } - + error = writeFb(newNet, modelPath.MNNModel, modelPath); } else { error = writeFb(netT, modelPath.MNNModel, modelPath); diff --git a/tools/converter/source/common/writeFb.cpp b/tools/converter/source/common/writeFb.cpp index e46b50a7e..4132a2770 100644 --- a/tools/converter/source/common/writeFb.cpp +++ b/tools/converter/source/common/writeFb.cpp @@ -60,7 +60,7 @@ static float _computeOpExternalSizeInMB(const MNN::OpT* op) { } return blob->external[1] / 1024.0f / 1024.0f; } - + default: break; } @@ -166,7 +166,7 @@ int writeFb(std::unique_ptr& netT, const std::string& MNNModelFile, c } std::ostringstream notSupportInfo; - if (!notSupportOps.empty()) { + if (!notSupportOps.empty() && !config.allowCustomOp) { for (auto name : notSupportOps) { notSupportInfo << name << " | "; } diff --git a/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp b/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp index f52931fd8..4e714a1e8 100644 --- a/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp +++ b/tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp @@ -28,8 +28,8 @@ static FormatSetType _getFormatType(const OpT* op, MNN_DATA_FORMAT originFormat) switch (op->type) { // NC4HW4 Ops with multi-input case MNN::OpType_SeqLen2Spatial: - case MNN::OpType_GroupNorm: - case MNN::OpType_Convolution: + case MNN::OpType_FmhaV2: + case MNN::OpType_Convolution: case MNN::OpType_Convolution3D: case MNN::OpType_ConvolutionDepthwise: case MNN::OpType_Deconvolution: diff --git a/tools/cpp/CMakeLists.txt b/tools/cpp/CMakeLists.txt index 134340106..c560fb401 100644 --- a/tools/cpp/CMakeLists.txt +++ b/tools/cpp/CMakeLists.txt @@ -1,4 +1,9 @@ set(MNN_CPP_TOOLS "") +if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?") + add_definitions(-DMNN_USE_NEON) +elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64") + add_definitions(-DMNN_USE_NEON) +endif() add_executable(GetMNNInfo ${CMAKE_CURRENT_LIST_DIR}/GetMNNInfo.cpp) list(APPEND MNN_CPP_TOOLS GetMNNInfo) add_executable(ModuleBasic.out ${CMAKE_CURRENT_LIST_DIR}/ModuleBasic.cpp) diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index 04954e7a4..6f0eb2538 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -8,6 +8,7 @@ #include "MNN_generated.h" #include +#include #include #include #define MNN_OPEN_TIME_TRACE @@ -93,6 +94,10 @@ int main(int argc, char *argv[]) { MNN_ERROR("Usage: ./ModuleBasic.out ${test.mnn} ${Dir} [runMask] [forwardType] [runLoops] [numberThread] [precision | memory] [cacheFile]\n"); return 0; } + BackendConfig backendConfigTmp; + auto _executor = Executor::newExecutor(MNN_FORWARD_CPU, backendConfigTmp, 1); + 
ExecutorScope _s(_executor); + std::string modelName = argv[1]; std::string directName = argv[2]; MNN_PRINT("Test %s from input info: %s\n", modelName.c_str(), directName.c_str()); @@ -277,6 +282,9 @@ int main(int argc, char *argv[]) { if (runMask & 1024) { rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 1); } + if (runMask & 2048) { + rtmgr->setExternalPath("tmp", Interpreter::EXTERNAL_FEATUREMAP_DIR); + } std::shared_ptr net; { AUTOTIME; @@ -419,6 +427,7 @@ int main(int argc, char *argv[]) { for (int i = 0; i < t; ++i) { Timer _l; auto out = net->onForward(inputs); + Variable::compute(out); for (auto o : out) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } diff --git a/tools/cpp/getPerformance.cpp b/tools/cpp/getPerformance.cpp index ff3b5dfc4..c0b5d6f01 100644 --- a/tools/cpp/getPerformance.cpp +++ b/tools/cpp/getPerformance.cpp @@ -207,18 +207,17 @@ static void _testMemcpy() { int size = 1024 * 1024; int loop = 10000; std::vector threads; + int threadNumber = 2; + std::vector> tmp(threadNumber); + for (int i=0; i tmp0(size); - std::vector tmp1(size); - auto t0 = tmp0.data(); - auto t1 = tmp1.data(); + for (int i=0; i using namespace MNN; using namespace Express; @@ -30,7 +31,7 @@ constexpr const char* path = "./imgs/cat.jpg"; template VARP cv2mnn(const cv::Mat& src) { - VARP dst = _Input({ src.rows, src.cols, src.channels() }, NHWC, halide_type_of()); + VARP dst = _Input({ 1, src.rows, src.cols, src.channels() }, NHWC, halide_type_of()); auto inputPtr = dst->writeMap(); memcpy(inputPtr, src.ptr(0), dst->getInfo()->size * sizeof(T)); return dst; @@ -46,12 +47,12 @@ VARP cv2mnn(const cv::Mat& src) { #define arg_switch(COND, CASE0, CASE1, CASE2, CASE3) arg_concat(arg_switch_, COND)(CASE0, CASE1, CASE2, CASE3) #define BENCH_IMPL(mode, func, ...)\ - auto t1 = std::chrono::high_resolution_clock::now();\ +arg_switch(mode, cv::func(__VA_ARGS__);, auto dst = func(__VA_ARGS__);dst->readMap();, auto dst = func(__VA_ARGS__);dst[0]->readMap();, func(__VA_ARGS__);)\ + Timer l_;\ for (int i = 0; i < LOOP; i++) {\ arg_switch(mode, cv::func(__VA_ARGS__);, auto dst = func(__VA_ARGS__);dst->readMap();, auto dst = func(__VA_ARGS__);dst[0]->readMap();, func(__VA_ARGS__);)\ }\ - auto t2 = std::chrono::high_resolution_clock::now();\ - auto duration = std::chrono::duration_cast(t2 - t1).count() / (1000. * LOOP);\ + auto duration = (float)l_.durationInUs() / 1000.f / LOOP;\ times.push_back(duration); \ #define BENCHMARK_NAME(mode, name, func, ...) 
\ @@ -73,17 +74,18 @@ void color(cv::Mat cvimg, VARP mnnimg) { cv::Mat dst; #define CVTCOLOR(code)\ BENCHMARK_NAME(0, code, cvtColor, cvimg, dst, cv::COLOR_##code)\ - BENCHMARK_NAME(1, code, cvtColor, mnnimg, COLOR_##code) + BENCHMARK_NAME(3, code, cvtColor, mnnimg, COLOR_##code) + CVTCOLOR(RGB2BGR) CVTCOLOR(RGB2GRAY) CVTCOLOR(RGB2RGBA) - CVTCOLOR(RGB2BGRA) CVTCOLOR(RGB2YUV) CVTCOLOR(RGB2XYZ) CVTCOLOR(RGB2HSV) CVTCOLOR(RGB2HSV_FULL) CVTCOLOR(RGB2BGR555) CVTCOLOR(RGB2BGR565) + } void filter(cv::Mat cvimg, VARP mnnimg) { diff --git a/tools/cv/source/imgproc/filter.cpp b/tools/cv/source/imgproc/filter.cpp index 5436b0170..b9e6c9204 100644 --- a/tools/cv/source/imgproc/filter.cpp +++ b/tools/cv/source/imgproc/filter.cpp @@ -32,7 +32,7 @@ static halide_type_t formatInput(VARP& src, bool fp = true) { src = _Convert(_Reshape(src, {1, channel, height, width}), NHWC); } } - if (fp) { + if (fp && src->getInfo() && src->getInfo()->type.code != halide_type_float) { src = _Cast(src, halide_type_of()); } return info->type; @@ -46,14 +46,17 @@ static VARP formatOutput(VARP src, halide_type_t type) { if (channel == 1) { squeeze_dims.push_back(-1); } - if (!squeeze_dims.empty()) { - src = _Squeeze(src, squeeze_dims); - } if (type == halide_type_of()) { src = _Minimum(src, _Scalar(255)); src = _Maximum(src, _Scalar(0)); } - return _Cast(src, type); + if (src->getInfo()) { + auto srctype = src->getInfo()->type; + if (srctype.code == type.code && srctype.bits == type.bits) { + return src; + } + } + return _Cast(src, type); // if same type, do not need. } template diff --git a/tools/train/source/nn/NN.cpp b/tools/train/source/nn/NN.cpp index a49c6afaf..8d49f6ada 100644 --- a/tools/train/source/nn/NN.cpp +++ b/tools/train/source/nn/NN.cpp @@ -710,7 +710,7 @@ class ConvBNReluFusedModule : public Module { int threadNumber = 1, ePack = 12; int unit2 = UP_DIV(outH * outW, ePack * threadNumber); int maxUnit = (int)::sqrtf((float)unit2); - const int MAX_UNIT = 4, MIN_UNIT = 2; + const int MAX_UNIT = 6, MIN_UNIT = 2; maxUnit = std::max(std::min(maxUnit, MAX_UNIT), MIN_UNIT); auto units = std::pair({0, 0}); diff --git a/transformers/diffusion/export/convert_mnn.py b/transformers/diffusion/export/convert_mnn.py new file mode 100644 index 000000000..b29b9ca38 --- /dev/null +++ b/transformers/diffusion/export/convert_mnn.py @@ -0,0 +1,21 @@ +import os +def convert(onnx_path, mnn_path, extra): + print('Onnx path: ', onnx_path) + print('MNN path: ', mnn_path) + print('Extra: ', extra) + convert_path = '../../../build/MNNConvert' + if not os.path.exists(convert_path): + print(convert_path + " not exist, use pymnn instead") + convert_path = 'mnnconvert' + models = ['text_encoder', 'unet', 'vae_decoder'] + for model in models: + cmd = convert_path + ' -f ONNX --modelFile ' + os.path.join(onnx_path, model, 'model.onnx') + ' --MNNModel ' + os.path.join(mnn_path, model + '.mnn') + ' --saveExternalData=1 ' + extra + print(cmd) + print(os.popen(cmd).read()) + +if __name__ == '__main__': + import sys + extra = "" + if len(sys.argv) > 3: + extra = sys.argv[3] + convert(sys.argv[1], sys.argv[2], extra) diff --git a/transformers/diffusion/pipeline.cpp b/transformers/diffusion/pipeline.cpp index ed35ba705..7ac441f04 100644 --- a/transformers/diffusion/pipeline.cpp +++ b/transformers/diffusion/pipeline.cpp @@ -113,6 +113,7 @@ bool Pipeline::load_modules() { // load text_encoder model { std::string model_path = mModelPath + "/text_encoder.mnn"; + MNN_PRINT("Load %s\n", model_path.c_str()); mModules[0].reset(Module::load( {"input_ids"}, 
{"last_hidden_state", "pooler_output"}, model_path.c_str(), runtime_manager_, &module_config)); @@ -125,6 +126,7 @@ bool Pipeline::load_modules() { // load unet model { std::string model_path = mModelPath + "/unet.mnn"; + MNN_PRINT("Load %s\n", model_path.c_str()); mModules[1].reset(Module::load( {"sample", "timestep", "encoder_hidden_states"}, {"out_sample"}, model_path.c_str(), runtime_manager_, &module_config)); @@ -137,6 +139,7 @@ bool Pipeline::load_modules() { // load vae_decoder model { std::string model_path = mModelPath + "/vae_decoder.mnn"; + MNN_PRINT("Load %s\n", model_path.c_str()); mModules[2].reset(Module::load( {"latent_sample"}, {"sample"}, model_path.c_str(), runtime_manager_, &module_config)); diff --git a/transformers/llm/config.json b/transformers/llm/config.json index f34f70063..d508b467d 100755 --- a/transformers/llm/config.json +++ b/transformers/llm/config.json @@ -6,10 +6,12 @@ "thread_num": 4, "precision": "low", "memory": "low", + "power":"normal", + "use_mmap":"false", "is_batch_quant": 1, "reuse_kv": false, "quant_kv": 0, "kvcache_limit": -1 -} \ No newline at end of file +} diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp index 72fcaa742..4ebba43d6 100644 --- a/transformers/llm/engine/include/llm/llm.hpp +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -112,10 +112,11 @@ class MNN_PUBLIC Llm { class Embedding : public Llm { public: Embedding(std::shared_ptr<LlmConfig> config); - static Embedding* createEmbedding(const std::string& config_path); + static Embedding* createEmbedding(const std::string& config_path, bool load = true); static float dist(MNN::Express::VARP var0, MNN::Express::VARP var1); virtual void load() override; - MNN::Express::VARP embedding(const std::string& txt); + MNN::Express::VARP ids_embedding(const std::vector<int>& ids); + MNN::Express::VARP txt_embedding(const std::string& txt); int dim() const; private: virtual std::vector<int> tokenizer(const std::string& query) override; diff --git a/transformers/llm/engine/ios/README.md b/transformers/llm/engine/ios/README.md new file mode 100644 index 000000000..4a682eb73 --- /dev/null +++ b/transformers/llm/engine/ios/README.md @@ -0,0 +1,44 @@ +# mnn-llm ios demo + +🚀 本示例代码全部由`ChatGPT-4`生成。 + +## 速度 + +[旧版测试prompt](../resource/prompt.txt) +- Qwen-1.8b-chat 4bit + - iPhone 11 : prefill 52.00 tok/s, decode 16.23 tok/s + - iPhone 14 Pro: prefill 102.63 tok/s, decode 33.53 tok/s +- Qwen-1.8b-chat 8bit + - iPhone 11 : prefill 61.90 tok/s, decode 14.75 tok/s + - iPhone 14 Pro: prefill 105.41 tok/s, decode 25.45 tok/s + +--- + +[新版测试prompt](../resource/bench.txt) +- Qwen1.5-0.5b-chat 4bit + - iPhone 15 Pro: prefill 282.73 tok/s, decode 51.68 tok/s +- Qwen2-0.5b-instruct 4bit + - iPhone 15 Pro: prefill 234.51 tok/s, decode 51.36 tok/s +- Qwen2-1.5b-instruct 4bit + - iPhone 15 Pro: prefill 107.64 tok/s, decode 25.57 tok/s + +## 编译 +1. 编译 MNN iOS Framework: 在 MNN 根目录下执行 +``` +sh package_scripts/ios/buildiOS.sh "-DMNN_ARM82=true -DMNN_LOW_MEMORY=true -DMNN_SUPPORT_TRANSFORMER_FUSE=true -DMNN_BUILD_LLM=true -DMNN_CPU_WEIGHT_DEQUANT_GEMM=true" +mv MNN-iOS-CPU-GPU/Static/MNN.framework transformers/llm/engine/ios/MNN.framework +``` +2. 下载模型文件: [Qwen1.5-0.5B-Chat-MNN](https://modelscope.cn/models/zhaode/Qwen1.5-0.5B-Chat-MNN/files) ,或者使用 export 下面的脚本导出模型 +3. 将模型文件拷贝到`${MNN根目录}/transformers/llm/engine/model/`目录下 +4. 在xcode项目属性中`Signing & Capabilities` > `Team`输入自己的账号;`Bundle Identifier`可以重新命名; +5. 
连接iPhone并编译执行,需要在手机端打开开发者模式,并在安装完成后在:`设置` > `通用` > `VPN与设备管理`中选择信任该账号; + +备注:如测试其他模型,可以将`ios/mnn-llm/model/`替换为其他模型的文件夹;同时修改`LLMInferenceEngineWrapper.m +38`的模型路径; + +## 性能 +等待模型加载完成后,发送:`benchmark`,即可进行benchmark测试; + +## 测试 +等待模型加载完成后即可发送信息,如下图所示: + +![ios-app](./ios_app.jpg) diff --git a/transformers/llm/engine/ios/ios_app.jpg b/transformers/llm/engine/ios/ios_app.jpg new file mode 100644 index 000000000..72a9d54d9 Binary files /dev/null and b/transformers/llm/engine/ios/ios_app.jpg differ diff --git a/transformers/llm/engine/ios/mnn-llm/icon.png b/transformers/llm/engine/ios/mnn-llm/icon.png new file mode 100644 index 000000000..824ebb7ab Binary files /dev/null and b/transformers/llm/engine/ios/mnn-llm/icon.png differ diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.pbxproj b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.pbxproj new file mode 100644 index 000000000..7672178ca --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.pbxproj @@ -0,0 +1,453 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 56; + objects = { + +/* Begin PBXBuildFile section */ + 4D5B978C2B2B21D3003AF2F1 /* mnn_llmApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4D5B978B2B2B21D3003AF2F1 /* mnn_llmApp.swift */; }; + 4D5B978E2B2B21D3003AF2F1 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4D5B978D2B2B21D3003AF2F1 /* ContentView.swift */; }; + 4D5B97902B2B21D5003AF2F1 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4D5B978F2B2B21D5003AF2F1 /* Assets.xcassets */; }; + 4D5B97932B2B21D5003AF2F1 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 4D5B97922B2B21D5003AF2F1 /* Preview Assets.xcassets */; }; + 4D5B97C42B2B29CF003AF2F1 /* LLMInferenceEngineWrapper.mm in Sources */ = {isa = PBXBuildFile; fileRef = 4D5B97C32B2B29CF003AF2F1 /* LLMInferenceEngineWrapper.mm */; }; + CE1A4A5D2C8596D900A62A4F /* MNN.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = CE1A4A5C2C8596D900A62A4F /* MNN.framework */; }; + CE1A4A7C2C85B69800A62A4F /* config.json in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A722C85B69800A62A4F /* config.json */; }; + CE1A4A7D2C85B69800A62A4F /* embeddings_bf16.bin in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A732C85B69800A62A4F /* embeddings_bf16.bin */; }; + CE1A4A7E2C85B69800A62A4F /* llm_config.json in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A742C85B69800A62A4F /* llm_config.json */; }; + CE1A4A7F2C85B69800A62A4F /* llm.mnn in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A752C85B69800A62A4F /* llm.mnn */; }; + CE1A4A802C85B69800A62A4F /* llm.mnn.weight in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A762C85B69800A62A4F /* llm.mnn.weight */; }; + CE1A4A842C85B69800A62A4F /* tokenizer.txt in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A7A2C85B69800A62A4F /* tokenizer.txt */; }; + CE1A4A862C85D43E00A62A4F /* bench.txt in Resources */ = {isa = PBXBuildFile; fileRef = CE1A4A852C85D43E00A62A4F /* bench.txt */; }; +/* End PBXBuildFile section */ + +/* Begin PBXCopyFilesBuildPhase section */ + 4D7E1C0A2C40C6530004DA17 /* Embed Watch Content */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = "$(CONTENTS_FOLDER_PATH)/Watch"; + dstSubfolderSpec = 16; + files = ( + ); + name = "Embed Watch Content"; + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXCopyFilesBuildPhase section */ + +/* Begin PBXFileReference section */ + 4D5B97882B2B21D3003AF2F1 /* 
mnn-llm.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "mnn-llm.app"; sourceTree = BUILT_PRODUCTS_DIR; }; + 4D5B978B2B2B21D3003AF2F1 /* mnn_llmApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = mnn_llmApp.swift; sourceTree = ""; }; + 4D5B978D2B2B21D3003AF2F1 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; + 4D5B978F2B2B21D5003AF2F1 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + 4D5B97922B2B21D5003AF2F1 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; + 4D5B97992B2B263D003AF2F1 /* LLMInferenceEngineWrapper.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = LLMInferenceEngineWrapper.h; sourceTree = ""; }; + 4D5B979A2B2B2677003AF2F1 /* mnn-llm-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "mnn-llm-Bridging-Header.h"; sourceTree = ""; }; + 4D5B97C32B2B29CF003AF2F1 /* LLMInferenceEngineWrapper.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = LLMInferenceEngineWrapper.mm; sourceTree = ""; }; + CE1A4A5C2C8596D900A62A4F /* MNN.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MNN.framework; path = ../MNN.framework; sourceTree = ""; }; + CE1A4A722C85B69800A62A4F /* config.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = config.json; sourceTree = ""; }; + CE1A4A732C85B69800A62A4F /* embeddings_bf16.bin */ = {isa = PBXFileReference; lastKnownFileType = archive.macbinary; path = embeddings_bf16.bin; sourceTree = ""; }; + CE1A4A742C85B69800A62A4F /* llm_config.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = llm_config.json; sourceTree = ""; }; + CE1A4A752C85B69800A62A4F /* llm.mnn */ = {isa = PBXFileReference; lastKnownFileType = file; path = llm.mnn; sourceTree = ""; }; + CE1A4A762C85B69800A62A4F /* llm.mnn.weight */ = {isa = PBXFileReference; lastKnownFileType = file; path = llm.mnn.weight; sourceTree = ""; }; + CE1A4A7A2C85B69800A62A4F /* tokenizer.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = tokenizer.txt; sourceTree = ""; }; + CE1A4A852C85D43E00A62A4F /* bench.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = bench.txt; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 4D5B97852B2B21D3003AF2F1 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + CE1A4A5D2C8596D900A62A4F /* MNN.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 4D5B977F2B2B21D3003AF2F1 = { + isa = PBXGroup; + children = ( + CE1A4A7B2C85B69800A62A4F /* model */, + 4D5B978A2B2B21D3003AF2F1 /* mnn-llm */, + 4D5B97892B2B21D3003AF2F1 /* Products */, + 4D5B97C52B2B2C26003AF2F1 /* Frameworks */, + ); + sourceTree = ""; + }; + 4D5B97892B2B21D3003AF2F1 /* Products */ = { + isa = PBXGroup; + children = ( + 4D5B97882B2B21D3003AF2F1 /* mnn-llm.app */, + ); + name = Products; + sourceTree = ""; + }; + 4D5B978A2B2B21D3003AF2F1 /* mnn-llm */ = { + isa = PBXGroup; + 
children = ( + 4D5B978B2B2B21D3003AF2F1 /* mnn_llmApp.swift */, + 4D5B978D2B2B21D3003AF2F1 /* ContentView.swift */, + 4D5B978F2B2B21D5003AF2F1 /* Assets.xcassets */, + 4D5B97912B2B21D5003AF2F1 /* Preview Content */, + 4D5B97992B2B263D003AF2F1 /* LLMInferenceEngineWrapper.h */, + 4D5B97C32B2B29CF003AF2F1 /* LLMInferenceEngineWrapper.mm */, + 4D5B979A2B2B2677003AF2F1 /* mnn-llm-Bridging-Header.h */, + ); + path = "mnn-llm"; + sourceTree = ""; + }; + 4D5B97912B2B21D5003AF2F1 /* Preview Content */ = { + isa = PBXGroup; + children = ( + 4D5B97922B2B21D5003AF2F1 /* Preview Assets.xcassets */, + ); + path = "Preview Content"; + sourceTree = ""; + }; + 4D5B97C52B2B2C26003AF2F1 /* Frameworks */ = { + isa = PBXGroup; + children = ( + CE1A4A5C2C8596D900A62A4F /* MNN.framework */, + ); + name = Frameworks; + sourceTree = ""; + }; + CE1A4A7B2C85B69800A62A4F /* model */ = { + isa = PBXGroup; + children = ( + CE1A4A852C85D43E00A62A4F /* bench.txt */, + CE1A4A722C85B69800A62A4F /* config.json */, + CE1A4A732C85B69800A62A4F /* embeddings_bf16.bin */, + CE1A4A742C85B69800A62A4F /* llm_config.json */, + CE1A4A752C85B69800A62A4F /* llm.mnn */, + CE1A4A762C85B69800A62A4F /* llm.mnn.weight */, + CE1A4A7A2C85B69800A62A4F /* tokenizer.txt */, + ); + name = model; + path = ../../model; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 4D5B97872B2B21D3003AF2F1 /* mnn-llm */ = { + isa = PBXNativeTarget; + buildConfigurationList = 4D5B97962B2B21D5003AF2F1 /* Build configuration list for PBXNativeTarget "mnn-llm" */; + buildPhases = ( + 4D5B97842B2B21D3003AF2F1 /* Sources */, + 4D5B97852B2B21D3003AF2F1 /* Frameworks */, + 4D5B97862B2B21D3003AF2F1 /* Resources */, + 4D7E1C0A2C40C6530004DA17 /* Embed Watch Content */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = "mnn-llm"; + productName = "mnn-llm"; + productReference = 4D5B97882B2B21D3003AF2F1 /* mnn-llm.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 4D5B97802B2B21D3003AF2F1 /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastSwiftUpdateCheck = 1540; + LastUpgradeCheck = 1410; + TargetAttributes = { + 4D5B97872B2B21D3003AF2F1 = { + CreatedOnToolsVersion = 14.1; + LastSwiftMigration = 1410; + }; + }; + }; + buildConfigurationList = 4D5B97832B2B21D3003AF2F1 /* Build configuration list for PBXProject "mnn-llm" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 4D5B977F2B2B21D3003AF2F1; + productRefGroup = 4D5B97892B2B21D3003AF2F1 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 4D5B97872B2B21D3003AF2F1 /* mnn-llm */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 4D5B97862B2B21D3003AF2F1 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + CE1A4A862C85D43E00A62A4F /* bench.txt in Resources */, + CE1A4A842C85B69800A62A4F /* tokenizer.txt in Resources */, + 4D5B97932B2B21D5003AF2F1 /* Preview Assets.xcassets in Resources */, + 4D5B97902B2B21D5003AF2F1 /* Assets.xcassets in Resources */, + CE1A4A7E2C85B69800A62A4F /* llm_config.json in Resources */, + CE1A4A802C85B69800A62A4F /* llm.mnn.weight in Resources */, + CE1A4A7F2C85B69800A62A4F /* llm.mnn in Resources */, + CE1A4A7D2C85B69800A62A4F /* embeddings_bf16.bin in Resources */, + CE1A4A7C2C85B69800A62A4F /* 
config.json in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 4D5B97842B2B21D3003AF2F1 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 4D5B97C42B2B29CF003AF2F1 /* LLMInferenceEngineWrapper.mm in Sources */, + 4D5B978E2B2B21D3003AF2F1 /* ContentView.swift in Sources */, + 4D5B978C2B2B21D3003AF2F1 /* mnn_llmApp.swift in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 4D5B97942B2B21D5003AF2F1 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + "USING_DISK_EMBED=1", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.1; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 4D5B97952B2B21D5003AF2F1 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + 
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_PREPROCESSOR_DEFINITIONS = " USING_DISK_EMBED=1"; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 16.1; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + SWIFT_OPTIMIZATION_LEVEL = "-O"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 4D5B97972B2B21D5003AF2F1 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES; + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + ASSETCATALOG_COMPILER_INCLUDE_ALL_APPICON_ASSETS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "c++17"; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_ASSET_PATHS = "\"mnn-llm/Preview Content\""; + DEVELOPMENT_TEAM = 6G7464HHUS; + ENABLE_PREVIEWS = YES; + FRAMEWORK_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/mnn-llm", + "$(PROJECT_DIR)/../", + "$(PROJECT_DIR)/../../", + ); + GCC_PREPROCESSOR_DEFINITIONS = ( + "MNN_ARM82=1", + "MNN_SUPPORT_TRANSFORMER_FUSE=1", + "MNN_LOW_MEMORY=1", + ); + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.zhaode.mnn-llm1"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = "mnn-llm/mnn-llm-Bridging-Header.h"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 4D5B97982B2B21D5003AF2F1 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_EMBED_SWIFT_STANDARD_LIBRARIES = YES; + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; + ASSETCATALOG_COMPILER_INCLUDE_ALL_APPICON_ASSETS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "c++17"; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_ASSET_PATHS = "\"mnn-llm/Preview 
Content\""; + DEVELOPMENT_TEAM = 6G7464HHUS; + ENABLE_PREVIEWS = YES; + FRAMEWORK_SEARCH_PATHS = ( + "$(inherited)", + "$(PROJECT_DIR)/mnn-llm", + "$(PROJECT_DIR)/../", + "$(PROJECT_DIR)/../../", + ); + GCC_PREPROCESSOR_DEFINITIONS = ( + "MNN_ARM82=1", + "MNN_SUPPORT_TRANSFORMER_FUSE=1", + "MNN_LOW_MEMORY=1", + ); + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.zhaode.mnn-llm1"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OBJC_BRIDGING_HEADER = "mnn-llm/mnn-llm-Bridging-Header.h"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 4D5B97832B2B21D3003AF2F1 /* Build configuration list for PBXProject "mnn-llm" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 4D5B97942B2B21D5003AF2F1 /* Debug */, + 4D5B97952B2B21D5003AF2F1 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 4D5B97962B2B21D5003AF2F1 /* Build configuration list for PBXNativeTarget "mnn-llm" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 4D5B97972B2B21D5003AF2F1 /* Debug */, + 4D5B97982B2B21D5003AF2F1 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 4D5B97802B2B21D3003AF2F1 /* Project object */; +} diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 000000000..919434a62 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ + + + + + diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist new file mode 100644 index 000000000..18d981003 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist @@ -0,0 +1,8 @@ + + + + + IDEDidComputeMac32BitWarning + + + diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AccentColor.colorset/Contents.json b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AccentColor.colorset/Contents.json new file mode 100644 index 000000000..eb8789700 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AccentColor.colorset/Contents.json @@ -0,0 +1,11 @@ +{ + "colors" : [ + { + "idiom" : "universal" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git 
a/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/Contents.json b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 000000000..a657e3367 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,14 @@ +{ + "images" : [ + { + "filename" : "icon.png", + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/icon.png b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/icon.png new file mode 100644 index 000000000..824ebb7ab Binary files /dev/null and b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/AppIcon.appiconset/icon.png differ diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/Contents.json b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/Contents.json new file mode 100644 index 000000000..73c00596a --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/ContentView.swift b/transformers/llm/engine/ios/mnn-llm/mnn-llm/ContentView.swift new file mode 100644 index 000000000..11ac631b1 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/ContentView.swift @@ -0,0 +1,152 @@ +// +// ContentView.swift +// mnn-llm +// +// Created by wangzhaode on 2023/12/14. +// + +import Combine +import SwiftUI + +class ChatViewModel: ObservableObject { + @Published var messages: [Message] = [] + @Published var isModelLoaded = false // 模型是否加载完成 + @Published var isProcessing: Bool = false // 标志表示是否有正在处理的LLM响应 + private var llm: LLMInferenceEngineWrapper? + + init() { + self.messages.append(Message(id: UUID(), text: " 模型加载中, 请稍等 ...", isUser: false)) + llm = LLMInferenceEngineWrapper { [weak self] success in + DispatchQueue.main.async { + self?.isModelLoaded = success + var loadresult = "模型加载完毕!" + if !success { + loadresult = "模型加载失败!" + } + self?.messages.append(Message(id: UUID(), text: loadresult, isUser: false)) + } + } + } + + func sendInput(_ input: String) { + // 将用户输入作为新消息添加 + let userMessage = Message(id: UUID(), text: input, isUser: true) + DispatchQueue.main.async { + self.messages.append(userMessage) + } + isProcessing = true + // 在后台线程处理耗时的输入 + DispatchQueue.global(qos: .userInitiated).async { + self.llm?.processInput(input) { [weak self] output in + // 切换回主线程来更新UI + DispatchQueue.main.async { + if (output.contains("")) { + self?.isProcessing = false + } else { + self?.appendResponse(output) + } + } + } + } + } + + private func appendResponse(_ output: String) { + if let lastMessage = messages.last, !lastMessage.isUser { + // 创建一个更新后的消息 + var updatedMessage = messages[messages.count - 1] + updatedMessage.text += output + // 替换数组中的旧消息 + self.messages[messages.count - 1] = updatedMessage + } else { + let newMessage = Message(id: UUID(), text: output, isUser: false) + self.messages.append(newMessage) + } + } +} + + +struct Message: Identifiable, Equatable { + let id: UUID + var text: String + let isUser: Bool +} + +struct ChatBubble: View { + let message: Message + + var body: some View { + HStack { + if message.isUser { + Spacer() + } + + Text(message.text) + .padding(10) + .foregroundColor(message.isUser ? 
.white : .black) + .background(message.isUser ? Color.blue : Color.gray.opacity(0.2)) + .cornerRadius(10) + .frame(maxWidth: 400, alignment: message.isUser ? .trailing : .leading) + + if !message.isUser { + Spacer() + } + } + .transition(.scale(scale: 0, anchor: message.isUser ? .bottomTrailing : .bottomLeading)) + } +} + +struct ChatView: View { + @StateObject var viewModel = ChatViewModel() + @State private var inputText: String = "" + + var body: some View { + NavigationView { // 包裹在 NavigationView 中 + VStack { + ScrollView { + ScrollViewReader { scrollView in + VStack(alignment: .leading, spacing: 10) { + ForEach(viewModel.messages) { message in + ChatBubble(message: message) + } + } + .padding(.horizontal) + .onChange(of: viewModel.messages) { _ in + scrollView.scrollTo(viewModel.messages.last?.id, anchor: .bottom) + } + } + } + + HStack { + TextField("Type a message...", text: $inputText) + .textFieldStyle(RoundedBorderTextFieldStyle()) + .frame(minHeight: 44) + + Button(action: { + viewModel.sendInput(inputText) + inputText = "" + }) { + Image(systemName: "arrow.up.circle.fill") + .resizable() + .aspectRatio(contentMode: .fit) + .frame(width: 44, height: 44) + } + .disabled(inputText.isEmpty || viewModel.isProcessing || !viewModel.isModelLoaded) + } + .padding() + } + .navigationBarTitle("mnn-llm", displayMode: .inline) // 设置标题 + } + } +} + +extension String { + var isBlank: Bool { + return allSatisfy({ $0.isWhitespace }) + } +} + +struct ChatView_Previews: PreviewProvider { + static var previews: some View { + ChatView() + } +} diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.h b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.h new file mode 100644 index 000000000..28374c06d --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.h @@ -0,0 +1,29 @@ +// +// LLMInferenceEngineWrapper.h +// mnn-llm +// +// Created by wangzhaode on 2023/12/14. +// + +#ifndef LLMInferenceEngineWrapper_h +#define LLMInferenceEngineWrapper_h + + +// LLMInferenceEngineWrapper.h +#import + +NS_ASSUME_NONNULL_BEGIN + +typedef void(^ModelLoadingCompletionHandler)(BOOL success); +typedef void (^StreamOutputHandler)(NSString * _Nonnull output); + +@interface LLMInferenceEngineWrapper : NSObject + +- (instancetype)initWithCompletionHandler:(ModelLoadingCompletionHandler)completionHandler; +- (void)processInput:(NSString *)input withStreamHandler:(StreamOutputHandler)handler; + +@end + +NS_ASSUME_NONNULL_END + +#endif /* LLMInferenceEngineWrapper_h */ diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm new file mode 100644 index 000000000..4d05379a4 --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/LLMInferenceEngineWrapper.mm @@ -0,0 +1,106 @@ +// +// LLMInferenceEngineWrapper.m +// mnn-llm +// +// Created by wangzhaode on 2023/12/14. 
+// + +#import "LLMInferenceEngineWrapper.h" +#include +using namespace MNN::Transformer; + +const char* GetMainBundleDirectory() { + NSString *bundleDirectory = [[NSBundle mainBundle] bundlePath]; + return [bundleDirectory UTF8String]; +} + +@implementation LLMInferenceEngineWrapper { + std::shared_ptr llm; +} + +- (instancetype)initWithCompletionHandler:(ModelLoadingCompletionHandler)completionHandler { + self = [super init]; + if (self) { + // 在后台线程异步加载模型 + dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{ + BOOL success = [self loadModel]; // 假设loadModel方法加载模型并返回加载的成功或失败 + // 切回主线程回调 + dispatch_async(dispatch_get_main_queue(), ^{ + completionHandler(success); + }); + }); + } + return self; +} + +- (BOOL)loadModel { + if (!llm) { + std::string model_dir = GetMainBundleDirectory(); + std::string config_path = model_dir + "/config.json"; + llm.reset(Llm::createLLM(config_path)); + NSString *tempDirectory = NSTemporaryDirectory(); + llm->set_config("{\"tmp_path\":\"" + std::string([tempDirectory UTF8String]) + "\", \"use_mmap\":true}"); + llm->load(); + } + return YES; +} + +- (void)processInput:(NSString *)input withStreamHandler:(StreamOutputHandler)handler { + LlmStreamBuffer::CallBack callback = [handler](const char* str, size_t len) { + if (handler) { + NSString *nsOutput = [NSString stringWithUTF8String:str]; + handler(nsOutput); + } + }; + LlmStreamBuffer streambuf(callback); + std::ostream os(&streambuf); + if (std::string([input UTF8String]) == "benchmark") { + // do benchmark + std::string model_dir = GetMainBundleDirectory(); + std::string prompt_file = model_dir + "/bench.txt"; + std::ifstream prompt_fs(prompt_file); + std::vector prompts; + std::string prompt; + while (std::getline(prompt_fs, prompt)) { + // prompt start with '#' will be ignored + if (prompt.substr(0, 1) == "#") { + continue; + } + std::string::size_type pos = 0; + while ((pos = prompt.find("\\n", pos)) != std::string::npos) { + prompt.replace(pos, 2, "\n"); + pos += 1; + } + prompts.push_back(prompt); + } + int prompt_len = 0; + int decode_len = 0; + int64_t prefill_time = 0; + int64_t decode_time = 0; + for (int i = 0; i < prompts.size(); i++) { + llm->response(prompts[i], &os, "\n"); + prompt_len += llm->prompt_len_; + decode_len += llm->gen_seq_len_; + prefill_time += llm->prefill_us_; + decode_time += llm->decode_us_; + } + float prefill_s = prefill_time / 1e6; + float decode_s = decode_time / 1e6; + os << "\n#################################\n" + << "prompt tokens num = " << prompt_len << "\n" + << "decode tokens num = " << decode_len << "\n" + << "prefill time = " << std::fixed << std::setprecision(2) << prefill_s << " s\n" + << " decode time = " << std::fixed << std::setprecision(2) << decode_s << " s\n" + << "prefill speed = " << std::fixed << std::setprecision(2) << prompt_len / prefill_s << " tok/s\n" + << " decode speed = " << std::fixed << std::setprecision(2) << decode_len / decode_s << " tok/s\n" + << "##################################\n"; + os << ""; + } else { + llm->response([input UTF8String], &os, ""); + } +} + +- (void)dealloc { + llm.reset(); +} +@end diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/Preview Content/Preview Assets.xcassets/Contents.json b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Preview Content/Preview Assets.xcassets/Contents.json new file mode 100644 index 000000000..73c00596a --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/Preview Content/Preview Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + 
"author" : "xcode", + "version" : 1 + } +} diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn-llm-Bridging-Header.h b/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn-llm-Bridging-Header.h new file mode 100644 index 000000000..208d3edbb --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn-llm-Bridging-Header.h @@ -0,0 +1,5 @@ +// +// Use this file to import your target's public headers that you would like to expose to Swift. +// + +#import "LLMInferenceEngineWrapper.h" diff --git a/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn_llmApp.swift b/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn_llmApp.swift new file mode 100644 index 000000000..c0585da9f --- /dev/null +++ b/transformers/llm/engine/ios/mnn-llm/mnn-llm/mnn_llmApp.swift @@ -0,0 +1,17 @@ +// +// mnn_llmApp.swift +// mnn-llm +// +// Created by wangzhaode on 2023/12/14. +// + +import SwiftUI + +@main +struct mnn_llmApp: App { + var body: some Scene { + WindowGroup { + ChatView() + } + } +} diff --git a/transformers/llm/engine/llm_demo.cpp b/transformers/llm/engine/llm_demo.cpp index 3e41b2eb0..1200957c0 100644 --- a/transformers/llm/engine/llm_demo.cpp +++ b/transformers/llm/engine/llm_demo.cpp @@ -8,6 +8,7 @@ #include "llm/llm.hpp" #define MNN_OPEN_TIME_TRACE #include +#include #include #include #include @@ -160,14 +161,19 @@ int main(int argc, const char* argv[]) { std::cout << "Usage: " << argv[0] << " config.json " << std::endl; return 0; } + MNN::BackendConfig backendConfig; + auto executor = MNN::Express::Executor::newExecutor(MNN_FORWARD_CPU, backendConfig, 1); + MNN::Express::ExecutorScope s(executor); + std::string config_path = argv[1]; std::cout << "config path is " << config_path << std::endl; std::unique_ptr llm(Llm::createLLM(config_path)); + llm->set_config("{\"tmp_path\":\"tmp\"}"); { AUTOTIME; llm->load(); } - if (false) { + if (true) { AUTOTIME; trace_prepare(llm.get()); } diff --git a/transformers/llm/engine/model/bench.txt b/transformers/llm/engine/model/bench.txt new file mode 100644 index 000000000..87e49b7f1 --- /dev/null +++ b/transformers/llm/engine/model/bench.txt @@ -0,0 +1,4 @@ +计算8乘以12 +将下面的句子翻译成中文:It's a beautiful day to learn something new. 
+描述优秀的领导者应具备的五个特质,并解释每个特质为什么重要 +近年来,随着技术的快速发展和全球化的深入推进,数字经济已成为推动世界经济增长的新引擎。数字经济不仅改变了人们的生活方式,促进了信息和资源的快速流通,还重塑了传统行业的业务模式和竞争格局。尽管数字经济的发展为全球经济增长提供了新的动能,但同时也带来了数据安全、隐私保护、数字鸿沟和市场垄断等一系列挑战。考虑到这些背景,请详细分析数字经济在促进世界经济增长方面的作用,包括但不限于数字经济对提高生产效率、创造就业机会和促进可持续发展的贡献。同时,探讨如何应对数字经济发展过程中出现的挑战,具体包括如何保护个人数据安全和隐私、缩小数字鸿沟以确保数字经济的包容性和公平性,以及如何制定有效政策以避免市场垄断情况的出现,最终实现数字经济的健康和可持续发展。 \ No newline at end of file diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index e01350eb9..d11254110 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -85,10 +85,19 @@ void Llm::init_runtime() { BackendConfig cpuBackendConfig; config.type = backend_type_convert(config_->backend_type()); config.numThread = config_->thread_num(); - if (config_->memory() == "low") { + if (config_->power() == "high") { + cpuBackendConfig.power = BackendConfig::Power_High; + } else if (config_->power() == "low") { + cpuBackendConfig.power = BackendConfig::Power_Low; + } + if (config_->memory() == "high") { + cpuBackendConfig.memory = BackendConfig::Memory_High; + } else if (config_->memory() == "low") { cpuBackendConfig.memory = BackendConfig::Memory_Low; } - if (config_->precision() == "low") { + if (config_->precision() == "high") { + cpuBackendConfig.precision = BackendConfig::Precision_High; + } else if (config_->precision() == "low") { cpuBackendConfig.precision = BackendConfig::Precision_Low; } config.backendConfig = &cpuBackendConfig; @@ -97,10 +106,16 @@ void Llm::init_runtime() { runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config)); runtime_manager_->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0); runtime_manager_->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, 1); // 1: per batch quant, 2: per tensor quant - runtime_manager_->setHint(MNN::Interpreter::KVCACHE_QUANT_OPTIONS, config_->quant_kv()); + runtime_manager_->setHint(MNN::Interpreter::QKV_QUANT_OPTIONS, config_->quant_qkv()); runtime_manager_->setHint(MNN::Interpreter::KVCACHE_SIZE_LIMIT, config_->kvcache_limit()); - runtime_manager_->setExternalPath("/tmp/.kvcache", MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR); - + std::string tmpPath = config_->tmp_path(); + if (config_->kvcache_mmap()) { + runtime_manager_->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_PATH_KVCACHE_DIR); + } + if (config_->use_mmap()) { + runtime_manager_->setExternalPath(tmpPath, MNN::Interpreter::EXTERNAL_WEIGHT_DIR); + } + #if DEBUG_MODE==1 runtime_manager_->setMode(MNN::Interpreter::Session_Debug); _initTimeTrace(); @@ -154,7 +169,7 @@ void Llm::load() { {"input_ids", "attention_mask", "position_ids", "past_key_values"}, {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); } - MNN_PRINT("Done!\n"); + MNN_PRINT("Load Module Done!\n"); } else { MNN_ERROR("Split version is depercerate\n"); } @@ -162,6 +177,8 @@ void Llm::load() { for (int v=0; vreuse_kv()) { + response(user_str); + } else { + history.emplace_back(std::make_pair("user", user_str)); + auto assistant_str = response(history); + history.emplace_back(std::make_pair("assistant", assistant_str)); + } std::cout << std::endl; } } @@ -777,10 +798,12 @@ float Embedding::dist(VARP var0, VARP var1) { return dist; } -Embedding* Embedding::createEmbedding(const std::string& config_path) { +Embedding* Embedding::createEmbedding(const std::string& config_path, bool load) { std::shared_ptr config(new LlmConfig(config_path)); Embedding* embedding = new Embedding(config); - embedding->load(); + if (load) { + embedding->load(); + } 
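// A minimal usage sketch of the deferred-load Embedding API added in this hunk (headers and
// namespaces omitted; the config path and token ids below are placeholders, not part of this change):
//   std::unique_ptr<Embedding> emb(Embedding::createEmbedding("bge/config.json", /*load*/ false));
//   emb->load();                                  // weights are only loaded on demand now
//   auto v0 = emb->txt_embedding("hello world");  // replaces the former embedding(txt)
//   auto v1 = emb->ids_embedding({1, 2, 3});      // placeholder token ids
//   float d = Embedding::dist(v0, v1);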
return embedding; } @@ -808,10 +831,9 @@ void Embedding::load() { MNN_PRINT("Done!\n"); } -VARP Embedding::embedding(const std::string& txt) { - auto ids = tokenizer(txt); +VARP Embedding::ids_embedding(const std::vector& ids) { int prompt_len = ids.size(); - auto inputs_ids = _Const(ids.data(), {prompt_len}, NCHW, halide_type_of()); + auto inputs_ids = embedding(ids); auto attention_mask = gen_attention_mask(prompt_len); auto position_ids = gen_position_ids(prompt_len); auto outputs = modules_[0]->onForward({inputs_ids, attention_mask, position_ids}); @@ -819,12 +841,12 @@ VARP Embedding::embedding(const std::string& txt) { return sentence_embeddings; } +VARP Embedding::txt_embedding(const std::string& txt) { + return ids_embedding(tokenizer(txt)); +} + std::vector Embedding::tokenizer(const std::string& query) { - auto prompt = query; - if (query.size() <= 256) { - prompt = "为这个句子生成表示以用于检索相关文章:" + query; - } - prompt = apply_prompt_template(prompt); + auto prompt = apply_prompt_template(query); auto ids = tokenizer_->encode(prompt); return ids; } diff --git a/transformers/llm/engine/src/llmconfig.hpp b/transformers/llm/engine/src/llmconfig.hpp index 57cc924a8..22b66c895 100644 --- a/transformers/llm/engine/src/llmconfig.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -241,13 +241,16 @@ class LlmConfig { std::string precision() const { return config_.value("precision", "low"); } + std::string power() const { + return config_.value("power", "normal"); + } std::string memory() const { return config_.value("memory", "low"); } - int quant_kv() const { - return config_.value("quant_kv", 0); + int quant_qkv() const { + return config_.value("quant_qkv", 0); } int kvcache_limit() const { @@ -264,6 +267,16 @@ class LlmConfig { return llm_config_.value("is_visual", false); } + bool use_mmap() const { + return config_.value("use_mmap", false); + } + bool kvcache_mmap() const { + return config_.value("kvcache_mmap", false); + } + std::string tmp_path() const { + return config_.value("tmp_path", ""); + } + int hidden_size() const { return llm_config_.value("hidden_size", 4096); } diff --git a/transformers/llm/export/README.md b/transformers/llm/export/README.md index bc72b39da..bdd38a9de 100644 --- a/transformers/llm/export/README.md +++ b/transformers/llm/export/README.md @@ -4,156 +4,85 @@ llm-export是一个llm模型导出工具,能够将llm模型导出为onnx和mnn模型。 -- 🚀 均完成`onnxruntime`正确性测试 - 🚀 优化原始代码,支持动态形状 - 🚀 优化原始代码,减少常量部分 -- 🚀 使用[OnnxSlim](https://github.com/WeLoveAI/OnnxSlim)优化onnx模型,性能提升约5%; by [@inisis](https://github.com/inisis) +- 🚀 使用[OnnxSlim](https://github.com/inisis/OnnxSlim)优化onnx模型,性能提升约5%; by [@inisis](https://github.com/inisis) - 🚀 支持将lora权重导出为onnx和mnn +- 🚀 Onnx推理代码[OnnxLLM](https://github.com/inisis/OnnxLLM) -## 模型支持与下载 -- [![Download][download-chatglm-6b-onnx]][release-chatglm-6b-onnx] -- [![Download][download-chatglm2-6b-onnx]][release-chatglm2-6b-onnx] -- [![Download][download-chatglm3-6b-onnx]][release-chatglm3-6b-onnx] -- [![Download][download-codegeex2-6b-onnx]][release-codegeex2-6b-onnx] -- [![Download][download-qwen-7b-chat-onnx]][release-qwen-7b-chat-onnx] -- [![Download][download-baichuan2-7b-chat-onnx]][release-baichuan2-7b-chat-onnx] -- [![Download][download-llama2-7b-chat-onnx]][release-llama2-7b-chat-onnx] -- [![Download][download-qwen-1.8b-chat-onnx]][release-qwen-1.8b-chat-onnx] -- [![Download][download-phi-2-onnx]][release-phi-2-onnx] -- [![Download][download-internlm-7b-onnx]][release-internlm-7b-onnx] -- [![Download][download-qwen-vl-onnx]][release-qwen-vl-onnx] -- 
[![Download][download-bge-large-zh-onnx]][release-bge-large-zh-onnx] -- [![Download][download-tinyllama-1.1b-chat-onnx]][release-tinyllama-1.1b-chat-onnx] -- [![Download][download-yi-6b-chat-onnx]][release-yi-6b-chat-onnx] -- [![Download][download-deepseek-7b-chat-onnx]][release-deepseek-7b-chat-onnx] -- [![Download][download-qwen1.5-0.5b-chat-onnx]][release-qwen1.5-0.5b-chat-onnx] -- [![Download][download-qwen1.5-1.8b-chat-onnx]][release-qwen1.5-1.8b-chat-onnx] -- [![Download][download-qwen1.5-4b-chat-onnx]][release-qwen1.5-4b-chat-onnx] -- [![Download][download-qwen1.5-7b-chat-onnx]][release-qwen1.5-7b-chat-onnx] -- [![Download][download-llama3-8b-instruct-onnx]][release-llama3-8b-instruct-onnx] +## 安装 +```sh +# pip install +pip install llmexport -[download-chatglm-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm-6b-onnx/total -[download-chatglm2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm2-6b-onnx/total -[download-chatglm3-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/chatglm3-6b-onnx/total -[download-codegeex2-6b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/codegeex2-6b-onnx/total -[download-qwen-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-7b-chat-onnx/total -[download-baichuan2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/baichuan2-7b-chat-onnx/total -[download-llama2-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/llama2-7b-chat-onnx/total -[download-qwen-1.8b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-1.8b-onnx/total -[download-phi-2-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/phi-2-onnx/total -[download-internlm-7b-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/internlm-7b-onnx/total -[download-qwen-vl-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen-vl-onnx/total -[download-bge-large-zh-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/bge-large-zh-onnx/total -[download-tinyllama-1.1b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/tinyllama-1.1b-chat-onnx/total -[download-yi-6b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/yi-6b-chat-onnx/total -[download-deepseek-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/deepseek-7b-chat-onnx/total -[download-qwen1.5-0.5b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen1.5-0.5b-chat-onnx/total -[download-qwen1.5-1.8b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen1.5-1.8b-chat-onnx/total -[download-qwen1.5-4b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen1.5-4b-chat-onnx/total -[download-qwen1.5-7b-chat-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/qwen1.5-7b-chat-onnx/total -[download-llama3-8b-instruct-onnx]: https://img.shields.io/github/downloads/wangzhaode/llm-export/llama3-8b-instruct-onnx/total -[release-chatglm-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm-6b-onnx -[release-chatglm2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm2-6b-onnx -[release-chatglm3-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/chatglm3-6b-onnx -[release-codegeex2-6b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/codegeex2-6b-onnx 
-[release-qwen-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-7b-chat-onnx -[release-baichuan2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/baichuan2-7b-chat-onnx -[release-llama2-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/llama2-7b-chat-onnx -[release-qwen-1.8b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-1.8b-onnx -[release-phi-2-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/phi-2-onnx -[release-internlm-7b-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/internlm-7b-onnx -[release-qwen-vl-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen-vl-onnx -[release-bge-large-zh-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/bge-large-zh-onnx -[release-tinyllama-1.1b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/tinyllama-1.1b-chat-onnx -[release-yi-6b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/yi-6b-chat-onnx -[release-deepseek-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/deepseek-7b-chat-onnx -[release-qwen1.5-0.5b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen1.5-0.5b-chat-onnx -[release-qwen1.5-1.8b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen1.5-1.8b-chat-onnx -[release-qwen1.5-4b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen1.5-4b-chat-onnx -[release-qwen1.5-7b-chat-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/qwen1.5-7b-chat-onnx -[release-llama3-8b-instruct-onnx]: https://github.com/wangzhaode/llm-export/releases/tag/llama3-8b-instruct-onnx +# git install +pip install git+https://github.com/wangzhaode/llm-export@master -## 用法 -1. 将该项目clone到本地 -```sh -git clone git@github.com:wangzhaode/llm-export.git +# local install +git clone https://github.com/wangzhaode/llm-export && cd llm-export/ +pip install . ``` -2. 将需要导出的LLM项目clone到本地,如:chatglm2-6b + +## 用法 + +1. 将需要导出的LLM项目clone到本地,如:chatglm2-6b ```sh git clone https://huggingface.co/THUDM/chatglm2-6b # 如果huggingface下载慢可以使用modelscope git clone https://modelscope.cn/ZhipuAI/chatglm2-6b.git ``` -3. 执行LLMExporter导出模型 +2. 
导出模型 ```sh -cd mnn-llm -# 将chatglm2-6b分为embedding, blocks, lm分别导出为onnx并转换为mnn, 并导出tokenizer.txt -python llm_export.py \ - --path ../chatglm2-6b \ - --export_split \ - --export_token \ - --export_mnn \ - --onnx_path ./chatglm2-6b-onnx \ - --mnn_path ./chatglm2-6b-mnn +# 将chatglm2-6b导出为onnx模型 +llmexport --path ../chatglm2-6b --export onnx +# 将chatglm2-6b导出为mnn模型, 量化参数为4bit, block-wise = 128 +llmexport --path ../chatglm2-6b --export mnn --quant_bit 4 --quant_block 128 ``` ## 功能 -- 支持将模型完整导出为一个onnx模型,使用`--export` -- 支持将模型分段导出为多个模型,使用`--export_split` -- 支持导出模型的词表到一个文本文件,每行代表一个token;其中token使用base64编码;使用`--export_verbose` -- 支持导出模型的Embedding层为一个onnx模型,使用`--export_embed`,同时支持bf16格式,使用`--embed_bf16` -- 支持分层导出模型的block,使用`--export_blocks`导出全部层;使用`--export_block $id`导出指定层 -- 支持导出模型的lm_head层为一个onnx模型,使用`--export_lm` -- 支持导出多模态模型的visual模型为一个onnx模型,使用`--export_visual` - 支持对模型进行对话测试,使用`--test $query`会返回llm的回复内容 -- 支持在导出onnx模型后使用onnxruntime对结果一致性进行校验,使用`--export_test` -- 支持将tokenizer导出为文本文件,使用`--export_token` -- 支持将导出的onnx模型转换为mnn模型,默认转换为非对称4bit量化,使用`--export_mnn` -- 指定导出路径使用`--onnx_path`和`--mnn_path` - 默认会使用onnx-slim对onnx模型进行优化,跳过该步骤使用`--skip_slim` - 支持合并lora权重后导出,指定lora权重的目录使用`--lora_path` +- 指定量化bit数使用`--quant_bit`;量化的block大小使用`--quant_block` +- 使用`--lm_quant_bit`来指定lm_head层权重的量化bit数,不指定则使用`--quant_bit`的量化bit数 +- 支持使用自己编译的`MNNConvert`,使用`--mnnconvert`(量化相关参数的组合用法见下文示例) ## 参数 ``` -usage: llm_export.py [-h] --path PATH - [--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora}] - [--lora_path LORA_PATH] [--onnx_path ONNX_PATH] [--mnn_path MNN_PATH] [--export_mnn] [--export_verbose] [--export_test] [--test TEST] [--export] - [--export_split] [--export_token] [--export_embed] [--export_visual] [--export_lm] [--export_block EXPORT_BLOCK] [--export_blocks] [--embed_bin] - [--embed_bf16] [--skip_slim] +usage: llmexport.py [-h] --path PATH [--type TYPE] [--lora_path LORA_PATH] [--dst_path DST_PATH] [--test TEST] [--export EXPORT] + [--skip_slim] [--quant_bit QUANT_BIT] [--quant_block QUANT_BLOCK] [--lm_quant_bit LM_QUANT_BIT] + [--mnnconvert MNNCONVERT] llm_exporter -optional arguments: +options: -h, --help show this help message and exit --path PATH path(`str` or `os.PathLike`): Can be either: - A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO] - A path to a *directory* clone from repo like `../chatglm-6b`. - --type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-1_8B,Qwen-VL-Chat,Qwen1_5-0_5B-Chat,Qwen1_5-1_8B-Chat,Qwen1_5-4B-Chat,Qwen1_5-7B-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,Llama-3-8B-Instruct,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh,lora} - type(`str`, *optional*): + --type TYPE type(`str`, *optional*): The pretrain llm model type. --lora_path LORA_PATH lora path, defaut is `None` mean not apply lora. - --onnx_path ONNX_PATH - export onnx model path, defaut is `./onnx`. - --mnn_path MNN_PATH export mnn model path, defaut is `./mnn`. - --export_mnn Whether or not to export mnn model after onnx. - --export_verbose Whether or not to export onnx with verbose. - --export_test Whether or not to export onnx with test using onnxruntime. + --dst_path DST_PATH export onnx/mnn model to path, default is `./model`.
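结合上面「功能」一节中新增的量化相关选项,下面给出一个把这些参数显式写全的导出命令示意(模型与 MNNConvert 的路径均为示例):

```sh
# 示意:4bit 权重量化、quant_block 取 128、lm_head 使用 8bit,并指定本地编译的 MNNConvert(路径仅为示例)
llmexport --path ../chatglm2-6b --export mnn --dst_path ./model \
    --quant_bit 4 --quant_block 128 --lm_quant_bit 8 \
    --mnnconvert ./MNN/build/MNNConvert
```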
--test TEST test model inference with query `TEST`. - --export export model to an `onnx` model. - --export_split export model split to some `onnx` models: - - embedding model. - - block models. - - lm_head model. - --export_token export llm tokenizer to a txt file. - --export_embed export llm embedding to an `onnx` model. - --export_visual export llm visual model to an `onnx` model. - --export_lm export llm lm_head to an `onnx` model. - --export_block EXPORT_BLOCK - export llm block [id] to an `onnx` model. - --export_blocks export llm all blocks to `onnx` models. - --embed_bin export embedding weight as bin file with dtype `bfloat16` - --embed_bf16 using `bfloat16` replace `float32` in embedding. + --export EXPORT export model to an onnx/mnn model. --skip_slim Whether or not to skip onnx-slim. + --quant_bit QUANT_BIT + mnn quant bit, 4 or 8, default is 4. + --quant_block QUANT_BLOCK + mnn quant block, default is 0, meaning channel-wise. + --lm_quant_bit LM_QUANT_BIT + mnn lm_head quant bit, 4 or 8, default is `quant_bit`. + --mnnconvert MNNCONVERT + local mnnconvert path, if invalid, pymnn is used. ``` + +## 支持模型 + +- llama/llama2/llama3/tinyllama +- qwen/qwen1.5/qwen2/qwen-vl +- baichuan2/phi-2/internlm/yi/deepseek +- chatglm/codegeex/chatglm2/chatglm3 +- phi-2/gemma-2 \ No newline at end of file diff --git a/transformers/llm/export/README_en.md b/transformers/llm/export/README_en.md deleted file mode 100644 index 9942c23f1..000000000 --- a/transformers/llm/export/README_en.md +++ /dev/null @@ -1,92 +0,0 @@ -# llm-export - -[中文](./README_en.md) - -llm-export is a tool for exporting llm models, capable of converting llm models into ONNX or MNN models. -- 🚀 All passed `onnxruntime` correctness tests -- 🚀 Optimized the original code to support dynamic shapes -- 🚀 Optimized the original code to reduce the constant portion -- 🚀 Using [OnnxSlim](https://github.com/WeLoveAI/OnnxSlim) slim onnx model,speed up 5%; by [@inisis](https://github.com/inisis) -- 🚀 Support export lora weight to onnx or MNN model - -## Model Support and Downloads - -## Usage -1. Clone this project locally -```sh -git clnoe git@github.com:wangzhaode/llm-export.git -``` -2. Clone the LLM project that you want to export locally, such as: chatglm2-6b -```sh -git clone https://huggingface.co/THUDM/chatglm2-6b -# If downloading from Hugging Face is slow, you can use ModelScope -git clone https://modelscope.cn/ZhipuAI/chatglm2-6b.git -``` -3. 
Execute LLMExporter to export the model -```sh -cd mnn-llm -# Divide chatglm2-6b into embedding, blocks, lm, export each as ONNX and convert to MNN, and also export tokenizer.txt -python llm_export.py \ - --path ../chatglm2-6b \ - --export_split \ - --export_token \ - --export_mnn \ - --onnx_path ./chatglm2-6b-onnx \ - --mnn_path ./chatglm2-6b-mnn -``` - -## Features -- Supports exporting the entire model as a single ONNX model, use --export -- Supports exporting the model in segments as multiple models, use --export_split -- Supports exporting the model's vocabulary to a text file, each line representing a token; tokens are encoded using base64, use --export_verbose -- Supports exporting the model's Embedding layer as an ONNX model, use --export_embed, also supports bf16 format, use --embed_bf16 -- Supports layered export of the model's blocks, use --export_blocks to export all layers; use --export_block $id to export a specified layer -- Supports exporting the model's lm_head layer as an ONNX model, use --export_lm -- Supports exporting the VL model's visual model as an ONNX model, use --export_visual -- Supports conducting a dialogue test on the model, using --test $query will return the llm's response -- Supports verifying the consistency of results using onnxruntime after exporting the ONNX model, use --export_test -- Supports exporting the tokenizer as a text file, use --export_token -- Supports converting the exported ONNX model to an MNN model, with default conversion to non-symmetric 4bit quantization, use --export_mnn -- Specify export paths using --onnx_path and --mnn_path -- Default using onnx-slim, skip using --skip_slim - -## Commad Args -``` -usage: llm_export.py [-h] --path PATH - [--type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-VL-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh}] - [--onnx_path ONNX_PATH] [--mnn_path MNN_PATH] [--export_mnn] [--export_verbose] [--export_test] [--test TEST] [--export] [--export_split] [--export_token] [--export_embed] [--export_visual] [--export_lm] - [--export_block EXPORT_BLOCK] [--export_blocks] [--embed_bf16] [--skip_slim] - -llm_exporter - -optional arguments: - -h, --help show this help message and exit - --path PATH path(`str` or `os.PathLike`): - Can be either: - - A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO] - - A path to a *directory* clone from repo like `../chatglm-6b`. - --type {chatglm-6b,chatglm2-6b,chatglm3-6b,codegeex2-6b,Qwen-7B-Chat,Qwen-1_8B-Chat,Qwen-VL-Chat,Baichuan2-7B-Chat,Llama-2-7b-chat-ms,internlm-chat-7b,TinyLlama-1_1B-Chat,Yi-6B-Chat,deepseek-llm-7b-chat,phi-2,bge-large-zh} - type(`str`, *optional*): - The pretrain llm model type. - --onnx_path ONNX_PATH - export onnx model path, defaut is `./onnx`. - --mnn_path MNN_PATH export mnn model path, defaut is `./mnn`. - --export_mnn Whether or not to export mnn model after onnx. - --export_verbose Whether or not to export onnx with verbose. - --export_test Whether or not to export onnx with test using onnxruntime. - --test TEST test model inference with query `TEST`. - --export export model to an `onnx` model. - --export_split export model split to some `onnx` models: - - embedding model. - - block models. - - lm_head model. - --export_token export llm tokenizer to a txt file. - --export_embed export llm embedding to an `onnx` model. - --export_visual export llm visual model to an `onnx` model. 
- --export_lm export llm lm_head to an `onnx` model. - --export_block EXPORT_BLOCK - export llm block [id] to an `onnx` model. - --export_blocks export llm all blocks to `onnx` models. - --embed_bf16 using `bfloat16` replace `float32` in embedding. - --skip_slim Whether or not to skip onnx-slim. -``` diff --git a/transformers/llm/export/llm_export.py b/transformers/llm/export/llm_export.py deleted file mode 100644 index 4b541b247..000000000 --- a/transformers/llm/export/llm_export.py +++ /dev/null @@ -1,1430 +0,0 @@ -import os -import base64 -import glob -import json -import shutil -import argparse -import torch -import numpy as np -from onnxslim import slim -import onnxruntime as ort -import sentencepiece as spm -from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from peft import LoraConfig, TaskType, get_peft_model, PeftModel -try: - import _tools as MNNTools -except: - MNNTools = None - -def onnx2mnn(onnx_path, mnn_dir, quant_bit = 4, asymmetric = True, external_data = False, bizCode : str= None): - model_name, model_extension = os.path.splitext(os.path.basename(onnx_path)) - if model_extension != '.onnx': - return - mnn_name = model_name + '.mnn' - mnn_path = os.path.join(mnn_dir, mnn_name) - convert_args = [ - '', - '-f', - 'ONNX', - '--modelFile', - str(onnx_path), - '--MNNModel', - str(mnn_path), - '--weightQuantBits', - str(quant_bit), - ] - if asymmetric: - convert_args.append("--weightQuantAsymmetric") - if external_data: - convert_args.append("--saveExternalData") - if bizCode is not None: - convert_args.append("--bizCode") - convert_args.append(str(bizCode)) - MNNTools.mnnconvert(convert_args) - -# some wrapper class for export -class Embedding(torch.nn.Module): - def __init__(self, embed, using_bf16: bool = False): - super().__init__() - self.bf16 = using_bf16 - self.embed_dim = embed.weight.shape[-1] - if using_bf16: - # using bf16 embedding weight - self.embed = embed.bfloat16() - else: - self.embed = embed - - def forward(self, input_ids): - res = self.embed(input_ids) - if self.bf16: - res = res.float() - return res.view(-1, 1, self.embed_dim) - -class Lm(torch.nn.Module): - def __init__(self, lm): - super().__init__() - self.lm = lm - - def forward(self, hidden_states): - m_logits = self.lm(hidden_states) - # token = torch.argmax(m_logits) - return m_logits - -class LLM(torch.nn.Module): - ''' - Base class for all llm model. Inherits from [`torch.nn.Module`]. 
- ''' - - def __init__(self, args): - super().__init__() - self.quant_bit = 4 - self.asymmetric = True - self.onnx_path = args.onnx_path - self.mnn_path = args.mnn_path - if not os.path.exists(self.onnx_path): - os.makedirs(self.onnx_path) - if not os.path.exists(self.mnn_path): - os.makedirs(self.mnn_path) - self.export_mnn = args.export_mnn - self.export_verbose = args.export_verbose - self.export_test = args.export_test - # default is False, just set True when using below command: - # `python llm_export ../path --export --embed_bin` to export single model without embedding - self.without_embed = False - self.embed_bin = True - self.embed_bf16 = args.embed_bf16 - self.skip_slim = args.skip_slim - tokenizer_model = os.path.join(args.path, 'tokenizer.model') - ice_text_model = os.path.join(args.path, 'ice_text.model') - try: - if os.path.exists(tokenizer_model): - self.sp_model = spm.SentencePieceProcessor(tokenizer_model) - elif os.path.exists(ice_text_model): - self.sp_model = spm.SentencePieceProcessor(ice_text_model) - else: - self.sp_model = None - except: - self.sp_model = None - merge_file = os.path.join(args.path, 'merges.txt') - if os.path.exists(merge_file): - self.merge_txt = merge_file - else: - self.merge_txt = None - self.stop_ids = [] - self.max_length = 1024 - self.hidden_size = 4096 - self.visual = None # defualt is not visual - self.lora_path = args.lora_path - self.load_hf(args.path) - self.load_model() - self.llm_config = { - 'hidden_size' : self.hidden_size, - 'layer_nums' : self.block_nums, - 'attention_mask': self.attention_mask_type, - 'key_value_shape': self.past_kv_shape[1:], - "prompt_template": self.build_prompt('%s'), - 'is_visual': False - } - - def load_hf(self, model_path: str): - self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - try: - self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval() - except: - self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval() - self.config = self.model.config - if self.lora_path is not None: - adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path) - self.model = adapter.merge_and_unload(progressbar=True) - - def load_model(self): - raise NotImplementedError - - def get_attention_mask(self) -> torch.Tensor: - raise NotImplementedError - - def get_position_ids(self) -> torch.Tensor: - raise NotImplementedError - - def export_vocab(self): - raise NotImplementedError - - def visual_embed(self, input_ids): - raise NotImplementedError - - def __embedding(self, input_ids): - if self.visual is not None and self.token_len == 0: - input_embeds = self.visual_embed(input_ids) - else: - input_embeds = self.embed(input_ids) - return input_embeds - - def __decode(self, hidden_states, attention_mask, position_ids, past_key_values): - presents = [] - for i in range(self.block_nums): - hidden_states, kv = self.blocks[i](hidden_states, attention_mask, position_ids, past_key_values[i]) - presents.append(kv) - logits = self.lm(hidden_states).reshape(-1) - presents = torch.stack(presents) - self.seq_len += 1 - self.token_len += 1 - return logits, presents - - def forward(self, input_ids, attention_mask, position_ids, past_key_values): - if self.without_embed: - return self.__decode(input_ids, attention_mask, position_ids, past_key_values) - return self.__decode(self.__embedding(input_ids), attention_mask, position_ids, past_key_values) - - # some test functions - def build_prompt(self, query): - if 
hasattr(self.tokenizer, 'build_prompt'): - prompt = self.tokenizer.build_prompt(query) - else: - prompt = query - return prompt - - def str_to_ids(self, prompt): - input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids'] - return input_ids - - def id_to_str(self, token_id): - word = self.tokenizer._convert_id_to_token(int(token_id)) - word = self.tokenizer.convert_tokens_to_string([word]) - return word - - def response(self, query): - prompt = self.build_prompt(query) - input_ids = self.str_to_ids(prompt) - self.seq_len = input_ids.numel() - self.context_len = self.seq_len - 2 - self.token_len = 0 - past_key_values = [None for i in range(self.block_nums)] - token_id = input_ids - while self.token_len < self.max_length: - attention_mask = self.get_attention_mask() - position_ids = self.get_position_ids() - logits, past_key_values = self.forward(token_id, attention_mask, position_ids, past_key_values) - token_id = torch.argmax(logits) - if token_id in self.stop_ids: - print("", end='\n') - break - word = self.id_to_str(token_id) - print(word, end="", flush=True) - - # some export functions - def assert_equal(self, torch_outs, onnx_outs): - if type(torch_outs) not in (list, tuple): - torch_outs = (torch_outs, ) - onnx_outs = (onnx_outs, ) - same = True - for orig, onnx in zip(torch_outs, onnx_outs): - orig = orig.detach().numpy() - if not np.allclose(orig, onnx, rtol=1e-3, atol=1e-3): - print('Error: onnx outputs dont match original. [shape = {}] onnx: {}, original: {}'.format(onnx.shape, onnx, orig)) - same = False - break - if same: - print('onnx test SUCCESS') - - def export_lm(self): - model = self.lm - hidden_states = torch.randn(1, self.hidden_size) - onnx_model = f'./{self.onnx_path}/lm.onnx' - torch.onnx.export(model, (hidden_states), - onnx_model, - verbose=self.export_verbose, - input_names=['hidden_states'], - output_names=['logits'], - do_constant_folding=True, - opset_version=15) - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - # test lm - if self.export_test: - original_outs = model(hidden_states) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'hidden_states' : hidden_states.numpy(), - } - onnx_outs = ort_session.run(None, inputs) - self.assert_equal(original_outs, onnx_outs) - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric) - - def export_visual(self): - if self.visual is None: - return - input_images = torch.randn((1, 3, self.image_size, self.image_size)) - model = self.visual - onnx_model = f'./{self.onnx_path}/visual.onnx' - torch.onnx.export(model, (input_images), - onnx_model, - verbose=self.export_verbose, - input_names=['input_images'], - output_names=['image_embeds'], - dynamic_axes={"input_images": { - 0: "size" - }}, - do_constant_folding=True, - opset_version=15) - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - # test - if self.export_test: - original_outs = model(input_images) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'input_images' : input_images.numpy(), - } - onnx_outs = ort_session.run(None, inputs)[0] - self.assert_equal(original_outs, onnx_outs) - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path) - - def export_embed(self): - model = self.embed - if self.embed_bin: - import ctypes - tensor_data = model.embed.weight.data - data_ptr = tensor_data.untyped_storage().data_ptr() - buffer = (ctypes.c_byte * (tensor_data.numel() * 
2)).from_address(data_ptr) - with open(f'./{self.onnx_path}/embeddings_bf16.bin', 'wb') as f: - f.write(buffer) - return - input_ids = torch.arange(3, dtype=torch.long) - onnx_model = f'./{self.onnx_path}/embedding.onnx' - torch.onnx.export(model, (input_ids), - onnx_model, - verbose=self.export_verbose, - input_names=['input_ids'], - output_names=['inputs_embeds'], - dynamic_axes={"input_ids": { - 0: "length" - }}, - do_constant_folding=True, - opset_version=15) - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - # test - if self.export_test: - original_outs = model(input_ids) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'input_ids' : input_ids.numpy(), - } - onnx_outs = ort_session.run(None, inputs) - self.assert_equal(original_outs, onnx_outs) - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path) - - def export_block(self, block_id: int): - self.seq_len = 3 - self.token_len = 0 - inputs_embeds = torch.randn((self.seq_len, 1, self.hidden_size)) - attention_mask = self.get_attention_mask() - position_ids = self.get_position_ids() - past_key_values = torch.zeros(self.past_kv_shape[1:]) - model = self.blocks[block_id] - onnx_model = f'./{self.onnx_path}/block_{block_id}.onnx' - torch.onnx.export( - model, (inputs_embeds, attention_mask, position_ids, past_key_values), - onnx_model, - verbose=self.export_verbose, - input_names=[ - 'inputs_embeds', 'attention_mask', 'position_ids', 'past_key_values' - ], - output_names=['hidden_states', 'presents'], - dynamic_axes=self.block_dynamic_axes, - do_constant_folding=True, - opset_version=15) - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - if self.export_test: - original_outs = model(inputs_embeds, attention_mask, position_ids, past_key_values) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'inputs_embeds' : inputs_embeds.detach().numpy(), - 'attention_mask' : attention_mask.numpy(), - 'position_ids' : position_ids.numpy(), - 'past_key_values' : past_key_values.numpy() - } - onnx_outs = ort_session.run(None, inputs) - self.assert_equal(original_outs, onnx_outs) - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric) - - def export_blocks(self): - for i in range(self.block_nums): - self.export_block(i) - - def export_config(self, is_single = True): - self.llm_config['is_single'] = is_single - with open(f'./{self.onnx_path}/llm_config.json', 'w', encoding='utf-8') as f: - json.dump(self.llm_config, f, ensure_ascii=False, indent=4) - - def export(self): - model = self - self.seq_len = 3 - self.token_len = 0 - input_ids = torch.arange(3, dtype=torch.long) - attention_mask = self.get_attention_mask() - position_ids = self.get_position_ids() - past_key_values = torch.zeros(self.past_kv_shape) - onnx_model = f'./{self.onnx_path}/llm.onnx' - if self.embed_bin: - self.without_embed = True - input_ids = self.__embedding(input_ids) - print('export start ...') - torch.onnx.export( - model, (input_ids, attention_mask, position_ids, past_key_values), - onnx_model, - verbose=self.export_verbose, - input_names=[ - 'input_ids', 'attention_mask', 'position_ids', 'past_key_values' - ], - output_names=['logits', 'presents'], - dynamic_axes=self.model_dynamic_axes, - do_constant_folding=True, - opset_version=15) - print('export done!') - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - for file_path in glob.glob(f'./{self.onnx_path}/onnx__*'): - try: - 
os.remove(file_path) - except FileNotFoundError: - pass - for file_path in glob.glob(f'./{self.onnx_path}/model.*'): - try: - os.remove(file_path) - except FileNotFoundError: - pass - if self.export_test: - # test - original_outs = model(input_ids, attention_mask, position_ids, past_key_values) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'input_ids' : input_ids.detach().numpy(), - 'attention_mask' : attention_mask.numpy(), - 'position_ids' : position_ids.numpy(), - 'past_key_values' : past_key_values.numpy() - } - onnx_outs = ort_session.run(None, inputs) - self.assert_equal(original_outs, onnx_outs) - if self.export_mnn: - # single model is > 2G, using external_data - onnx2mnn(onnx_model, self.mnn_path, self.quant_bit, self.asymmetric, True) - if self.without_embed: - self.without_embed = False - - def export_tokenizer(self): - # TOKENIZER MAGIC NUMBER - MAGIC_NUMBER = 430 - # TOKENIZER TYPE - SENTENCEPIECE = 0; TIKTOIKEN = 1; BERT = 2; HUGGINGFACE = 3 - def write_line(fp, *args): - for arg in args: - for token in arg: - fp.write(str(token) + ' ') - fp.write('\n') - def write_header(fp, type, speicals, prefix = []): - fp.write(f'{MAGIC_NUMBER} {type}\n') - fp.write(f'{len(speicals)} {len(self.stop_ids)} {len(prefix)}\n') - write_line(fp, speicals, self.stop_ids, prefix) - - file_path = os.path.join(self.onnx_path, "tokenizer.txt") - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - if hasattr(self.tokenizer, 'special_tokens'): - for k, v in self.tokenizer.special_tokens.items(): - special_list.append(v) - if hasattr(self.tokenizer, 'gmask_token_id'): - special_list.append(self.tokenizer.gmask_token_id) - vocab_list = [] - prefix_list = [] - if hasattr(self.tokenizer, 'get_prefix_tokens'): - prefix_list = self.tokenizer.get_prefix_tokens() - if self.sp_model is not None: - # senetencepiece - print('# senetencepiece tokenier') - NORMAL = 1; UNKNOWN = 2; CONTROL = 3 - USER_DEFINED = 4; UNUSED = 5; BYTE = 6 - for i in range(self.sp_model.GetPieceSize()): - token = self.sp_model.IdToPiece(i) - score = self.sp_model.GetScore(i) - type = NORMAL - if self.sp_model.IsUnknown(i): - type = UNKNOWN - elif self.sp_model.IsControl(i): - type = CONTROL - elif self.sp_model.IsUnused(i): - type = UNUSED - elif self.sp_model.IsByte(i): - type = BYTE - if self.model_name == 'Chatglm_6b': - if '' in token: token = '\n' - if '<|tab|>' in token: token = '\t' - if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')]) - if '▁' in token: token = token.replace('▁', ' ') - token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") - vocab_list.append(f'{token_encode} {score} {type}\n') - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, SENTENCEPIECE, special_list, prefix_list) - fp.write(f'{len(vocab_list)}\n') - for vocab in vocab_list: - fp.write(vocab) - elif hasattr(self.tokenizer, 'mergeable_ranks'): - print('# tiktoken tokenier') - # tikton - vocab_list = [] - for k, v in self.tokenizer.mergeable_ranks.items(): - line = base64.b64encode(k).decode("utf8") + "\n" - vocab_list.append(line) - if hasattr(self.tokenizer, 'special_tokens'): - for k, v in self.tokenizer.special_tokens.items(): - line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n" - vocab_list.append(line) - if hasattr(self.tokenizer, 'added_tokens_decoder'): - for k, v in self.tokenizer.added_tokens_decoder.items(): - line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n" - vocab_list.append(line) - with 
open(file_path, "w", encoding="utf8") as fp: - write_header(fp, TIKTOIKEN, special_list, prefix_list) - fp.write(f'{len(vocab_list)}\n') - for vocab in vocab_list: - fp.write(vocab) - elif self.merge_txt is not None: - # huggingface tokenizer - merge_list = [] - vocab = self.tokenizer.get_vocab() - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - vocab_list = ['' for i in range(len(vocab))] - # load vocab - for k, v in vocab.items(): - vocab_list[int(v)] = k - # load merge - with open(self.merge_txt, 'rt') as merge: - for line in merge.readlines(): - merge_list.append(line) - # write to tokenizer.txt - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, HUGGINGFACE, special_list) - fp.write(f'{len(vocab_list)} {len(merge_list)}\n') - for v in vocab_list: - fp.write(v + '\n') - for m in merge_list: - fp.write(m) - else: - print('# other tiktoken tokenier') - # other tikton - def unicode_to_byte(u: int): - if u >= 256 and u <= 288: - return u - 256 - if u >= 289 and u <= 322: - return u - 162 - if u == 323: - return 173 - if u == 65372: # | - return 124 - if u == 9601: # _ - return 95 - return u - vocab = self.tokenizer.get_vocab() - vocab_list = ['' for i in range(len(vocab))] - for k, v in vocab.items(): - try: - vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]).decode('utf-8', errors='ignore') - except: - vocab_list[int(v)] = k - special_list = list(self.tokenizer.added_tokens_decoder.keys()) - with open(file_path, "w", encoding="utf8") as fp: - write_header(fp, TIKTOIKEN, special_list) - fp.write(f'{len(vocab_list)}\n') - for v in vocab_list: - line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n" - fp.write(line) - -# chatglm -class GLMBlock(torch.nn.Module): - def __init__(self, block, block_id, final_layernorm = None): - super().__init__() - self.block = block - self.block_id = block_id - self.hidden_size = 4096 - self.final_layernorm = final_layernorm - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - hidden_states, presents = self.block(hidden_states, - position_ids, - attention_mask, - self.block_id, - past_kv, - use_cache=True) - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - return hidden_states, presents - -class Chatglm_6b(LLM): - def __init__(self, args): - self.attention_mask_type = 'glm' - self.model_name = 'Chatglm_6b' - super().__init__(args) - - def load_model(self): - transformer = self.model.transformer - self.lm_ = self.model.lm_head - self.embed_ = transformer.word_embeddings - self.blocks_ = transformer.layers - self.final_layernorm_ = transformer.final_layernorm - # some wrapper - self.stop_ids.append(self.tokenizer._convert_token_to_id(self.tokenizer.eos_token)) - self.block_nums = len(self.blocks_) - self.lm = Lm(self.lm_) - # chatglm embedding and lm using same param, copy embedding when using bf16 - if self.embed_bf16: - import copy - embed_copy = copy.deepcopy(self.embed_) - self.embed = Embedding(embed_copy, self.embed_bf16) - else: - self.embed = Embedding(self.embed_, self.embed_bf16) - self.blocks = [GLMBlock(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)] - # some config for export - self.past_kv_shape = [28, 2, 0, 1, 32, 128] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - 
"attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 2: "seq_len" }, - "past_key_values" : { 1: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 2: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - - def get_attention_mask(self) -> torch.Tensor: - if self.token_len: - return torch.zeros([1]).bool().reshape([1, 1, 1, 1]) - attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool) - for i in range(self.seq_len - 1): - attention_mask[i][-1] = True - attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len]) - return attention_mask - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1]) - position_ids_0 = torch.arange(self.seq_len, dtype=torch.long) - position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long) - position_ids_0[-1] = position_ids_0[-2] - position_ids_1[-1] = 1 - position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1) - return position_ids - - def build_prompt(self, query): - return f'{query}[gMASK]' - -# chatglm2 -class GLM2Block(torch.nn.Module): - def __init__(self, block, block_id, config, final_layernorm = None): - super().__init__() - self.block = block - self.block_id = block_id - self.final_layernorm = final_layernorm - self.config = config - self.hidden_size = 4096 - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - rope_ratio = self.config.rope_ratio - base = 10000 * rope_ratio - theta = 1.0 / (base ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1).unsqueeze(0).contiguous() - hidden_states, presents = self.block(hidden_states, - attention_mask, - kv_cache=past_kv, - rotary_pos_emb=rotary_pos_emb) - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - return hidden_states, presents - -class Chatglm2_6b(LLM): - def __init__(self, args): - self.attention_mask_type = 'glm2' - super().__init__(args) - self.model_name = 'Chatglm2_6b' - if 'codegeex2-6b' in args.path: - self.model_name = 'Codegeex2_6b' - - def load_model(self): - transformer = self.model.transformer - self.lm_ = transformer.output_layer - self.embed_ = transformer.embedding.word_embeddings - self.blocks_ = transformer.encoder.layers - self.final_layernorm_ = transformer.encoder.final_layernorm - # some wrapper - if self.tokenizer.eos_token_id is None: - # codegeex2-6b - self.stop_ids.append(self.tokenizer.tokenizer.eos_id) - else: - self.stop_ids.append(self.tokenizer.eos_token_id) - if hasattr(self.config, 'eos_token_id'): - if type(self.config.eos_token_id) is list: - for eos_id in self.config.eos_token_id: - self.stop_ids.append(eos_id) - elif type(self.config.eos_token_id) is int: - self.stop_ids.append(self.config.eos_token_id) - self.block_nums = len(self.blocks_) - self.embed = Embedding(self.embed_, self.embed_bf16) - self.lm = Lm(self.lm_) - self.blocks = [GLM2Block(self.blocks_[i], i, self.config, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)] - # some config for 
export - self.past_kv_shape = [28, 2, 0, 1, 2, 128] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 1: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - num_layers = self.config.num_layers - if num_layers > 28: - self.past_kv_shape = [num_layers, 2, 1, 2, 0, 128] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 3: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 4: "history_len" } - } - - def get_attention_mask(self) -> torch.Tensor: - if self.token_len: - return torch.zeros([1, 1, 1, 1]).bool() - attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool()) - return attention_mask - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([self.token_len], dtype=torch.long) - return torch.arange(self.seq_len, dtype=torch.long) - -# chatglm3 -class Chatglm3_6b(Chatglm2_6b): - def __init__(self, args): - super().__init__(args) - self.model_name = 'Chatglm3_6b' - - def build_prompt(self, query): - return f'<|user|>\n{query}\n<|assistant|>\n' - -# qwen -class QWENBlock(torch.nn.Module): - def __init__(self, name, block, block_id, hidden_size, final_layernorm = None): - super().__init__() - self.name = name - self.block = block - self.block_id = block_id - self.final_layernorm = final_layernorm - self.hidden_size = hidden_size - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1) - rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0) - if self.name != 'Qwen-7B': - rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)]) - hidden_states = hidden_states.view(1, -1, self.hidden_size) - hidden_states, presents = self.block(hidden_states=hidden_states, - layer_past=past_kv, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - use_cache=True) - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - return hidden_states, presents - -class QWEN18Block(torch.nn.Module): - def __init__(self, block, block_id, hidden_size, final_layernorm = None): - super().__init__() - self.block = block - self.block_id = block_id - self.final_layernorm = final_layernorm - self.hidden_size = hidden_size - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1).unsqueeze(1).unsqueeze(0) - rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), 
torch.sin(rotary_pos_emb)]) - hidden_states = hidden_states.view(1, -1, self.hidden_size) - hidden_states, presents = self.block(hidden_states, - rotary_pos_emb, - past_kv, - attention_mask, - use_cache=True) - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - return hidden_states, presents - -class Qwen_Chat(LLM): - def __init__(self, args): - self.attention_mask_type = 'int' - super().__init__(args) - if 'VL' in self.model_name: - self.llm_config['is_visual'] = True - self.llm_config['attention_mask'] = 'float' - self.llm_config['img_size'] = 448 - self.llm_config['imgpad_len'] = 256 - self.llm_config['img_start'] = self.tokenizer.img_start_id - self.llm_config['img_end'] = self.tokenizer.img_end_id - self.llm_config['img_pad'] = self.tokenizer.img_pad_id - - - def load_model(self): - # Qwen models - self.model_name = 'Qwen-7B' - if '1_8' in model_path: - self.model_name = 'Qwen-1_8b' - if 'VL' in model_path: - self.model_name = 'Qwen-VL' - transformer = self.model.transformer - self.lm_ = self.model.lm_head - self.embed_ = transformer.wte - self.blocks_ = transformer.h - self.final_layernorm_ = transformer.ln_f - if hasattr(transformer, 'visual'): - self.visual = transformer.visual - self.image_start_id = transformer.config.visual['image_start_id'] - self.image_size = transformer.config.visual['image_size'] - # some wrapper - self.stop_ids.append(self.tokenizer.im_end_id) - self.block_nums = len(self.blocks_) - self.hidden_size = transformer.embed_dim - self.embed = Embedding(self.embed_, self.embed_bf16) - self.lm = Lm(self.lm_) - self.blocks = [QWENBlock(self.model_name, self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)] - if self.block_nums == 32: - # qwen-7b, qwen-vl - self.past_kv_shape = [32, 2, 1, 0, 32, 128] - elif self.block_nums == 24: - # qwen-1.8b - self.past_kv_shape = [24, 2, 1, 0, 16, 128] - # some config for export - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 3: "history_len" } - } - - def build_prompt(self, query): - return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' - - def get_attention_mask(self) -> torch.Tensor: - if self.model_name == 'Qwen-VL': - if self.token_len: - return torch.zeros([1, 1, 1, 1], dtype=torch.float32) - return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min - if self.token_len: - return torch.ones([1, 1, 1, 1]).bool() - return torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool()) - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([self.seq_len - 1], dtype=torch.long) - return torch.arange(self.seq_len, dtype=torch.long) - - def visual_embed(self, input_ids): - if not torch.any(input_ids == self.image_start_id): - return self.embed(input_ids) - bos_pos = torch.where(input_ids == self.image_start_id) - eos_pos = torch.where(input_ids == self.image_start_id + 1) - img_pos = torch.stack((bos_pos[0], bos_pos[1], 
eos_pos[1]), dim=1) - images = [] - for i, a, b in img_pos: - image = input_ids[i][a + 1 : b - 1].tolist() - image = image[ : image.index(self.image_start_id + 2)] - images.append(bytes(image).decode('utf-8')) - images = self.visual.encode(images) - hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size) - for idx, (i, a, b) in enumerate(img_pos): - hidden_states[i][a + 1 : b] = images[idx] - return hidden_states.view(-1, 1, self.hidden_size) - -class QWEN2Block(torch.nn.Module): - def __init__(self, name, block, block_id, config, final_layernorm = None): - super().__init__() - self.name = name - self.block = block - self.block_id = block_id - self.final_layernorm = final_layernorm - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // config.num_attention_heads - self.rope_theta = config.rope_theta - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1) - rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0) - rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)]) - hidden_states = hidden_states.view(1, -1, self.hidden_size) - hidden_states, presents = self.block(hidden_states=hidden_states, - attention_mask=attention_mask, - past_key_value=past_kv, - rotary_pos_emb=rotary_pos_emb, - use_cache=True) - - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - # print('###', presents.shape) - return hidden_states, presents - -class Qwen2_Chat(LLM): - def __init__(self, args): - self.attention_mask_type = 'float' - super().__init__(args) - - def load_model(self): - # Qwen2 models - self.model_name = 'Qwen2' - transformer = self.model.model - self.lm_ = self.model.lm_head - self.embed_ = transformer.embed_tokens - self.blocks_ = transformer.layers - self.final_layernorm_ = transformer.norm - # some wrapper - self.stop_ids.append(self.tokenizer.eos_token_id) - if hasattr(self.model, 'generation_config'): - for id in self.model.generation_config.eos_token_id: - self.stop_ids.append(id) - self.block_nums = self.config.num_hidden_layers - self.hidden_size = self.config.hidden_size - self.num_heads = self.config.num_attention_heads - self.kv_heads = self.config.num_key_value_heads - self.rope_theta = self.config.rope_theta - self.head_dim = self.hidden_size // self.num_heads - if self.embed_.weight is self.lm_.weight: - import copy - embed_copy = copy.deepcopy(self.embed_) - self.embed = Embedding(embed_copy, self.embed_bf16) - else: - self.embed = Embedding(self.embed_, self.embed_bf16) - self.lm = Lm(self.lm_) - self.past_kv_shape = [self.block_nums, 2, 1, 0, self.kv_heads, self.head_dim] - self.blocks = [QWEN2Block(self.model_name, self.blocks_[i], i, self.config, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)] - # some config for export - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 1: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - 
"attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - - def build_prompt(self, query): - return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' - - def get_attention_mask(self) -> torch.Tensor: - if self.token_len: - return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) - return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min - - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([[self.seq_len - 1]], dtype=torch.long) - return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) - - def visual_embed(self, input_ids): - if not torch.any(input_ids == self.image_start_id): - return self.embed(input_ids) - bos_pos = torch.where(input_ids == self.image_start_id) - eos_pos = torch.where(input_ids == self.image_start_id + 1) - img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) - images = [] - for i, a, b in img_pos: - image = input_ids[i][a + 1 : b - 1].tolist() - image = image[ : image.index(self.image_start_id + 2)] - images.append(bytes(image).decode('utf-8')) - images = self.visual.encode(images) - hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size) - for idx, (i, a, b) in enumerate(img_pos): - hidden_states[i][a + 1 : b] = images[idx] - return hidden_states.view(-1, 1, self.hidden_size) - -# llama2 -class LLAMA2Block(torch.nn.Module): - def __init__(self, block, block_id, hidden_size, head_dim, final_layernorm = None): - super().__init__() - self.block = block - self.block_id = block_id - self.head_dim = head_dim - self.final_layernorm = final_layernorm - self.hidden_size = hidden_size - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - theta = 1.0 / (10000.0 ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.cat((idx_theta, idx_theta), dim=-1) - rotary_pos_emb = rotary_pos_emb.unsqueeze(1).unsqueeze(0) - rotary_pos_emb = torch.stack([torch.cos(rotary_pos_emb), torch.sin(rotary_pos_emb)]) - hidden_states = hidden_states.view(1, -1, self.hidden_size) - position_ids = position_ids.view(1, -1) - hidden_states, presents = self.block(hidden_states, - attention_mask, - position_ids, - past_kv, - rotary_pos_emb=rotary_pos_emb, - use_cache=True) - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) - if isinstance(presents, tuple): - presents = torch.stack(presents) - return hidden_states, presents - -class Llama2_7b_Chat(LLM): - def __init__(self, args): - self.attention_mask_type = 'float' - self.model_name = 'Llama2_7b' - if 'Baichuan2' in args.path: - self.model_name = 'Baichuan2_7B' - if 'internlm' in args.path: - self.model_name = 'Internlm_7b' - if 'TinyLlama' in args.path: - self.model_name = 'TinyLlama' - if 'Yi' in args.path: - self.model_name = 'Yi' - if 'deepseek' in args.path: - self.model_name = 'deepseek' - if 'Llama-3' in args.path: - self.model_name = 'Llama3_8B' - super().__init__(args) - - def load_model(self): - self.config = self.model.config - transformer = self.model.model - self.lm_ = self.model.lm_head - self.embed_ = transformer.embed_tokens - self.blocks_ = transformer.layers - self.final_layernorm_ = transformer.norm - # some wrapper - 
self.hidden_size = self.embed_.weight.shape[-1] - self.stop_ids.append(self.tokenizer.eos_token_id) - if hasattr(self.model, 'generation_config'): - self.stop_ids.append(self.model.generation_config.eos_token_id) - if self.model_name == 'Llama3_8B': - self.stop_ids.append(self.tokenizer.convert_tokens_to_ids("<|eot_id|>")) - self.block_nums = len(self.blocks_) - self.embed = Embedding(self.embed_, self.embed_bf16) - self.lm = Lm(self.lm_) - self.block_nums = self.config.num_hidden_layers - self.hidden_size = self.config.hidden_size - self.num_attention_heads = self.config.num_attention_heads - self.head_dim = self.hidden_size // self.num_attention_heads - if hasattr(self.config, 'num_key_value_heads'): - self.num_key_value_heads = self.config.num_key_value_heads - else: - self.num_key_value_heads = self.config.num_attention_heads - self.blocks = [LLAMA2Block(self.blocks_[i], i, self.hidden_size, self.head_dim, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)] - self.past_kv_shape = [self.block_nums, 2, 1, 0, self.num_key_value_heads, self.head_dim] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 1: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 1: "seq_len" }, - "past_key_values" : { 3: "history_len" } - } - - def build_prompt(self, query): - if 'Baichuan2' in self.model_name: - return f'{query}' - if 'Internlm_7b' in self.model_name: - return f'<|User|>:{query}\n<|Bot|>:' - if 'TinyLlama' in self.model_name: - return f'<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\n{query}\n<|assistant|>\n' - if 'Yi' in self.model_name: - return f'<|im_start|> user\n{query}<|im_end|>\n<|im_start|> assistant\n' - if 'deepseek' in self.model_name: - return f'<|begin_of_sentence|>User: {query}\n\nAssistant:' - if 'Llama3' in self.model_name: - return f'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' - return f'[INST]{query}[/INST]' - - def get_attention_mask(self) -> torch.Tensor: - if self.token_len: - return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) - return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([[self.seq_len - 1]], dtype=torch.long) - return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) - -# phi-2 -class PHI2Block(torch.nn.Module): - def __init__(self, block, block_id, hidden_size): - super().__init__() - self.block = block - self.block_id = block_id - self.hidden_size = hidden_size - - def forward(self, hidden_states, attention_mask, position_ids, past_kv): - theta = 1.0 / (10000 ** (torch.arange(0, 32, 2, dtype=torch.float32) / 32)) - position_ids = position_ids.float().reshape(-1, 1) - idx_theta = position_ids * theta - rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=0).contiguous() - hidden_states = hidden_states.view(1, -1, self.hidden_size) - hidden_states, presents = self.block(hidden_states, - past_kv, - rotary_pos_emb=rotary_pos_emb, - causal_mask=attention_mask - ) - if self.block_id == 31: - hidden_states = hidden_states[:, -1, :] - return hidden_states, presents - -class 
phi_2(LLM): - def __init__(self, args): - self.attention_mask_type = 'glm' - super().__init__(args) - self.model_name = 'phi-2' - self.asymmetric = False # TODO: some precision bug when using asymmetric - - def load_model(self): - transformer = self.model.transformer - self.lm_ = self.model.lm_head - self.embed_ = transformer.embd.wte - self.hidden_size = self.embed_.weight.shape[-1] - self.blocks_ = transformer.h - # self.final_layernorm_ = transformer.final_layernorm - # some wrapper - self.stop_ids.append(self.tokenizer.eos_token_id) - self.block_nums = len(self.blocks_) - self.embed = Embedding(self.embed_, self.embed_bf16) - self.lm = Lm(self.lm_) - self.blocks = [PHI2Block(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)] - # some config for export - self.past_kv_shape = [len(self.blocks), 1, 0, 2, 32, 80] - self.block_dynamic_axes = { - "inputs_embeds" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 1: "history_len" } - } - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "attention_mask" : { 2: "seq_len", 3: "seq_len" }, - "position_ids" : { 0: "seq_len" }, - "past_key_values" : { 2: "history_len" } - } - - def build_prompt(self, query): - return f'Instruct: {query}\nOutput:' - - def get_attention_mask(self) -> torch.Tensor: - if self.token_len: - return torch.zeros([1, 1, 1, 1]).bool() - attention_mask = ~torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]).bool()) - return attention_mask - - def get_position_ids(self) -> torch.Tensor: - if self.token_len: - return torch.tensor([[self.seq_len - 1]], dtype=torch.long) - return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) - -# BGE is Embedding Model based Bert -class BGEBlock(torch.nn.Module): - def __init__(self, block, block_id, hidden_size): - super().__init__() - self.block = block - self.block_id = block_id - self.hidden_size = hidden_size - - def forward(self, hidden_states, attention_mask): - hidden_states = self.block(hidden_states, attention_mask)[0] - return hidden_states - -class bge(LLM): - def __init__(self, args): - self.attention_mask_type = 'int' - self.past_kv_shape = [] - super().__init__(args) - self.model_name = 'bge-large-zh' - - def forward(self, input_ids, position_ids, attention_mask): - input_ids = input_ids.view(1, -1) - token_type_ids = (1 - attention_mask).view(1, -1) - hidden_states = self.embed(input_ids, token_type_ids, position_ids)[0].unsqueeze(0) - for i in range(self.block_nums): - hidden_states = self.blocks[i](hidden_states, attention_mask) - # hidden_states = self.lm(hidden_states) # sentence_embeddings not need - sentence_embeddings = hidden_states[:, 0] - sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) - return sentence_embeddings - - def response(self, query): - self.eval() - input_ids = self.tokenizer(query)['input_ids'] - self.seq_len = len(input_ids) - input_ids = torch.tensor(input_ids) - position_ids = self.get_position_ids() - attention_mask = self.get_attention_mask() - res = self.forward(input_ids, position_ids, attention_mask) - return res - - def load_model(self): - self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval() - transformer = self.model.encoder - self.lm_ = self.model.pooler - self.embed_ = self.model.embeddings - self.hidden_size = self.embed_.word_embeddings.weight.shape[-1] - self.blocks_ = transformer.layer - # some wrapper - self.stop_ids = [] - 
self.block_nums = len(self.blocks_) - self.embed = self.embed_ - self.lm = self.lm_ - self.blocks = [BGEBlock(self.blocks_[i], i, self.hidden_size) for i in range(self.block_nums)] - # some config for export - self.model_dynamic_axes = { - "input_ids" : { 0: "seq_len" }, - "position_ids" : { 1: "seq_len" }, - "attention_mask" : { 3: "seq_len" } - } - - def export(self): - model = self.eval() - self.seq_len = 3 - input_ids = torch.arange(3, dtype=torch.long) - position_ids = self.get_position_ids() - attention_mask = self.get_attention_mask() - onnx_model = f'./{self.onnx_path}/bge.onnx' - torch.onnx.export( - model, (input_ids, position_ids, attention_mask), - onnx_model, - verbose=self.export_verbose, - input_names=[ - 'input_ids', - 'position_ids', - 'attention_mask' - ], - output_names=['sentence_embeddings'], - dynamic_axes=self.model_dynamic_axes, - do_constant_folding=True, - opset_version=15) - if not self.skip_slim: - slim(onnx_model, output_model=onnx_model) - if self.export_test: - self.seq_len = 4 - position_ids = self.get_position_ids() - input_ids = torch.tensor([ 101, 872, 1962, 102 ], dtype=torch.long) - attention_mask = self.get_attention_mask() - # test - original_outs = model(input_ids, position_ids, attention_mask) - ort_session = ort.InferenceSession(onnx_model, providers=['CPUExecutionProvider']) - inputs = { - 'input_ids' : input_ids.detach().numpy(), - 'position_ids' : position_ids.detach().numpy(), - 'attention_mask' : attention_mask.detach().numpy() - } - onnx_outs = ort_session.run(None, inputs)[0] - self.assert_equal(original_outs, onnx_outs) - - token_str = None - if False: # save tokenizer in mnn - self.export_tokenizer() - token_path = os.path.join(self.onnx_path, "tokenizer.txt") - token_str = open(token_path, 'rt').read() - - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path, 8, True, bizCode=token_str) - - def build_prompt(self, query): - return f'[CLS]{query}[SEP]' - - def get_position_ids(self) -> torch.Tensor: - return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) - - def get_attention_mask(self) -> torch.Tensor: - return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long) - -class LoraModule(torch.nn.Module): - def __init__(self, args): - super().__init__() - self.onnx_path = args.onnx_path - self.mnn_path = args.mnn_path - self.export_mnn = args.export_mnn - import peft - lora_weight = peft.load_peft_weights(args.path) - for k, v in lora_weight.items(): - k = k.replace('.', '/') - self.register_buffer(k, v.cpu()) - - def forward(self, dummpy): - return self._buffers - - def export(self): - onnx_model = f'./{self.onnx_path}/lora.onnx' - torch.onnx.export(self.eval(), torch.tensor([]), onnx_model) - if self.export_mnn: - onnx2mnn(onnx_model, self.mnn_path) - - -if __name__ == '__main__': - llm_models = { - 'chatglm-6b': Chatglm_6b, - 'chatglm2-6b': Chatglm2_6b, - 'codegeex2-6b': Chatglm2_6b, - 'chatglm3-6b': Chatglm3_6b, - 'glm-4-9b-chat': Chatglm3_6b, - 'Qwen-7B-Chat': Qwen_Chat, - 'Qwen-1_8B-Chat': Qwen_Chat, - 'Qwen-1_8B': Qwen_Chat, - 'Qwen-VL-Chat': Qwen_Chat, - 'Qwen1_5-0_5B-Chat': Qwen2_Chat, - 'Qwen1_5-1_8B-Chat': Qwen2_Chat, - 'Qwen1_5-4B-Chat': Qwen2_Chat, - 'Qwen1_5-7B-Chat': Qwen2_Chat, - 'Qwen2-0_5B-Instruct': Qwen2_Chat, - 'Qwen2-1_5B-Instruct': Qwen2_Chat, - 'Qwen2-7B-Instruct': Qwen2_Chat, - 'Baichuan2-7B-Chat': Llama2_7b_Chat, - 'Llama-2-7b-chat-ms': Llama2_7b_Chat, - 'Llama-3-8B-Instruct': Llama2_7b_Chat, - 'internlm-chat-7b': Llama2_7b_Chat, - 'TinyLlama-1_1B-Chat': Llama2_7b_Chat, - 'Yi-6B-Chat': Llama2_7b_Chat, - 
'deepseek-llm-7b-chat': Llama2_7b_Chat, - 'MiniCPM-1.2b': Llama2_7b_Chat, - 'MiniCPM-2.4b': Llama2_7b_Chat, - 'phi-2': phi_2, - 'bge-large-zh': bge, - 'lora': LoraModule - } - parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True, - help='path(`str` or `os.PathLike`):\nCan be either:' - '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]' - '\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.') - parser.add_argument('--type', type=str, choices=llm_models.keys(), default=None, - help='type(`str`, *optional*):' - '\n\tThe pretrain llm model type.' - ) - parser.add_argument('--lora_path', type=str, default=None, help='lora path, defaut is `None` mean not apply lora.') - parser.add_argument('--onnx_path', type=str, default='./onnx', help='export onnx model path, defaut is `./onnx`.') - parser.add_argument('--mnn_path', type=str, default='./mnn', help='export mnn model path, defaut is `./mnn`.') - parser.add_argument('--export_mnn', action='store_true', default=False, help='Whether or not to export mnn model after onnx.') - parser.add_argument('--export_verbose', action='store_true', default=False, help='Whether or not to export onnx with verbose.') - parser.add_argument('--export_test', action='store_true', help='Whether or not to export onnx with test using onnxruntime.') - parser.add_argument('--test', type=str, help='test model inference with query `TEST`.') - parser.add_argument('--export', action='store_true', help='export model to an `onnx` model.') - parser.add_argument('--export_split', action='store_true', - help='export model split to some `onnx` models:' - '\n\t- embedding model.' - '\n\t- block models.' - '\n\t- lm_head model.' 
- ) - parser.add_argument('--export_visual', action='store_true', help='export llm visual model to an `onnx` model.') - parser.add_argument('--export_lm', action='store_true', help='export llm lm_head to an `onnx` model.') - parser.add_argument('--export_block', type=int, help='export llm block [id] to an `onnx` model.') - parser.add_argument('--export_blocks', action='store_true', help='export llm all blocks to `onnx` models.') - parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.') - - # No use now, add invoid of call error - parser.add_argument('--export_token', action='store_true', help='export llm tokenizer to a txt file.') - parser.add_argument('--export_embed', action='store_true', help='export llm embedding to an `onnx` model.') - parser.add_argument('--embed_bf16', default=True, action='store_true', help='using `bfloat16` replace `float32` in embedding.') - parser.add_argument('--embed_bin', action='store_true', help='export embedding weight as bin file with dtype `bfloat16`') - - args = parser.parse_args() - model_path = args.path - model_type = args.type - # not sepcify model type, using path - if model_type is None: - for model in llm_models: - if model in model_path: - model_type = model - if model_type is None: - raise RuntimeError('Please specify model type.') - - # copy modeling py file to pretrain model for export - for file in glob.glob(f'./llm_models/{model_type}/*'): - shutil.copy2(file, model_path) - - llm_exporter = llm_models[model_type](args) - - # some actions - if args.test is not None: - llm_exporter.response(args.test) - - if args.export or args.export_split: - llm_exporter.export_config(args.export) - - if args.export: - llm_exporter.export() - - llm_exporter.export_tokenizer() - - llm_exporter.export_embed() - - if args.export_visual or args.export_split: - llm_exporter.export_visual() - - if args.export_lm or args.export_split: - llm_exporter.export_lm() - - if args.export_blocks or args.export_split: - llm_exporter.export_blocks() - - if args.export_block is not None: - llm_exporter.export_block(args.export_block) \ No newline at end of file diff --git a/transformers/llm/export/llm_models/Baichuan2-7B-Chat/modeling_baichuan.py b/transformers/llm/export/llm_models/Baichuan2-7B-Chat/modeling_baichuan.py deleted file mode 100755 index 5a0b69e83..000000000 --- a/transformers/llm/export/llm_models/Baichuan2-7B-Chat/modeling_baichuan.py +++ /dev/null @@ -1,825 +0,0 @@ -# Copyright 2023 Baichuan Inc. All Rights Reserved. - -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
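Aside: the deleted exporter classes above all feed their ONNX graphs the same two input conventions — an additive causal mask plus `0..seq_len-1` position ids during prefill, and a single all-zero mask row plus the last position index during decoding. The standalone sketch below restates that logic for reference (function names are mine, not from the original file):

```python
# Minimal recap of the mask / position-id convention used by the deleted exporter
# classes above. Names here are illustrative only.
import torch

def build_attention_mask(seq_len: int, decoding: bool) -> torch.Tensor:
    if decoding:
        # Decode step: the one new token may attend to all `seq_len` cached positions.
        return torch.zeros([1, 1, 1, seq_len], dtype=torch.float32)
    # Prefill: additive causal mask, 0 on/below the diagonal, a large negative value above it.
    return (1 - torch.tril(torch.ones([1, 1, seq_len, seq_len]))) * torch.finfo(torch.float32).min

def build_position_ids(seq_len: int, decoding: bool) -> torch.Tensor:
    if decoding:
        return torch.tensor([[seq_len - 1]], dtype=torch.long)
    return torch.arange(seq_len, dtype=torch.long).unsqueeze(0)

if __name__ == "__main__":
    print(build_attention_mask(4, decoding=False).shape)  # torch.Size([1, 1, 4, 4])
    print(build_position_ids(4, decoding=False))          # tensor([[0, 1, 2, 3]])
    print(build_attention_mask(5, decoding=True))         # zeros of shape [1, 1, 1, 5]
    print(build_position_ids(5, decoding=True))           # tensor([[4]])
```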
- - -from .configuration_baichuan import BaichuanConfig -from .generation_utils import build_chat_input, TextIterStreamer - -import math -from typing import List, Optional, Tuple, Union -from threading import Thread - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from torch.nn import functional as F -from transformers import PreTrainedModel, PretrainedConfig -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.generation.utils import GenerationConfig -from transformers.utils import logging, ContextManagers - -import os -from contextlib import contextmanager -logger = logging.get_logger(__name__) - -try: - from xformers import ops as xops -except ImportError: - xops = None - logger.warning( - "Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers\npip install xformers." - ) - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
- """ - if len(mask.size()) == 3: - bsz, src_len, _ = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:,None,:,:].expand(bsz, 1, tgt_len, src_len).to(dtype) - else: - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32) - freqs = torch.outer(t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32) - self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32) - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
- if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32) - freqs = torch.outer(t, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - self.cos_cached = emb.cos()[None, None, :, :].to(torch.float32).to(x.device) - self.sin_cached = emb.sin()[None, None, :, :].to(torch.float32).to(x.device) - elif self.cos_cached.device != x.device: - self.cos_cached = self.cos_cached.to(x.device) - self.sin_cached = self.sin_cached.to(x.device) - return ( - self.cos_cached[:, :, :seq_len, ...], - self.sin_cached[:, :, :seq_len, ...], - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids): - cos = torch.squeeze(cos_) # [seq_len, dim] - sin = torch.squeeze(sin_) # [seq_len, dim] - # print(f'### cos.shape = {cos.shape}, position_ids.shape = {position_ids.shape}, cos[position_ids].shape = {cos[position_ids].shape}') - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # print(f'### q.shape = {q.shape}, cos.shape = {cos.shape}') - # cos = cos[position_ids] - # sin = sin[position_ids] - q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin) - k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin) - return q_embed.to(q.dtype), k_embed.to(k.dtype) - - -class MLP(nn.Module): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: BaichuanConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def raw_atten(self, query_layer, key_layer, value_layer, attention_mask): - attn_weight = torch.softmax((query_layer @ key_layer.transpose(-2, -1) / math.sqrt(query_layer.size(-1))) + attention_mask, dim=-1) - return attn_weight @ value_layer - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - proj = self.W_pack(hidden_states) - proj = proj.reshape([1, -1, 3, 4096]).permute([2, 0, 1, 3]) - ''' - # proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - if xops is not None and self.training: - attn_weights = None - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = xops.memory_efficient_attention( - query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask() - ) - else: - attn_output = self.raw_atten(query_states, key_states, value_states, attention_mask) - attn_output = attn_output.transpose(1, 2) - ''' - #--------------- - query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim) - key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim) - value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = 
query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - #--------------- - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class DecoderLayer(nn.Module): - def __init__(self, config: BaichuanConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Attention(config=config) - self.mlp = MLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class BaichuanPreTrainedModel(PreTrainedModel): - config_class = BaichuanConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, BaichuanModel): - module.gradient_checkpointing = value - - -class BaichuanModel(BaichuanPreTrainedModel): - def __init__(self, config: BaichuanConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = 
nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = 
self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class NormHead(nn.Module): - def __init__(self, hidden_size, vocab_size, bias=False): - super().__init__() - self.weight = nn.Parameter(torch.empty((vocab_size, hidden_size))) - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - self.first_flag = True - - def forward(self, hidden_states): - if self.training: - norm_weight = nn.functional.normalize(self.weight) - elif self.first_flag: - self.first_flag = False - self.weight = nn.Parameter(nn.functional.normalize(self.weight)) - norm_weight = self.weight - else: - norm_weight = self.weight - return nn.functional.linear(hidden_states, norm_weight) - -_init_weights = True -@contextmanager -def no_init_weights(_enable=True): - global _init_weights - old_init_weights = _init_weights - if _enable: - _init_weights = False - try: - yield - finally: - _init_weights = old_init_weights - -class BaichuanForCausalLM(BaichuanPreTrainedModel): - def __init__(self, config, *model_args, **model_kwargs): - super().__init__(config, *model_args, **model_kwargs) - self.model = BaichuanModel(config) - - self.lm_head = NormHead(config.hidden_size, config.vocab_size, bias=False) - if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: - try: - from .quantizer import quantize_offline, init_model_weight_int4 - except ImportError: - raise ImportError(f"Needs QLinear to run quantize.") - quantize_offline(self, 4) - # Initialize weights and apply final processing - self.post_init() - - def 
get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *model_args, - config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - revision: str = "main", - use_safetensors: bool = None, - **kwargs, - ): - # Load config if we don't provide a configuration - if not isinstance(config, PretrainedConfig): - config_path = config if config is not None else pretrained_model_name_or_path - config, model_kwargs = cls.config_class.from_pretrained( - config_path, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=False, - proxies=None, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder="", - _from_auto=False, - _from_pipeline=None, - **kwargs, - ) - else: - model_kwargs = kwargs - - if hasattr(config, "quantization_config") and config.quantization_config['load_in_4bit']: - try: - from .quantizer import init_model_weight_int4 - from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map - from accelerate.utils import CustomDtype - from accelerate.utils import get_balanced_memory - except ImportError: - raise ImportError(f"Needs import model weight init func to run quantize.") - # Instantiate model. - init_contexts = [no_init_weights(_enable=True)] - init_contexts.append(init_empty_weights()) - with ContextManagers(init_contexts): - model = cls(config) - - model_file = os.path.join(pretrained_model_name_or_path, 'pytorch_model.bin') - state_dict = torch.load(model_file, map_location="cpu") - model.is_quantized = True - - device_map = kwargs.pop("device_map", None) - torch_dtype = kwargs.pop("torch_dtype", None) - - kwargs = {"no_split_module_classes": model._no_split_modules} - target_dtype = CustomDtype.INT4 - max_memory = get_balanced_memory( - model, - dtype=target_dtype, - low_zero=(device_map == "balanced_low_0"), - max_memory=None, - **kwargs, - ) - kwargs["max_memory"] = max_memory - - device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs) - model = init_model_weight_int4(config, model, state_dict) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.eval() - # If it is a model with generation capabilities, attempt to load the generation config - if model.can_generate(): - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=False, - proxies=None, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder="", - _from_auto=False, - _from_pipeline=None, - **kwargs, - ) - except (OSError, TypeError): - logger.info( - "Generation config file not found, using a generation config created from the model config." 
- ) - pass - - if device_map is not None: - dispatch_model(model, device_map=device_map) - - return model - return super(BaichuanForCausalLM, cls).from_pretrained(pretrained_model_name_or_path, *model_args, - config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, - force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, - use_safetensors=use_safetensors, **kwargs) - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - softmax_normalizer = shift_logits.max(-1).values ** 2 - z_loss = self.config.z_loss_weight * softmax_normalizer.mean() - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + z_loss - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - def quantize(self, bits: int): - try: - from .quantizer import quantize_online - except ImportError: - raise ImportError(f"Needs QLinear to run quantize.") - return quantize_online(self, bits) - - def chat(self, tokenizer, messages: List[dict], stream=False, - generation_config: Optional[GenerationConfig]=None): - generation_config = generation_config or self.generation_config - input_ids = build_chat_input(self, tokenizer, messages, generation_config.max_new_tokens) - if stream: - streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) - Thread(target=self.generate, kwargs=dict( - inputs=input_ids, streamer=streamer, - generation_config=generation_config, - )).start() - return streamer - else: - outputs = self.generate(input_ids, generation_config=generation_config) - print(outputs[0]) - response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) - return response diff --git a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/config.json b/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/config.json deleted file mode 100755 index e2ba91313..000000000 --- a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 4096, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "tie_word_embeddings": false, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": true, - "vocab_size": 32000 -} diff --git a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/configuration_llama.py b/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. 
Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/modeling_llama.py b/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/modeling_llama.py deleted file mode 100644 index 493b040b7..000000000 --- a/transformers/llm/export/llm_models/Llama-2-7b-chat-ms/modeling_llama.py +++ /dev/null @@ -1,1040 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
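Side note on the rotary update applied just below: multiplying by the squeezed cos/sin tables and adding `rotate_half` is a per-pair 2-D rotation, so each position vector keeps its norm. A small self-contained check (shapes and values here are illustrative only):

```python
# Numeric check that q*cos + rotate_half(q)*sin is a rotation of paired channels
# (x[..., i], x[..., i + dim//2]); head_dim and sequence length are arbitrary here.
import torch

def rotate_half(x):                      # same helper as defined above
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

torch.manual_seed(0)
q = torch.randn(1, 3, 4)                 # [batch, seq, head_dim], head_dim = 4
pos = torch.arange(3, dtype=torch.float32).reshape(-1, 1)
theta = 1.0 / (10000 ** (torch.arange(0, 4, 2, dtype=torch.float32) / 4))
angles = pos * theta                     # [seq, head_dim // 2]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)

q_rot = q * cos + rotate_half(q) * sin   # rotation => per-position norm unchanged
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
print(q_rot.shape)                       # torch.Size([1, 3, 4])
```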
- # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = 
(key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: 
Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/config.json b/transformers/llm/export/llm_models/Llama-3-8B-Instruct/config.json deleted file mode 100755 index d9c36dfca..000000000 --- a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct", - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "attention_bias": false, - "attention_dropout": 0.0, - "bos_token_id": 128000, - "eos_token_id": 128001, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 14336, - "max_position_embeddings": 8192, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 8, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 500000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.38.2", - "use_cache": true, - "vocab_size": 128256 -} diff --git a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/configuration_llama.py 
b/transformers/llm/export/llm_models/Llama-3-8B-Instruct/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `input_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it.
This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-6): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling - strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the 
`rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/modeling_llama.py b/transformers/llm/export/llm_models/Llama-3-8B-Instruct/modeling_llama.py deleted file mode 100644 index 493b040b7..000000000 --- a/transformers/llm/export/llm_models/Llama-3-8B-Instruct/modeling_llama.py +++ /dev/null @@ -1,1040 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = 
(key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: 
Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
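# A hedged, illustrative sketch (not taken from the deleted file): the rewritten
# attention forward above keeps query/key/value in [batch, seq_len, heads, head_dim]
# layout, takes precomputed cos/sin via `rotary_pos_emb`, concatenates the KV cache
# along the sequence axis and stacks it into a single tensor, apparently to keep the
# traced graph simple for export. A minimal runnable illustration of that RoPE
# application follows; all shapes here are made-up examples.
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) on the last dimension, as in the deleted file.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, q_len, num_heads, head_dim = 1, 4, 24, 64
query_states = torch.randn(bsz, q_len, num_heads, head_dim)
key_states = torch.randn(bsz, q_len, num_heads, head_dim)
value_states = torch.randn(bsz, q_len, num_heads, head_dim)
# Externally supplied rotary embeddings, broadcast over the head axis.
cos = torch.randn(bsz, q_len, 1, head_dim)
sin = torch.randn(bsz, q_len, 1, head_dim)

query_states = (query_states * cos) + (rotate_half(query_states) * sin)
key_states = (key_states * cos) + (rotate_half(key_states) * sin)

# The cache becomes one stacked tensor of shape [2, batch, kv_seq_len, heads, head_dim].
past_key_value = torch.stack((key_states, value_states))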
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/MiniCPM-1.2b/config.json b/transformers/llm/export/llm_models/MiniCPM-1.2b/config.json deleted file mode 100644 index 0bfa72faa..000000000 --- a/transformers/llm/export/llm_models/MiniCPM-1.2b/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 1536, - "initializer_range": 0.1, - "intermediate_size": 3840, - "max_position_embeddings": 4096, - "model_type": "llama", - "num_attention_heads": 24, - "num_hidden_layers": 52, - "num_key_value_heads": 8, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.31.0.dev0", - "use_cache": true, - "vocab_size": 73440 -} \ No newline at end of file diff --git a/transformers/llm/export/llm_models/MiniCPM-1.2b/configuration_llama.py b/transformers/llm/export/llm_models/MiniCPM-1.2b/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- 
a/transformers/llm/export/llm_models/MiniCPM-1.2b/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). 
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
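# A hedged illustration of the `rope_scaling` format that the validation code
# below accepts: a two-field dict with "type" in {"linear", "dynamic"} and a
# float "factor" > 1. The concrete numbers are examples, not values from the
# MiniCPM-1.2b config (which sets rope_scaling to null).
import torch

rope_scaling = {"type": "linear", "factor": 2.0}

# With linear scaling, positions are divided by the factor before the cos/sin
# cache is built (see LlamaLinearScalingRotaryEmbedding._set_cos_sin_cache),
# so a model trained for 2048 positions can address roughly 2048 * factor.
max_position_embeddings = 2048
t = torch.arange(int(max_position_embeddings * rope_scaling["factor"]), dtype=torch.float32)
t_scaled = t / rope_scaling["factor"]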
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/MiniCPM-1.2b/convert_minicpm_to_llama.py b/transformers/llm/export/llm_models/MiniCPM-1.2b/convert_minicpm_to_llama.py deleted file mode 100644 index 7e6a56d5f..000000000 --- a/transformers/llm/export/llm_models/MiniCPM-1.2b/convert_minicpm_to_llama.py +++ /dev/null @@ -1,38 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -import torch -import math -#torch.manual_seed(0) - -path = "path-to-MiniCPM-1B-sft-bf16" -tokenizer = AutoTokenizer.from_pretrained(path) -model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, trust_remote_code=True) - -responds, history = model.chat(tokenizer, "山东省最高的山是哪座山, 它比黄山高还是矮?差距多少?", temperature=0.3, top_p=0.5) -print(responds) - - -state_dict = model.state_dict() -print(state_dict.keys()) - -scale_emb = 12 -dim_model_base = 256 -scale_depth = 1.4 -num_layers = 52 -hidden_size = 1536 - -new_emb = state_dict["model.embed_tokens.weight"] * scale_emb -state_dict["model.embed_tokens.weight"] = new_emb - -new_emb = state_dict["lm_head.weight"] / (hidden_size / dim_model_base) -state_dict["lm_head.weight"] = new_emb - -for i in range(num_layers): - attn_out_name = f"model.layers.{i}.self_attn.o_proj.weight" - new_weight = state_dict[attn_out_name] * (scale_depth / math.sqrt(num_layers)) - state_dict[attn_out_name] = new_weight - - ffn_down_proj_name = f"model.layers.{i}.mlp.down_proj.weight" - new_weight = state_dict[ffn_down_proj_name] * (scale_depth / math.sqrt(num_layers)) - state_dict[ffn_down_proj_name] = new_weight - -torch.save(state_dict, "pytorch_model_llama.bin") diff --git a/transformers/llm/export/llm_models/MiniCPM-1.2b/modeling_llama.py b/transformers/llm/export/llm_models/MiniCPM-1.2b/modeling_llama.py deleted file mode 100644 index 8c562c604..000000000 --- a/transformers/llm/export/llm_models/MiniCPM-1.2b/modeling_llama.py +++ /dev/null @@ -1,1010 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
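# A hedged restatement of what the deleted convert_minicpm_to_llama.py above does:
# MiniCPM's runtime activation scalings are folded into the weights once, offline,
# so the resulting checkpoint can be served by this plain Llama modeling code.
# The constants mirror that script; the helper name is ours, for illustration only.
import math

SCALE_EMB, DIM_MODEL_BASE, SCALE_DEPTH = 12, 256, 1.4
NUM_LAYERS, HIDDEN_SIZE = 52, 1536

def fold_minicpm_scales(state_dict):
    # Embedding outputs are pre-scaled; the LM head is divided by hidden/base.
    state_dict["model.embed_tokens.weight"] *= SCALE_EMB
    state_dict["lm_head.weight"] /= HIDDEN_SIZE / DIM_MODEL_BASE
    # Residual-branch outputs (attention o_proj, MLP down_proj) absorb the
    # depth scaling scale_depth / sqrt(num_layers).
    layer_scale = SCALE_DEPTH / math.sqrt(NUM_LAYERS)
    for i in range(NUM_LAYERS):
        state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] *= layer_scale
        state_dict[f"model.layers.{i}.mlp.down_proj.weight"] *= layer_scale
    return state_dict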
-# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
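# A small, hedged demonstration of the grouped-query-attention expansion that
# repeat_kv performs: each key/value head is repeated
# num_attention_heads // num_key_value_heads times so it lines up with the query
# heads (24 vs. 8 in the MiniCPM-1.2b config above, i.e. n_rep = 3). Batch and
# sequence sizes below are arbitrary.
import torch

batch, num_kv_heads, n_rep, slen, head_dim = 1, 8, 3, 5, 64
kv = torch.randn(batch, num_kv_heads, slen, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
expanded = expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# Same result as repeating each KV head n_rep times along the head axis.
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))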
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, 
dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = 
False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
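The `labels` contract described here is the usual causal-LM one: logits at position `t` are scored against the token at position `t+1`, and `-100` entries are skipped because that is `CrossEntropyLoss`'s default `ignore_index`. A minimal sketch of the shift-and-flatten step performed in the deleted `LlamaForCausalLM.forward` (assumes `torch`; the numbers are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)       # (batch, seq_len, vocab)
labels = torch.tensor([[3, 7, -100, 2, 9]])  # -100 marks ignored positions

shift_logits = logits[..., :-1, :].contiguous()  # drop the last prediction step
shift_labels = labels[..., 1:].contiguous()      # drop the first token

loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss)
```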
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/MiniCPM-2.4b/config.json b/transformers/llm/export/llm_models/MiniCPM-2.4b/config.json deleted file mode 100644 index 541a5f8c4..000000000 --- a/transformers/llm/export/llm_models/MiniCPM-2.4b/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 2304, - "initializer_range": 0.1, - "intermediate_size": 5760, - "max_position_embeddings": 4096, - "model_type": "llama", - "num_attention_heads": 36, - "num_hidden_layers": 40, - "num_key_value_heads": 36, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.31.0.dev0", - "use_cache": true, - "vocab_size": 122756 -} \ No newline at end of file diff --git a/transformers/llm/export/llm_models/MiniCPM-2.4b/configuration_llama.py b/transformers/llm/export/llm_models/MiniCPM-2.4b/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- 
a/transformers/llm/export/llm_models/MiniCPM-2.4b/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). 
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
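For reference, a hedged sketch of building such a configuration with the hyper-parameters from the deleted MiniCPM-2.4b `config.json` earlier in this diff. It is written against the upstream `transformers.LlamaConfig` (the vendored `configuration_llama.py` being removed accepts the same keyword arguments); the `rope_scaling` dict is illustrative only, since MiniCPM-2.4b ships with `rope_scaling: null`, and newer transformers releases may normalize its keys:

```python
from transformers import LlamaConfig

config = LlamaConfig(
    vocab_size=122756,
    hidden_size=2304,
    intermediate_size=5760,
    num_hidden_layers=40,
    num_attention_heads=36,
    num_key_value_heads=36,   # equal to num_attention_heads -> plain multi-head attention
    max_position_embeddings=4096,
    rms_norm_eps=1e-5,
    rope_scaling={"type": "linear", "factor": 2.0},  # {"type", "factor"} is the format the validator below checks
)
print(config.num_key_value_heads, config.rope_scaling)
```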
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/MiniCPM-2.4b/modeling_llama.py b/transformers/llm/export/llm_models/MiniCPM-2.4b/modeling_llama.py deleted file mode 100644 index 8c562c604..000000000 --- a/transformers/llm/export/llm_models/MiniCPM-2.4b/modeling_llama.py +++ /dev/null @@ -1,1010 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
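`rotate_half` plus the elementwise `cos`/`sin` weighting applied below amounts to a 2-D rotation of each `(i, i + head_dim/2)` dimension pair. A tiny numeric check with `head_dim = 2` (assumes `torch`; a standalone sketch, not the removed function):

```python
import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

theta = math.pi / 6                       # rotate by 30 degrees
q = torch.tensor([1.0, 0.0])              # head_dim == 2: dims (0, 1) form one rotation pair
cos = torch.tensor([math.cos(theta)] * 2)
sin = torch.tensor([math.sin(theta)] * 2)

q_rot = q * cos + rotate_half(q) * sin
print(q_rot)  # tensor([0.8660, 0.5000]) -> (cos 30°, sin 30°)
```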
- # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = 
repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
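`repeat_kv` above is what lets grouped-query attention reuse a small number of key/value heads for many query heads: each K/V head is expanded `num_key_value_groups` times before the usual scaled dot-product. A standalone sketch of that expansion (assumes `torch`; the shapes are illustrative):

```python
import torch

def repeat_kv(hidden_states, n_rep):
    # Same expand/reshape trick as the deleted helper:
    # (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)
    b, kv_heads, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(b, kv_heads, n_rep, s, d)
    return hidden_states.reshape(b, kv_heads * n_rep, s, d)

# 2 key/value heads serving 8 query heads -> each K/V head is repeated 4 times (GQA).
k = torch.randn(1, 2, 5, 16)
print(repeat_kv(k, n_rep=4).shape)  # torch.Size([1, 8, 5, 16])
```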
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. 
- - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - 
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - 
position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = 
kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
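The `prepare_inputs_for_generation` removed above derives `position_ids` on the fly from the attention mask when the caller does not supply them. A minimal standalone sketch of that derivation, with made-up tensor values:

```python
import torch

# Hypothetical right-padded batch: 1 marks a real token, 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# A running count of real tokens minus one gives 0-based positions;
# padded slots are filled with a dummy value (1), as in the deleted code.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 1, 1],
#         [0, 1, 2, 3, 4]])
```

When a KV cache is present, only the newest position is kept (`position_ids[:, -1:]`), matching the single-token input fed during incremental decoding.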
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen-1_8B-Chat/modeling_qwen.py b/transformers/llm/export/llm_models/Qwen-1_8B-Chat/modeling_qwen.py deleted file mode 100755 index 5138eea76..000000000 --- a/transformers/llm/export/llm_models/Qwen-1_8B-Chat/modeling_qwen.py +++ /dev/null @@ -1,1406 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
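The `LlamaForSequenceClassification` head removed above pools the logits of the last non-padding token in each row before computing the loss. A small self-contained sketch of that pooling; the pad id and shapes are assumptions for illustration:

```python
import torch

pad_token_id = 0                        # assumed pad id
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 6, 8, 2]])
logits = torch.randn(2, 5, 3)           # (batch, seq_len, num_labels), made-up scores

# Index of the last non-padding token per row, then gather its logits.
sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 4])
pooled_logits = logits[torch.arange(2), sequence_lengths]
print(pooled_logits.shape)              # torch.Size([2, 3])
```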
- -import copy -import importlib -import math -import pathlib -from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import warnings - -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList -from transformers.generation.logits_process import LogitsProcessorList - -if TYPE_CHECKING: - from transformers.generation.streamers import BaseStreamer -from transformers.generation.utils import GenerateOutput -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -try: - from einops import rearrange -except ImportError: - rearrange = None -from torch import nn - -SUPPORT_CUDA = torch.cuda.is_available() -SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() -SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 -SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 - - -from .configuration_qwen import QWenConfig -from .qwen_generation_utils import ( - HistoryType, - make_context, - decode_tokens, - get_stop_words_ids, - StopWordsLogitsProcessor, -) - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "qwen" -_CONFIG_FOR_DOC = "QWenConfig" - -QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] - -_ERROR_BAD_CHAT_FORMAT = """\ -We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". -If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). -我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 -如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 -""" - -_SENTINEL = object() -_ERROR_STREAM_IN_CHAT = """\ -Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). -向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 -""" - -_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ -We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
-检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 -""" - -apply_rotary_emb_func = None -rms_norm = None -flash_attn_unpadded_func = None -flash_attn_func = None - -def _import_flash_attn(): - global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func - try: - from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func - apply_rotary_emb_func = __apply_rotary_emb_func - except ImportError: - logger.warn( - "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" - ) - - try: - from flash_attn.ops.rms_norm import rms_norm as __rms_norm - rms_norm = __rms_norm - except ImportError: - logger.warn( - "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" - ) - - try: - import flash_attn - _flash_attn_func = None - if not hasattr(flash_attn, '__version__'): - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - else: - if int(flash_attn.__version__.split(".")[0]) >= 2: - if int(flash_attn.__version__.split(".")[1]) >= 1: - from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func - from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func - else: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - flash_attn_unpadded_func = __flash_attn_unpadded_func - flash_attn_func = _flash_attn_func - except ImportError: - logger.warn( - "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention" - ) - -def quantize_cache_v(fdata, bits, qmax, qmin): - # b, s, head, h-dim->b, head, s, h-dim - qtype = torch.uint8 - device = fdata.device - shape = fdata.shape - - fdata_cal = torch.flatten(fdata, 2) - fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) - fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) - # Compute params - if qmax.device != fmax.device: - qmax = qmax.to(device) - qmin = qmin.to(device) - scale = (fmax - fmin) / (qmax - qmin) - zero = qmin - fmin / scale - scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() - zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() - # Quantize - res_data = fdata / scale + zero - qdata = torch.clamp(res_data, qmin, qmax).to(qtype) - return qdata.contiguous(), scale, zero - -def dequantize_cache_torch(qdata, scale, zero): - data = scale * (qdata - zero) - return data - -class FlashSelfAttention(torch.nn.Module): - def __init__( - self, - causal=False, - softmax_scale=None, - attention_dropout=0.0, - ): - super().__init__() - assert flash_attn_unpadded_func is not None, ( - "Please install FlashAttention first, " "e.g., with pip install flash-attn" - ) - assert ( - rearrange is not None - ), "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def unpad_input(self, hidden_states, attention_mask): - valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) - seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(valid_mask.flatten(), 
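The `quantize_cache_v` / `dequantize_cache_torch` pair defined above stores the KV cache as uint8 with a per-head scale and zero point and reconstructs it on the fly. A simplified round trip under the same affine scheme (per-row rather than per-head, with made-up shapes) might look like this:

```python
import torch

def quantize_uint8(x: torch.Tensor):
    # Per-row affine quantization to uint8, in the spirit of quantize_cache_v.
    fmax = x.amax(dim=-1, keepdim=True)
    fmin = x.amin(dim=-1, keepdim=True)
    qmax, qmin = 255.0, 0.0
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    q = torch.clamp(x / scale + zero, qmin, qmax).to(torch.uint8)
    return q, scale, zero

def dequantize_uint8(q, scale, zero):
    return scale * (q.float() - zero)

x = torch.randn(2, 8)                      # made-up cache slice
q, scale, zero = quantize_uint8(x)
x_hat = dequantize_uint8(q, scale, zero)
print((x - x_hat).abs().max())             # reconstruction error bounded by one step
```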
as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - hidden_states = hidden_states[indices] - return hidden_states, indices, cu_seqlens, max_seqlen_in_batch - - def pad_input(self, hidden_states, indices, batch, seqlen): - output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, - dtype=hidden_states.dtype) - output[indices] = hidden_states - return rearrange(output, '(b s) ... -> b s ...', b=batch) - - def forward(self, q, k, v, attention_mask=None): - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) - assert all((i.is_cuda for i in (q, k, v))) - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = k.shape[1] - seqlen_out = seqlen_q - - if flash_attn_func is not None and batch_size == 1: - dropout_p = self.dropout_p if self.training else 0 - output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal) - return output - - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q.device, - ) - - if batch_size > 1 and attention_mask is not None: - k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) - if q.size(0) == v.size(0): - q = q[indices_k] - cu_seqlens_q = cu_seqlens_k - seqlen_q = seqlen_k - v = v[indices_k] - else: - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=q.device, - ) - - if self.training: - assert seqlen_k == seqlen_q - is_causal = self.causal - dropout_p = self.dropout_p - else: - is_causal = seqlen_q == seqlen_k - dropout_p = 0 - - output = flash_attn_unpadded_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlen_q, - seqlen_k, - dropout_p, - softmax_scale=self.softmax_scale, - causal=is_causal, - ) - if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: - output = self.pad_input(output, indices_k, batch_size, seqlen_out) - else: - new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] - output = output.view(new_shape) - return output - - -class QWenAttention(nn.Module): - def __init__(self, config): - super().__init__() - - self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - self.seq_length = config.seq_length - - self.hidden_size = config.hidden_size - self.split_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - - self.use_flash_attn = config.use_flash_attn - self.scale_attn_weights = True - - self.projection_size = config.kv_channels * config.num_attention_heads - - assert self.projection_size % config.num_attention_heads == 0 - self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) - - self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) - - self.c_proj = nn.Linear( - config.hidden_size, self.projection_size, bias=not config.no_bias - ) - - self.is_fp32 = not (config.bf16 or config.fp16) - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - self.core_attention_flash = FlashSelfAttention( - causal=True, attention_dropout=config.attn_dropout_prob - ) - self.bf16 = config.bf16 - - self.use_dynamic_ntk = config.use_dynamic_ntk - self.use_logn_attn = config.use_logn_attn - - logn_list = 
[ - math.log(i, self.seq_length) if i > self.seq_length else 1 - for i in range(1, 32768) - ] - logn_tensor = torch.tensor(logn_list)[None, :, None, None] - self.register_buffer("logn_tensor", logn_tensor, persistent=False) - - self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False - self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False - self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False - cache_dtype = torch.float - if self.bf16: - cache_dtype=torch.bfloat16 - elif config.fp16: - cache_dtype = torch.float16 - self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) - self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) - - if config.use_cache_quantization and config.use_cache_kernel: - # pre check if the support files existing - module_root = pathlib.Path(__file__).parent - src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") - if any(not (module_root/src).is_file() for src in src_files): - warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") - self.cache_kernels = None - else: - try: - from .cpp_kernels import cache_autogptq_cuda_256 - self.cache_kernels = cache_autogptq_cuda_256 - except ImportError: - warnings.warn("Failed to import KV cache kernels.") - self.cache_kernels = None - - def _attn(self, query, key, value, no_use_mask, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], - value.size(-1) ** 0.5, - dtype=attn_weights.dtype, - device=attn_weights.device, - ) - - query_length, key_length = query.size(-2), key.size(-2) - if attention_mask is None: - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] - else: - causal_mask = attention_mask - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def __attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): - device = query.device - if self.use_cache_quantization: - qk, qk_scale, qk_zero = key - if self.use_cache_kernel and self.cache_kernels is not None: - shape = query.shape[:-1] + (qk.shape[-2],) - attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_faster_old( - query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), - qk.transpose(-1, -2).contiguous(), - attn_weights, - qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), - qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) - # attn_weights = attn_weights.to(query.dtype).contiguous() - else: - key = dequantize_cache_torch(qk, qk_scale, 
qk_zero) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - else: - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - if self.use_cache_quantization: - size_temp = value[0].size(-1) - else: - size_temp = value.size(-1) - attn_weights = attn_weights / (size_temp ** 0.5) - - mask_value = torch.finfo(attn_weights.dtype).min - if causal_mask is not None: - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - if self.softmax_in_fp32: - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(query.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - if self.use_cache_quantization: - qv, qv_scale, qv_zero = value - if self.use_cache_kernel and self.cache_kernels is not None: - shape = attn_weights.shape[:-1] + (query.shape[-1],) - attn_output = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( - attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), - qv.contiguous(), # dtype: int32 - attn_output, - qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), - qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) - if attn_output.dtype != query.dtype: - attn_output = attn_output.to(query.dtype) - attn_weights = attn_weights.to(query.dtype) - else: - value = dequantize_cache_torch(qv, qv_scale, qv_zero) - attn_output = torch.matmul(attn_weights, value) - else: - attn_output = torch.matmul(attn_weights, value) - - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _split_heads(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - - def _merge_heads(self, tensor, num_heads, attn_head_size): - tensor = tensor.contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ): - mixed_x_layer = self.c_attn(hidden_states) - - query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if rotary_pos_emb_list is not None: - cur_len = query.shape[1] - if True: - rotary_pos_emb = rotary_pos_emb_list - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query = apply_rotary_pos_emb(query, q_pos_emb) - key = 
apply_rotary_pos_emb(key, k_pos_emb) - else: - query_list = [] - key_list = [] - for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] - key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] - query = torch.cat(query_list, dim=0) - key = torch.cat(key_list, dim=0) - - if self.use_cache_quantization: - key = quantize_cache_v(key.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - value = quantize_cache_v(value.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - - - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - if self.use_cache_quantization: - # use_cache_quantization: - # present=((q_key,key_scale,key_zero_point), - # (q_value,value_scale,value_zero_point)) - key = (torch.cat((past_key[0], key[0]), dim=2), - torch.cat((past_key[1], key[1]), dim=2), - torch.cat((past_key[2], key[2]), dim=2)) - value = (torch.cat((past_value[0], value[0]), dim=2), - torch.cat((past_value[1], value[1]), dim=2), - torch.cat((past_value[2], value[2]), dim=2)) - else: - # not use_cache_quantization: - # present=(key,value) - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache: - present = (key, value) - else: - present = None - - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) - if key_size > self.seq_length and self.use_logn_attn and not self.training: - if self.use_cache_quantization: - seq_start = key[0].size(2) - query.size(1) - seq_end = key[0].size(2) - else: - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) - query = query * logn_tensor.expand_as(query) - - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - and query.is_cuda - ): - q, k, v = query, key, value - attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) - else: - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) - if query.size(1) == key_size: - causal_mask = torch.tril( - torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) - ).view(1, 1, key_size, key_size) - else: - causal_mask = None - query = query.permute(0, 2, 1, 3) - if not self.use_cache_quantization: - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - if ( - causal_mask is None - and self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - and not query.is_cuda - ): - raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) - - if not self.use_cache_quantization and SUPPORT_TORCH2 and False: - if attention_mask is not None: - # attention_mask = attention_mask.expand( - # -1, -1, causal_mask.size(2), -1 - # ) - # if causal_mask is not None: - # attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) - causal_mask = attention_mask - else: - attention_mask = causal_mask - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ).transpose(1, 2) - attn_weight = None - else: - attn_output, attn_weight = self._attn( - # query, key, value, causal_mask, attention_mask, head_mask - query, key, value, attention_mask, attention_mask, head_mask - 
) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer) - - outputs = (attn_output, present) - if output_attentions: - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - raise ValueError("Cannot output attentions while using flash-attn") - elif not self.use_cache_quantization and SUPPORT_TORCH2: - raise ValueError("Cannot output attentions while using scaled_dot_product_attention") - else: - outputs += (attn_weight,) - - return outputs - - -class QWenMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.w1 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - self.w2 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - ff_dim_in = config.intermediate_size // 2 - self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) - - def forward(self, hidden_states): - a1 = self.w1(hidden_states) - a2 = self.w2(hidden_states) - intermediate_parallel = a1 * F.silu(a2) - output = self.c_proj(intermediate_parallel) - return output - - -class QWenBlock(nn.Module): - def __init__(self, config): - super().__init__() - hidden_size = config.hidden_size - self.bf16 = config.bf16 - - self.ln_1 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - self.attn = QWenAttention(config) - self.ln_2 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - - self.mlp = QWenMLP(config) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb: Optional[List[List[torch.Tensor]]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ): - layernorm_output = self.ln_1(hidden_states) - - attn_outputs = self.attn( - layernorm_output, - rotary_pos_emb, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - residual = hidden_states - layernorm_input = attn_output + residual - - layernorm_output = self.ln_2(layernorm_input) - - residual = layernorm_input - mlp_output = self.mlp(layernorm_output) - hidden_states = residual + mlp_output - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs - - -class QWenPreTrainedModel(PreTrainedModel): - config_class = QWenConfig - base_model_prefix = "transformer" - is_parallelizable = False - supports_gradient_checkpointing = True - _no_split_modules = ["QWenBlock"] - _skip_keys_device_placement = "past_key_values" - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, RMSNorm): 
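The `QWenMLP` removed above is a gated (SwiGLU-style) feed-forward block: two parallel projections, SiLU on one branch acting as a gate, then a down projection back to the hidden size. A minimal sketch with made-up sizes:

```python
import torch
import torch.nn.functional as F
from torch import nn

hidden_size, intermediate_size = 8, 32      # assumed sizes for illustration
w1 = nn.Linear(hidden_size, intermediate_size // 2, bias=False)
w2 = nn.Linear(hidden_size, intermediate_size // 2, bias=False)
c_proj = nn.Linear(intermediate_size // 2, hidden_size, bias=False)

x = torch.randn(2, 4, hidden_size)          # (batch, seq_len, hidden)
y = c_proj(w1(x) * F.silu(w2(x)))           # gate with SiLU, then project back
print(y.shape)                              # torch.Size([2, 4, 8])
```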
- module.weight.data.fill_(1.0) - - for name, p in module.named_parameters(): - if name == "c_proj.weight": - p.data.normal_( - mean=0.0, - std=( - self.config.initializer_range - / math.sqrt(2 * self.config.num_hidden_layers) - ), - ) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, QWenModel): - module.gradient_checkpointing = value - - -class QWenModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.embed_dim = config.hidden_size - self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False - - self.gradient_checkpointing = False - self.use_dynamic_ntk = config.use_dynamic_ntk - self.seq_length = config.seq_length - - self.wte = nn.Embedding(self.vocab_size, self.embed_dim) - - self.drop = nn.Dropout(config.emb_dropout_prob) - - if config.rotary_pct == 1.0: - self.rotary_ndims = None - else: - assert config.rotary_pct < 1 - self.rotary_ndims = int( - config.kv_channels * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else config.kv_channels - ) - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - - self.use_flash_attn = config.use_flash_attn - self.is_fp32 = not (config.bf16 or config.fp16) - - self.h = nn.ModuleList( - [ - QWenBlock( - config - ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = RMSNorm( - self.embed_dim, - eps=config.layer_norm_epsilon, - ) - - self.post_init() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def get_ntk_alpha(self, true_seq_len): - context_value = math.log(true_seq_len / self.seq_length, 2) + 1 - ntk_alpha = 2 ** math.ceil(context_value) - 1 - ntk_alpha = max(ntk_alpha, 1) - return ntk_alpha - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - 
raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - if self.use_cache_quantization: - past_length = past_key_values[0][0][0].size(2) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - encoder_attention_mask = None - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds - - kv_seq_len = hidden_states.size()[1] - if past_key_values[0] is not None: - # past key values[0][0] shape: bs * seq_len * head_num * dim - if self.use_cache_quantization: - kv_seq_len += past_key_values[0][0][0].shape[2] - else: - kv_seq_len += past_key_values[0][0].shape[1] - - if self.training or not self.use_dynamic_ntk: - ntk_alpha_list = [1.0] - elif kv_seq_len != hidden_states.size()[1]: - ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list - else: - ntk_alpha_list = [] - if attention_mask is not None and kv_seq_len > self.seq_length: - true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) - for i in range(hidden_states.size()[0]): - true_seq_len = true_seq_lens[i].item() - ntk_alpha = self.get_ntk_alpha(true_seq_len) - ntk_alpha_list.append(ntk_alpha) - else: - ntk_alpha = self.get_ntk_alpha(kv_seq_len) - ntk_alpha_list.append(ntk_alpha) - self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list - rotary_pos_emb_list = [ - self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list - ] - - hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - rotary_pos_emb_list, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - rotary_pos_emb=rotary_pos_emb_list, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states] if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class QWenLMHeadModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] - - def __init__(self, config): - super().__init__(config) - assert ( - config.bf16 + config.fp16 + config.fp32 <= 1 - ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" - - autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 - - if autoset_precision: - if SUPPORT_BF16: - logger.warn( - "The model is automatically converting to bf16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." - ) - config.bf16 = True - elif SUPPORT_FP16: - logger.warn( - "The model is automatically converting to fp16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
- ) - config.fp16 = True - else: - config.fp32 = True - - if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") - if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") - if config.fp32: - if SUPPORT_BF16: - logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") - elif SUPPORT_FP16: - logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") - - if config.use_flash_attn == "auto": - if config.bf16 or config.fp16: - logger.warn("Try importing flash-attention for faster inference...") - config.use_flash_attn = True - else: - config.use_flash_attn = False - if config.use_flash_attn and config.fp32: - logger.warn("Flash attention will be disabled because it does NOT support fp32.") - - if config.use_flash_attn: - _import_flash_attn() - - self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - if config.bf16: - self.transformer.bfloat16() - self.lm_head.bfloat16() - if config.fp16: - self.transformer.half() - self.lm_head.half() - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - - if input_ids.size(0) == 1: - attention_mask = None - else: - attention_mask = kwargs.get("attention_mask", None) - - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - 
hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - - def chat( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stream: Optional[bool] = _SENTINEL, - stop_words_ids: Optional[List[List[int]]] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Tuple[str, HistoryType]: - generation_config = generation_config if generation_config is not None else self.generation_config - - assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - else: - # make a copy of the user's input such that is is left untouched - history = copy.deepcopy(history) - - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - input_ids = torch.tensor([context_tokens]).to(self.device) - outputs = self.generate( - input_ids, - stop_words_ids=stop_words_ids, - return_dict_in_generate=False, - generation_config=generation_config, - **kwargs, - ) - - response = decode_tokens( - outputs[0], - tokenizer, - raw_text_len=len(raw_text), - context_length=len(context_tokens), - chat_format=generation_config.chat_format, - verbose=False, - errors='replace' - ) - - # as history is a copy of the user inputs, - # we can always return the new turn to the user. 
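Both causal-LM heads removed in this diff (`LlamaForCausalLM` earlier and `QWenLMHeadModel` here) compute the training loss by shifting logits and labels one position so that position t predicts token t+1. A standalone sketch with made-up shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Drop the last logit and the first label so each position predicts its successor.
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.item())
```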
- # separating input history and output history also enables the user - # to implement more complex history management - history.append((query, response)) - - return response, history - - def chat_stream( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stop_words_ids: Optional[List[List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Generator[str, Any, None]: - generation_config = generation_config if generation_config is not None else self.generation_config - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - input_ids = torch.tensor([context_tokens]).to(self.device) - - from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate_stream = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate_stream( - input_ids, - return_dict_in_generate=False, - generation_config=stream_config, - logits_processor=logits_processor, - seed=-1, - **kwargs): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') - - return stream_generator() - - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - generation_config = generation_config if generation_config is not None else self.generation_config - - # Process stop_words_ids. 
- stop_words_ids = kwargs.pop("stop_words_ids", None) - if stop_words_ids is None and generation_config is not None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - if stop_words_ids is None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - - return super().generate( - inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - **kwargs, - ) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000): - super().__init__() - self.dim = dim - self.base = base - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - if importlib.util.find_spec("einops") is None: - raise RuntimeError("einops is required for Rotary Embedding") - - self._rotary_pos_emb_cache = None - self._seq_len_cached = 0 - self._ntk_alpha_cached = 1.0 - self._ntk_alpha_cached_list = [1.0] - - def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): - if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: - base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() - / self.dim - ) - ) - self._seq_len_cached = max(2 * seqlen, 16) - self._ntk_alpha_cached = ntk_alpha - seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) - freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - from einops import rearrange - - emb = rearrange(emb, "n d -> 1 n 1 d") - - cos, sin = emb.cos(), emb.sin() - self._rotary_pos_emb_cache = [cos, sin] - - def forward(self, max_seq_len, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha) - cos, sin = self._rotary_pos_emb_cache - return [cos[:, :max_seq_len], sin[:, :max_seq_len]] - - -def _rotate_half(x): - from einops import rearrange - - x = rearrange(x, "... (j d) -> ... 
j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t, freqs): - """ Apply rotary embedding to the first rotary_dim of the iput - - Arguments: - t (tensor(batch_size, seq_len, n_head, head_dim)): - the input embedding/hidden states - freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): - the cached cos/sin position embeddings - """ - rot_dim = freqs[0].shape[-1] - cos, sin = freqs - t_float = t.float() - if apply_rotary_emb_func is not None and t.is_cuda: - # apply_rotary_emb in flash_attn requires cos/sin to be of - # shape (seqlen, rotary_dim / 2) and apply rotary embedding - # to the first rotary_dim of the input - cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] - sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] - return apply_rotary_emb_func(t_float, cos, sin).type_as(t) - else: - t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] - t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) - return torch.cat((t_rot, t_pass), dim=-1).type_as(t) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - if rms_norm is not None and x.is_cuda: - return rms_norm(x, self.weight, self.eps) - else: - output = self._norm(x.float()).type_as(x) - return output * self.weight diff --git a/transformers/llm/export/llm_models/Qwen-1_8B/modeling_qwen.py b/transformers/llm/export/llm_models/Qwen-1_8B/modeling_qwen.py deleted file mode 100755 index 5138eea76..000000000 --- a/transformers/llm/export/llm_models/Qwen-1_8B/modeling_qwen.py +++ /dev/null @@ -1,1406 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
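The `RMSNorm` defined at the end of the file above normalizes by the root-mean-square of the features (no mean subtraction, unlike LayerNorm) and applies a learned per-channel scale; the fused flash-attn kernel is used only when available. A minimal pure-PyTorch sketch of the same math:

```python
import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    """Minimal RMSNorm: x / sqrt(mean(x^2) + eps), then a learned scale."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * self.weight

x = torch.randn(2, 4, 8)
print(SimpleRMSNorm(8)(x).shape)   # torch.Size([2, 4, 8])
```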
- -import copy -import importlib -import math -import pathlib -from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -import warnings - -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList -from transformers.generation.logits_process import LogitsProcessorList - -if TYPE_CHECKING: - from transformers.generation.streamers import BaseStreamer -from transformers.generation.utils import GenerateOutput -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -try: - from einops import rearrange -except ImportError: - rearrange = None -from torch import nn - -SUPPORT_CUDA = torch.cuda.is_available() -SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() -SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 -SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 - - -from .configuration_qwen import QWenConfig -from .qwen_generation_utils import ( - HistoryType, - make_context, - decode_tokens, - get_stop_words_ids, - StopWordsLogitsProcessor, -) - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "qwen" -_CONFIG_FOR_DOC = "QWenConfig" - -QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] - -_ERROR_BAD_CHAT_FORMAT = """\ -We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". -If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). -我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 -如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 -""" - -_SENTINEL = object() -_ERROR_STREAM_IN_CHAT = """\ -Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). -向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 -""" - -_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ -We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
-检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 -""" - -apply_rotary_emb_func = None -rms_norm = None -flash_attn_unpadded_func = None -flash_attn_func = None - -def _import_flash_attn(): - global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func - try: - from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func - apply_rotary_emb_func = __apply_rotary_emb_func - except ImportError: - logger.warn( - "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" - ) - - try: - from flash_attn.ops.rms_norm import rms_norm as __rms_norm - rms_norm = __rms_norm - except ImportError: - logger.warn( - "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" - ) - - try: - import flash_attn - _flash_attn_func = None - if not hasattr(flash_attn, '__version__'): - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - else: - if int(flash_attn.__version__.split(".")[0]) >= 2: - if int(flash_attn.__version__.split(".")[1]) >= 1: - from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func - from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func - else: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - flash_attn_unpadded_func = __flash_attn_unpadded_func - flash_attn_func = _flash_attn_func - except ImportError: - logger.warn( - "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention" - ) - -def quantize_cache_v(fdata, bits, qmax, qmin): - # b, s, head, h-dim->b, head, s, h-dim - qtype = torch.uint8 - device = fdata.device - shape = fdata.shape - - fdata_cal = torch.flatten(fdata, 2) - fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) - fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) - # Compute params - if qmax.device != fmax.device: - qmax = qmax.to(device) - qmin = qmin.to(device) - scale = (fmax - fmin) / (qmax - qmin) - zero = qmin - fmin / scale - scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() - zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() - # Quantize - res_data = fdata / scale + zero - qdata = torch.clamp(res_data, qmin, qmax).to(qtype) - return qdata.contiguous(), scale, zero - -def dequantize_cache_torch(qdata, scale, zero): - data = scale * (qdata - zero) - return data - -class FlashSelfAttention(torch.nn.Module): - def __init__( - self, - causal=False, - softmax_scale=None, - attention_dropout=0.0, - ): - super().__init__() - assert flash_attn_unpadded_func is not None, ( - "Please install FlashAttention first, " "e.g., with pip install flash-attn" - ) - assert ( - rearrange is not None - ), "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def unpad_input(self, hidden_states, attention_mask): - valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) - seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(valid_mask.flatten(), 
as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - hidden_states = hidden_states[indices] - return hidden_states, indices, cu_seqlens, max_seqlen_in_batch - - def pad_input(self, hidden_states, indices, batch, seqlen): - output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, - dtype=hidden_states.dtype) - output[indices] = hidden_states - return rearrange(output, '(b s) ... -> b s ...', b=batch) - - def forward(self, q, k, v, attention_mask=None): - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) - assert all((i.is_cuda for i in (q, k, v))) - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = k.shape[1] - seqlen_out = seqlen_q - - if flash_attn_func is not None and batch_size == 1: - dropout_p = self.dropout_p if self.training else 0 - output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal) - return output - - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q.device, - ) - - if batch_size > 1 and attention_mask is not None: - k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) - if q.size(0) == v.size(0): - q = q[indices_k] - cu_seqlens_q = cu_seqlens_k - seqlen_q = seqlen_k - v = v[indices_k] - else: - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=q.device, - ) - - if self.training: - assert seqlen_k == seqlen_q - is_causal = self.causal - dropout_p = self.dropout_p - else: - is_causal = seqlen_q == seqlen_k - dropout_p = 0 - - output = flash_attn_unpadded_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlen_q, - seqlen_k, - dropout_p, - softmax_scale=self.softmax_scale, - causal=is_causal, - ) - if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: - output = self.pad_input(output, indices_k, batch_size, seqlen_out) - else: - new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] - output = output.view(new_shape) - return output - - -class QWenAttention(nn.Module): - def __init__(self, config): - super().__init__() - - self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - self.seq_length = config.seq_length - - self.hidden_size = config.hidden_size - self.split_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - - self.use_flash_attn = config.use_flash_attn - self.scale_attn_weights = True - - self.projection_size = config.kv_channels * config.num_attention_heads - - assert self.projection_size % config.num_attention_heads == 0 - self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) - - self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) - - self.c_proj = nn.Linear( - config.hidden_size, self.projection_size, bias=not config.no_bias - ) - - self.is_fp32 = not (config.bf16 or config.fp16) - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - self.core_attention_flash = FlashSelfAttention( - causal=True, attention_dropout=config.attn_dropout_prob - ) - self.bf16 = config.bf16 - - self.use_dynamic_ntk = config.use_dynamic_ntk - self.use_logn_attn = config.use_logn_attn - - logn_list = 
[ - math.log(i, self.seq_length) if i > self.seq_length else 1 - for i in range(1, 32768) - ] - logn_tensor = torch.tensor(logn_list)[None, :, None, None] - self.register_buffer("logn_tensor", logn_tensor, persistent=False) - - self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False - self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False - self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False - cache_dtype = torch.float - if self.bf16: - cache_dtype=torch.bfloat16 - elif config.fp16: - cache_dtype = torch.float16 - self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) - self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) - - if config.use_cache_quantization and config.use_cache_kernel: - # pre check if the support files existing - module_root = pathlib.Path(__file__).parent - src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") - if any(not (module_root/src).is_file() for src in src_files): - warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") - self.cache_kernels = None - else: - try: - from .cpp_kernels import cache_autogptq_cuda_256 - self.cache_kernels = cache_autogptq_cuda_256 - except ImportError: - warnings.warn("Failed to import KV cache kernels.") - self.cache_kernels = None - - def _attn(self, query, key, value, no_use_mask, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / torch.full( - [], - value.size(-1) ** 0.5, - dtype=attn_weights.dtype, - device=attn_weights.device, - ) - - query_length, key_length = query.size(-2), key.size(-2) - if attention_mask is None: - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] - else: - causal_mask = attention_mask - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def __attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): - device = query.device - if self.use_cache_quantization: - qk, qk_scale, qk_zero = key - if self.use_cache_kernel and self.cache_kernels is not None: - shape = query.shape[:-1] + (qk.shape[-2],) - attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_faster_old( - query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), - qk.transpose(-1, -2).contiguous(), - attn_weights, - qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), - qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) - # attn_weights = attn_weights.to(query.dtype).contiguous() - else: - key = dequantize_cache_torch(qk, qk_scale, 
qk_zero) - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - else: - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - if self.use_cache_quantization: - size_temp = value[0].size(-1) - else: - size_temp = value.size(-1) - attn_weights = attn_weights / (size_temp ** 0.5) - - mask_value = torch.finfo(attn_weights.dtype).min - if causal_mask is not None: - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - if self.softmax_in_fp32: - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) - else: - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(query.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - if self.use_cache_quantization: - qv, qv_scale, qv_zero = value - if self.use_cache_kernel and self.cache_kernels is not None: - shape = attn_weights.shape[:-1] + (query.shape[-1],) - attn_output = torch.zeros(shape, dtype=torch.float16, device=device) - self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( - attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), - qv.contiguous(), # dtype: int32 - attn_output, - qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), - qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) - if attn_output.dtype != query.dtype: - attn_output = attn_output.to(query.dtype) - attn_weights = attn_weights.to(query.dtype) - else: - value = dequantize_cache_torch(qv, qv_scale, qv_zero) - attn_output = torch.matmul(attn_weights, value) - else: - attn_output = torch.matmul(attn_weights, value) - - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _split_heads(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - - def _merge_heads(self, tensor, num_heads, attn_head_size): - tensor = tensor.contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb_list: Optional[List[torch.Tensor]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ): - mixed_x_layer = self.c_attn(hidden_states) - - query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if rotary_pos_emb_list is not None: - cur_len = query.shape[1] - if True: - rotary_pos_emb = rotary_pos_emb_list - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query = apply_rotary_pos_emb(query, q_pos_emb) - key = 
apply_rotary_pos_emb(key, k_pos_emb) - else: - query_list = [] - key_list = [] - for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] - key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] - query = torch.cat(query_list, dim=0) - key = torch.cat(key_list, dim=0) - - if self.use_cache_quantization: - key = quantize_cache_v(key.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - value = quantize_cache_v(value.permute(0, 2, 1, 3), - bits=8, - qmin=self.cache_qmin, - qmax=self.cache_qmax) - - - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - if self.use_cache_quantization: - # use_cache_quantization: - # present=((q_key,key_scale,key_zero_point), - # (q_value,value_scale,value_zero_point)) - key = (torch.cat((past_key[0], key[0]), dim=2), - torch.cat((past_key[1], key[1]), dim=2), - torch.cat((past_key[2], key[2]), dim=2)) - value = (torch.cat((past_value[0], value[0]), dim=2), - torch.cat((past_value[1], value[1]), dim=2), - torch.cat((past_value[2], value[2]), dim=2)) - else: - # not use_cache_quantization: - # present=(key,value) - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache: - present = (key, value) - else: - present = None - - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) - if key_size > self.seq_length and self.use_logn_attn and not self.training: - if self.use_cache_quantization: - seq_start = key[0].size(2) - query.size(1) - seq_end = key[0].size(2) - else: - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) - query = query * logn_tensor.expand_as(query) - - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - and query.is_cuda - ): - q, k, v = query, key, value - attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) - else: - key_size = key[0].size(2) if self.use_cache_quantization else key.size(1) - if query.size(1) == key_size: - causal_mask = torch.tril( - torch.ones((key_size, key_size), dtype=torch.bool, device=query.device) - ).view(1, 1, key_size, key_size) - else: - causal_mask = None - query = query.permute(0, 2, 1, 3) - if not self.use_cache_quantization: - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - if ( - causal_mask is None - and self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - and not query.is_cuda - ): - raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) - - if not self.use_cache_quantization and SUPPORT_TORCH2 and False: - if attention_mask is not None: - # attention_mask = attention_mask.expand( - # -1, -1, causal_mask.size(2), -1 - # ) - # if causal_mask is not None: - # attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) - causal_mask = attention_mask - else: - attention_mask = causal_mask - attn_output = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask - ).transpose(1, 2) - attn_weight = None - else: - attn_output, attn_weight = self._attn( - # query, key, value, causal_mask, attention_mask, head_mask - query, key, value, attention_mask, attention_mask, head_mask - 
) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer) - - outputs = (attn_output, present) - if output_attentions: - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - raise ValueError("Cannot output attentions while using flash-attn") - elif not self.use_cache_quantization and SUPPORT_TORCH2: - raise ValueError("Cannot output attentions while using scaled_dot_product_attention") - else: - outputs += (attn_weight,) - - return outputs - - -class QWenMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.w1 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - self.w2 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - ff_dim_in = config.intermediate_size // 2 - self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) - - def forward(self, hidden_states): - a1 = self.w1(hidden_states) - a2 = self.w2(hidden_states) - intermediate_parallel = a1 * F.silu(a2) - output = self.c_proj(intermediate_parallel) - return output - - -class QWenBlock(nn.Module): - def __init__(self, config): - super().__init__() - hidden_size = config.hidden_size - self.bf16 = config.bf16 - - self.ln_1 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - self.attn = QWenAttention(config) - self.ln_2 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - - self.mlp = QWenMLP(config) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb: Optional[List[List[torch.Tensor]]] = None, - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ): - layernorm_output = self.ln_1(hidden_states) - - attn_outputs = self.attn( - layernorm_output, - rotary_pos_emb, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - residual = hidden_states - layernorm_input = attn_output + residual - - layernorm_output = self.ln_2(layernorm_input) - - residual = layernorm_input - mlp_output = self.mlp(layernorm_output) - hidden_states = residual + mlp_output - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs - - -class QWenPreTrainedModel(PreTrainedModel): - config_class = QWenConfig - base_model_prefix = "transformer" - is_parallelizable = False - supports_gradient_checkpointing = True - _no_split_modules = ["QWenBlock"] - _skip_keys_device_placement = "past_key_values" - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, RMSNorm): 
- module.weight.data.fill_(1.0) - - for name, p in module.named_parameters(): - if name == "c_proj.weight": - p.data.normal_( - mean=0.0, - std=( - self.config.initializer_range - / math.sqrt(2 * self.config.num_hidden_layers) - ), - ) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, QWenModel): - module.gradient_checkpointing = value - - -class QWenModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.embed_dim = config.hidden_size - self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False - - self.gradient_checkpointing = False - self.use_dynamic_ntk = config.use_dynamic_ntk - self.seq_length = config.seq_length - - self.wte = nn.Embedding(self.vocab_size, self.embed_dim) - - self.drop = nn.Dropout(config.emb_dropout_prob) - - if config.rotary_pct == 1.0: - self.rotary_ndims = None - else: - assert config.rotary_pct < 1 - self.rotary_ndims = int( - config.kv_channels * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else config.kv_channels - ) - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - - self.use_flash_attn = config.use_flash_attn - self.is_fp32 = not (config.bf16 or config.fp16) - - self.h = nn.ModuleList( - [ - QWenBlock( - config - ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = RMSNorm( - self.embed_dim, - eps=config.layer_norm_epsilon, - ) - - self.post_init() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def get_ntk_alpha(self, true_seq_len): - context_value = math.log(true_seq_len / self.seq_length, 2) + 1 - ntk_alpha = 2 ** math.ceil(context_value) - 1 - ntk_alpha = max(ntk_alpha, 1) - return ntk_alpha - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - 
raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - if self.use_cache_quantization: - past_length = past_key_values[0][0][0].size(2) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - encoder_attention_mask = None - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds - - kv_seq_len = hidden_states.size()[1] - if past_key_values[0] is not None: - # past key values[0][0] shape: bs * seq_len * head_num * dim - if self.use_cache_quantization: - kv_seq_len += past_key_values[0][0][0].shape[2] - else: - kv_seq_len += past_key_values[0][0].shape[1] - - if self.training or not self.use_dynamic_ntk: - ntk_alpha_list = [1.0] - elif kv_seq_len != hidden_states.size()[1]: - ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list - else: - ntk_alpha_list = [] - if attention_mask is not None and kv_seq_len > self.seq_length: - true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) - for i in range(hidden_states.size()[0]): - true_seq_len = true_seq_lens[i].item() - ntk_alpha = self.get_ntk_alpha(true_seq_len) - ntk_alpha_list.append(ntk_alpha) - else: - ntk_alpha = self.get_ntk_alpha(kv_seq_len) - ntk_alpha_list.append(ntk_alpha) - self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list - rotary_pos_emb_list = [ - self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list - ] - - hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - rotary_pos_emb_list, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - rotary_pos_emb=rotary_pos_emb_list, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states] if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class QWenLMHeadModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] - - def __init__(self, config): - super().__init__(config) - assert ( - config.bf16 + config.fp16 + config.fp32 <= 1 - ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" - - autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 - - if autoset_precision: - if SUPPORT_BF16: - logger.warn( - "The model is automatically converting to bf16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." - ) - config.bf16 = True - elif SUPPORT_FP16: - logger.warn( - "The model is automatically converting to fp16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
- ) - config.fp16 = True - else: - config.fp32 = True - - if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") - if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") - if config.fp32: - if SUPPORT_BF16: - logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") - elif SUPPORT_FP16: - logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") - - if config.use_flash_attn == "auto": - if config.bf16 or config.fp16: - logger.warn("Try importing flash-attention for faster inference...") - config.use_flash_attn = True - else: - config.use_flash_attn = False - if config.use_flash_attn and config.fp32: - logger.warn("Flash attention will be disabled because it does NOT support fp32.") - - if config.use_flash_attn: - _import_flash_attn() - - self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - if config.bf16: - self.transformer.bfloat16() - self.lm_head.bfloat16() - if config.fp16: - self.transformer.half() - self.lm_head.half() - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - - if input_ids.size(0) == 1: - attention_mask = None - else: - attention_mask = kwargs.get("attention_mask", None) - - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - 
hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - - def chat( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stream: Optional[bool] = _SENTINEL, - stop_words_ids: Optional[List[List[int]]] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Tuple[str, HistoryType]: - generation_config = generation_config if generation_config is not None else self.generation_config - - assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - else: - # make a copy of the user's input such that is is left untouched - history = copy.deepcopy(history) - - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - input_ids = torch.tensor([context_tokens]).to(self.device) - outputs = self.generate( - input_ids, - stop_words_ids=stop_words_ids, - return_dict_in_generate=False, - generation_config=generation_config, - **kwargs, - ) - - response = decode_tokens( - outputs[0], - tokenizer, - raw_text_len=len(raw_text), - context_length=len(context_tokens), - chat_format=generation_config.chat_format, - verbose=False, - errors='replace' - ) - - # as history is a copy of the user inputs, - # we can always return the new turn to the user. 
- # separating input history and output history also enables the user - # to implement more complex history management - history.append((query, response)) - - return response, history - - def chat_stream( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stop_words_ids: Optional[List[List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Generator[str, Any, None]: - generation_config = generation_config if generation_config is not None else self.generation_config - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - input_ids = torch.tensor([context_tokens]).to(self.device) - - from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate_stream = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate_stream( - input_ids, - return_dict_in_generate=False, - generation_config=stream_config, - logits_processor=logits_processor, - seed=-1, - **kwargs): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') - - return stream_generator() - - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - generation_config = generation_config if generation_config is not None else self.generation_config - - # Process stop_words_ids. 
- stop_words_ids = kwargs.pop("stop_words_ids", None) - if stop_words_ids is None and generation_config is not None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - if stop_words_ids is None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - - return super().generate( - inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - **kwargs, - ) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000): - super().__init__() - self.dim = dim - self.base = base - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - if importlib.util.find_spec("einops") is None: - raise RuntimeError("einops is required for Rotary Embedding") - - self._rotary_pos_emb_cache = None - self._seq_len_cached = 0 - self._ntk_alpha_cached = 1.0 - self._ntk_alpha_cached_list = [1.0] - - def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): - if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: - base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() - / self.dim - ) - ) - self._seq_len_cached = max(2 * seqlen, 16) - self._ntk_alpha_cached = ntk_alpha - seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) - freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - from einops import rearrange - - emb = rearrange(emb, "n d -> 1 n 1 d") - - cos, sin = emb.cos(), emb.sin() - self._rotary_pos_emb_cache = [cos, sin] - - def forward(self, max_seq_len, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha) - cos, sin = self._rotary_pos_emb_cache - return [cos[:, :max_seq_len], sin[:, :max_seq_len]] - - -def _rotate_half(x): - from einops import rearrange - - x = rearrange(x, "... (j d) -> ... 
j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t, freqs): - """ Apply rotary embedding to the first rotary_dim of the iput - - Arguments: - t (tensor(batch_size, seq_len, n_head, head_dim)): - the input embedding/hidden states - freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): - the cached cos/sin position embeddings - """ - rot_dim = freqs[0].shape[-1] - cos, sin = freqs - t_float = t.float() - if apply_rotary_emb_func is not None and t.is_cuda: - # apply_rotary_emb in flash_attn requires cos/sin to be of - # shape (seqlen, rotary_dim / 2) and apply rotary embedding - # to the first rotary_dim of the input - cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] - sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] - return apply_rotary_emb_func(t_float, cos, sin).type_as(t) - else: - t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] - t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) - return torch.cat((t_rot, t_pass), dim=-1).type_as(t) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - if rms_norm is not None and x.is_cuda: - return rms_norm(x, self.weight, self.eps) - else: - output = self._norm(x.float()).type_as(x) - return output * self.weight diff --git a/transformers/llm/export/llm_models/Qwen-7B-Chat/config.json b/transformers/llm/export/llm_models/Qwen-7B-Chat/config.json deleted file mode 100644 index 2a794d958..000000000 --- a/transformers/llm/export/llm_models/Qwen-7B-Chat/config.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "architectures": [ - "QWenLMHeadModel" - ], - "auto_map": { - "AutoConfig": "configuration_qwen.QWenConfig", - "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" - }, - "attn_dropout_prob": 0.0, - "bf16": false, - "fp16": false, - "fp32": false, - "emb_dropout_prob": 0.0, - "intermediate_size": 22016, - "initializer_range": 0.02, - "kv_channels": 128, - "layer_norm_epsilon": 1e-06, - "model_type": "qwen", - "hidden_size": 4096, - "num_attention_heads": 32, - "num_hidden_layers": 32, - "max_position_embeddings": 8192, - "no_bias": true, - "onnx_safe": null, - "rotary_emb_base": 10000, - "rotary_pct": 1.0, - "scale_attn_weights": true, - "seq_length": 2048, - "tie_word_embeddings": false, - "tokenizer_type": "QWenTokenizer", - "transformers_version": "4.31.0", - "use_cache": true, - "use_flash_attn": "auto", - "vocab_size": 151936, - "use_dynamic_ntk": true, - "use_logn_attn": false -} diff --git a/transformers/llm/export/llm_models/Qwen-7B-Chat/modeling_qwen.py b/transformers/llm/export/llm_models/Qwen-7B-Chat/modeling_qwen.py deleted file mode 100644 index 698486f6f..000000000 --- a/transformers/llm/export/llm_models/Qwen-7B-Chat/modeling_qwen.py +++ /dev/null @@ -1,1199 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import importlib -import math -from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.cuda.amp import autocast - -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList -from transformers.generation.logits_process import LogitsProcessorList - -if TYPE_CHECKING: - from transformers.generation.streamers import BaseStreamer -from transformers.generation.utils import GenerateOutput -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -try: - from einops import rearrange -except ImportError: - rearrange = None -from torch import nn - -SUPPORT_CUDA = torch.cuda.is_available() -SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() -SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 - -from .configuration_qwen import QWenConfig -from .qwen_generation_utils import ( - HistoryType, - make_context, - decode_tokens, - get_stop_words_ids, - StopWordsLogitsProcessor, -) - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "qwen" -_CONFIG_FOR_DOC = "QWenConfig" - -QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] - -_ERROR_BAD_CHAT_FORMAT = """\ -We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". -If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). -我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 -如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 -""" - -_SENTINEL = object() -_ERROR_STREAM_IN_CHAT = """\ -Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). 
-向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 -""" - -apply_rotary_emb_func = None -rms_norm = None -flash_attn_unpadded_func = None - - -def _import_flash_attn(): - global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func - try: - from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func - apply_rotary_emb_func = __apply_rotary_emb_func - except ImportError: - logger.warn( - "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" - ) - - try: - from flash_attn.ops.rms_norm import rms_norm as __rms_norm - rms_norm = __rms_norm - except ImportError: - logger.warn( - "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" - ) - - try: - import flash_attn - if not hasattr(flash_attn, '__version__'): - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - else: - if int(flash_attn.__version__.split(".")[0]) >= 2: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func - else: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func - flash_attn_unpadded_func = __flash_attn_unpadded_func - except ImportError: - logger.warn( - "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " - "https://github.com/Dao-AILab/flash-attention" - ) - - -class FlashSelfAttention(torch.nn.Module): - def __init__( - self, - causal=False, - softmax_scale=None, - attention_dropout=0.0, - ): - super().__init__() - assert flash_attn_unpadded_func is not None, ( - "Please install FlashAttention first, " "e.g., with pip install flash-attn" - ) - assert ( - rearrange is not None - ), "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, q, k, v): - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) - assert all((i.is_cuda for i in (q, k, v))) - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = k.shape[1] - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * seqlen_q, - step=seqlen_q, - dtype=torch.int32, - device=q.device, - ) - - if self.training: - assert seqlen_k == seqlen_q - - is_causal = self.causal - cu_seqlens_k = cu_seqlens_q - else: - is_causal = seqlen_q == seqlen_k - cu_seqlens_k = torch.arange( - 0, - (batch_size + 1) * seqlen_k, - step=seqlen_k, - dtype=torch.int32, - device=q.device, - ) - self.dropout_p = 0 - output = flash_attn_unpadded_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlen_q, - seqlen_k, - self.dropout_p, - softmax_scale=self.softmax_scale, - causal=is_causal, - ) - - output = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - return output - - -class QWenAttention(nn.Module): - def __init__(self, config): - super().__init__() - - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril( - torch.ones((max_positions, max_positions), dtype=torch.bool) - ).view(1, 1, max_positions, max_positions), - persistent=False, - ) - self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - self.seq_length = config.seq_length - - self.hidden_size = config.hidden_size - self.split_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - - self.use_flash_attn = config.use_flash_attn - self.scale_attn_weights = True - - self.projection_size = config.kv_channels * config.num_attention_heads - - assert self.projection_size % config.num_attention_heads == 0 - self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) - - self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) - - self.c_proj = nn.Linear( - config.hidden_size, self.projection_size, bias=not config.no_bias - ) - - self.is_fp32 = not (config.bf16 or config.fp16) - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - self.core_attention_flash = FlashSelfAttention( - causal=True, attention_dropout=config.attn_dropout_prob - ) - - self.bf16 = config.bf16 - - if config.rotary_pct == 1.0: - self.rotary_ndims = None - else: - assert config.rotary_pct < 1 - self.rotary_ndims = int( - self.hidden_size_per_attention_head * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else self.hidden_size_per_attention_head - ) - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - - self.use_dynamic_ntk = config.use_dynamic_ntk - self.use_logn_attn = config.use_logn_attn - - logn_list = [ - math.log(i, self.seq_length) if i > self.seq_length else 1 - for i in range(1, 32768) - ] - self.logn_tensor = torch.tensor(logn_list)[None, :, None, None] - self._ntk_cached = 1.0 - - self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - - def _attn(self, query, key, value, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / math.sqrt(self.head_dim) - - query_length, key_length = query.size(-2), key.size(-2) - if attention_mask is None: - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] - else: - causal_mask = attention_mask - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where( - causal_mask, attn_weights.to(attn_weights.dtype), mask_value - ) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _upcast_and_reordered_attn( - self, query, key, value, attention_mask=None, head_mask=None - ): - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - attn_weights = torch.empty( - bsz * num_heads, - q_seq_len, - k_seq_len, - dtype=torch.float32, - device=query.device, - ) - - 
scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( - -1, dk, k_seq_len - ) - attn_weights = torch.baddbmm( - attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor - ) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[ - :, :, key_length - query_length : key_length, :key_length - ] - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if attn_weights.dtype != torch.float32: - raise RuntimeError( - "Error with upcasting, attn_weights does not have dtype torch.float32" - ) - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - - def _split_heads(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - - def _merge_heads(self, tensor, num_heads, attn_head_size): - tensor = tensor.contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ): - mixed_x_layer = self.c_attn(hidden_states) - query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - kv_seq_len = hidden_states.size()[1] - if layer_past is not None: - # layer past[0] shape: bs * seq_len * head_num * dim - kv_seq_len += layer_past[0].shape[1] - if ( - self.use_dynamic_ntk - and kv_seq_len == hidden_states.size()[1] - and not self.training - ): - context_value = math.log(kv_seq_len / self.seq_length, 2) + 1 - ntk_alpha = 2 ** math.ceil(context_value) - 1 - ntk_alpha = max(ntk_alpha, 1) - self._ntk_cached = ntk_alpha - else: - ntk_alpha = self._ntk_cached - if rotary_pos_emb is None: - rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to( - hidden_states.device - ) - cur_len = query.shape[1] - rotary_pos_emb = rotary_pos_emb[:, -cur_len:, :, :] - - if rotary_pos_emb is not None and False: - if isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = rotary_pos_emb - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - - if rotary_pos_emb is not None: - ''' - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - cur_len = query.shape[1] - q_pos_emb = q_pos_emb[:, -cur_len:, :, :] - k_pos_emb = 
k_pos_emb[:, -cur_len:, :, :] - ''' - query = apply_rotary_pos_emb(query, rotary_pos_emb) - key = apply_rotary_pos_emb(key, rotary_pos_emb) - - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache: - present = torch.stack((key, value)) - else: - present = None - - if self.use_logn_attn and not self.training: - if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: - self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] - query = query * logn_tensor.expand_as(query) - - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - and query.is_cuda - ): - q, k, v = query, key, value - context_layer = self.core_attention_flash(q, k, v) - - context_layer = rearrange( - context_layer, "b s h d -> b s (h d)" - ).contiguous() - else: - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - attn_output, attn_weight = self._attn( - query, key, value, attention_mask, head_mask - ) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer) - outputs = (attn_output, present) - if output_attentions: - if ( - self.use_flash_attn - and flash_attn_unpadded_func is not None - and not self.is_fp32 - ): - raise ValueError("Cannot output attentions while using flash-attn") - else: - outputs += (attn_weight,) - - return outputs - - -class QWenMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.w1 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - self.w2 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - ff_dim_in = config.intermediate_size // 2 - self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) - - def forward(self, hidden_states): - a1 = self.w1(hidden_states) - a2 = self.w2(hidden_states) - intermediate_parallel = a1 * F.silu(a2) - output = self.c_proj(intermediate_parallel) - return output - - -class QWenBlock(nn.Module): - def __init__(self, config): - super().__init__() - hidden_size = config.hidden_size - self.bf16 = config.bf16 - - self.ln_1 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - self.attn = QWenAttention(config) - self.ln_2 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - - self.mlp = QWenMLP(config) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ): - layernorm_output = self.ln_1(hidden_states) - - attn_outputs = self.attn( - layernorm_output, - layer_past=layer_past, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - - outputs = attn_outputs[1:] - - residual = hidden_states - layernorm_input = 
attn_output + residual - - layernorm_output = self.ln_2(layernorm_input) - - residual = layernorm_input - mlp_output = self.mlp(layernorm_output) - hidden_states = residual + mlp_output - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs - - -class QWenPreTrainedModel(PreTrainedModel): - config_class = QWenConfig - base_model_prefix = "transformer" - is_parallelizable = False - supports_gradient_checkpointing = True - _no_split_modules = ["QWenBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, RMSNorm): - module.weight.data.fill_(1.0) - - for name, p in module.named_parameters(): - if name == "c_proj.weight": - p.data.normal_( - mean=0.0, - std=( - self.config.initializer_range - / math.sqrt(2 * self.config.num_hidden_layers) - ), - ) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, QWenModel): - module.gradient_checkpointing = value - - -class QWenModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.embed_dim = config.hidden_size - - self.gradient_checkpointing = False - - self.wte = nn.Embedding(self.vocab_size, self.embed_dim) - - self.drop = nn.Dropout(config.emb_dropout_prob) - self.h = nn.ModuleList( - [ - QWenBlock( - config, - ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = RMSNorm( - self.embed_dim, - eps=config.layer_norm_epsilon, - ) - - self.post_init() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = 
input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - if attention_mask is not None: - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = attention_mask.view(batch_size, -1) - attention_mask = attention_mask[:, None, None, :] - attention_mask = attention_mask.to(dtype=self.dtype) - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min - - encoder_attention_mask = None - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds - - hidden_states = self.drop(hidden_states) - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[1],) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states] if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class QWenLMHeadModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] - - def __init__(self, config): - super().__init__(config) - assert ( - config.bf16 + config.fp16 + config.fp32 <= 1 - ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" - - autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 - - if autoset_precision: - if SUPPORT_BF16: - logger.warn( - "The model is automatically converting to bf16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." - ) - config.bf16 = True - elif SUPPORT_FP16: - logger.warn( - "The model is automatically converting to fp16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
- ) - config.fp16 = True - else: - config.fp32 = True - - if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") - if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") - if config.fp32: - if SUPPORT_BF16: - logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") - elif SUPPORT_FP16: - logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") - - if config.use_flash_attn == "auto": - if config.bf16 or config.fp16: - logger.warn("Try importing flash-attention for faster inference...") - config.use_flash_attn = True - else: - config.use_flash_attn = False - if config.use_flash_attn and config.fp32: - logger.warn("Flash attention will be disabled because it does NOT support fp32.") - - if config.use_flash_attn: - _import_flash_attn() - - self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - if config.bf16: - self.transformer.bfloat16() - self.lm_head.bfloat16() - if config.fp16: - self.transformer.half() - self.lm_head.half() - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - token_type_ids = kwargs.get("token_type_ids", None) - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - ) - return model_inputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - 
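The loss computed in the forward pass just below follows the standard next-token scheme: logits are shifted left and labels shifted right by one position before cross-entropy. A minimal sketch of that scheme (batch, sequence length and vocabulary size here are arbitrary illustrative values):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# position t predicts token t+1, so drop the last logit and the first label
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))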
transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - - def chat( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - append_history: bool = True, - stream: Optional[bool] = _SENTINEL, - stop_words_ids: Optional[List[List[int]]] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Tuple[str, HistoryType]: - generation_config = generation_config if generation_config is not None else self.generation_config - - assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - input_ids = torch.tensor([context_tokens]).to(self.device) - outputs = self.generate( - input_ids, - stop_words_ids=stop_words_ids, - return_dict_in_generate=False, - generation_config=generation_config, - **kwargs, - ) - - response = decode_tokens( - outputs[0], - tokenizer, - raw_text_len=len(raw_text), - context_length=len(context_tokens), - chat_format=generation_config.chat_format, - verbose=False, - errors='replace' - ) - - if append_history: - history.append((query, response)) - - return response, history - - def chat_stream( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stop_words_ids: Optional[List[List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Generator[str, Any, None]: - generation_config = 
generation_config if generation_config is not None else self.generation_config - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - input_ids = torch.tensor([context_tokens]).to(self.device) - - from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate_stream = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate_stream( - input_ids, - return_dict_in_generate=False, - generation_config=stream_config, - logits_processor=logits_processor, - seed=-1, - **kwargs): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') - - return stream_generator() - - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - generation_config = generation_config if generation_config is not None else self.generation_config - - # Process stop_words_ids. 
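The generate() override being deleted here pops stop_words_ids, wraps them in a StopWordsLogitsProcessor and appends it to the logits_processor list before delegating to the base generate (continued just below). The same wiring can be reproduced with the public transformers API; the processor class below is a hypothetical stand-in written for illustration, and the token id 151645 is only an example:

import torch
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList

class ForceEosOnStopWords(LogitsProcessor):
    """Illustrative stand-in: once a stop sequence is a suffix of the input, force EOS."""
    def __init__(self, stop_words_ids, eos_token_id):
        self.stop_words_ids = stop_words_ids
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        for i, seq in enumerate(input_ids.tolist()):
            for stop in self.stop_words_ids:
                if len(seq) >= len(stop) and seq[-len(stop):] == stop:
                    scores[i, :] = torch.finfo(scores.dtype).min
                    scores[i, self.eos_token_id] = 0.0
        return scores

processors = LogitsProcessorList([ForceEosOnStopWords([[151645]], eos_token_id=151645)])
fake_ids = torch.tensor([[1, 2, 151645]])
fake_scores = torch.zeros(1, 152000)
out = processors(fake_ids, fake_scores)  # only the EOS column keeps a non-minimal score
# in practice: model.generate(input_ids, logits_processor=processors, ...)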
- stop_words_ids = kwargs.pop("stop_words_ids", None) - if stop_words_ids is None and generation_config is not None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - if stop_words_ids is None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - - return super().generate( - inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - **kwargs, - ) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000): - super().__init__() - self.dim = dim - self.base = base - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - if importlib.util.find_spec("einops") is None: - raise RuntimeError("einops is required for Rotary Embedding") - - self._rotary_pos_emb_cache = None - self._seq_len_cached = 0 - self._ntk_alpha_cached = 1.0 - - def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): - seqlen = max_seq_len + offset - if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: - base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() - / self.dim - ) - ) - self._seq_len_cached = max(2 * seqlen, 16) - self._ntk_alpha_cached = ntk_alpha - seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) - freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - from einops import rearrange - - self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d") - - def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) - return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len] - - -def _rotate_half(x): - from einops import rearrange - - x = rearrange(x, "... (j d) -> ... 
j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t, freqs): - if apply_rotary_emb_func is not None and t.is_cuda: - t_ = t.float() - freqs = freqs.squeeze(0).squeeze(1) - cos = freqs[:, : freqs.shape[-1] // 2].cos() - sin = freqs[:, : freqs.shape[-1] // 2].sin() - output = apply_rotary_emb_func(t_, cos, sin).type_as(t) - return output - else: - rot_dim = freqs.shape[-1] - t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] - t_ = t_.float() - t_pass_ = t_pass_.float() - t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin()) - return torch.cat((t_, t_pass_), dim=-1).type_as(t) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - if rms_norm is not None and x.is_cuda: - return rms_norm(x, self.weight, self.eps) - else: - output = self._norm(x.float()).type_as(x) - return output * self.weight diff --git a/transformers/llm/export/llm_models/Qwen-VL-Chat/modeling_qwen.py b/transformers/llm/export/llm_models/Qwen-VL-Chat/modeling_qwen.py deleted file mode 100755 index d7b3c4798..000000000 --- a/transformers/llm/export/llm_models/Qwen-VL-Chat/modeling_qwen.py +++ /dev/null @@ -1,1162 +0,0 @@ -# Copyright (c) Alibaba Cloud. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import importlib -import math -from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.cuda.amp import autocast - -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList -from transformers.generation.logits_process import LogitsProcessorList - -if TYPE_CHECKING: - from transformers.generation.streamers import BaseStreamer -from transformers.generation.utils import GenerateOutput -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging - -try: - from einops import rearrange -except ImportError: - rearrange = None -from torch import nn - -SUPPORT_CUDA = torch.cuda.is_available() -SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() -SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 - -from .configuration_qwen import QWenConfig -from .qwen_generation_utils import ( - HistoryType, - make_context, - decode_tokens, - get_stop_words_ids, - StopWordsLogitsProcessor, -) -from .visual import VisionTransformer - - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "qwen" -_CONFIG_FOR_DOC = "QWenConfig" - -QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] - -_ERROR_BAD_CHAT_FORMAT = """\ -We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". -If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). 
-我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 -如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 -""" - -_SENTINEL = object() -_ERROR_STREAM_IN_CHAT = """\ -Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). -向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 -""" - -apply_rotary_emb_func = None -rms_norm = None - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class QWenAttention(nn.Module): - def __init__(self, config): - super().__init__() - - self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) - self.seq_length = config.seq_length - - self.hidden_size = config.hidden_size - self.split_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - - self.scale_attn_weights = True - - self.projection_size = config.kv_channels * config.num_attention_heads - - assert self.projection_size % config.num_attention_heads == 0 - self.hidden_size_per_attention_head = ( - self.projection_size // config.num_attention_heads - ) - - self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) - - self.c_proj = nn.Linear( - config.hidden_size, self.projection_size, bias=not config.no_bias - ) - - self.is_fp32 = not (config.bf16 or config.fp16) - self.bf16 = config.bf16 - - self.use_dynamic_ntk = config.use_dynamic_ntk - self.use_logn_attn = config.use_logn_attn - - logn_list = [ - math.log(i, self.seq_length) if i > self.seq_length else 1 - for i in range(1, 32768) - ] - self.logn_tensor = torch.tensor(logn_list)[None, :, None, None] - - self.attn_dropout = nn.Dropout(config.attn_dropout_prob) - - def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None): - attn_weights = torch.matmul(query, key.transpose(-1, -2)) - - if self.scale_attn_weights: - attn_weights = attn_weights / math.sqrt(self.head_dim) - - # causal_mask = self.bias[ - # :, :, key_length - query_length : key_length, :key_length - # ] - # mask_value = 
torch.finfo(attn_weights.dtype).min - # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( - # attn_weights.device - # ) - # attn_weights = torch.where( - # causal_mask, attn_weights.to(attn_weights.dtype), mask_value - # ) - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - attn_output = attn_output.transpose(1, 2) - - return attn_output, attn_weights - - def _upcast_and_reordered_attn( - self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None - ): - bsz, num_heads, q_seq_len, dk = query.size() - _, _, k_seq_len, _ = key.size() - - attn_weights = torch.empty( - bsz * num_heads, - q_seq_len, - k_seq_len, - dtype=torch.float32, - device=query.device, - ) - - scale_factor = 1.0 - if self.scale_attn_weights: - scale_factor /= float(value.size(-1)) ** 0.5 - - with autocast(enabled=False): - q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( - -1, dk, k_seq_len - ) - attn_weights = torch.baddbmm( - attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor - ) - attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) - - query_length, key_length = query.size(-2), key.size(-2) - causal_mask = registered_causal_mask[ - :, :, key_length - query_length : key_length, :key_length - ] - mask_value = torch.finfo(attn_weights.dtype).min - mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( - attn_weights.device - ) - attn_weights = torch.where(causal_mask, attn_weights, mask_value) - - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if attn_weights.dtype != torch.float32: - raise RuntimeError( - "Error with upcasting, attn_weights does not have dtype torch.float32" - ) - attn_weights = attn_weights.type(value.dtype) - attn_weights = self.attn_dropout(attn_weights) - - if head_mask is not None: - attn_weights = attn_weights * head_mask - - attn_output = torch.matmul(attn_weights, value) - - return attn_output, attn_weights - - def _split_heads(self, tensor, num_heads, attn_head_size): - new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) - tensor = tensor.view(new_shape) - return tensor - - def _merge_heads(self, tensor, num_heads, attn_head_size): - tensor = tensor.contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) - return tensor.view(new_shape) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb: Optional[List[torch.Tensor]] = None, - registered_causal_mask: Optional[torch.Tensor] = None, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ): - - mixed_x_layer = self.c_attn(hidden_states) - - query, key, value = mixed_x_layer.split(self.split_size, dim=2) - - query = self._split_heads(query, self.num_heads, self.head_dim) - key = self._split_heads(key, self.num_heads, self.head_dim) - value = self._split_heads(value, self.num_heads, self.head_dim) - - if 
rotary_pos_emb is not None: - ''' - cur_len = query.shape[1] - rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] - rotary_pos_emb = (rotary_pos_emb,) * 2 - q_pos_emb, k_pos_emb = rotary_pos_emb - # Slice the pos emb for current inference - print('len(q_pos_emb) = ', len(q_pos_emb)) # 2 - print('q_pos_emb[0].shape = ', q_pos_emb[0].shape) # 1, 20, 1, 128 - query = apply_rotary_pos_emb(query, q_pos_emb) - key = apply_rotary_pos_emb(key, k_pos_emb) - ''' - query = apply_rotary_pos_emb(query, rotary_pos_emb) - key = apply_rotary_pos_emb(key, rotary_pos_emb) - - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key = torch.cat((past_key, key), dim=1) - value = torch.cat((past_value, value), dim=1) - - if use_cache: - present = torch.stack([key, value]) - else: - present = None - - if self.use_logn_attn and not self.training and False: - if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype: - self.logn_tensor = self.logn_tensor.to(query.device).type_as(query) - seq_start = key.size(1) - query.size(1) - seq_end = key.size(1) - logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] - query = query * logn_tensor.expand_as(query) - - query = query.permute(0, 2, 1, 3) - key = key.permute(0, 2, 1, 3) - value = value.permute(0, 2, 1, 3) - attn_output, attn_weight = self._attn( - query, key, value, registered_causal_mask, attention_mask, head_mask - ) - context_layer = self._merge_heads( - attn_output, self.num_heads, self.head_dim - ) - - attn_output = self.c_proj(context_layer) - - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weight,) - - return outputs - - -class QWenMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.w1 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - self.w2 = nn.Linear( - config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias - ) - ff_dim_in = config.intermediate_size // 2 - self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) - - def forward(self, hidden_states): - a1 = self.w1(hidden_states) - a2 = self.w2(hidden_states) - intermediate_parallel = a1 * F.silu(a2) - output = self.c_proj(intermediate_parallel) - return output - -class QWenBlock(nn.Module): - def __init__(self, config): - super().__init__() - hidden_size = config.hidden_size - self.bf16 = config.bf16 - - self.ln_1 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - self.attn = QWenAttention(config) - self.ln_2 = RMSNorm( - hidden_size, - eps=config.layer_norm_epsilon, - ) - - self.mlp = QWenMLP(config) - - def forward( - self, - hidden_states: Optional[Tuple[torch.FloatTensor]], - rotary_pos_emb: Optional[List[torch.Tensor]] = None, - registered_causal_mask: Optional[torch.Tensor] = None, - layer_past: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, - output_attentions: Optional[bool] = False, - ): - layernorm_output = self.ln_1(hidden_states) - - attn_outputs = self.attn( - layernorm_output, - rotary_pos_emb, - registered_causal_mask=registered_causal_mask, - layer_past=layer_past, - attention_mask=attention_mask, - head_mask=head_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - attn_output = attn_outputs[0] - - 
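QWenMLP above is a gated feed-forward block in the SwiGLU style: one linear branch gates the other through SiLU before the output projection (a1 * silu(a2)), and QWenBlock applies it in a pre-norm residual layout. A minimal sketch of the gated MLP on its own (dimensions are illustrative; the original uses intermediate_size // 2 for the inner width):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    def __init__(self, hidden_size, ffn_size):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, ffn_size, bias=False)    # value branch
        self.w2 = nn.Linear(hidden_size, ffn_size, bias=False)    # gate branch
        self.c_proj = nn.Linear(ffn_size, hidden_size, bias=False)

    def forward(self, x):
        return self.c_proj(self.w1(x) * F.silu(self.w2(x)))       # a1 * silu(a2), as in QWenMLP

x = torch.randn(2, 4, 32)
y = GatedMLP(32, 86)(x)  # output has the same shape as x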
outputs = attn_outputs[1:] - - residual = hidden_states - layernorm_input = attn_output + residual - - layernorm_output = self.ln_2(layernorm_input) - - residual = layernorm_input - mlp_output = self.mlp(layernorm_output) - hidden_states = residual + mlp_output - - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] - - return outputs - - -class QWenPreTrainedModel(PreTrainedModel): - config_class = QWenConfig - base_model_prefix = "transformer" - is_parallelizable = False - supports_gradient_checkpointing = True - _no_split_modules = ["QWenBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module): - """Initialize the weights.""" - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, RMSNorm): - module.weight.data.fill_(1.0) - - for name, p in module.named_parameters(): - if name == "c_proj.weight": - p.data.normal_( - mean=0.0, - std=( - self.config.initializer_range - / math.sqrt(2 * self.config.num_hidden_layers) - ), - ) - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, QWenModel): - module.gradient_checkpointing = value - - -class QWenModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = ["attn.masked_bias"] - - def __init__(self, config): - super().__init__(config) - self.vocab_size = config.vocab_size - self.num_hidden_layers = config.num_hidden_layers - self.embed_dim = config.hidden_size - - self.gradient_checkpointing = False - self.use_dynamic_ntk = config.use_dynamic_ntk - self.seq_length = config.seq_length - - self.wte = nn.Embedding(self.vocab_size, self.embed_dim) - - self.drop = nn.Dropout(config.emb_dropout_prob) - - if config.rotary_pct == 1.0: - self.rotary_ndims = None - else: - assert config.rotary_pct < 1 - self.rotary_ndims = int( - config.kv_channels * config.rotary_pct - ) - dim = ( - self.rotary_ndims - if self.rotary_ndims is not None - else config.kv_channels - ) - self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) - - self.use_flash_attn = config.use_flash_attn - self.is_fp32 = not (config.bf16 or config.fp16) - self.registered_causal_mask = None - # if ( - # self.use_flash_attn - # and flash_attn_unpadded_func is not None - # and not self.is_fp32 - # ): - # self.registered_causal_mask = None - # else: - # max_positions = config.max_position_embeddings - # self.register_buffer( - # "registered_causal_mask", - # torch.tril( - # torch.ones((max_positions, max_positions), dtype=torch.bool) - # ).view(1, 1, max_positions, max_positions), - # persistent=False, - # ) - - self.h = nn.ModuleList( - [ - QWenBlock( - config - ) - for i in range(config.num_hidden_layers) - ] - ) - self.ln_f = RMSNorm( - self.embed_dim, - eps=config.layer_norm_epsilon, - ) - - self.visual = VisionTransformer(**config.visual) - - self.post_init() - - def get_input_embeddings(self): - return self.wte - - def set_input_embeddings(self, new_embeddings): - self.wte = new_embeddings - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, 
past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']): - bos_pos = torch.where(input_ids == self.config.visual['image_start_id']) - eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1) - assert (bos_pos[0] == eos_pos[0]).all() - img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) - images = [] - for i, a, b in img_pos: - image = input_ids[i][a + 1 : b - 1].tolist() - image = image[ : image.index(self.config.visual['image_start_id'] + 2)] - images.append(bytes(image).decode('utf-8')) - - images = self.visual.encode(images) - assert images.shape[0] == len(images) - fake_images = None - elif self.training: - fake_images=torch.zeros(1,3,224,224).to( - dtype=self.visual.conv1.weight.dtype, device=self.visual.conv1.weight.device) - images = self.visual(fake_images) - else: - fake_images = None - images = None - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time" - ) - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if position_ids is not None: - position_ids = position_ids.view(-1, input_shape[-1]) - - if past_key_values is None: - past_length = 0 - 
past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - - if position_ids is None: - position_ids = torch.arange( - past_length, - input_shape[-1] + past_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) - - encoder_attention_mask = None - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - - if batch_size <= 0: - raise ValueError("batch_size has to be defined and > 0") - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, input_shape, inputs_embeds, past_length - ) - - hidden_states = inputs_embeds - - kv_seq_len = hidden_states.size()[1] - if past_key_values[0] is not None: - # past key values[0][0] shape: bs * seq_len * head_num * dim - kv_seq_len += past_key_values[0][0].shape[1] - if ( - self.use_dynamic_ntk - and kv_seq_len == hidden_states.size()[1] - and not self.training - ): - context_value = math.log(kv_seq_len / self.seq_length, 2) + 1 - ntk_alpha = 2 ** math.ceil(context_value) - 1 - ntk_alpha = max(ntk_alpha, 1) - else: - ntk_alpha = self.rotary_emb._ntk_alpha_cached - - rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) - for idx in range(len(rotary_pos_emb)): - rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device) - - hidden_states = self.drop(hidden_states).clone() - if fake_images is not None: - hidden_states = hidden_states + images.mean()*0 - elif images is not None: - for idx, (i, a, b) in enumerate(img_pos): - hidden_states[i][a + 1 : b] = images[idx] - output_shape = input_shape + (hidden_states.size(-1),) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, use_cache, output_attentions) - - return custom_forward - - outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - rotary_pos_emb, - self.registered_causal_mask, - None, - attention_mask, - head_mask[i], - encoder_hidden_states, - encoder_attention_mask, - ) - else: - outputs = block( - hidden_states, - layer_past=layer_past, - rotary_pos_emb=rotary_pos_emb, - registered_causal_mask=self.registered_causal_mask, - attention_mask=attention_mask, - head_mask=head_mask[i], - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - ) - - hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - - hidden_states = self.ln_f(hidden_states) - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v for v in [hidden_states, presents, all_hidden_states] if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class QWenLMHeadModel(QWenPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] - - def __init__(self, config): - super().__init__(config) - assert ( - config.bf16 + config.fp16 + config.fp32 <= 1 - ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" - - autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 - - if autoset_precision: - if SUPPORT_BF16: - logger.warn( - "The model is automatically converting to bf16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." - ) - config.bf16 = True - elif SUPPORT_FP16: - logger.warn( - "The model is automatically converting to fp16 for faster inference. " - "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
- ) - config.fp16 = True - else: - config.fp32 = True - - if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: - logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") - if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: - logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") - if config.fp32: - if SUPPORT_BF16: - logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") - elif SUPPORT_FP16: - logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") - - self.transformer = QWenModel(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - if config.bf16: - self.transformer.bfloat16() - self.lm_head.bfloat16() - if config.fp16: - self.transformer.half() - self.lm_head.half() - self.post_init() - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs - ): - token_type_ids = kwargs.get("token_type_ids", None) - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None - - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - } - ) - return model_inputs - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - transformer_outputs = self.transformer( - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - 
output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - labels = labels.to(lm_logits.device) - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - loss_fct = CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) - ) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - - return tuple( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ) - for layer_past in past_key_values - ) - - def chat( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - append_history: bool = True, - stream: Optional[bool] = _SENTINEL, - stop_words_ids: Optional[List[List[int]]] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Tuple[str, HistoryType]: - generation_config = generation_config if generation_config is not None else self.generation_config - - assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - input_ids = torch.tensor([context_tokens]).to(self.device) - outputs = self.generate( - input_ids, - stop_words_ids=stop_words_ids, - return_dict_in_generate=False, - generation_config=generation_config, - **kwargs, - ) - - response = decode_tokens( - outputs[0], - tokenizer, - raw_text_len=len(raw_text), - context_length=len(context_tokens), - chat_format=generation_config.chat_format, - verbose=False, - errors='replace' - ) - - if append_history: - history.append((query, response)) - - return response, history - - def chat_stream( - self, - tokenizer: PreTrainedTokenizer, - query: str, - history: Optional[HistoryType], - system: str = "You are a helpful assistant.", - stop_words_ids: Optional[List[List[int]]] = None, - logits_processor: Optional[LogitsProcessorList] = None, - generation_config: Optional[GenerationConfig] = None, - **kwargs, - ) -> Generator[str, Any, None]: - generation_config = generation_config if generation_config is not None else self.generation_config - assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT - if history is None: - history = [] - if stop_words_ids is None: - stop_words_ids = [] - - max_window_size = kwargs.get('max_window_size', None) - if max_window_size is None: - max_window_size = generation_config.max_window_size - raw_text, 
context_tokens = make_context( - tokenizer, - query, - history=history, - system=system, - max_window_size=max_window_size, - chat_format=generation_config.chat_format, - ) - - stop_words_ids.extend(get_stop_words_ids( - generation_config.chat_format, tokenizer - )) - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - input_ids = torch.tensor([context_tokens]).to(self.device) - - from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig - self.__class__.generate_stream = NewGenerationMixin.generate - self.__class__.sample_stream = NewGenerationMixin.sample_stream - stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) - - def stream_generator(): - outputs = [] - for token in self.generate_stream( - input_ids, - return_dict_in_generate=False, - generation_config=stream_config, - logits_processor=logits_processor, - seed=-1, - **kwargs): - outputs.append(token.item()) - yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore', keep_image_special=True) - - return stream_generator() - - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - generation_config = generation_config if generation_config is not None else self.generation_config - - # Process stop_words_ids. 
- stop_words_ids = kwargs.pop("stop_words_ids", None) - if stop_words_ids is None and generation_config is not None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - if stop_words_ids is None: - stop_words_ids = getattr(generation_config, "stop_words_ids", None) - - if stop_words_ids is not None: - stop_words_logits_processor = StopWordsLogitsProcessor( - stop_words_ids=stop_words_ids, - eos_token_id=generation_config.eos_token_id, - ) - if logits_processor is None: - logits_processor = LogitsProcessorList([stop_words_logits_processor]) - else: - logits_processor.append(stop_words_logits_processor) - - return super().generate( - inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - **kwargs, - ) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000): - super().__init__() - self.dim = dim - self.base = base - self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) - if importlib.util.find_spec("einops") is None: - raise RuntimeError("einops is required for Rotary Embedding") - - self._rotary_pos_emb_cache = None - self._seq_len_cached = 0 - self._ntk_alpha_cached = 1.0 - - def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): - seqlen = max_seq_len + offset - if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: - base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) - self.inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() - / self.dim - ) - ) - self._seq_len_cached = max(2 * seqlen, 16) - self._ntk_alpha_cached = ntk_alpha - seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) - freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) - - emb = torch.cat((freqs, freqs), dim=-1) - from einops import rearrange - - emb = rearrange(emb, "n d -> 1 n 1 d") - - cos, sin = emb.cos(), emb.sin() - self._rotary_pos_emb_cache = [cos, sin] - - def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): - self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) - cos, sin = self._rotary_pos_emb_cache - return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]] - - -def _rotate_half(x): - from einops import rearrange - - x = rearrange(x, "... (j d) -> ... 
j d", j=2) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t, freqs): - cos, sin = freqs - if apply_rotary_emb_func is not None and t.is_cuda: - t_ = t.float() - cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2] - sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2] - output = apply_rotary_emb_func(t_, cos, sin).type_as(t) - return output - else: - rot_dim = freqs[0].shape[-1] - cos, sin = freqs - t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] - t_ = t_.float() - t_pass_ = t_pass_.float() - t_ = (t_ * cos) + (_rotate_half(t_) * sin) - return torch.cat((t_, t_pass_), dim=-1).type_as(t) - - -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - if rms_norm is not None and x.is_cuda: - return rms_norm(x, self.weight, self.eps) - else: - output = self._norm(x.float()).type_as(x) - return output * self.weight diff --git a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/config.json b/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/config.json deleted file mode 100755 index ea93bc66a..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/config.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 1024, - "initializer_range": 0.02, - "intermediate_size": 2816, - "max_position_embeddings": 32768, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 16, - "num_hidden_layers": 24, - "num_key_value_heads": 16, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.37.0", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} - diff --git a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. 
- max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-0_5B-Chat/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
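The configuration class deleted above is a copy of the upstream HuggingFace implementation, so an equivalent object can still be constructed without the local file. A hedged sketch, assuming a transformers release that already ships Qwen2Config upstream (4.37 or newer), using the same values as the removed config.json:

```python
# Illustrative sketch only: build the same configuration from upstream transformers
# (assumes transformers >= 4.37, which provides Qwen2Config).
from transformers import Qwen2Config

cfg = Qwen2Config(
    vocab_size=151936,
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_key_value_heads=16,
    max_position_embeddings=32768,
    rms_norm_eps=1e-6,
    rope_theta=1000000.0,
    tie_word_embeddings=True,
)
print(cfg.model_type)                                   # "qwen2"
print(cfg.hidden_size // cfg.num_attention_heads)       # 64 (per-head dimension)
```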
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
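The eager attention forward that ends here expands the key/value heads with repeat_kv before the score and context matmuls. As the deleted docstring states, that helper is equivalent to torch.repeat_interleave along the head dimension; a small standalone check of the claim (illustrative sketch only, assumes just PyTorch, helper copied from the deleted code):

```python
# Illustrative check of the repeat_kv helper used by the deleted attention code.
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 5, 8)           # 4 key/value heads
expanded = repeat_kv(kv, n_rep=3)       # serve 12 query heads
assert expanded.shape == (2, 12, 5, 8)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))
```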
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
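The `prepare_inputs_for_generation` method shown above creates `position_ids` on the fly from the attention mask and keeps only the tokens not yet covered by the KV cache. A minimal, self-contained sketch of that bookkeeping on toy tensors (illustrative only; the tensor values and `past_length` here are made up, not taken from the diff):

```python
import torch

# Toy batch: two sequences, left-padded, as the generation path above expects.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
input_ids = torch.tensor([[0, 0, 11, 12, 13],
                          [21, 22, 23, 24, 25]])

# "Create position_ids on the fly for batch generation": cumulative sum of the
# mask minus one, with padding positions clamped to 1 (same trick as above).
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# Once `past_length` tokens sit in the cache, only the new tokens are fed in,
# and position_ids is sliced to match.
past_length = 4
trimmed_input_ids = input_ids[:, past_length:]
trimmed_position_ids = position_ids[:, -trimmed_input_ids.shape[1]:]
print(trimmed_input_ids.shape, trimmed_position_ids.shape)  # torch.Size([2, 1]) torch.Size([2, 1])
```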
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/config.json b/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/config.json deleted file mode 100755 index 26ce493f6..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 5504, - "max_position_embeddings": 32768, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 16, - "num_hidden_layers": 24, - "num_key_value_heads": 16, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.37.0", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-1_8B-Chat/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
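For reference, the `config.json` removed above pinned the Qwen1.5-1.8B-Chat hyperparameters that `Qwen2Config` wraps. A small sketch in plain Python (no dependency on the deleted `configuration_qwen2.py`) that restates those values and derives the per-head dimension used by the attention code that follows:

```python
# Hyperparameters copied from the deleted Qwen1_5-1_8B-Chat/config.json.
qwen1_5_1_8b_chat = {
    "hidden_size": 2048,
    "intermediate_size": 5504,
    "num_hidden_layers": 24,
    "num_attention_heads": 16,
    "num_key_value_heads": 16,
    "max_position_embeddings": 32768,
    "rms_norm_eps": 1e-06,
    "rope_theta": 1000000.0,
    "vocab_size": 151936,
}

# head_dim is derived the same way as in Qwen2Attention.__init__ further below.
head_dim = qwen1_5_1_8b_chat["hidden_size"] // qwen1_5_1_8b_chat["num_attention_heads"]
assert head_dim * qwen1_5_1_8b_chat["num_attention_heads"] == qwen1_5_1_8b_chat["hidden_size"]
print(head_dim)  # 128
```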
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
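`Qwen2RMSNorm` above normalizes each hidden vector by its root mean square (computed in float32) and rescales with a learned weight. A standalone sketch of the same formula, checked against the expected unit RMS (illustrative; shapes and dtypes are made up, not taken from the diff):

```python
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """weight * x / sqrt(mean(x^2) + eps), computed in float32 as in Qwen2RMSNorm above."""
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(orig_dtype)

hidden = torch.randn(2, 5, 8, dtype=torch.float16)
weight = torch.ones(8)
out = rms_norm(hidden, weight)

# The RMS of each normalized vector is ~1 (up to eps and fp16 rounding).
print(out.float().pow(2).mean(-1).sqrt())
```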
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
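The uncommented forward path above keeps Q/K/V in `[batch, seq, heads, head_dim]` layout until after RoPE, concatenates the KV cache along the sequence dimension, then permutes keys to `[batch, heads, head_dim, seq]` so the score computation is a plain matmul. A condensed, self-contained sketch of that path with random tensors (layer sizes, cos/sin values and the omission of the additive mask and `repeat_kv` are simplifications for illustration, not part of the diff):

```python
import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bsz, q_len, past_len, num_heads, head_dim = 1, 4, 3, 2, 8

# Pretend the q/k/v projections already produced these, in [batch, seq, heads, head_dim] layout.
query_states = torch.randn(bsz, q_len, num_heads, head_dim)
key_states   = torch.randn(bsz, q_len, num_heads, head_dim)
value_states = torch.randn(bsz, q_len, num_heads, head_dim)
past_key     = torch.randn(bsz, past_len, num_heads, head_dim)
past_value   = torch.randn(bsz, past_len, num_heads, head_dim)

# rotary_pos_emb arrives precomputed for the current positions (broadcastable cos/sin).
cos = torch.randn(bsz, q_len, 1, head_dim)
sin = torch.randn(bsz, q_len, 1, head_dim)
query_states = (query_states * cos) + (rotate_half(query_states) * sin)
key_states   = (key_states * cos) + (rotate_half(key_states) * sin)

# KV cache: concatenate along the sequence dimension (dim=1 in this layout).
key_states   = torch.cat((past_key, key_states), dim=1)
value_states = torch.cat((past_value, value_states), dim=1)

# Rearrange for the matmuls: q -> [b, h, q, d], k -> [b, h, d, kv], v -> [b, h, kv, d].
query_states = query_states.transpose(1, 2)
key_states   = key_states.permute([0, 2, 3, 1])
value_states = value_states.transpose(1, 2)

attn_weights = torch.matmul(query_states, key_states) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output  = torch.matmul(attn_weights, value_states)                    # [b, h, q, d]
attn_output  = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 4, 16])
```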
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
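The Flash-Attention path above removes padding before calling the varlen kernels: `_get_unpad_data`, defined near the top of this file, turns a 2-D padding mask into flat token indices, cumulative sequence lengths, and the maximum length in the batch. A small sketch of that computation on a toy mask (illustrative; it does not call flash-attn itself):

```python
import torch
import torch.nn.functional as F

# 1 = real token, 0 = padding, matching the attention_mask handled above.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)            # per-sequence lengths
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # real-token positions
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)              # tensor([0, 1, 2, 5, 6, 7, 8, 9])
print(cu_seqlens)           # tensor([0, 3, 8], dtype=torch.int32)
print(max_seqlen_in_batch)  # 5
```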
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
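`Qwen2DecoderLayer` above is a standard pre-norm residual block: RMSNorm, self-attention, residual add, then RMSNorm, MLP, residual add. A structural sketch of that control flow with placeholder submodules (the stand-ins here are `nn.LayerNorm`, `nn.MultiheadAttention`, and a plain two-layer MLP; the real layer uses the Qwen2 attention, gated-SiLU MLP, and RMSNorm classes defined earlier in this file):

```python
import torch
from torch import nn

class ToyPreNormBlock(nn.Module):
    """Same residual wiring as Qwen2DecoderLayer above; submodules are stand-ins."""

    def __init__(self, hidden_size: int = 32, intermediate_size: int = 96):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Norm -> attention -> residual add.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states, hidden_states, hidden_states)
        hidden_states = residual + hidden_states

        # Norm -> MLP -> residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + hidden_states

block = ToyPreNormBlock()
print(block(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 32])
```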
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/config.json b/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/config.json deleted file mode 100755 index 9f2be4f60..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 2560, - "initializer_range": 0.02, - "intermediate_size": 6912, - "max_position_embeddings": 32768, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 20, - "num_hidden_layers": 40, - "num_key_value_heads": 20, - "rms_norm_eps": 1e-06, - "rope_theta": 5000000.0, - "sliding_window": 32768, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.37.0", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-4B-Chat/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
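For orientation, the hyperparameters in the config.json removed above map one-to-one onto the constructor of the (likewise removed) `Qwen2Config`. The following is a sketch only, assuming a `transformers` release that already ships `Qwen2Config` upstream (4.37+); it is not part of the diff:

```python
from transformers import Qwen2Config  # upstream class mirrored by the deleted local copy

# Values copied from the removed Qwen1_5-4B-Chat config.json; omitted fields
# fall back to the Qwen2Config defaults described in the class docstring above.
config = Qwen2Config(
    vocab_size=151936,
    hidden_size=2560,
    intermediate_size=6912,
    num_hidden_layers=40,
    num_attention_heads=20,
    num_key_value_heads=20,    # equal to num_attention_heads, so plain MHA rather than grouped GQA
    max_position_embeddings=32768,
    rms_norm_eps=1e-6,
    rope_theta=5000000.0,
    use_sliding_window=False,  # sliding window attention disabled for this checkpoint
    sliding_window=32768,
    max_window_layers=21,
    tie_word_embeddings=False,
)
```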
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
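The rewritten eager `forward` above (the upstream cache path is left commented out) takes the rotary cos/sin as an explicit `rotary_pos_emb` argument, keeps the KV cache as one stacked tensor concatenated along the sequence axis, and applies `rotate_half` inline, which keeps the traced graph simple for export. Below is a minimal, self-contained sketch of that pattern; the helper name and toy shapes are illustrative, not from the diff, and it assumes `num_key_value_heads == num_heads` (true for the 4B config) so `repeat_kv` can be skipped:

```python
import torch

def rotate_half(x):
    # Swap the two halves of the last dimension, negating the second half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def eager_attention_step(q, k, v, cos, sin, past_kv=None):
    # q, k, v: [bsz, q_len, num_heads, head_dim]; cos/sin broadcastable to them,
    # e.g. [1, q_len, 1, head_dim]. past_kv: stacked [2, bsz, past_len, num_heads, head_dim].
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    if past_kv is not None:
        k = torch.cat((past_kv[0], k), dim=1)   # grow the cache along the sequence axis
        v = torch.cat((past_kv[1], v), dim=1)
    present_kv = torch.stack((k, v))            # cache handed back as a single tensor
    # Scores in [bsz, num_heads, q_len, kv_len] layout; the causal-mask addition is omitted here.
    scores = torch.matmul(q.transpose(1, 2), k.permute(0, 2, 3, 1)) / (q.shape[-1] ** 0.5)
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)
    out = torch.matmul(probs, v.transpose(1, 2)).transpose(1, 2)  # [bsz, q_len, num_heads, head_dim]
    return out.reshape(out.shape[0], out.shape[1], -1), present_kv
```

Each decode step feeds `present_kv` back in as `past_kv`, and `cos`/`sin` are sliced for the current positions outside the module, so the rotary table behaves as a plain graph input rather than a buffer recomputed inside the model.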
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/config.json b/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/config.json deleted file mode 100755 index 6b0cfb9b4..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 32768, - "max_window_layers": 28, - "model_type": "qwen2", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.37.0", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen1_5-7B-Chat/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
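The `Qwen2Config` class and the `Qwen1_5-7B-Chat/config.json` removed above describe the same model. As a quick orientation, the sketch below (illustrative only, not part of the deleted files) shows how the JSON fields map onto the constructor, assuming a transformers release that ships `Qwen2Config` with the same schema; any field not passed keeps the default listed in `__init__`.

```python
# Illustrative sketch only: recreate the removed Qwen1_5-7B-Chat configuration
# programmatically. Values are copied from the deleted config.json; anything not
# passed falls back to the Qwen2Config defaults shown in configuration_qwen2.py.
from transformers import Qwen2Config  # assumes transformers >= 4.37

qwen1_5_7b_chat = Qwen2Config(
    vocab_size=151936,
    hidden_size=4096,
    intermediate_size=11008,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=32,        # equal to num_attention_heads -> plain MHA, no GQA grouping
    max_position_embeddings=32768,
    rms_norm_eps=1e-6,
    rope_theta=1000000.0,
    use_sliding_window=False,      # sliding_window=32768 is recorded but not enabled
    sliding_window=32768,
    max_window_layers=28,
    tie_word_embeddings=False,
)
assert qwen1_5_7b_chat.num_key_value_heads == qwen1_5_7b_chat.num_attention_heads
```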
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
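Note the export-oriented changes in `Qwen2Attention.forward` above: the original Hugging Face path is kept only as a commented-out block, and the active path takes precomputed rotary `cos`/`sin` through `rotary_pos_emb`, keeps the KV cache as a plain stacked tensor concatenated along the sequence axis, and leaves tensors in `[batch, seq, heads, head_dim]` layout until the final matmuls. The sketch below is an illustrative, self-contained rendering of that path under stated assumptions (broadcastable `cos`/`sin`, `num_key_value_heads == num_attention_heads` so `repeat_kv` is skipped); it is not the file's own code.

```python
# Minimal sketch of the export-friendly attention step used above (assumptions:
# cos/sin broadcast against [bsz, q_len, heads, head_dim]; no GQA, so repeat_kv
# is omitted; attn_mask, if given, is additive with shape [bsz, 1, q_len, kv_len]).
import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def eager_attention_step(q, k, v, cos, sin, past_kv=None, attn_mask=None):
    # q, k, v: [bsz, q_len, heads, head_dim]; past_kv: [2, bsz, past_len, heads, head_dim]
    q = (q * cos) + (rotate_half(q) * sin)          # RoPE applied with externally supplied cos/sin
    k = (k * cos) + (rotate_half(k) * sin)
    if past_kv is not None:
        k = torch.cat((past_kv[0], k), dim=1)       # grow the cache along the sequence axis
        v = torch.cat((past_kv[1], v), dim=1)
    present_kv = torch.stack((k, v))                # cache handed back as one stacked tensor
    scores = torch.matmul(q.transpose(1, 2),        # [bsz, heads, q_len, head_dim]
                          k.permute(0, 2, 3, 1))    # [bsz, heads, head_dim, kv_len]
    scores = scores / math.sqrt(q.shape[-1])
    if attn_mask is not None:
        scores = scores + attn_mask                 # additive causal/padding mask
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    out = torch.matmul(probs, v.transpose(1, 2))    # [bsz, heads, q_len, head_dim]
    return out.transpose(1, 2), present_kv          # back to [bsz, q_len, heads, head_dim]
```

Trading the `Cache` object for a stacked tensor and taking `cos`/`sin` as explicit inputs keeps the traced computation a pure tensor-in/tensor-out function, which is presumably also why the decoder layer below pins `self.self_attn = Qwen2Attention(config, layer_idx)` instead of selecting from `QWEN2_ATTENTION_CLASSES`.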
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/config.json b/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/config.json deleted file mode 100755 index 8f9ea8a58..000000000 --- a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 896, - "initializer_range": 0.02, - "intermediate_size": 4864, - "max_position_embeddings": 32768, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 14, - "num_hidden_layers": 24, - "num_key_value_heads": 2, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.40.1", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen2-0_5B-Instruct/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/config.json b/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/config.json deleted file mode 100755 index bdc572b07..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 1536, - "initializer_range": 0.02, - "intermediate_size": 8960, - "max_position_embeddings": 32768, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 12, - "num_hidden_layers": 28, - "num_key_value_heads": 2, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": true, - "torch_dtype": "bfloat16", - "transformers_version": "4.40.1", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B-Instruct/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
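The two deletions above remove the model folder's standalone configuration. Before the modeling code that follows, a brief sketch of how the removed config.json values would map onto the `Qwen2Config` constructor shown above may help when reproducing the setup elsewhere; the local import path is an assumption for illustration, and every value is taken from the deleted config.json, with anything not listed falling back to the constructor defaults.

```python
# Illustrative sketch only: assumes the deleted configuration_qwen2.py is
# importable from the working directory, as it was inside the model folder.
from configuration_qwen2 import Qwen2Config

# Values mirror the deleted Qwen2-1_5B-Instruct config.json; omitted fields
# keep the constructor defaults.
config = Qwen2Config(
    vocab_size=151936,
    hidden_size=1536,
    intermediate_size=8960,
    num_hidden_layers=28,
    num_attention_heads=12,
    num_key_value_heads=2,
    max_position_embeddings=32768,
    rms_norm_eps=1e-6,
    rope_theta=1000000.0,
    use_sliding_window=False,
    sliding_window=32768,
    max_window_layers=21,
    tie_word_embeddings=True,
    bos_token_id=151643,
    eos_token_id=151645,
)

# head_dim is derived rather than stored: hidden_size // num_attention_heads.
assert config.hidden_size // config.num_attention_heads == 128
```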
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
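The forward pass above is the export-oriented variant of eager attention: the upstream Cache/RoPE path is kept only as a commented-out reference, while the active path expects `rotary_pos_emb` to be supplied as a precomputed `(cos, sin)` pair and keeps the KV cache as a single stacked tensor of shape `(2, batch, kv_len, num_key_value_heads, head_dim)`. A minimal driver sketch follows; it assumes the deleted configuration_qwen2.py and modeling_qwen2.py are importable and that `cos`/`sin` are shaped `(batch, q_len, 1, head_dim)` so they broadcast against the pre-transpose `(batch, q_len, num_heads, head_dim)` layout used here.

```python
# Illustrative driver only: assumes the deleted configuration_qwen2.py and
# modeling_qwen2.py from this folder are importable as local modules.
import torch
from configuration_qwen2 import Qwen2Config
from modeling_qwen2 import Qwen2Attention

config = Qwen2Config(hidden_size=64, num_attention_heads=4,
                     num_key_value_heads=2, intermediate_size=128,
                     num_hidden_layers=1)
attn = Qwen2Attention(config, layer_idx=0).eval()

def rope(position_ids, total_len, dtype):
    # Qwen2RotaryEmbedding returns (total_len, head_dim) tables; gather the
    # rows for the current positions and add a broadcast axis over the heads,
    # giving (batch, q_len, 1, head_dim).
    cos, sin = attn.rotary_emb(torch.zeros(1, dtype=dtype), seq_len=total_len)
    return cos[position_ids].unsqueeze(2), sin[position_ids].unsqueeze(2)

# Prefill: 4 tokens, no cache yet; additive causal mask of shape (b, 1, q, kv).
bsz, q_len = 1, 4
hidden = torch.randn(bsz, q_len, config.hidden_size)
mask = torch.triu(torch.full((q_len, q_len), float("-inf")), 1).view(1, 1, q_len, q_len)
out, _, cache = attn(hidden, attention_mask=mask,
                     rotary_pos_emb=rope(torch.arange(q_len).unsqueeze(0), q_len, hidden.dtype),
                     past_key_value=None)
# cache is torch.stack((k, v)): (2, bsz, kv_len, num_key_value_heads, head_dim)
print(out.shape, cache.shape)   # (1, 4, 64), (2, 1, 4, 2, 16)

# One decode step: the new token attends to the 4 cached positions plus itself.
new = torch.randn(bsz, 1, config.hidden_size)
out, _, cache = attn(new, attention_mask=torch.zeros(1, 1, 1, q_len + 1),
                     rotary_pos_emb=rope(torch.tensor([[q_len]]), q_len + 1, new.dtype),
                     past_key_value=cache)
print(out.shape, cache.shape)   # (1, 1, 64), (2, 1, 5, 2, 16)
```

Because the cache stays a plain tensor and the rotary tables are passed in from outside, the graph traces as a fixed set of tensor ops, which is presumably why the `DynamicCache`-based path above is left commented out for export.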
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B/config.json b/transformers/llm/export/llm_models/Qwen2-1_5B/config.json deleted file mode 100755 index 08a0ac476..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151643, - "hidden_act": "silu", - "hidden_size": 1536, - "initializer_range": 0.02, - "intermediate_size": 8960, - "max_position_embeddings": 131072, - "max_window_layers": 21, - "model_type": "qwen2", - "num_attention_heads": 12, - "num_hidden_layers": 28, - "num_key_value_heads": 2, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 131072, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.38.2", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 151936 -} diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen2-1_5B/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-1_5B/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen2-1_5B/modeling_qwen2.py deleted file mode 100644 index f8d5b5345..000000000 --- a/transformers/llm/export/llm_models/Qwen2-1_5B/modeling_qwen2.py +++ /dev/null @@ -1,1434 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
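The deleted Qwen2-1_5B `config.json` and the `Qwen2Config` docstring above describe Grouped Query Attention: 12 attention heads share 2 key/value heads, so each key/value head serves 6 query heads. A minimal standalone sketch of that head expansion (not part of the patch; the shapes are toy values, and the helper mirrors the `repeat_kv` function defined further down in `modeling_qwen2.py`):

```python
# Standalone sketch: GQA head expansion with the Qwen2-1_5B values from the
# deleted config.json (12 attention heads sharing 2 key/value heads).
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq, head_dim) -> (batch, num_kv_heads * n_rep, seq, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

num_attention_heads, num_key_value_heads, head_dim = 12, 2, 128  # head_dim = 1536 / 12
n_rep = num_attention_heads // num_key_value_heads               # 6 query heads per KV head
k = torch.randn(1, num_key_value_heads, 16, head_dim)            # toy key cache, seq_len = 16
print(repeat_kv(k, n_rep).shape)                                 # torch.Size([1, 12, 16, 128])
```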
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. 
This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/config.json b/transformers/llm/export/llm_models/Qwen2-7B-Instruct/config.json deleted file mode 100755 index eac7cd285..000000000 --- a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/config.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "architectures": [ - "Qwen2ForCausalLM" - ], - "auto_map": { - "AutoConfig": "configuration_qwen2.Qwen2Config", - "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" - }, - "attention_dropout": 0.0, - "bos_token_id": 151643, - "eos_token_id": 151645, - "hidden_act": "silu", - "hidden_size": 3584, - "initializer_range": 0.02, - "intermediate_size": 18944, - "max_position_embeddings": 32768, - "max_window_layers": 28, - "model_type": "qwen2", - "num_attention_heads": 28, - "num_hidden_layers": 28, - "num_key_value_heads": 4, - "rms_norm_eps": 1e-06, - "rope_theta": 1000000.0, - "sliding_window": 32768, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.41.2", - "use_cache": true, - "use_sliding_window": false, - "vocab_size": 152064 -} diff --git a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/configuration_qwen2.py b/transformers/llm/export/llm_models/Qwen2-7B-Instruct/configuration_qwen2.py deleted file mode 100644 index b6ca1ed43..000000000 --- a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/configuration_qwen2.py +++ /dev/null @@ -1,144 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Qwen2 model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -QWEN2_PRETRAINED_CONFIG_ARCHIVE_MAP = { - "Qwen/Qwen2-7B-beta": "https://huggingface.co/Qwen/Qwen2-7B-beta/resolve/main/config.json", -} - - -class Qwen2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a - Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of - Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta). - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. 
- - - Args: - vocab_size (`int`, *optional*, defaults to 151936): - Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Qwen2Model`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 22016): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*, defaults to 32): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 32768): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether the model's input and output word embeddings should be tied. - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - use_sliding_window (`bool`, *optional*, defaults to `False`): - Whether to use sliding window attention. - sliding_window (`int`, *optional*, defaults to 4096): - Sliding window attention (SWA) window size. If not specified, will default to `4096`. - max_window_layers (`int`, *optional*, defaults to 28): - The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. 
- - ```python - >>> from transformers import Qwen2Model, Qwen2Config - - >>> # Initializing a Qwen2 style configuration - >>> configuration = Qwen2Config() - - >>> # Initializing a model from the Qwen2-7B style configuration - >>> model = Qwen2Model(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "qwen2" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=151936, - hidden_size=4096, - intermediate_size=22016, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=32, - hidden_act="silu", - max_position_embeddings=32768, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - tie_word_embeddings=False, - rope_theta=10000.0, - use_sliding_window=False, - sliding_window=4096, - max_window_layers=28, - attention_dropout=0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.use_sliding_window = use_sliding_window - self.sliding_window = sliding_window - self.max_window_layers = max_window_layers - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/modeling_qwen2.py b/transformers/llm/export/llm_models/Qwen2-7B-Instruct/modeling_qwen2.py deleted file mode 100644 index 595a3e91c..000000000 --- a/transformers/llm/export/llm_models/Qwen2-7B-Instruct/modeling_qwen2.py +++ /dev/null @@ -1,1436 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch Qwen2 model.""" -import inspect -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_qwen2 import Qwen2Config - - -# if is_flash_attn_2_available(): - #from flash_attn import flash_attn_func, flash_attn_varlen_func - #from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - #_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - - -logger = logging.get_logger(__name__) - - -_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta" -_CONFIG_FOR_DOC = "Qwen2Config" - -QWEN2_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "Qwen/Qwen2-7B-beta", - # See all Qwen2 models at https://huggingface.co/models?filter=qwen2 -] - - -# Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2 -class Qwen2RMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - Qwen2RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2 -class Qwen2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2 -class Qwen2MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class Qwen2Attention(nn.Module): - """ - Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer - and "Generating Long Sequences with Sparse Transformers". - """ - - def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " - "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.attention_dropout = config.attention_dropout - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - self.rotary_emb = Qwen2RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - kv_seq_len += past_key_value[0].shape[2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - else: - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=2) - value_states = torch.cat((past_value, value_states), dim=2) - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = torch.stack((key_states, value_states)) - # repeat k/v heads if n_kv_heads < n_heads - # key_states = repeat_kv(key_states, self.num_key_value_groups) - # value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - 
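# Illustrative sketch of the export-oriented attention path in the forward above:
# rotary cos/sin are supplied from outside, K/V are cached along the sequence axis in
# [batch, seq, heads, head_dim] layout, and keys are permuted to [batch, heads, head_dim, seq]
# before the matmul. Names and shapes are assumptions for the example (and it assumes
# num_key_value_heads == num_attention_heads, so no repeat_kv); it is a minimal sketch,
# not an exact copy of the deleted module.
import math
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def eager_attention(q, k, v, cos, sin, past_kv=None):
    # q, k, v: [B, L, H, D]; cos/sin broadcastable to them (e.g. [1, L, 1, D])
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    if past_kv is not None:                      # append the new tokens to the cache
        k = torch.cat((past_kv[0], k), dim=1)
        v = torch.cat((past_kv[1], v), dim=1)
    present = torch.stack((k, v))                # cache returned as one stacked tensor
    q = q.transpose(1, 2)                        # [B, H, L, D]
    k = k.permute(0, 2, 3, 1)                    # [B, H, D, S]
    v = v.transpose(1, 2)                        # [B, H, S, D]
    scores = torch.matmul(q, k) / math.sqrt(q.shape[-1])
    probs = torch.softmax(scores.float(), dim=-1).to(q.dtype)   # upcast softmax to fp32
    attn = torch.matmul(probs, v)                # [B, H, L, D]
    return attn.transpose(1, 2), present         # back to [B, L, H, D]

# Smoke test: identity rotation (cos=1, sin=0), one decode step after a 4-token prefill.
q = torch.randn(1, 4, 2, 8)
cos, sin = torch.ones(1, 4, 1, 8), torch.zeros(1, 4, 1, 8)
out, cache = eager_attention(q, q, q, cos, sin)
step = torch.randn(1, 1, 2, 8)
out2, _ = eager_attention(step, step, step, cos[:, :1], sin[:, :1], past_kv=cache)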
- -class Qwen2FlashAttention2(Qwen2Attention): - """ - Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` - as the weights of the module stays untouched. The only required change would be on the forward pass - where it needs to correctly call the public API of flash attention and deal with padding tokens - in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom - config.max_window_layers layers. - """ - - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window - ) - - if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" - " make sure to upgrade flash-attn library." 
- ) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if ( - getattr(self.config, "sliding_window", None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents - ): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" - f" {past_key.shape}" - ) - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - use_sliding_windows=False, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. 
- - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_sliding_windows (`bool`, *optional*): - Whether to activate sliding window attention. - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Decide whether to use SWA or not by layer index. - if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: - use_sliding_windows = False - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - if not use_sliding_windows: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - if not use_sliding_windows: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), - ) - - return attn_output - - # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape - - # On the first iteration we need to properly re-create the padding mask - # by slicing it on the proper place - if kv_seq_len != attention_mask.shape[-1]: - attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, 
head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2 -class Qwen2SdpaAttention(Qwen2Attention): - """ - Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Qwen2Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and attention_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. - is_causal=self.is_causal and attention_mask is None and q_len > 1, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -QWEN2_ATTENTION_CLASSES = { - "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, - "sdpa": Qwen2SdpaAttention, -} - - -class Qwen2DecoderLayer(nn.Module): - def __init__(self, config: Qwen2Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": - logger.warning_once( - f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " - "unexpected results may be encountered." 
- ) - # self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - self.self_attn = Qwen2Attention(config, layer_idx) - - self.mlp = Qwen2MLP(config) - self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.FloatTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. " - "Please make sure use `attention_mask` instead.`" - ) - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, sequence_length)` where padding elements are indicated by 0. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -QWEN2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Qwen2Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2PreTrainedModel(PreTrainedModel): - config_class = Qwen2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Qwen2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -QWEN2_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance; - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.", - QWEN2_START_DOCSTRING, -) -class Qwen2Model(Qwen2PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`] - - Args: - config: Qwen2Config - """ - - def __init__(self, config: Qwen2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self._attn_implementation = config._attn_implementation - self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, 
seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class Qwen2ForCausalLM(Qwen2PreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Qwen2Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling 
loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2ForCausalLM - - >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The Qwen2 Model transformer with a sequence classification head on top (linear layer). - - [`Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - QWEN2_START_DOCSTRING, -) -class Qwen2ForSequenceClassification(Qwen2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = Qwen2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 - sequence_lengths = sequence_lengths % input_ids.shape[-1] - sequence_lengths = sequence_lengths.to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif 
self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/config.json b/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/config.json deleted file mode 100755 index 117c9e5d6..000000000 --- a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/config.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "_name_or_path": "/mnt/petrelfs/libo1.p/alignment-handbook/data/tinyllama-2T-sft-full", - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "attention_bias": false, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 2048, - "initializer_range": 0.02, - "intermediate_size": 5632, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 22, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.35.0", - "use_cache": false, - "vocab_size": 32000 -} diff --git a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/configuration_llama.py b/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. 
- - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. 
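# Illustrative sketch of the `rope_scaling` format described above: a dict with a strategy
# name ("linear" or "dynamic") and a float factor greater than 1, which is what the
# validation further below enforces. The sketch uses the LlamaConfig shipped with
# transformers; the factor of 2.0 is an assumption chosen for the example.
from transformers import LlamaConfig

cfg = LlamaConfig(
    max_position_embeddings=2048,
    rope_scaling={"type": "linear", "factor": 2.0},  # stretch RoPE positions 2x past the pretraining window
)
print(cfg.rope_scaling)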
- - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/modeling_llama.py b/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/modeling_llama.py deleted file mode 100644 index 493b040b7..000000000 --- a/transformers/llm/export/llm_models/TinyLlama-1_1B-Chat/modeling_llama.py +++ /dev/null @@ -1,1040 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. 
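`_make_causal_mask` above builds an additive bias rather than a boolean mask: positions a token may attend to get 0, future positions get the dtype minimum. A three-token sketch of the same construction (toy size, no KV cache):

```python
import torch

# Same idea as _make_causal_mask above, for a 3-token prompt with no cache.
tgt_len, dtype = 3, torch.float32
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
cols = torch.arange(tgt_len)
mask.masked_fill_(cols < (cols + 1).view(tgt_len, 1), 0)  # zero the lower triangle (col <= row)

print(mask)  # row i keeps columns 0..i at 0 and pushes future columns to the dtype minimum
```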
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
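`_set_cos_sin_cache` above precomputes the rotary cos/sin tables once from the inverse frequencies. The sketch below repeats that computation outside the class; `dim=64` and `seq_len=8` are only example sizes.

```python
import torch

# Hedged sketch of the cos/sin cache built by _set_cos_sin_cache above.
dim, base, seq_len = 64, 10000, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # (dim/2,)
t = torch.arange(seq_len, dtype=inv_freq.dtype)                      # positions 0..seq_len-1
freqs = torch.einsum("i,j->ij", t, inv_freq)                         # (seq_len, dim/2)
emb = torch.cat((freqs, freqs), dim=-1)                              # (seq_len, dim)
cos_cached, sin_cached = emb.cos(), emb.sin()

print(cos_cached.shape, sin_cached.shape)  # torch.Size([8, 64]) twice
```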
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
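`apply_rotary_pos_emb` in this file squeezes the cached cos/sin and applies `q * cos + rotate_half(q) * sin` directly. A single-head sketch of that formula with toy sizes:

```python
import torch

def rotate_half(x):
    # Same trick as above: split the last dim in two, swap the halves and negate one.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

dim, seq_len = 8, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
cos, sin = torch.cat((freqs, freqs), -1).cos(), torch.cat((freqs, freqs), -1).sin()

q = torch.randn(seq_len, dim)                 # one head: (seq_len, head_dim)
q_rot = (q * cos) + (rotate_half(q) * sin)    # the formula used by apply_rotary_pos_emb
print(q_rot.shape)                            # torch.Size([4, 8])
```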
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in 
range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, 
q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. 
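`LlamaDecoderLayer.forward` above is a pre-norm block: normalize, transform, add the residual, twice. The skeleton below keeps only that dataflow; `nn.LayerNorm` and `nn.Linear` are stand-ins for the real RMSNorm, attention, and MLP.

```python
import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    """Dataflow-only sketch of the decoder layer above; submodules are stand-ins."""
    def __init__(self, hidden_size=64):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)           # stands in for LlamaRMSNorm
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)  # stands in for LlamaRMSNorm
        self.self_attn = nn.Linear(hidden_size, hidden_size)       # stands in for LlamaAttention
        self.mlp = nn.Linear(hidden_size, hidden_size)             # stands in for LlamaMLP

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = residual + hidden_states                   # first residual add
        residual = hidden_states
        hidden_states = self.mlp(self.post_attention_layernorm(hidden_states))
        return residual + hidden_states                            # second residual add

print(PreNormBlockSketch()(torch.randn(1, 4, 64)).shape)  # torch.Size([1, 4, 64])
```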
- - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
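The `attention_mask` convention documented above (1 = attend, 0 = padding) is turned into an additive bias by `_expand_mask`. A one-sequence sketch that broadcasts over query positions instead of materializing `tgt_len`:

```python
import torch

# Padding-mask convention from the docstring above: 1 = keep, 0 = masked.
attention_mask = torch.tensor([[1, 1, 1, 0]])          # one sequence, last token is padding
dtype = torch.float32

# Same expansion as _expand_mask: an additive bias over key positions.
expanded = attention_mask[:, None, None, :].to(dtype)  # [bsz, 1, 1, src_len]
inverted = 1.0 - expanded
bias = inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

print(bias.squeeze())  # 0 for real tokens, a large negative value at the padded position
```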
- - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: 
Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
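The `labels` contract described above is consumed further down by shifting logits and labels one step and letting `CrossEntropyLoss` skip the `-100` positions (its default `ignore_index`). A standalone sketch with toy shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

# Standalone sketch of the shifted next-token loss used by the causal-LM head below.
vocab_size, seq_len = 32000, 5
logits = torch.randn(1, seq_len, vocab_size)
labels = torch.tensor([[12, 7, 431, -100, 9]])        # -100 marks positions to ignore

shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)  # predict token t+1 from t
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)  # ignore_index defaults to -100
print(loss.item())
```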
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/Yi-6B-Chat/config.json b/transformers/llm/export/llm_models/Yi-6B-Chat/config.json deleted file mode 100755 index aad6b1d39..000000000 --- a/transformers/llm/export/llm_models/Yi-6B-Chat/config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "attention_bias": false, - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 4096, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 4, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": null, - "rope_theta": 5000000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.35.0", - "use_cache": true, - "vocab_size": 64000 -} diff --git a/transformers/llm/export/llm_models/Yi-6B-Chat/configuration_llama.py b/transformers/llm/export/llm_models/Yi-6B-Chat/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- 
a/transformers/llm/export/llm_models/Yi-6B-Chat/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). 
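The `pretraining_tp` option documented above is what drives the weight-splitting branches in the deleted `LlamaMLP`/`LlamaAttention`. A small sketch (shapes are made up) showing that slicing a projection into `tp` row blocks and concatenating the partial outputs matches the unsliced matmul:

```python
import torch
import torch.nn.functional as F

# Sketch of the pretraining_tp slicing used by LlamaMLP/LlamaAttention.
tp, hidden, out = 2, 8, 12
weight = torch.randn(out, hidden)          # nn.Linear weight layout: (out_features, in_features)
x = torch.randn(3, hidden)

slices = weight.split(out // tp, dim=0)    # tp row blocks -> tp output slices
sliced = torch.cat([F.linear(x, s) for s in slices], dim=-1)

print(torch.allclose(sliced, F.linear(x, weight)))  # True: same result as the full projection
```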
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/Yi-6B-Chat/modeling_llama.py b/transformers/llm/export/llm_models/Yi-6B-Chat/modeling_llama.py deleted file mode 100644 index 493b040b7..000000000 --- a/transformers/llm/export/llm_models/Yi-6B-Chat/modeling_llama.py +++ /dev/null @@ -1,1040 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. 
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
- # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = 
(key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: 
Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/chatglm-6b/modeling_chatglm.py b/transformers/llm/export/llm_models/chatglm-6b/modeling_chatglm.py deleted file mode 100644 index 82effe877..000000000 --- a/transformers/llm/export/llm_models/chatglm-6b/modeling_chatglm.py +++ /dev/null @@ -1,1441 +0,0 @@ -""" PyTorch ChatGLM model. 
""" - -import math -import copy -import os -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.utils import ( - add_code_sample_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - BaseModelOutputWithPastAndCrossAttentions, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm-6b", - # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm -] - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): - """Load tf checkpoints in a pytorch model.""" - try: - import re - - import numpy as np - import tensorflow as tf - except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." 
- ) - raise - tf_path = os.path.abspath(tf_checkpoint_path) - logger.info(f"Converting TensorFlow checkpoint from {tf_path}") - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - logger.info(f"Loading TF weight {name} with shape {shape}") - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split("/") - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): - logger.info(f"Skipping {'/'.join(name)}") - continue - pointer = model - for m_name in name: - if re.fullmatch(r"[A-Za-z]+_\d+", m_name): - scope_names = re.split(r"_(\d+)", m_name) - else: - scope_names = [m_name] - if scope_names[0] == "kernel" or scope_names[0] == "gamma": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "output_bias" or scope_names[0] == "beta": - pointer = getattr(pointer, "bias") - elif scope_names[0] == "output_weights": - pointer = getattr(pointer, "weight") - elif scope_names[0] == "squad": - pointer = getattr(pointer, "classifier") - else: - try: - pointer = getattr(pointer, scope_names[0]) - except AttributeError: - logger.info(f"Skipping {'/'.join(name)}") - continue - if len(scope_names) >= 2: - num = int(scope_names[1]) - pointer = pointer[num] - if m_name[-11:] == "_embeddings": - pointer = getattr(pointer, "weight") - elif m_name == "kernel": - array = np.transpose(array) - try: - assert ( - pointer.shape == array.shape - ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - logger.info(f"Initialize PyTorch weight {name}") - pointer.data = torch.from_numpy(array) - return model - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(config.hidden_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -@torch.jit.script -def gelu_impl(x): - """OpenAI's gelu implementation.""" - return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * - (1.0 + 0.044715 * x * x))) - - -def gelu(x): - return gelu_impl(x) - - -class RotaryEmbedding(torch.nn.Module): - def __init__(self, dim, base=10000, precision=torch.half, learnable=False): - super().__init__() - inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) - inv_freq = inv_freq.half() - self.learnable = learnable - if learnable: - self.inv_freq = torch.nn.Parameter(inv_freq) - self.max_seq_len_cached = None - else: - self.register_buffer('inv_freq', inv_freq) - self.max_seq_len_cached = None - self.cos_cached = None - self.sin_cached = None - self.precision = precision - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - pass - - def forward(self, x, seq_dim=1, seq_len=None): - if seq_len is None: - seq_len = x.shape[seq_dim] - if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached): - self.max_seq_len_cached = None if self.learnable else seq_len - t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - if self.precision == torch.bfloat16: - emb = emb.float() - - # [sx, 1 (b * np), hn] - cos_cached = emb.cos()[:, None, :] - sin_cached = emb.sin()[:, None, :] - if self.precision == torch.bfloat16: - cos_cached = cos_cached.bfloat16() - sin_cached = sin_cached.bfloat16() - if self.learnable: - return cos_cached, sin_cached - self.cos_cached, self.sin_cached = cos_cached, sin_cached - return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...] - - def _apply(self, fn): - if self.cos_cached is not None: - self.cos_cached = fn(self.cos_cached) - if self.sin_cached is not None: - self.sin_cached = fn(self.sin_cached) - return super()._apply(fn) - - -def rotate_half(x): - x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb_index(q, k, cos, sin, position_id): - # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn] - cos, sin = F.embedding(position_id, torch.squeeze(cos)).unsqueeze(2), \ - F.embedding(position_id, torch.squeeze(sin)).unsqueeze(2) - q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) - return q, k - - -def attention_fn( - self, - query_layer, - key_layer, - value_layer, - attention_mask, - hidden_size_per_partition, - layer_id, - layer_past=None, - scaling_attention_score=True, - use_cache=False, -): - if layer_past is not None: - past_key, past_value = layer_past[0], layer_past[1] - key_layer = torch.cat((past_key, key_layer), dim=0) - value_layer = torch.cat((past_value, value_layer), dim=0) - - # seqlen, batch, num_attention_heads, hidden_size_per_attention_head - seq_len, b, nh, hidden_size = key_layer.shape - - if use_cache: - present = (key_layer, value_layer) - else: - present = None - - query_key_layer_scaling_coeff = float(layer_id + 1) - if scaling_attention_score: - query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff) - - # =================================== - # Raw attention scores. 
[b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - #query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - #key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - query_layer = query_layer.squeeze(1) - key_layer = key_layer.squeeze(1) - - matmul_result = torch.zeros( - 1, 1, 1, - dtype=query_layer.dtype, - device=query_layer.device, - ) - - matmul_result = torch.baddbmm( - matmul_result, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=1.0, - ) - - # change view to [b, np, sq, sk] - # attention_scores = matmul_result.view(*output_size) - attention_scores = matmul_result.unsqueeze(0) - - if self.scale_mask_softmax: - self.scale_mask_softmax.scale = query_key_layer_scaling_coeff - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous()) - else: - if not (attention_mask == 0).all(): - # if auto-regressive, skip - attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.dtype - attention_scores = attention_scores.float() - attention_scores = attention_scores * query_key_layer_scaling_coeff - - attention_probs = F.softmax(attention_scores, dim=-1) - - attention_probs = attention_probs.type(dtype) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [sk, b * np, hn] - # value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - value_layer = value_layer.squeeze(1) - - # change view [b * np, sq, sk] - # attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - attention_probs = attention_probs.squeeze(0) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - # context_layer = context_layer.view(*output_size) - context_layer = context_layer.unsqueeze(0) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,) - # context_layer = context_layer.view(*new_context_layer_shape) - context_layer = context_layer.view([-1, 1, hidden_size_per_partition]) - outputs = (context_layer, present, attention_probs) - - return outputs - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class SelfAttention(torch.nn.Module): - def __init__(self, hidden_size, num_attention_heads, - layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True, empty_init=True): - if empty_init: - init_method = skip_init - else: - init_method = default_init - super(SelfAttention, self).__init__() - - self.layer_id = layer_id - self.hidden_size = hidden_size - self.hidden_size_per_partition = hidden_size - self.num_attention_heads = num_attention_heads - self.num_attention_heads_per_partition = num_attention_heads - self.position_encoding_2d = position_encoding_2d - 
self.rotary_emb = RotaryEmbedding( - self.hidden_size // (self.num_attention_heads * 2) - if position_encoding_2d - else self.hidden_size // self.num_attention_heads, - base=10000, - precision=torch.half, - learnable=False, - ) - - self.scale_mask_softmax = None - - if hidden_size_per_attention_head is None: - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - else: - self.hidden_size_per_attention_head = hidden_size_per_attention_head - - self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head - - # Strided linear layer. - self.query_key_value = init_method( - torch.nn.Linear, - hidden_size, - 3 * self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - - self.dense = init_method( - torch.nn.Linear, - self.inner_hidden_size, - hidden_size, - bias=bias, - dtype=params_dtype, - ) - - @staticmethod - def attention_mask_func(attention_scores, attention_mask): - attention_scores.masked_fill_(attention_mask, -10000.0) - return attention_scores - - def split_tensor_along_last_dim(self, tensor, num_partitions, - contiguous_split_chunks=False): - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # [seq_len, batch, 3 * hidden_size] - mixed_raw_layer = self.query_key_value(hidden_states) - - # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head] - new_tensor_shape = mixed_raw_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape) - - # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] - (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3) - - if self.position_encoding_2d: - q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) - k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1)) - cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) - position_ids, block_position_ids = position_ids[:, 0, :].transpose(0, 1).contiguous(), \ - position_ids[:, 1, :].transpose(0, 1).contiguous() - q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids) - q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids) - query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1)) - key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1)) - else: - position_ids = position_ids.transpose(0, 1) - cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1) - # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head] - query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, 
cos, sin, position_ids) - - # [seq_len, batch, hidden_size] - context_layer, present, attention_probs = attention_fn( - self=self, - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - hidden_size_per_partition=self.hidden_size_per_partition, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache - ) - - output = self.dense(context_layer) - - outputs = (output, present) - - if output_attentions: - outputs += (attention_probs,) - - return outputs # output, present, attention_probs - - -class GEGLU(torch.nn.Module): - def __init__(self): - super().__init__() - self.activation_fn = F.gelu - - def forward(self, x): - # dim=-1 breaks in jit for pt<1.10 - x1, x2 = x.chunk(2, dim=(x.ndim - 1)) - return x1 * self.activation_fn(x2) - - -class GLU(torch.nn.Module): - def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): - super(GLU, self).__init__() - if empty_init: - init_method = skip_init - else: - init_method = default_init - self.layer_id = layer_id - self.activation_func = activation_func - - # Project to 4h. - self.hidden_size = hidden_size - if inner_hidden_size is None: - inner_hidden_size = 4 * hidden_size - self.inner_hidden_size = inner_hidden_size - self.dense_h_to_4h = init_method( - torch.nn.Linear, - self.hidden_size, - self.inner_hidden_size, - bias=bias, - dtype=params_dtype, - ) - # Project back to h. - self.dense_4h_to_h = init_method( - torch.nn.Linear, - self.inner_hidden_size, - self.hidden_size, - bias=bias, - dtype=params_dtype, - ) - - def forward(self, hidden_states): - """ - hidden_states: [seq_len, batch, hidden_size] - """ - - # [seq_len, batch, inner_hidden_size] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel) - - output = self.dense_4h_to_h(intermediate_parallel) - - return output - - -class GLMBlock(torch.nn.Module): - def __init__( - self, - hidden_size, - num_attention_heads, - layernorm_epsilon, - layer_id, - inner_hidden_size=None, - hidden_size_per_attention_head=None, - layernorm=LayerNorm, - use_bias=True, - params_dtype=torch.float, - num_layers=28, - position_encoding_2d=True, - empty_init=True - ): - super(GLMBlock, self).__init__() - # Set output layer initialization if not provided. - - self.layer_id = layer_id - - # Layernorm on the input data. - self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.position_encoding_2d = position_encoding_2d - - # Self attention. - self.attention = SelfAttention( - hidden_size, - num_attention_heads, - layer_id, - hidden_size_per_attention_head=hidden_size_per_attention_head, - bias=use_bias, - params_dtype=params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init - ) - - # Layernorm on the input data. 
- self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon) - - self.num_layers = num_layers - - # GLU - self.mlp = GLU( - hidden_size, - inner_hidden_size=inner_hidden_size, - bias=use_bias, - layer_id=layer_id, - params_dtype=params_dtype, - empty_init=empty_init - ) - - def forward( - self, - hidden_states: torch.Tensor, - position_ids, - attention_mask: torch.Tensor, - layer_id, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - """ - hidden_states: [seq_len, batch, hidden_size] - attention_mask: [(1, 1), seq_len, seq_len] - """ - - # Layer norm at the begining of the transformer layer. - # [seq_len, batch, hidden_size] - attention_input = self.input_layernorm(hidden_states) - - # Self attention. - attention_outputs = self.attention( - attention_input, - position_ids, - attention_mask=attention_mask, - layer_id=layer_id, - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - attention_output = attention_outputs[0] - - outputs = attention_outputs[1:] - - # Residual connection. - alpha = (2 * self.num_layers) ** 0.5 - hidden_states = attention_input * alpha + attention_output - - mlp_input = self.post_attention_layernorm(hidden_states) - - # MLP. - mlp_output = self.mlp(mlp_input) - - # Second residual connection. - output = mlp_input * alpha + mlp_output - - if use_cache: - outputs = (output,) + outputs - else: - outputs = (output,) + outputs[1:] - - return outputs # hidden_states, present, attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, device): - batch_size, seq_length = input_ids.shape - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) - attention_mask.tril_() - for i, context_length in enumerate(context_lengths): - attention_mask[i, :, :context_length] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): - batch_size, seq_length = input_ids.shape - if use_gmasks is None: - use_gmasks = [False] * batch_size - context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] - if self.position_encoding_2d: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - position_ids[i, context_length:] = mask_positions[i] - block_position_ids = [torch.cat(( - torch.zeros(context_length, dtype=torch.long, device=device), - torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 - )) for context_length in context_lengths] - block_position_ids = torch.stack(block_position_ids, dim=0) - position_ids = torch.stack((position_ids, block_position_ids), dim=1) - else: - position_ids = torch.arange(seq_length, dtype=torch.long, 
device=device).unsqueeze(0).repeat(batch_size, 1) - for i, context_length in enumerate(context_lengths): - if not use_gmasks[i]: - position_ids[i, context_length:] = mask_positions[i] - - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ChatGLMModel): - module.gradient_checkpointing = value - - -CHATGLM_6B_START_DOCSTRING = r""" - This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general - usage and behavior. - - Parameters: - config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. - Initializing with a config file does not load the weights associated with the model, only the configuration. - Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - -CHATGLM_6B_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `({0})`): - Indices of input sequence tokens in the vocabulary. - - Indices can be obtained using [`ChatGLM6BTokenizer`]. - See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: - - - 0 corresponds to a *sentence A* token, - - 1 corresponds to a *sentence B* token. - - [What are token type IDs?](../glossary#token-type-ids) - position_ids (`torch.LongTensor` of shape `({0})`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range `[0, config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) - head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - - inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert *input_ids* indices into associated vectors - than the model's internal embedding lookup matrix. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
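`get_position_ids` above builds ChatGLM-6B's two-channel position ids: an absolute channel that freezes at the mask position once the prompt ends, and a block channel that starts counting from 1 after the prompt. A hedged single-sequence sketch (the batch loop is dropped and the function name and sizes are mine):

```python
import torch

# One sequence of length 6 whose prompt occupies the first 4 tokens and whose
# mask token sits at index 3.
def positions_2d(seq_length: int, context_length: int, mask_position: int) -> torch.Tensor:
    position_ids = torch.arange(seq_length, dtype=torch.long)
    position_ids[context_length:] = mask_position           # generated part stays at the mask slot
    block_position_ids = torch.cat((
        torch.zeros(context_length, dtype=torch.long),
        torch.arange(seq_length - context_length, dtype=torch.long) + 1,
    ))
    return torch.stack((position_ids, block_position_ids))  # shape [2, seq_length]

print(positions_2d(seq_length=6, context_length=4, mask_position=3))
# tensor([[0, 1, 2, 3, 3, 3],
#         [0, 0, 0, 0, 1, 2]])
```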
-""" - - -@add_start_docstrings( - "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.", - CHATGLM_6B_START_DOCSTRING, -) -class ChatGLMModel(ChatGLMPreTrainedModel): - """ - - The model can behave as an encoder (with only self-attention) as well - as a decoder, in which case a layer of cross-attention is added between - the self-attention layers, following the architecture described in [Attention is - all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, - Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - - To behave as an decoder the model needs to be initialized with the - `is_decoder` argument of the configuration set to `True`. - To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` - argument and `add_cross_attention` set to `True`; an - `encoder_hidden_states` is then expected as an input to the forward pass. - """ - - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - # recording parameters - self.max_sequence_length = config.max_sequence_length - self.hidden_size = config.hidden_size - self.params_dtype = torch.half - self.num_attention_heads = config.num_attention_heads - self.vocab_size = config.vocab_size - self.num_layers = config.num_layers - self.layernorm_epsilon = config.layernorm_epsilon - self.inner_hidden_size = config.inner_hidden_size - self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads - self.position_encoding_2d = config.position_encoding_2d - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - - self.word_embeddings = init_method( - torch.nn.Embedding, - num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, - dtype=self.params_dtype - ) - self.gradient_checkpointing = False - - def get_layer(layer_id): - return GLMBlock( - self.hidden_size, - self.num_attention_heads, - self.layernorm_epsilon, - layer_id, - inner_hidden_size=self.inner_hidden_size, - hidden_size_per_attention_head=self.hidden_size_per_attention_head, - layernorm=LayerNorm, - use_bias=True, - params_dtype=self.params_dtype, - position_encoding_2d=self.position_encoding_2d, - empty_init=empty_init - ) - - self.layers = torch.nn.ModuleList( - [get_layer(layer_id) for layer_id in range(self.num_layers)] - ) - - # Final layer norm before output. 
- self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) - - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - # total_params = sum(p.numel() for p in self.parameters()) - # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) - # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) - - def get_input_embeddings(self): - return self.word_embeddings - - def set_input_embeddings(self, new_embeddings: torch.Tensor): - self.word_embeddings = new_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.num_attention_heads, - self.hidden_size // self.num_attention_heads - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - # past_key_values = [(v[0], v[1]) for v in past_key_values] - return past_key_values - - @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @add_code_sample_docstrings( - checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPastAndCrossAttentions, - config_class=_CONFIG_FOR_DOC, - ) - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]: - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - if past_key_values is None: - if self.pre_seq_len is not None: - past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, - dtype=inputs_embeds.dtype) - else: - past_key_values = tuple([None] * len(self.layers)) - - if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - - - if position_ids is None: - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - position_ids = self.get_position_ids( - input_ids, - mask_positions=mask_positions, - device=input_ids.device, - use_gmasks=use_gmasks - ) - - if self.pre_seq_len is not None and attention_mask is not None: - prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( - attention_mask.device) - prefix_attention_mask = (prefix_attention_mask < 0.5).bool() - attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) - - # [seq_len, batch, hidden_size] - hidden_states = inputs_embeds.transpose(0, 1) - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if attention_mask is None: - attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() - else: - attention_mask = attention_mask.to(hidden_states.device) - - for i, layer in enumerate(self.layers): - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_past = past_key_values[i] - - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - position_ids, - attention_mask, - torch.tensor(i), - layer_past, - use_cache, - output_attentions - ) - else: - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=layer_past, - use_cache=use_cache, - output_attentions=output_attentions - ) - - hidden_states = layer_ret[0] - - if use_cache: - presents = presents + (layer_ret[1],) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],) - - # Final layer norm. 
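When no mask is supplied, the forward pass above falls back to `get_masks`, which leaves the prompt fully visible, keeps only the generated tail causal, and marks masked-out positions with True. A hedged single-example sketch (batch dimension and the `unsqueeze` step omitted; the helper name is mine):

```python
import torch

def chatglm_mask(seq_length: int, context_length: int) -> torch.Tensor:
    mask = torch.ones(seq_length, seq_length).tril_()
    mask[:, :context_length] = 1        # bidirectional attention over the prompt
    return (mask < 0.5)                 # True == do not attend

print(chatglm_mask(5, 3).int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])
```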
- hidden_states = self.final_layernorm(hidden_states) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - - # self.hidden_size = config.hidden_size - # self.params_dtype = torch.half - # self.vocab_size = config.vocab_size - self.max_sequence_length = config.max_sequence_length - - self.position_encoding_2d = config.position_encoding_2d - - self.transformer = ChatGLMModel(config, empty_init=empty_init) - - self.lm_head = init_method( - nn.Linear, - config.hidden_size, - config.vocab_size, - bias=False, - dtype=torch.half - ) - - self.config = config - - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - if attention_mask is not None and attention_mask.dtype == torch.bool: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) - new_attention_mask = attention_mask[:, :, -1:].clone() - new_attention_mask[..., -1] = False - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, new_attention_mask], dim=2 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id[:, 1, :] += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past: Optional[torch.Tensor] = None, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - **kwargs - ) -> dict: - batch_size, seq_length = input_ids.shape - MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id - seqs = input_ids.tolist() - mask_positions, use_gmasks = [], [] - for seq in seqs: - mask_token = gMASK if gMASK in seq else MASK - use_gmask = mask_token == gMASK - mask_positions.append(seq.index(mask_token)) - use_gmasks.append(use_gmask) - - # only last token for input_ids if past is not None - if past is not None or past_key_values is not None: - last_token = input_ids[:, -1].unsqueeze(-1) - if attention_mask is not None and attention_mask.dtype == torch.bool: - attention_mask = attention_mask[:, :, -1:] - else: - attention_mask = None - if position_ids is 
not None: - position_ids = position_ids[..., -1:] - else: - context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] - if self.position_encoding_2d: - position_ids = torch.tensor( - [[mask_position, seq_length - context_length] for mask_position, context_length in - zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) - else: - position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, - device=input_ids.device).unsqueeze(-1) - - if past is None: - past = past_key_values - return { - "input_ids": last_token, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - else: - if attention_mask is not None and attention_mask.dtype != torch.bool: - logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") - attention_mask = None - if attention_mask is None: - attention_mask = self.get_masks( - input_ids, - device=input_ids.device - ) - if position_ids is None: - position_ids = self.get_position_ids( - input_ids, - device=input_ids.device, - mask_positions=mask_positions, - use_gmasks=use_gmasks - ) - - return { - "input_ids": input_ids, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - 
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - punkts = [ - [",", ","], - ["!", "!"], - [":", ":"], - [";", ";"], - ["\?", "?"], - ] - for item in punkts: - response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) - response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) - return response - - @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if not history: - prompt = query - else: - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) - prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - new_history = history + [(query, response)] - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if 
generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - yield input_ids - - def quantize(self, bits: int, empty_init=False, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) - return self diff --git a/transformers/llm/export/llm_models/chatglm2-6b/modeling_chatglm.py b/transformers/llm/export/llm_models/chatglm2-6b/modeling_chatglm.py deleted file mode 100644 index e9b5ca258..000000000 --- a/transformers/llm/export/llm_models/chatglm2-6b/modeling_chatglm.py +++ /dev/null @@ -1,1193 +0,0 @@ -""" PyTorch ChatGLM model. 
""" - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm2-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. 
- if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=dtype, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per 
partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2 and False: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. 
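The fallback branch of `CoreAttention` above computes attention manually: scaled scores via `baddbmm`, a masked softmax, then a `bmm` over the values. A compact sketch of that path with toy sizes, dropout and the per-layer scaling coefficient omitted (names are mine):

```python
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    # q, k, v: [batch*heads, seq, head_dim]; mask: True == do not attend
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.bmm(probs, v)          # [batch*heads, seq, head_dim]

q = k = v = torch.randn(2, 4, 8)
causal = torch.triu(torch.ones(4, 4), diagonal=1).bool().unsqueeze(0)
print(manual_attention(q, k, v, causal).shape)   # torch.Size([2, 4, 8])
```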
- # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. 
- # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. 
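`MLP` above implements SwiGLU: the up-projection emits twice the FFN width, one half is passed through SiLU and gates the other half, and the gated result is projected back to the hidden size. A self-contained sketch with made-up toy dimensions (class name is mine):

```python
import torch
import torch.nn.functional as F
from torch import nn

class SwiGLUMLP(nn.Module):
    def __init__(self, hidden_size: int, ffn_hidden_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)
        self.down = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, value = self.up(x).chunk(2, dim=-1)   # [.., 2*ffn] -> two [.., ffn]
        return self.down(F.silu(gate) * value)

mlp = SwiGLUMLP(hidden_size=8, ffn_hidden_size=32)
print(mlp(torch.randn(4, 2, 8)).shape)   # torch.Size([4, 2, 8])
```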
- def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. 
- """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. 
- if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary 
positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = 
None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - return response - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - prompt = tokenizer.build_prompt(query, history=history) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - def build_stream_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - if history: - prompt = "\n\n[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - input_ids = tokenizer.encode(prompt, add_special_tokens=False) - input_ids = input_ids[1:] - inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False) - else: - prompt = "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192, num_beams=1, - do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = self.build_inputs(tokenizer, query, history=history) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values=None, - max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None and not return_past_key_values: - inputs = self.build_inputs(tokenizer, query, history=history) - else: - inputs = self.build_stream_inputs(tokenizer, query, history=history) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, - return_past_key_values=return_past_key_values, **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = 
tokenizer.decode(outputs) - if response and response[-1] != "�": - response = self.process_response(response) - new_history = history + [(query, response)] - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self diff --git a/transformers/llm/export/llm_models/chatglm3-6b/modeling_chatglm.py b/transformers/llm/export/llm_models/chatglm3-6b/modeling_chatglm.py deleted file mode 100755 index f887c44ce..000000000 --- a/transformers/llm/export/llm_models/chatglm3-6b/modeling_chatglm.py +++ /dev/null @@ -1,1293 +0,0 @@ -""" PyTorch ChatGLM model. 
""" - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -from copy import deepcopy - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" -_CONFIG_FOR_DOC = "ChatGLMConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm3-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. 
- if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. - """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and 
per partition values. - self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2 and False: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. 
- # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. 
- # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. 
- def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. 
- """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. 
- if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary 
positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: 
Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, output, history): - content = "" - history = deepcopy(history) - for response in output.split("<|assistant|>"): - metadata, content = response.split("\n", maxsplit=1) - if not metadata.strip(): - content = content.strip() - history.append({"role": "assistant", "metadata": metadata, "content": content}) - content = content.replace("[[训练时间]]", "2023年") - else: - history.append({"role": "assistant", "metadata": metadata, "content": content}) - if history[0]["role"] == "system" and "tools" in history[0]: - content = "\n".join(content.split("\n")[1:-1]) - def tool_call(**kwargs): - return kwargs - parameters = eval(content) - content = {"name": metadata.strip(), "parameters": parameters} - else: - content = {"name": metadata.strip(), "content": content} - return content, history - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", - max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = tokenizer.build_chat_input(query, history=history, role=role) - inputs = inputs.to(self.device) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - history.append({"role": role, "content": query}) - response, history = self.process_response(response, history) - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", - past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, - logits_processor=None, return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None: - inputs = tokenizer.build_chat_input(query, history=history, role=role) - else: - inputs = tokenizer.build_chat_input(query, role=role) - inputs = inputs.to(self.device) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - history.append({"role": role, "content": query}) - for outputs 
in self.stream_generate(**inputs, past_key_values=past_key_values, - eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, - **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response, new_history = self.process_response(response, history) - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - model_kwargs["use_cache"] = generation_config.use_cache - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self - - -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: 
Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/codegeex2-6b/modeling_chatglm.py b/transformers/llm/export/llm_models/codegeex2-6b/modeling_chatglm.py deleted file mode 100755 index fdc619f81..000000000 --- a/transformers/llm/export/llm_models/codegeex2-6b/modeling_chatglm.py +++ /dev/null @@ -1,1092 +0,0 @@ -""" PyTorch ChatGLM model. 
""" - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B" -_CONFIG_FOR_DOC = "ChatGLM6BConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm-6b", - # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_original_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
- """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=dtype, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - if self.original_impl: - return self.forward_original_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - - -@torch.jit.script -def apply_rotary_pos_emb_original(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, input: torch.Tensor): - norm_x = torch.mean(input * input, dim=-1, keepdim=True) - x_normed = input * torch.rsqrt(norm_x + self.eps) - return self.weight * x_normed - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
- self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2 and False: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. 
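CoreAttention folds `layer_number` into `norm_factor` and multiplies it back via `coeff` before the fp32 softmax. The standalone sketch below (illustration only, simplified shapes, float32 instead of half precision) shows that the two factors cancel, so the scores still equal `q @ k^T / sqrt(head_dim)` while the raw matmul stays in a smaller numeric range.

```python
# Sketch of the apply_query_key_layer_scaling trick: scale scores down by an
# extra 1/layer_number in the matmul, then undo it before the fp32 softmax.
import math
import torch

def core_scores(q, k, layer_number, head_dim):
    # q: [b*np, sq, hn], k: [b*np, sk, hn]
    norm_factor = math.sqrt(head_dim) * layer_number
    scores = torch.bmm(q, k.transpose(1, 2)) / norm_factor  # scaled-down matmul
    return scores.float() * layer_number                    # undo the extra factor

q, k = torch.randn(2, 4, 8), torch.randn(2, 5, 8)
ref = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(8)
assert torch.allclose(core_scores(q, k, layer_number=3, head_dim=8), ref, atol=1e-5)
```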
- # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - self.interleaved_qkv = config.interleaved_qkv - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. 
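Under `multi_query_attention`, the fused `query_key_value` projection is split unevenly: queries keep all attention heads while keys and values only get `multi_query_group_num` heads. A shape-only sketch, with head counts chosen here purely for illustration:

```python
# Illustrative split of the fused QKV projection under multi-query attention.
import torch

num_heads, num_kv_groups, head_dim = 32, 2, 128
hidden = num_heads * head_dim                       # 4096
qkv_hidden = hidden + 2 * num_kv_groups * head_dim  # 4096 + 512

mixed = torch.randn(5, 1, qkv_hidden)               # [sq, b, (np + 2*g) * hn]
q, k, v = mixed.split([num_heads * head_dim,
                       num_kv_groups * head_dim,
                       num_kv_groups * head_dim], dim=-1)
q = q.view(5, 1, num_heads, head_dim)               # [sq, b, np, hn]
k = k.view(5, 1, num_kv_groups, head_dim)           # [sq, b, g,  hn]
v = v.view(5, 1, num_kv_groups, head_dim)
print(q.shape, k.shape, v.shape)
```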
- # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - if self.interleaved_qkv: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - if not self.interleaved_qkv: - query_layer = query_layer.view( - query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ).contiguous() - key_layer = key_layer.view( - key_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ).contiguous() - value_layer = value_layer.view( - value_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ).contiguous() - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb_original(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb_original(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if use_cache: - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. 
[sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. 
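The MLP above doubles the output width of `dense_h_to_4h` so that SwiGLU can gate one half of the projection with the other; a minimal sketch with made-up sizes:

```python
# Sketch of the SwiGLU feed-forward path: project to 2*ffn width, chunk,
# gate with SiLU, then project back to the hidden size.
import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 8, 16
dense_h_to_4h = torch.nn.Linear(hidden_size, ffn_hidden_size * 2, bias=False)
dense_4h_to_h = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)

x = torch.randn(3, 2, hidden_size)                  # [s, b, h]
gate, value = torch.chunk(dense_h_to_4h(x), 2, dim=-1)
out = dense_4h_to_h(F.silu(gate) * value)           # [s, b, h]
print(out.shape)
```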
- if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - - hidden_states, kv_cache = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. 
- """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, ChatGLMModel): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. 
- if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - if config.rotary_percent < 1.0: - rotary_dim = int(rotary_dim * config.rotary_percent) - - # partial rotary embeddings, which is better than full rotary - # Wang and Komatsuzaki et al - # https://github.com/kingoflolz/mesh-transformer-jax/ - self.rotary_pos_emb = RotaryEmbedding(rotary_dim, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.gradient_checkpointing = False - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if full_attention_mask is None and attention_mask is not None and not attention_mask.all(): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. 
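The rotary cache is computed once for the full `seq_length` and then gathered per token before being transposed to `[sq, b, ...]` for the encoder; a simplified sketch of those shape manipulations (dimensions chosen arbitrarily):

```python
# Sketch of building and indexing the rotary position-embedding cache.
import torch

def rope_cache(seq_len, n_elem, base=10000):
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    idx_theta = torch.outer(torch.arange(seq_len).float(), theta)
    # [seq_len, n_elem/2, 2] holding (cos, sin) pairs
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

cache = rope_cache(seq_len=128, n_elem=64)
position_ids = torch.tensor([[0, 1, 2, 3]])         # [b, sq]
per_token = cache[position_ids]                     # [b, sq, n_elem/2, 2]
per_token = per_token.transpose(0, 1).contiguous()  # [sq, b, n_elem/2, 2]
print(per_token.shape)
```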
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - input_pos: int = None, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if past_key_values is not None: - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - 
past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. - """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, response): - response = response.strip() - response = response.replace("[[训练时间]]", "2023年") - return response - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = None): - prompt = "" - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response) - prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query) - inputs = tokenizer([prompt], return_tensors="pt") - inputs = inputs.to(self.device) - return inputs - - @torch.no_grad() - def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1, - do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = self.build_inputs(tokenizer, query, history=history) - outputs = self.generate(**inputs, **gen_kwargs) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, - do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs): - if history is None: - history = [] - if logits_processor is None: 
- logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = self.build_inputs(tokenizer, query, history=history) - for outputs in self.stream_generate(**inputs, **gen_kwargs): - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs) - response = self.process_response(response) - new_history = history + [(query, response)] - yield response, new_history - - @torch.no_grad() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long()) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - yield input_ids - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self diff --git a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/config.json b/transformers/llm/export/llm_models/deepseek-llm-7b-chat/config.json deleted file mode 100755 index 67a803b6a..000000000 --- a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "architectures": [ - "LlamaForCausalLM" - ], - "auto_map": { - "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM" - }, - "bos_token_id": 100000, - "eos_token_id": 100001, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 4096, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 30, - "num_key_value_heads": 32, - "pretraining_tp": 1, - "rms_norm_eps": 1e-06, - "rope_scaling": null, - "rope_theta": 10000.0, - "tie_word_embeddings": false, - "torch_dtype": "bfloat16", - "transformers_version": "4.33.1", - "use_cache": true, - "vocab_size": 102400 -} diff --git a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/configuration_llama.py 
b/transformers/llm/export/llm_models/deepseek-llm-7b-chat/configuration_llama.py deleted file mode 100644 index 1b0e9c357..000000000 --- a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/configuration_llama.py +++ /dev/null @@ -1,174 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. 
This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - - Example: - - ```python - >>> from transformers import LlamaModel, LlamaConfig - - >>> # Initializing a LLaMA llama-7b style configuration - >>> configuration = LlamaConfig() - - >>> # Initializing a model from the llama-7b style configuration - >>> model = LlamaModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - model_type = "llama" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the 
`rope_scaling` configuration. - """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/modeling_llama.py b/transformers/llm/export/llm_models/deepseek-llm-7b-chat/modeling_llama.py deleted file mode 100644 index 493b040b7..000000000 --- a/transformers/llm/export/llm_models/deepseek-llm-7b-chat/modeling_llama.py +++ /dev/null @@ -1,1040 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings -from .configuration_llama import LlamaConfig - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. 
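As a usage note for the validation above, only a two-field dict with `type` in `{"linear", "dynamic"}` and a float `factor > 1.0` is accepted; a hypothetical sketch of configurations that pass or fail:

```python
# Hypothetical rope_scaling values checked by _rope_scaling_validation.
rope_scaling = {"type": "linear", "factor": 2.0}    # accepted: stretches the usable context
# rope_scaling = {"type": "ntk", "factor": 2.0}     # rejected: unknown scaling type
# rope_scaling = {"type": "linear", "factor": 1}    # rejected: factor must be a float > 1.0

assert isinstance(rope_scaling, dict) and len(rope_scaling) == 2
assert rope_scaling["type"] in ("linear", "dynamic")
assert isinstance(rope_scaling["factor"], float) and rope_scaling["factor"] > 1.0
```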
- """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
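The linear-scaling variant above simply divides the position index by `scaling_factor` before building the cos/sin cache, which stretches the pretrained rotary period over longer sequences; a condensed sketch (buffer registration and dtype handling omitted):

```python
# Condensed illustration of linear RoPE scaling: positions are compressed by
# scaling_factor so the original frequencies cover a longer context window.
import torch

dim, base, scaling_factor = 8, 10000.0, 2.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

t = torch.arange(16).float() / scaling_factor       # positions compressed by 2x
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)             # [seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)
```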
- # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = torch.squeeze(cos) # [seq_len, dim] - sin = torch.squeeze(sin) # [seq_len, dim] - # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - ''' - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = 
(key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - ''' - #--------------- - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - #--------------- - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: 
Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
-""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": 
kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
- """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/glm-4-9b-chat/modeling_chatglm.py b/transformers/llm/export/llm_models/glm-4-9b-chat/modeling_chatglm.py deleted file mode 100755 index e86f5a2f4..000000000 --- a/transformers/llm/export/llm_models/glm-4-9b-chat/modeling_chatglm.py +++ /dev/null @@ -1,1238 +0,0 @@ -""" PyTorch ChatGLM model. 
""" -import json -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -from copy import deepcopy - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" -_CONFIG_FOR_DOC = "ChatGLMConfig" - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 198] = 5e4 - return scores - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - self.rope_ratio = rope_ratio - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
- """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - base = base * self.rope_ratio - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - - -@torch.jit.script -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [b, np, sq, hn] - b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:, :sq] - xshaped = x.reshape(b, np, sq, rot_dim // 2, 2) - rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
- self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def raw_atten(self, query_layer, key_layer, value_layer, attention_mask): - attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / self.norm_factor - if attention_mask is None: - seq_len = query_layer.shape[2] - attention_mask = ~torch.tril(torch.ones([1, 1, seq_len, seq_len], device=attn_weights.device).bool()) - attn_weights = attn_weights.masked_fill(attention_mask, float("-inf")) - #mask_value = torch.finfo(attn_weights.dtype).min - #attn_weights = torch.where(attention_mask, attn_weights.to(attn_weights.dtype), mask_value) - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - context_layer = torch.matmul(attn_weights, value_layer) - return context_layer - context_layer = context_layer.transpose(1, 2).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - return context_layer - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2 and False: - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.transpose(1, 2).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) - - # [b, np, sq, hn] -> [b * np, sq, hn] - query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) - # [b, np, sk, hn] -> [b * np, sk, hn] - key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. 
[b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer, # [b * np, sq, hn] - key_layer.transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [b * np, sk, hn] - #value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) - # change view [b * np, sq, sk] - #attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - # context_layer = torch.bmm(attention_probs, value_layer) - context_layer = torch.matmul(attention_probs, value_layer) - # change view [b, np, sq, hn] - # context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [b, sq, np, hn] - context_layer = context_layer.transpose(1, 2).contiguous() - # [b, sq, np, hn] --> [b, sq, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
- self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True): - # hidden_states: [b, sq, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # [b, sq, np, hn] -> [b, np, sq, hn] - query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]] - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=2) - value_layer = torch.cat((cache_v, value_layer), dim=2) - if 
use_cache: - ''' - if kv_cache is None: - kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1) - else: - kv_cache = (key_layer, value_layer) - ''' - kv_cache = torch.stack([key_layer, value_layer], axis=0) - # ''' - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(2) - key_layer = key_layer.expand( - -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:] - ) - value_layer = value_layer.unsqueeze(2) - value_layer = value_layer.expand( - -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:] - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. 
- self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - hidden_states = hidden_states.view(1, -1, 4096) - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. - def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache, - use_reentrant=False - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - # token by token decoding, use tuple format - if kv_caches[0] is not None: - presents = presents + (kv_cache,) - # prefilling in decoding, use tensor format to save cuda memory - else: - if len(presents) == 0: - presents = kv_cache - else: - presents = torch.cat((presents, kv_cache), dim=0) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[2] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): - if not self.supports_gradient_checkpointing: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. 
- words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope, - device=device, dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def set_input_embeddings(self, value): - self.embedding.word_embeddings = value - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - - # Run encoder. 
- hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - if presents is not None and type(presents) is torch.Tensor: - presents = presents.split(1, dim=0) - presents = list(presents) - presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents] - presents = [tuple([x.squeeze(0) for x in y]) for y in presents] - presents = tuple(presents) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - 
return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[:, -1:] - lm_logits = self.transformer.output_layer(hidden_states) - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - return tuple( - ( - layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, output, history): - content = "" - history = deepcopy(history) - for response in output.split("<|assistant|>"): - if "\n" in response: - metadata, content = response.split("\n", maxsplit=1) - else: - metadata, content = "", response - if not metadata.strip(): - content = content.strip() - history.append({"role": "assistant", "metadata": metadata, "content": content}) - content = content.replace("[[训练时间]]", "2023年") - else: - history.append({"role": "assistant", "metadata": metadata, "content": content}) - if history[0]["role"] == "system" and "tools" in history[0]: - parameters = json.loads(content) - content = {"name": metadata.strip(), "parameters": parameters} - else: - content = {"name": metadata.strip(), "content": content} - return content, history - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - history.append({"role": role, "content": query}) - inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True, - return_tensors="pt", return_dict=True) - inputs = inputs.to(self.device) - eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"), - tokenizer.convert_tokens_to_ids("<|observation|>")] - outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - response, history = self.process_response(response, history) - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, - logits_processor=None, return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|user|>"), - tokenizer.convert_tokens_to_ids("<|observation|>")] - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None: - inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}], - add_generation_prompt=True, tokenize=True, return_tensors="pt", - return_dict=True) - else: - inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False, - add_generation_prompt=True, tokenize=True, return_tensors="pt", - return_dict=True) - inputs = inputs.to(self.device) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - 
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - history.append({"role": role, "content": query}) - for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, - eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, - **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response, new_history = self.process_response(response, history) - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - model_kwargs["use_cache"] = generation_config.use_cache - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - 
full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/internlm-chat-7b/modeling_internlm.py b/transformers/llm/export/llm_models/internlm-chat-7b/modeling_internlm.py deleted file mode 100755 index b636e8716..000000000 --- a/transformers/llm/export/llm_models/internlm-chat-7b/modeling_internlm.py +++ /dev/null @@ -1,1046 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" PyTorch InternLM model.""" -import math -from typing import List, Optional, Tuple, Union -import threading, queue - -import torch -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.generation.streamers import BaseStreamer -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) -from .configuration_internlm import InternLMConfig - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "InternLMConfig" - - -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - -class InternLMRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - InternLMRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class InternLMRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. 
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    # cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    # sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = torch.squeeze(cos)  # [seq_len, dim]
-    sin = torch.squeeze(sin)  # [seq_len, dim]
-    # cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    # sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class InternLMMLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-    ):
-        super().__init__()
-        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
-        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
-        self.act_fn = ACT2FN[hidden_act]
-
-    def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-
-
-class InternLMAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: InternLMConfig):
-        super().__init__()
-        self.config = config
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.max_position_embeddings = config.max_position_embeddings
-
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
- ) - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) - self.rotary_emb = InternLMRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - ''' - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - if rotary_pos_emb is None: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - else: - cos, sin = rotary_pos_emb - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - ''' - #--------------- - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim) - - kv_seq_len = key_states.shape[1] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[1] - # rope - cos, sin = rotary_pos_emb - query_states = (query_states * cos) + (rotate_half(query_states) * sin) - key_states = (key_states * cos) + (rotate_half(key_states) * sin) - # kv cache - if past_key_value is not None: - past_key, past_value = past_key_value[0], past_key_value[1] - key_states = torch.cat((past_key, key_states), dim=1) - value_states = torch.cat((past_value, value_states), dim=1) - past_key_value = torch.stack((key_states, value_states)) - query_states = query_states.transpose(1, 2) - key_states = key_states.permute([0, 2, 3, 1]) - value_states = value_states.transpose(1, 2) - attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) - #--------------- - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if 
attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - # attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class InternLMDecoderLayer(nn.Module): - def __init__(self, config: InternLMConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = InternLMAttention(config=config) - self.mlp = InternLMMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). 
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - rotary_pos_emb=rotary_pos_emb, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -INTERNLM_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`InternLMConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare InternLM Model outputting raw hidden-states without any specific head on top.", - INTERNLM_START_DOCSTRING, -) -class InternLMPreTrainedModel(PreTrainedModel): - config_class = InternLMConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["InternLMDecoderLayer"] - _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, InternLMModel): - module.gradient_checkpointing = value - - -INTERNLM_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that - don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all - `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare InternLM Model outputting raw hidden-states without any specific head on top.", - INTERNLM_START_DOCSTRING, -) -class InternLMModel(InternLMPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLMDecoderLayer`] - - Args: - config: InternLMConfig - """ - - _auto_class = "AutoModel" - - def __init__(self, config: InternLMConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + 
past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class InternLMForCausalLM(InternLMPreTrainedModel): - _auto_class = "AutoModelForCausalLM" - - def __init__(self, config): - super().__init__(config) - self.model = InternLMModel(config) - - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - 
position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, InternLMForCausalLM - - >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = 
position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) - return reordered_past - - def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []): - prompt = "" - for record in history: - prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}\n""" - if len(prompt) == 0: - prompt += "" - prompt += f"""<|User|>:{query}\n<|Bot|>:""" - return tokenizer([prompt], return_tensors="pt") - - @torch.no_grad() - def chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = [], - streamer: Optional[BaseStreamer] = None, - max_new_tokens: int = 1024, - do_sample: bool = True, - temperature: float = 0.8, - top_p: float = 0.8, - **kwargs, - ): - inputs = self.build_inputs(tokenizer, query, history) - inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} - outputs = self.generate( - **inputs, - streamer=streamer, - max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temperature, - top_p=top_p, - **kwargs, - ) - outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] - response = tokenizer.decode(outputs, skip_special_tokens=True) - response = response.split("")[0] - history = history + [(query, response)] - return response, history - - @torch.no_grad() - def stream_chat( - self, - tokenizer, - query: str, - history: List[Tuple[str, str]] = [], - max_new_tokens: int = 1024, - do_sample: bool = True, - temperature: float = 0.8, - top_p: float = 0.8, - **kwargs, - ): - """ - Return a generator in format: (response, history) - Eg. 
- ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) - ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) - """ - - response_queue = queue.Queue(maxsize=20) - - class ChatStreamer(BaseStreamer): - def __init__(self, tokenizer) -> None: - super().__init__() - self.tokenizer = tokenizer - self.queue = response_queue - self.query = query - self.history = history - self.response = "" - self.received_inputs = False - self.queue.put((self.response, history + [(self.query, self.response)])) - - def put(self, value): - if len(value.shape) > 1 and value.shape[0] > 1: - raise ValueError("ChatStreamer only supports batch size 1") - elif len(value.shape) > 1: - value = value[0] - - if not self.received_inputs: - # The first received value is input_ids, ignore here - self.received_inputs = True - return - - token = self.tokenizer.decode([value[-1]], skip_special_tokens=True) - if token.strip() != "": - self.response = self.response + token - history = self.history + [(self.query, self.response)] - self.queue.put((self.response, history)) - - def end(self): - self.queue.put(None) - - def stream_producer(): - return self.chat( - tokenizer=tokenizer, - query=query, - streamer=ChatStreamer(tokenizer=tokenizer), - history=history, - max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temperature, - top_p=top_p, - **kwargs, - ) - - def consumer(): - producer = threading.Thread(target=stream_producer) - producer.start() - while True: - res = response_queue.get() - if res is not None: - return - yield res - - return consumer() - - -@add_start_docstrings( - """ - The InternLM Model transformer with a sequence classification head on top (linear layer). - - [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - INTERNLM_START_DOCSTRING, -) -class InternLMForSequenceClassification(InternLMPreTrainedModel): - _keys_to_ignore_on_load_missing = [r"lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = InternLMModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not 
return_dict: - output = (pooled_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/transformers/llm/export/llm_models/phi-2/modeling_phi.py b/transformers/llm/export/llm_models/phi-2/modeling_phi.py deleted file mode 100644 index 30b7fc8fd..000000000 --- a/transformers/llm/export/llm_models/phi-2/modeling_phi.py +++ /dev/null @@ -1,989 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. -# -# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. -# Licensed under the BSD 3-Clause License. - -from __future__ import annotations - -import math -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -from einops import rearrange, repeat -from transformers import PretrainedConfig, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_outputs import CausalLMOutputWithPast - -from .configuration_phi import PhiConfig - -try: - from flash_attn.bert_padding import pad_input, unpad_input - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding - from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention - from flash_attn.ops.fused_dense import FusedDense -except: - pad_input, unpad_input = None, None - FlashRotaryEmbedding = None - FlashSelfAttention, FlashCrossAttention = None, None - FusedDense = None - - -@dataclass -class InferenceParams: - """Inference parameters passed to model to efficiently calculate - and store context during inference. - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py. - - Args: - max_seqlen: Maximum sequence length. - max_batch_size: Maximum batch size. - seqlen_offset: Sequence length offset. - batch_size_offset: Batch size offset. - key_value_memory_dict: Key value memory dictionary. - lengths_per_sample: Lengths per sample. 
- - """ - - max_seqlen: int = field(metadata={"help": "Maximum sequence length."}) - - max_batch_size: int = field(metadata={"help": "Maximum batch size."}) - - seqlen_offset: int = field(default=0, metadata={"help": "Sequence length offset."}) - - batch_size_offset: int = field(default=0, metadata={"help": "Batch size offset."}) - - key_value_memory_dict: Dict[str, Any] = field( - default_factory=dict, metadata={"help": "Key value memory dictionary."} - ) - - lengths_per_sample: torch.Tensor = field(default=None, metadata={"help": "Lengths per sample."}) - - -class Embedding(nn.Module): - """Token embedding with dropout.""" - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.wte = nn.Embedding(config.vocab_size, config.n_embd) - self.drop = nn.Dropout(config.embd_pdrop) - - def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - - hidden_states = self.wte(input_ids) - hidden_states = self.drop(hidden_states) - - return hidden_states - - -def _apply_rotary_emb( - x: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, -) -> torch.FloatTensor: - _, seqlen, _, _ = x.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - x_rot = x[:, :, :, :rotary_dim] - x_pass = x[:, :, :, rotary_dim:] - - x1, x2 = x_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") - x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] - - x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype) - - return torch.cat([x_rot, x_pass], axis=-1) - - -def _apply_rotary_emb_kv( - kv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = kv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - k_rot = kv[:, :, 0, :, :rotary_dim] - k_pass = kv[:, :, 0, :, rotary_dim:] - - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") - k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]] - - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype) - - return torch.cat( - [ - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - kv[:, :, 1:2, :, :], - ], - axis=2, - ) - - -def _apply_rotary_emb_qkv( - qkv: torch.FloatTensor, - cos: torch.FloatTensor, - sin: torch.FloatTensor, - cos_k: Optional[torch.FloatTensor] = None, - sin_k: Optional[torch.FloatTensor] = None, -) -> torch.FloatTensor: - _, seqlen, _, _, _ = qkv.shape - _, rotary_dim = cos.shape - rotary_dim *= 2 - - q_rot = qkv[:, :, 0, :, :rotary_dim] - q_pass = qkv[:, :, 0, :, rotary_dim:] - - k_rot = qkv[:, :, 1, :, :rotary_dim] - k_pass = qkv[:, :, 1, :, rotary_dim:] - - q1, q2 = q_rot.chunk(2, dim=-1) - k1, k2 = k_rot.chunk(2, dim=-1) - c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d") - q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]] - - q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype) - k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype) - - return torch.cat( - [ - torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2), - torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2), - qkv[:, :, 2:3, :, :], - ], - axis=2, - ) - - -class RotaryEmbedding(nn.Module): - """Rotary 
positional embedding (RoPE). - - Reference: - RoFormer: Enhanced Transformer with Rotary Position Embedding. - https://arxiv.org/pdf/2104.09864.pdf. - - """ - - def __init__( - self, - dim: int, - base: int = 10000, - scale_base: Optional[float] = None, - pos_idx_in_fp32: bool = True, - max_position_embeddings: int = 2048, - device: Optional[str] = None, - **kwargs, - ) -> None: - super().__init__() - - if scale_base is not None: - raise NotImplementedError - - self.dim = dim - self.base = float(base) - self.scale_base = scale_base - self.pos_idx_in_fp32 = pos_idx_in_fp32 - self.max_position_embeddings = max_position_embeddings - self.device = device - - # Generate and save the inverse frequency buffer (non-trainable) - inv_freq = self._compute_inv_freq(device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Generate and save the scale buffer (non-trainable) - scale = ( - (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) - if scale_base is not None - else None - ) - self.register_buffer("scale", scale, persistent=False) - - # Initialize cached attributes since ONNX can't rely on dynamic initialization - self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32) - - def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor: - return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) - - def _update_cos_sin_cache( - self, - seqlen: int, - device: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - ) -> None: - self._seq_len_cached = seqlen - - # fp32 is preferred since the output of `torch.arange` can be quite large - # and bf16 would lose a lot of precision - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - if self.inv_freq.dtype != torch.float32: - inv_freq = self._compute_inv_freq(device=device) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - inv_freq = self.inv_freq - - # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP - freqs = torch.outer(t, inv_freq) - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") - - # Force the scale multiplication to happen in fp32 - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward( - self, - qkv: torch.Tensor, - kv: Optional[torch.Tensor] = None, - seqlen_offset: int = 0, - **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if ( - self._seq_len_cached < qkv.shape[1] + seqlen_offset - or self._cos_cached.device != qkv.device - or self._cos_cached.dtype != qkv.dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype) - - if kv is None: - return _apply_rotary_emb_qkv( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - else: - q = _apply_rotary_emb( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - kv = 
_apply_rotary_emb_kv( - kv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - ) - - return q, kv - - -class MLP(nn.Module): - """Multi-Layer Perceptron. - - Reference: - Attention Is All You Need. - https://arxiv.org/pdf/1706.03762.pdf. - - """ - - def __init__( - self, - config: PretrainedConfig, - n_inner: Optional[int] = None, - act_fn: Optional[str] = None, - ) -> None: - super().__init__() - - act_fn = config.activation_function if act_fn is None else act_fn - - n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner - n_inner = n_inner if n_inner is not None else 4 * config.n_embd - - self.fc1 = nn.Linear(config.n_embd, n_inner) - self.fc2 = nn.Linear(n_inner, config.n_embd) - self.act = ACT2FN[act_fn] - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.fc1(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.fc2(hidden_states) - - return hidden_states - - -class SelfAttention(nn.Module): - """Self-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. - - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - qkv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - q, k, v = qkv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal: - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - scores = scores + causal_mask.to(dtype=scores.dtype) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - output = torch.einsum("bhts,bshd->bthd", attention, v) - - return output - - -class CrossAttention(nn.Module): - """Cross-attention layer (compatible with PyTorch). - - Reference: - https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py. 
- - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: Optional[float] = None, - attention_dropout: float = 0.0, - ) -> None: - super().__init__() - - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - @torch.autocast("cpu", enabled=False) - @torch.autocast("cuda", enabled=False) - def forward( - self, - q: torch.FloatTensor, - kv: torch.FloatTensor, - causal: bool = None, - key_padding_mask: Optional[torch.BoolTensor] = None, - causal_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.FloatTensor: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - if kv.shape[3] != q.shape[2]: - kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - - q = q.to(torch.float32) - k = k.to(torch.float32) - - causal = self.causal if causal is None else causal - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - - # Autocast is manually disabled to avoid `torch.einsum` performing the operation - # using float16, which might lead to overflow - # scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - q = q.permute(0, 2, 1, 3) - k_ = (k * softmax_scale).permute(0, 2, 3, 1) - scores = torch.matmul(q, k_) - - if key_padding_mask is not None: - padding_mask = torch.full( - (batch_size, seqlen_k), - -10000.0, - dtype=scores.dtype, - device=scores.device, - ) - padding_mask.masked_fill_(key_padding_mask, 0.0) - - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - - if causal_mask is not None: - scores = scores.masked_fill(causal_mask, -10000.0) - elif causal: - rows = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") - cols = torch.arange(seqlen_k, device=k.device, dtype=torch.long) - causal_mask = cols > rows + seqlen_k - seqlen_q - - scores = scores.masked_fill(causal_mask, -10000.0) - - attention = torch.softmax(scores, dim=-1).to(v.dtype) - attention = self.drop(attention) - - # output = torch.einsum("bhts,bshd->bthd", attention, v) - v = v.permute(0, 2, 1, 3) - output = torch.matmul(attention, v).permute(0, 2, 1, 3) - return output - - -def _find_mha_dims( - config: PretrainedConfig, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, -) -> Tuple[int, int]: - if n_head is None and head_dim is None: - head_dim = config.n_embd // config.n_head - n_head = config.n_head - elif n_head is None or head_dim is None: - raise ValueError("`n_head` and `head_dim` must be both specified or `None`.") - - if n_head_kv is None: - n_head_kv = getattr(config, "n_head_kv", None) or n_head - - return n_head, n_head_kv, head_dim - - -def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor: - num_heads, head_dim = kv.shape[-2:] - - if layer_idx not in inference_params.key_value_memory_dict: - inference_params.key_value_memory_dict[layer_idx] = torch.empty( - inference_params.max_batch_size, - inference_params.max_seqlen, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - - sequence_start = inference_params.seqlen_offset - sequence_end = sequence_start + kv.shape[1] - - # When the current sequence length is equal to or larger than the maximum sequence length, - # we need to concatenate the current `kv` with the cached `kv` to expand its length - if sequence_end >= 
inference_params.max_seqlen: - inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1) - - inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...] - - return kv - - -class MHA(nn.Module): - """Multi-head attention layer.""" - - def __init__( - self, - config: PretrainedConfig, - dtype: Optional[torch.dtype] = None, - device: Optional[str] = None, - rotary_dim: Optional[int] = None, - rotary_base: float = 10000.0, - rotary_scale_base: Optional[float] = None, - n_head: Optional[int] = None, - n_head_kv: Optional[int] = None, - head_dim: Optional[int] = None, - bias: bool = True, - causal: bool = True, - softmax_scale: Optional[float] = None, - layer_idx: Optional[int] = None, - return_residual: bool = False, - checkpointing: bool = False, - ) -> None: - super().__init__() - - # Rotary embedding - self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0) - if self.rotary_dim > 0: - rotary_cls = FlashRotaryEmbedding if config.flash_rotary else RotaryEmbedding - if rotary_cls is None: - rotary_cls = RotaryEmbedding - - rotary_kwargs = {} - if rotary_cls is RotaryEmbedding: - rotary_kwargs["max_position_embeddings"] = config.n_positions - - self.rotary_emb = rotary_cls( - self.rotary_dim, - base=rotary_base, - scale_base=rotary_scale_base, - device=device, - **rotary_kwargs, - ) - - # MLP - self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims( - config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim - ) - op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv) - hidden_size = config.n_embd - - linear_cls = FusedDense if config.fused_dense else nn.Linear - if linear_cls is None: - linear_cls = nn.Linear - - self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype) - self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype) - - # Attention - attn_cls = FlashSelfAttention if config.flash_attn else SelfAttention - if attn_cls is None: - attn_cls = SelfAttention - - cross_attn_cls = FlashCrossAttention if config.flash_attn else CrossAttention - if cross_attn_cls is None: - cross_attn_cls = CrossAttention - - self.inner_attn = attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - self.inner_cross_attn = cross_attn_cls( - causal=causal, - softmax_scale=softmax_scale, - attention_dropout=config.attn_pdrop, - ) - - self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention - self.layer_idx = layer_idx - self.return_residual = return_residual - self.checkpointing = checkpointing - - def _forward_self_attn( - self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor] - ) -> torch.FloatTensor: - qkv = self.Wqkv(x) - qkv = rearrange(qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim) - - if self.rotary_dim > 0: - qkv = self.rotary_emb(qkv) - - if self.flash_attn: - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - - cu_seqlens, max_seqlen = None, None - if key_padding_mask is not None: - # If `key_padding_mask` is supplied, we need to unpad the input and retrieve - # the `cu_seqlens` and `max_seqlen` to be used by `flash-attn` - qkv, indices, cu_seqlens, max_seqlen = unpad_input(qkv, key_padding_mask) - - if self.checkpointing: - attn_output = torch.utils.checkpoint.checkpoint( - self.inner_attn, qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen - ) - else: - attn_output = self.inner_attn(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen).to(qkv.device) - - # If `key_padding_mask` is supplied, we need to pad the output back to the original shape - return pad_input(attn_output, indices, batch_size, seqlen) if key_padding_mask is not None else attn_output - - if self.checkpointing: - return torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, key_padding_mask=key_padding_mask) - - return self.inner_attn(qkv, key_padding_mask=key_padding_mask) - - def _forward_cross_attn( - self, - x: torch.FloatTensor, - past_key_values: Optional[Union[torch.Tensor, InferenceParams]], - key_padding_mask: Optional[torch.BoolTensor], - rotary_pos_emb: Optional[torch.Tensor] = None, - causal_mask: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size = x.shape[0] - - qkv = self.Wqkv(x) - - q = qkv[..., : self.n_head * self.head_dim] - q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) - - kv = qkv[..., self.n_head * self.head_dim :] - kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) - - if rotary_pos_emb is None: - seqlen_offset = past_key_values.seqlen_offset if past_key_values is not None else 0 - causal = None if seqlen_offset == 0 else False - if self.rotary_dim > 0: - q, kv = self.rotary_emb(q, kv=kv, seqlen_offset=seqlen_offset) - else: - causal = False - cos_pos, sin_pos = rotary_pos_emb - q = _apply_rotary_emb(q, cos_pos, sin_pos) - kv = _apply_rotary_emb_kv(kv, cos_pos, sin_pos) - - if past_key_values is not None: - if type(past_key_values) is InferenceParams: - kv = _update_kv_cache(kv, past_key_values, self.layer_idx) - else: - # kv.shape is [1, past + 1, 2, 32, 80] - #print('kv.shape', kv.shape) - #print('past_key_values.shape', past_key_values.shape) - kv = torch.cat((past_key_values, kv), dim=1) - - if self.flash_attn: - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = kv.shape[1] - - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = ( - None, - None, - None, - None, - ) - if key_padding_mask is not None: - kv, _, cu_seqlens_k, max_seqlen_k = unpad_input(kv, key_padding_mask) - - if seqlen_q == 1: - key_padding_mask = torch.ones(batch_size, 1, device=q.device) - elif seqlen_q != seqlen_k: - key_padding_mask = key_padding_mask[:, -seqlen_q:] - - q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, key_padding_mask) - - if self.checkpointing: - attn_output = torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - ) - else: - attn_output = self.inner_cross_attn( - q, - kv, - causal=causal, - cu_seqlens=cu_seqlens_q, - max_seqlen=max_seqlen_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_k=max_seqlen_k, - ) - - return ( - pad_input(attn_output, indices_q, batch_size, max_seqlen_q) - if key_padding_mask is not None - else 
attn_output - ) - - if self.checkpointing: - return torch.utils.checkpoint.checkpoint( - self.inner_cross_attn, - q, - kv, - key_padding_mask=key_padding_mask, - causal=causal, - ) - output = self.inner_cross_attn(q, kv, key_padding_mask=key_padding_mask, causal=causal, causal_mask=causal_mask) - return output, kv - - def forward( - self, - x: torch.FloatTensor, - past_key_values: Optional[InferenceParams] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - causal_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: - if attention_mask is not None: - attention_mask = attention_mask.bool() - else: - attention_mask = None - - attention_mask = None - kv = None - # MHA - if self.n_head == self.n_head_kv: - if past_key_values is None and False: - # If `past_key_values` are not supplied, we run self-attention - attn_output = self._forward_self_attn(x, attention_mask) - else: - # If `past_key_values` are supplied, it means that we might have cached values and - # could take advantage of cross-attention - attn_output, kv = self._forward_cross_attn(x, past_key_values, attention_mask, rotary_pos_emb, causal_mask) - # MQA / GQA - else: - # Regardless of `past_key_values` being supplied or not, it always use cross-attention - # because `q` and `kv` lengths might be different - attn_output = self._forward_cross_attn(x, past_key_values, attention_mask) - - output = rearrange(attn_output, "... h d -> ... (h d)") - output = self.out_proj(output) - - # return output if not self.return_residual else (output, x) - return output, kv - - -class ParallelBlock(nn.Module): - """Parallel block. - - This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen). - - """ - - def __init__( - self, - config: PretrainedConfig, - block_idx: Optional[int] = None, - ) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - self.block_idx = block_idx - - self.mixer = MHA(config, layer_idx=block_idx) - self.mlp = MLP(config) - - def forward( - self, - hidden_states: torch.FloatTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - causal_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.FloatTensor: - residual = hidden_states - hidden_states = self.ln(hidden_states) - - attn_outputs, kv = self.mixer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - causal_mask=causal_mask, - ) - - if isinstance(attn_outputs, tuple): - attn_outputs = attn_outputs[0] - - # attn_outputs = self.resid_dropout(attn_outputs) - # feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) - feed_forward_hidden_states = self.mlp(hidden_states) - hidden_states = attn_outputs + feed_forward_hidden_states + residual - return hidden_states, kv - - -class CausalLMHead(nn.Module): - """Causal Language Modeling head. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. 
- - """ - - def __init__(self, config: PretrainedConfig) -> None: - super().__init__() - - self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.linear = nn.Linear(config.n_embd, config.vocab_size) - - def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: - hidden_states = self.ln(hidden_states) - logits = self.linear(hidden_states).to(torch.float32) - return logits - - -class CausalLMLoss(nn.Module): - """Causal Language Modeling loss. - - Reference: - Improving Language Understanding by Generative Pre-Training. - https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf. - - """ - - def __init__(self, shift_labels: bool = True) -> None: - super().__init__() - - self.shift_labels = shift_labels - self.loss_fct = nn.CrossEntropyLoss() - - def forward(self, logits: torch.FloatTensor, labels: torch.LongTensor) -> torch.FloatTensor: - if self.shift_labels: - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - - loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) - - return loss - - -class PhiPreTrainedModel(PreTrainedModel): - """Phi pre-trained model.""" - - config_class = PhiConfig - base_model_prefix = "transformer" - supports_gradient_checkpointing = False - _no_split_modules = ["ParallelBlock"] - - def __init__(self, *inputs, **kwargs) -> None: - super().__init__(*inputs, **kwargs) - - def _init_weights(self, module: nn.Module) -> None: - if isinstance(module, (nn.Linear,)): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.bias is not None: - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None, - **kwargs, - ) -> Dict[str, Any]: - if past_key_values is None or not (isinstance(past_key_values, InferenceParams)): - past_key_values = InferenceParams( - max_seqlen=self.config.n_positions, - max_batch_size=input_ids.shape[0], - seqlen_offset=0, - batch_size_offset=0, - key_value_memory_dict={}, - lengths_per_sample=None, - ) - else: - # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids` - past_key_values.seqlen_offset = input_ids.shape[1] - 1 - input_ids = input_ids[:, -1].unsqueeze(-1) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - } - - -class PhiModel(PhiPreTrainedModel): - """Phi model.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.embd = Embedding(config) - self.h = nn.ModuleList([ParallelBlock(config, block_idx=i) for i in range(config.n_layer)]) - self.gradient_checkpointing = False - self.post_init() - - def get_input_embeddings(self) -> nn.Embedding: - return self.embd.wte - - def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None: - self.embd.wte = new_embeddings - - def forward( - self, - input_ids: 
torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - ) -> torch.FloatTensor: - hidden_states = self.embd(input_ids) - - for layer in self.h: - hidden_states, _ = layer( - hidden_states, - past_key_values=past_key_values, - attention_mask=attention_mask, - ) - - return hidden_states - - -class PhiForCausalLM(PhiPreTrainedModel): - """Phi for Causal Language Modeling.""" - - _keys_to_ignore_on_load_missing = [""] - _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] - - def __init__(self, config: PhiConfig) -> None: - super().__init__(config) - - self.transformer = PhiModel(config) - self.lm_head = CausalLMHead(config) - self.loss = CausalLMLoss() - - self.post_init() - - def get_output_embeddings(self) -> nn.Linear: - return self.lm_head.linear - - def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: - self.lm_head.linear = new_embeddings - - def forward( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None, - attention_mask: Optional[torch.BoolTensor] = None, - labels: Optional[torch.LongTensor] = None, - **kwargs, - ) -> CausalLMOutputWithPast: - hidden_states = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask) - lm_logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - loss = self.loss(lm_logits, labels) - - return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values) diff --git a/transformers/llm/export/llmexport.py b/transformers/llm/export/llmexport.py new file mode 100644 index 000000000..17862a632 --- /dev/null +++ b/transformers/llm/export/llmexport.py @@ -0,0 +1,1705 @@ +import os +import sys +import math +import copy +import json +import time +import base64 +import logging +import warnings +import argparse +import functools +from typing import Optional, Tuple + +from yaspin import yaspin + +import onnx +import torch +import numpy as np +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + +RESET = "\033[0m" +GREEN = "\033[32;1m" +YELLOW = "\033[33;4m" +EXPORT_LOG = '.export.log' + +# ignore warnning info +warnings.filterwarnings("ignore") +logging.basicConfig(level=logging.ERROR) +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +def spinner_run(text='Processing...'): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with yaspin(text=text, color="cyan") as spinner: + start = time.time() + try: + result = func(*args, **kwargs) + except Exception as e: + spinner.fail("💥 Failed") + print(e) + exit(1) + end = time.time() + during = f'[{end-start:05.2f} s]'.replace('[0', '[ ') + padding = ' ' * (64 - len(spinner.text) - len(result)) + spinner.text = f'{spinner.text}{YELLOW}{result}{RESET}{padding}{GREEN}{during}{RESET}' + spinner.ok("✅ Done") + return result + return wrapper + return decorator + +class ModelMapper: + def __init__(self): + self.attrs = [] + self.mapper = dict() + self.regist_models() + + def get_map(self, config): + model_type = config.model_type + if model_type == 'chatglm': + if hasattr(config, 'vocab_size') and config.vocab_size == 130528: + model_type = 'chatglm' + else: + model_type = 'chatglm2' + if model_type in self.mapper: + return model_type, self.mapper[model_type] + return model_type, self.default_map + + def regist(self, model_type, model_map): + assert('config' in model_map and + 
'decoder' in model_map and + 'attention' in model_map) + self.mapper[model_type] = model_map + + def regist_models(self): + self.defualt_map() + # regist models + self.regist_llama() + self.regist_qwen() + self.regist_glm() + self.regist_glm2() + self.regist_phi() + self.regist_gemma2() + + def regist_llama(self): + llama_map = self.default_map + self.regist('llama', llama_map) + self.regist('qwen2', llama_map) + self.regist('internlm', llama_map) + baichuan_map = copy.deepcopy(self.default_map) + baichuan_map[self.attention_key] = { + 'qkv_proj': 'W_pack', + 'o_proj': 'o_proj' + } + self.regist('baichuan', baichuan_map) + + def regist_qwen(self): + qwen_map = { + 'config': { + 'hidden_size': 'hidden_size', + 'num_attention_heads': 'num_attention_heads', + 'num_hidden_layers': 'num_hidden_layers', + 'rope_theta': 'rotary_emb_base', + }, + 'model': { + 'lm_': 'lm_head', + 'embed_': 'transformer.wte', + 'blocks_': 'transformer.h', + 'final_layernorm_': 'transformer.ln_f', + 'visual': 'transformer.visual' + }, + 'decoder': { + 'self_attn': 'attn', + 'mlp': 'mlp', + 'input_layernorm': 'ln_1', + 'post_attention_layernorm': 'ln_2' + }, + 'attention': { + 'qkv_proj': 'c_attn', + 'o_proj': 'c_proj' + } + } + self.regist('qwen', qwen_map) + + def regist_glm(self): + glm_map = { + 'config': { + 'hidden_size': 'hidden_size', + 'num_attention_heads': 'num_attention_heads', + 'num_hidden_layers': 'num_layers' + }, + 'model': { + 'lm_': 'lm_head', + 'embed_': 'transformer.word_embeddings', + 'blocks_': 'transformer.layers', + 'final_layernorm_': 'transformer.final_layernorm', + }, + 'decoder': { + 'self_attn': 'attention', + 'mlp': 'mlp', + 'input_layernorm': 'input_layernorm', + 'post_attention_layernorm': 'post_attention_layernorm' + }, + 'attention': { + 'qkv_proj': 'query_key_value', + 'o_proj': 'dense' + } + } + self.regist('chatglm', glm_map) + + def regist_glm2(self): + glm2_map = { + 'config': { + 'hidden_size': 'hidden_size', + 'num_attention_heads': 'num_attention_heads', + 'num_key_value_heads': 'multi_query_group_num', + 'num_hidden_layers': 'num_layers', + }, + 'model': { + 'lm_': 'transformer.output_layer', + 'embed_': 'transformer.embedding.word_embeddings', + 'blocks_': 'transformer.encoder.layers', + 'final_layernorm_': 'transformer.encoder.final_layernorm', + }, + 'decoder': { + 'self_attn': 'self_attention', + 'mlp': 'mlp', + 'input_layernorm': 'input_layernorm', + 'post_attention_layernorm': 'post_attention_layernorm' + }, + 'attention': { + 'qkv_proj': 'query_key_value', + 'o_proj': 'dense' + } + } + self.regist('chatglm2', glm2_map) + + def regist_phi(self): + phi_map = { + 'config': { + 'hidden_size': 'n_embd', + 'num_attention_heads': 'n_head', + 'num_hidden_layers': 'n_layer', + 'rotary_dim': 'rotary_dim' + }, + 'model': { + 'lm_': 'lm_head.linear', + 'embed_': 'transformer.embd.wte', + 'blocks_': 'transformer.h', + 'final_layernorm_': 'lm_head.ln', + }, + 'decoder': { + 'self_attn': 'mixer', + 'mlp': 'mlp', + 'input_layernorm': 'ln', + }, + 'attention': { + 'qkv_proj': 'Wqkv', + 'o_proj': 'out_proj' + } + } + self.regist('phi-msft', phi_map) + + def regist_gemma2(self): + gemma2_config = copy.deepcopy(self.default_config) + gemma2_config['head_dim'] = 'head_dim' + gemma2_decoder = copy.deepcopy(self.default_decoder) + gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm' + gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm' + gemma2_map = { + 'config': gemma2_config, + 'model': self.defualt_model, + 'decoder': gemma2_decoder, + 
'attention': self.default_attention + } + self.regist('gemma2', gemma2_map) + + def defualt_map(self): + # default map is `LlamaForCausalLM` + self.config_key = 'config' + self.model_key = 'model' + self.decoder_key = 'decoder' + self.attention_key = 'attention' + self.default_config = { + 'hidden_size': 'hidden_size', + 'num_attention_heads': 'num_attention_heads', + 'num_hidden_layers': 'num_hidden_layers', + 'num_key_value_heads': 'num_key_value_heads', + 'rope_theta': 'rope_theta' + } + self.defualt_model = { + 'lm_': 'lm_head', + 'embed_': 'model.embed_tokens', + 'blocks_': 'model.layers', + 'final_layernorm_': 'model.norm', + } + self.default_decoder = { + 'self_attn': 'self_attn', + 'mlp': 'mlp', + 'input_layernorm': 'input_layernorm', + 'post_attention_layernorm': 'post_attention_layernorm' + } + self.default_attention = { + 'q_proj': 'q_proj', + 'k_proj': 'k_proj', + 'v_proj': 'v_proj', + 'o_proj': 'o_proj' + } + self.default_map = { + 'config': self.default_config, + 'model': self.defualt_model, + 'decoder': self.default_decoder, + 'attention': self.default_attention + } + + @staticmethod + def do_map(dst, src, map): + for dst_attr, src_attr in map.items(): + attributes = src_attr.split('.') + obj = src + for attr in attributes: + if hasattr(obj, attr): + obj = getattr(obj, attr) + else: + obj = None + break + setattr(dst, dst_attr, obj) + + +# Export class +class LlmExporterOp(torch.autograd.Function): + @staticmethod + def symbolic(g, input, in_features, out_features, has_bias, name): + args = [input] + # These become the operator attributes. + kwargs = { + "in_features_i": in_features, + "out_features_i": out_features, + "has_bias_i": has_bias, + "name_s": name + } + from torch.onnx.symbolic_helper import _get_tensor_sizes + out_sizes = _get_tensor_sizes(input)[:-1] + [out_features] + output_type = input.type().with_sizes(out_sizes) + return g.op("LlmExporter::FakeLinear", input, **kwargs).setType(output_type) + + @staticmethod + def forward(ctx, input, in_features, out_features, has_bias, name): + out_shape = list(input.shape)[:-1] + [out_features] + return input.new_zeros(out_shape) + +class FakeLinear(torch.nn.Module): + def __init__(self, in_features, out_features, has_bias, name): + super(FakeLinear, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.has_bias = has_bias + self.name = name + + def forward(self, x): + return LlmExporterOp.apply(x, self.in_features, self.out_features, self.has_bias, self.name) + +class OnnxRebuilder: + def __init__(self, onnx_path, weight_ops): + self.weight_ops = weight_ops + self.onnx_model = onnx.load(onnx_path) + self.dst_path = onnx_path + self.onnx_weight_path = f'{onnx_path}.data' + self.onnx_weight_offset = 0 + + def make_external(self, name, data, shape): + # write to external weight + length = self.onnx_weight.write(data.tobytes()) + location = os.path.basename(self.onnx_weight_path) + offset = self.onnx_weight_offset + self.onnx_weight_offset += length + tensor = onnx.TensorProto() + tensor.name = name + tensor.data_type = onnx.TensorProto.FLOAT + tensor.dims.extend(shape) + # external info + tensor.data_location = onnx.TensorProto.EXTERNAL + for k, v in { "location": location, "offset": offset, "length": length }.items(): + entry = tensor.external_data.add() + entry.key = k + entry.value = str(v) + self.onnx_model.graph.initializer.append(tensor) + + def build_weight(self, name, has_bias, ic, oc): + assert(name in self.weight_ops) + linear = self.weight_ops[name] + 
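
The `LlmExporterOp`/`FakeLinear` pair above is a placeholder trick: during ONNX tracing every real linear layer is replaced by a custom-domain node that only records `in_features`, `out_features`, `has_bias` and a name, and `OnnxRebuilder` later swaps those nodes for real `MatMul`/`Add` nodes backed by external weight data. As a standalone illustration of that pattern (separate from the patch itself; the `Example` domain, module names and output file are invented, and it assumes the TorchScript-based `torch.onnx.export` path):

```python
import torch
from torch.onnx.symbolic_helper import _get_tensor_sizes  # same helper the patch uses

class ExampleFakeLinear(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, out_features, name):
        # Emit a node in a custom ONNX domain; only metadata is stored here,
        # the real weights are attached to the graph afterwards.
        out_sizes = _get_tensor_sizes(x)[:-1] + [out_features]
        node = g.op("Example::FakeLinear", x, out_features_i=out_features, name_s=name)
        return node.setType(x.type().with_sizes(out_sizes))

    @staticmethod
    def forward(ctx, x, out_features, name):
        # Shape-correct dummy output so tracing can proceed without weights.
        return x.new_zeros(list(x.shape)[:-1] + [out_features])

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return ExampleFakeLinear.apply(x, 8, "layer0")

torch.onnx.export(TinyModel(), torch.randn(1, 4), "fake_linear_demo.onnx",
                  input_names=["x"], output_names=["y"],
                  custom_opsets={"Example": 1})
```
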
assert(linear.in_features == ic and + linear.out_features == oc and + (linear.bias is not None) == has_bias) + weight_name, bias_name = f'{name}_weight', f'{name}_bias' + weight = linear.weight.data.transpose(1, 0).flatten().numpy() + self.make_external(weight_name, weight, [ic, oc]) + if has_bias: + bias = linear.bias.data.flatten().numpy() + self.make_external(bias_name, bias, [oc]) + return weight_name, bias_name + + def rebuild(self): + from onnx import helper + new_nodes = [] + self.onnx_weight = open(self.onnx_weight_path, 'wb') + for node in self.onnx_model.graph.node: + if node.op_type == 'FakeLinear': + attributes = {a.name: a for a in node.attribute} + name = attributes.get('name').s.decode('utf-8') + has_bias = attributes.get('has_bias').i + ic = attributes.get('in_features').i + oc = attributes.get('out_features').i + weight, bias = self.build_weight(name, has_bias, ic, oc) + if has_bias: + # fakelinear -> matmul + add + middle_tensor = f'{name}_matmul' + new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], [middle_tensor], name)) + new_nodes.append(helper.make_node('Add', [middle_tensor, bias], node.output, name)) + else: + # fakelinear -> matmul + new_nodes.append(helper.make_node('MatMul', [node.input[0], weight], node.output, name)) + else: + new_nodes.append(node) + self.onnx_weight.close() + del self.onnx_model.graph.node[:] + self.onnx_model.graph.node.extend(new_nodes) + onnx.save(self.onnx_model, self.dst_path) + return self.onnx_weight_path + +class MNNConveter: + def __init__(self, onnx_path, weight_ops, config): + self.weight_ops = weight_ops + self.quant_block = config.quant_block + self.quant_bit = config.quant_bit + self.lm_quant_bit = config.lm_quant_bit + self.mnn_weight_offset = 0 + self.onnx_model_path = onnx_path + self.mnn_model_path = onnx_path.replace('.onnx', '.mnn') + self.mnn_weight_path = f'{self.mnn_model_path}.weight' + if os.path.exists(config.mnnconvert): + self.mnnconvert = config.mnnconvert + else: + self.mnnconvert = None + + def convert(self, convert_args): + sfd = os.dup(1) + log_fp = open(EXPORT_LOG, "a") + log_fd = log_fp.fileno() + # mnnconvert ... 
> .convert_mnn.log + os.dup2(log_fd, 1) + try: + sys.argv = convert_args + sys.argc = len(convert_args) + if self.mnnconvert is None: + from MNN.tools import mnnconvert + mnnconvert.main() + else: + convert_args[0] = self.mnnconvert + cmd = ' '.join(convert_args) + message = os.popen(cmd).read() + print(message) + sys.argv = [] + finally: + os.dup2(sfd, 1) + os.close(log_fd) + + @spinner_run(f'convert onnx model to ') + def onnx2mnn(self, onnx_path, mnn_path, args = []): + convert_args = [ + '', + '-f', + 'ONNX', + '--modelFile', + str(onnx_path), + '--MNNModel', + str(mnn_path), + '--transformerFuse', + '--allowCustomOp' + ] + convert_args += args + self.convert(convert_args) + return mnn_path + + def mnn2json(self, mnn_path, json_path): + convert_args = [ + '', + '-f', + 'MNN', + '--modelFile', + str(mnn_path), + '--JsonFile', + str(json_path) + ] + self.convert(convert_args) + return json_path + + def json2mnn(self, json_path, mnn_path): + convert_args = [ + '', + '-f', + 'JSON', + '--modelFile', + str(json_path), + '--MNNModel', + str(mnn_path) + ] + self.convert(convert_args) + return mnn_path + + def export(self): + if self.weight_ops is None: + quant_args = [ + '--weightQuantBits', + str(self.quant_bit), + '--weightQuantBlock', + str(self.quant_block) + ] + self.onnx2mnn(self.onnx_model_path, self.mnn_model_path, quant_args) + else: + mnn_json = f'{self.mnn_model_path}.json' + self.onnx2mnn(self.onnx_model_path, self.mnn_model_path) + self.mnn2json(self.mnn_model_path, mnn_json) + self.rebuild(mnn_json) + self.json2mnn(mnn_json, self.mnn_model_path) + + @spinner_run(f'quant model weight to ') + def rebuild(self, json_path): + self.mnn_weight = open(self.mnn_weight_path, 'wb') + mnn_graph = json.load(open(json_path, 'rt')) + new_ops = [] + for op in mnn_graph['oplists']: + if op['type'] == 'Extra': + new_ops += self.rebuild_op(op, mnn_graph) + else: + new_ops.append(op) + mnn_graph['oplists'] = new_ops + with open(json_path, 'w', encoding='utf-8') as file: + json.dump(mnn_graph, file, ensure_ascii=False, indent=4) + return self.mnn_weight_path + + def quant(self, weight, quant_bit, quant_block): + weight = weight.numpy() + oc, ic = weight.shape + if quant_block == 0: + block_size = ic + else: + block_size = quant_block + block_num = ic // block_size + weight = weight.reshape(oc, block_num, block_size) + max_val = np.max(weight, axis=-1, keepdims=True) + min_val = np.min(weight, axis=-1, keepdims=True) + offset = 1 << (quant_bit - 1) + clip_max = offset - 1 + clip_min = -offset + scale = (max_val - min_val) / (clip_max - clip_min) + q_weight = np.round((weight - min_val) / scale) + clip_min + q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8) + q_weight = q_weight.reshape(-1, 2) + if quant_bit == 4: + q_weight = q_weight[:, 0] * 16 + q_weight[:, 1] + alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten() + return q_weight, alpha, clip_min + + def write_npy(self, data): + return self.mnn_weight.write(data.tobytes()) + + def write_header(self, ic, oc, quant_bit): + dim_num = self.mnn_weight.write(b'\x02') + shape_dtype = np.int16 + if oc > 65535 or ic > 65535: + shape_dtype = np.int32 + dim_length = self.write_npy(np.array([oc, ic]).astype(shape_dtype)) + offset = 1 << (quant_bit - 1) + weight_map = [i for i in range(-offset, offset)] + if len(weight_map) == 256: + weight_map.insert(0, 0) + else: + weight_map.insert(0, len(weight_map)) + map_length = self.write_npy(np.array(weight_map, dtype=np.int8)) + header_length = dim_num + 
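
`MNNConveter.quant` above implements asymmetric per-block weight quantization: each block of `quant_block` input channels gets its own `(min, scale)` pair (serialized as `alpha`), codes are shifted by `offset` into unsigned range, and two 4-bit codes are packed per byte. A small round-trip sketch of the same math, separate from the patch (block size and bit width are example choices, and the nibble packing is omitted):

```python
import numpy as np

def quant_dequant(weight, quant_bit=4, block_size=32):
    oc, ic = weight.shape
    w = weight.reshape(oc, ic // block_size, block_size)
    w_max = w.max(axis=-1, keepdims=True)
    w_min = w.min(axis=-1, keepdims=True)
    offset = 1 << (quant_bit - 1)
    clip_max, clip_min = offset - 1, -offset          # e.g. +7 / -8 for 4 bit
    scale = (w_max - w_min) / (clip_max - clip_min)
    q = np.clip(np.round((w - w_min) / scale) + clip_min, clip_min, clip_max)
    # Dequantize with the per-block (min, scale) pair, i.e. the `alpha` data.
    w_hat = (q - clip_min) * scale + w_min
    return w_hat.reshape(oc, ic)

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 64)).astype(np.float32)
print("max abs error:", np.abs(w - quant_dequant(w)).max())
```

The reconstruction error per block is bounded by roughly half the block's `scale`, which is why smaller `quant_block` values trade file size for accuracy.
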
dim_length + map_length + return header_length, shape_dtype == np.int32 + + def build_weight(self, linear, quant_bit, quant_block): + ic, oc = linear.in_features, linear.out_features + q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block) + header_len, shape_int32 = self.write_header(ic, oc, quant_bit) + weight_len = self.write_npy(q_weight) + header_len + alpha_len = self.write_npy(alpha) + if linear.bias is not None: + bias = linear.bias.data.flatten().numpy() + bias_length = self.write_npy(bias) + else: + bias_length = 0 + # bias = np.zeros([oc], dtype=np.float32) + # bias_length = self.write_npy(bias) + external = [self.mnn_weight_offset, weight_len, alpha_len, bias_length, 0] + self.mnn_weight_offset += (weight_len + alpha_len + bias_length) + return external, q_min, shape_int32 + + def build_tensor(self, graph, tensor_name): + tensor_idx = [len(graph['tensorName'])] + graph['tensorName'].append(tensor_name) + return tensor_idx + + def rebuild_op(self, op, graph): + attrs = op['main']['attr'] + for attr in attrs: + if attr['key'] == 'name': + name = attr['s'] + elif attr['key'] == "in_features": + ic = attr["i"] + elif attr['key'] == "out_features": + oc = attr["i"] + elif attr['key'] == "has_bias": + has_bias = attr["i"] + linear = self.weight_ops[name] + assert(linear.in_features == ic and + linear.out_features == oc and + (linear.bias is not None) == has_bias) + + + quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit + external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block) + + origin_input = op['inputIndexes'] + origin_output = op['outputIndexes'] + # build new tensor + pre_reshape_name = f'{name}/pre_reshape' + pre_convert_name = f'{name}/pre_convert' + conv_name = name + post_convert_name = f'{name}/post_convert' + post_reshape_name = f'{name}/post_reshape' + pre_reshape_output = self.build_tensor(graph, pre_reshape_name) + pre_convert_output = self.build_tensor(graph, pre_convert_name) + conv_output = self.build_tensor(graph, conv_name) + post_convert_output = self.build_tensor(graph, post_convert_name) + # [batch, seq, hidden_size_i] -[Linear] -> [batch, seq, hidden_size_o] + # [1, seq, hidden_size_i] ->[Reshape]-> [seq, hidden_size_i, 1, 1] + # -[Convert]-[Convolution]-[Convert]-> [Reshape] -> [1, seq, hidden_size_o] + pre_reshape = { + "name": pre_reshape_name, + "type": "Reshape", + "inputIndexes": origin_input, + "outputIndexes": pre_reshape_output, + "main_type": "Reshape", + "main": { + "dims": [-1, ic, 1, 1], + "dimType": "NCHW" + }, + "defaultDimentionFormat": "NHWC" + } + pre_convert = { + "name": pre_convert_name, + "inputIndexes": pre_reshape_output, + "outputIndexes": pre_convert_output, + "type": "ConvertTensor", + "main_type": "TensorConvertInfo", + "main": { + "source": "NCHW", + "dest": "NC4HW4" + }, + "defaultDimentionFormat": "NHWC" + } + conv_op = { + "name": conv_name, + "inputIndexes": pre_convert_output, + "outputIndexes": conv_output, + "type": "Convolution", + "main_type": "Convolution2D", + "main": { + 'common': { + 'dilateX': 1, 'dilateY': 1, 'strideX': 1, 'strideY': 1, + 'kernelX': 1, 'kernelY': 1, 'padX': 0, 'padY': 0, 'group': 1, + 'outputCount': oc, 'relu': False, 'padMode': 'CAFFE', + 'relu6': False, 'inputCount': ic, 'hasOutputShape': False + }, + "quanParameter": { + "quantScale": 1.0, "scaleIn": 0.0, "scaleOut": 0.0, + "useInt32": False, "has_scaleInt": False, "shapeInt32": shape_int32, + "type": 1, "aMax": 0, "aMin": q_min, "readType": oc * (ic // self.quant_block), 
"weightSize": 0 + }, + "external": external + }, + "defaultDimentionFormat": "NHWC" + } + post_convert = { + "name": post_convert_name, + "inputIndexes": conv_output, + "outputIndexes": post_convert_output, + "type": "ConvertTensor", + "main_type": "TensorConvertInfo", + "main": { + "source": "NC4HW4", + "dest": "NCHW" + }, + "defaultDimentionFormat": "NHWC" + } + post_reshape = { + "name": post_reshape_name, + "type": "Reshape", + "inputIndexes": post_convert_output, + "outputIndexes": origin_output, + "main_type": "Reshape", + "main": { + "dims": [1, -1, oc], + "dimType": "NCHW" + }, + "defaultDimentionFormat": "NHWC" + } + return [pre_reshape, pre_convert, conv_op, post_convert, post_reshape] + +# some wrapper class for export +class Embedding(torch.nn.Module): + def __init__(self, embed, config): + super().__init__() + self.hidden_size = config.hidden_size + self.embed = embed + if config.model_type == 'gemma2': + normalizer = torch.tensor(self.hidden_size**0.5) + self.embed.weight.data *= normalizer + + def forward(self, input_ids): + inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size) + return inputs_embeds + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +class Attention(torch.nn.Module): + def __init__(self, attn, config): + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.rotary = config.rotary + ModelMapper.do_map(self, attn, config.model_map['attention']) + if hasattr(self, 'qkv_proj') and self.qkv_proj is not None: + # split qkv linear to q, k, v + split_sizes = [self.hidden_size] * 3 + if self.qkv_proj.weight.shape[0] != self.hidden_size * 3: + # M/GQA + qkv_hidden_size = self.qkv_proj.weight.shape[0] + kv_hidden_size = (qkv_hidden_size - self.hidden_size) // 2 + split_sizes = [self.hidden_size, kv_hidden_size, kv_hidden_size] + self.q_proj = torch.nn.Linear(self.hidden_size, split_sizes[0]) + self.k_proj = torch.nn.Linear(self.hidden_size, split_sizes[1]) + self.v_proj = torch.nn.Linear(self.hidden_size, split_sizes[2]) + if config.model_type == 'chatglm': + # chatglm-6b + qkv_weight = self.qkv_proj.weight.data.view(self.num_heads, 3, self.head_dim, self.hidden_size) + self.q_proj.weight.data = qkv_weight[:, 0, :, :].reshape(self.hidden_size, self.hidden_size) + self.k_proj.weight.data = qkv_weight[:, 1, :, :].reshape(self.hidden_size, self.hidden_size) + self.v_proj.weight.data = qkv_weight[:, 2, :, :].reshape(self.hidden_size, self.hidden_size) + qkv_bias = self.qkv_proj.bias.data.view(self.num_heads, 3, self.head_dim) + self.q_proj.bias.data = qkv_bias[:, 0, :].reshape(self.hidden_size) + self.k_proj.bias.data = qkv_bias[:, 1, :].reshape(self.hidden_size) + self.v_proj.bias.data = qkv_bias[:, 2, :].reshape(self.hidden_size) + else: + # other + qw, kw, vw = torch.split(self.qkv_proj.weight, split_sizes) + self.q_proj.weight.data = qw + self.k_proj.weight.data = kw + self.v_proj.weight.data = vw + if self.qkv_proj.bias is not None: + qb, kb, vb = torch.split(self.qkv_proj.bias, split_sizes) + self.q_proj.bias.data = 
qb + self.k_proj.bias.data = kb + self.v_proj.bias.data = vb + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + kv_seq_len = key_states.shape[1] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[1] + + # rope + cos, sin = rotary_pos_emb[0], rotary_pos_emb[1] + query_states = self.rotary.apply_rotary_pos(query_states, cos, sin) + key_states = self.rotary.apply_rotary_pos(key_states, cos, sin) + # kv cache + if past_key_value is not None: + past_key, past_value = past_key_value[0], past_key_value[1] + key_states = torch.cat((past_key, key_states), dim=1) + value_states = torch.cat((past_value, value_states), dim=1) + + past_key_value = torch.stack((key_states, value_states)) + query_states = query_states.transpose(1, 2) + key_states = key_states.permute([0, 2, 3, 1]) + value_states = value_states.transpose(1, 2) + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + #------- attention ---------- + # query_states @ key_states + attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim) + # attention_mask + if attention_mask.dtype in (torch.bool, torch.int32): + # chatglm + attn_weights.masked_fill_(attention_mask, -10000.0) + else: + attn_weights = attn_weights + attention_mask + # upcast softmax to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + # attn_weights @ value_states + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, past_key_value + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +class Rotary(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.rope_theta = config.rope_theta + self.rotary_dim = config.head_dim + self.model_type = config.model_type + if hasattr(config, 'rotary_dim'): + self.rotary_dim = config.rotary_dim + if self.model_type == 'chatglm': + self.rotary_dim = config.head_dim // 2 + + def forward(self, position_ids): + theta = 1.0 / (self.rope_theta ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float32) / self.rotary_dim)) + position_ids = position_ids.float().reshape(-1, 1) + idx_theta = position_ids * theta + rotary_pos_emb = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)]) + if self.model_type != 'chatglm2': + rotary_pos_emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + rotary_pos_emb = rotary_pos_emb.unsqueeze(2).unsqueeze(1) + return rotary_pos_emb + + def apply_rotary_pos(self, x, cos, sin): + if self.model_type == 'chatglm': + return self.chatglm_rotary_pos(x, cos, sin) + if self.model_type == 'chatglm2': + return 
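
`Rotary.forward` precomputes `cos`/`sin` tables from `rope_theta` and the position ids, and the default `llama_rotary_pos` path applies them as `x * cos + rotate_half(x) * sin`. A self-contained sketch of that path, separate from the patch, with invented sizes; it also checks that the rotation preserves vector norms:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len, rope_theta = 16, 6, 10000.0
theta = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
idx_theta = torch.arange(seq_len, dtype=torch.float32).reshape(-1, 1) * theta   # [seq, dim/2]
cos = torch.cat((idx_theta.cos(), idx_theta.cos()), dim=-1).unsqueeze(1)        # [seq, 1, dim]
sin = torch.cat((idx_theta.sin(), idx_theta.sin()), dim=-1).unsqueeze(1)

q = torch.randn(1, seq_len, 4, head_dim)          # [batch, seq, heads, head_dim]
q_rot = (q * cos) + (rotate_half(q) * sin)
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # rotation keeps norms
```
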
self.chatglm2_rotary_pos(x, cos, sin) + if self.model_type == 'phi-msft': + return self.phi_rotary_pos(x, cos, sin) + return self.llama_rotary_pos(x, cos, sin) + + def llama_rotary_pos(self, x, cos, sin): + x = (x * cos) + (rotate_half(x) * sin) + return x + + def phi_rotary_pos(self, x, cos, sin): + x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:] + x = (x * cos) + (rotate_half(x) * sin) + return torch.cat((x, x_pass), dim=-1) + + def chatglm2_rotary_pos(self, x, cos, sin): + x, x_pass = x[..., :self.rotary_dim], x[..., self.rotary_dim:] + b, s, n, h = x.shape + xshaped = x.view(b, s, n, h//2, 2) + x = torch.concat( + [ + xshaped[..., 0] * cos - xshaped[..., 1] * sin, + xshaped[..., 1] * cos + xshaped[..., 0] * sin, + ], + -1, + ) + return torch.cat((x, x_pass), dim=-1) + + def chatglm_rotary_pos(self, x, cos, sin): + seq = x.shape[1] + x1, x2 = x[..., :self.rotary_dim], x[..., self.rotary_dim:] + cos1, sin1 = cos[:, :seq, ...], sin[:, :seq, ...] + cos2, sin2 = cos[:, seq:, ...], sin[:, seq:, ...] + x1 = (x1 * cos1) + (rotate_half(x1) * sin1) + x2 = (x2 * cos2) + (rotate_half(x2) * sin2) + return torch.cat((x1, x2), dim=-1) + +class Decoder(torch.nn.Module): + def __init__(self, decoder, config): + super().__init__() + ModelMapper.do_map(self, decoder, config.model_map['decoder']) + self.hidden_size = config.hidden_size + self.self_attn = Attention(self.self_attn, config) + # chatglm + self.alpha = (2 * config.num_hidden_layers) ** 0.5 if config.model_type == 'chatglm' else 1.0 + + def forward( + self, + hidden_states: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + hidden_states = hidden_states.view(1, -1, self.hidden_size) + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + norm_hidden_states = hidden_states + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + rotary_pos_emb=rotary_pos_emb, + attention_mask=attention_mask, + past_key_value=past_key_value, + ) + # Fully Connected + if not hasattr(self, 'post_attention_layernorm'): + # phi + feed_forward_hidden_states = self.mlp(norm_hidden_states) + hidden_states = hidden_states + feed_forward_hidden_states + residual + elif self.alpha != 1.0: + # chatglm-6b + hidden_states = norm_hidden_states * self.alpha + hidden_states + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_input * self.alpha + mlp_output + elif hasattr(self, 'pre_feedforward_layernorm'): + # gemma2 + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.pre_feedforward_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + else: + # general + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, present_key_value + +class Lm(torch.nn.Module): + def __init__(self, lm_, final_layernorm_, config): + super().__init__() + self.final_layernorm = final_layernorm_ + self.lm = lm_ + self.hidden_size = 
config.hidden_size + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size) + hidden_states = self.final_layernorm(hidden_states) + m_logits = self.lm(hidden_states) + return m_logits + +class LlmExporter(torch.nn.Module): + ''' + Base class for all llm model export. Inherits from [`torch.nn.Module`]. + ''' + + def __init__(self, args): + super().__init__() + self.init_from_args(args) + self.load_model(args.path) + + def init_from_args(self, args): + self.max_length = 1024 + self.stop_ids = [] + self.visual = None + self.dst_name = 'llm' + # load config from args + self.path = args.path + self.dst_path = args.dst_path + self.lora_path = args.lora_path + self.skip_slim = args.skip_slim + self.quant_bit = args.quant_bit + self.quant_block = args.quant_block + self.mnnconvert = args.mnnconvert + if args.lm_quant_bit is not None: + self.lm_quant_bit = args.lm_quant_bit + else: + self.lm_quant_bit = self.quant_bit + # init export dst dir + if not os.path.exists(self.dst_path): + os.makedirs(self.dst_path) + + def load_pretrained(self, model_path: str): + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + try: + self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).float().eval() + except: + self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval() + self.config = self.model.config + if self.lora_path is not None: + from peft import PeftModel + adapter = PeftModel.from_pretrained(self.model, model_id=self.lora_path) + self.model = adapter.merge_and_unload(progressbar=True) + + @spinner_run(f'load pretrained model ') + def load_model(self, model_path): + self.load_pretrained(model_path) + self.attention_mask_type = 'float' + # load tokenizer info + self.stop_ids.append(self.tokenizer.eos_token_id) + if hasattr(self.tokenizer, 'im_end_id'): + self.stop_ids.append(self.tokenizer.im_end_id) + eot_id = self.tokenizer.encode('<|eot_id|>') + if len(eot_id) == 1: + self.stop_ids.append(eot_id[0]) + if hasattr(self.model, 'generation_config'): + eos_token_id = self.model.generation_config.eos_token_id + from collections.abc import Iterable + if isinstance(eos_token_id, int): + self.stop_ids.append(eos_token_id) + elif isinstance(eos_token_id, Iterable): + for id in eos_token_id: + self.stop_ids.append(id) + self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None] + self.stop_ids = list(set(self.stop_ids)) + model_mapper = ModelMapper() + + self.model_type, self.model_map = model_mapper.get_map(self.config) + # print(self.model) + # print(self.model_type, self.model_map) + # load config info + ModelMapper.do_map(self, self.config, self.model_map['config']) + if not hasattr(self, 'num_key_value_heads') or self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + if not hasattr(self, 'rope_theta') or self.rope_theta is None: + self.rope_theta = 10000.0 + if not hasattr(self, 'head_dim') or self.head_dim is None: + self.head_dim = self.hidden_size // self.num_attention_heads + # some export info + self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim] + self.block_dynamic_axes = { + "inputs_embeds" : { 0: "seq_len" }, + "attention_mask" : { 2: "seq_len", 3: "seq_len" }, + "position_ids" : { 0: "seq_len" }, + "past_key_values" : { 1: "history_len" } + } + self.model_dynamic_axes = { + "input_ids" : { 0: "seq_len" }, + "attention_mask" : { 2: 
"seq_len", 3: "seq_len" }, + "position_ids" : { 0: "seq_len" }, + "past_key_values" : { 2: "history_len" } + } + self.llm_config = { + 'hidden_size' : self.hidden_size, + 'layer_nums' : self.num_hidden_layers, + 'attention_mask': self.attention_mask_type, + 'key_value_shape': self.past_kv_shape[1:], + "prompt_template": self.build_prompt('%s'), + 'is_visual': False + } + # load modules + ModelMapper.do_map(self, self.model, self.model_map['model']) + # rebuild modules + if self.embed_.weight is self.lm_.weight: + import copy + embed_copy = copy.deepcopy(self.embed_) + self.embed = Embedding(embed_copy, self) + else: + self.embed = Embedding(self.embed_, self) + # Rotary + self.rotary = Rotary(self) + self.blocks = [] + for block in self.blocks_.children(): + self.blocks.append(Decoder(block, self)) + self.lm = Lm(self.lm_, self.final_layernorm_, self) + # visual model + if self.visual is not None: + self.image_start_id = self.config.visual['image_start_id'] + self.image_size = self.config.visual['image_size'] + self.llm_config['is_visual'] = True + self.llm_config['img_size'] = self.image_size + self.llm_config['imgpad_len'] = 256 + self.llm_config['img_start'] = self.tokenizer.img_start_id + self.llm_config['img_end'] = self.tokenizer.img_end_id + self.llm_config['img_pad'] = self.tokenizer.img_pad_id + return model_path + + def get_attention_mask(self) -> torch.Tensor: + if self.model_type == 'chatglm': + return self.chatglm_attention_mask() + if self.token_len: + return torch.zeros([1, 1, 1, self.seq_len], dtype=torch.float32) + return (1 - torch.tril(torch.ones([1, 1, self.seq_len, self.seq_len]))) * torch.finfo(torch.float32).min + + def get_position_ids(self) -> torch.Tensor: + if self.model_type == 'chatglm': + return self.chatglm_position_ids() + if self.token_len: + return torch.tensor([[self.seq_len - 1]], dtype=torch.long) + return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) + + def chatglm_attention_mask(self): + if self.token_len: + return torch.zeros([1]).bool().reshape([1, 1, 1, 1]) + attention_mask = torch.zeros([self.seq_len, self.seq_len], dtype=torch.bool) + for i in range(self.seq_len - 1): + attention_mask[i][-1] = True + attention_mask = attention_mask.reshape([1, 1, self.seq_len, self.seq_len]) + return attention_mask + + def chatglm_position_ids(self): + if self.token_len: + return torch.tensor([self.context_len, self.token_len + 1]).reshape([1, 2, 1]) + position_ids_0 = torch.arange(self.seq_len, dtype=torch.long) + position_ids_1 = torch.zeros(self.seq_len, dtype=torch.long) + position_ids_0[-1] = position_ids_0[-2] + position_ids_1[-1] = 1 + position_ids = torch.stack([position_ids_0, position_ids_1]).view(1, 2, -1) + return position_ids + + def visual_embed(self, input_ids): + if not torch.any(input_ids == self.image_start_id): + return self.embed(input_ids) + bos_pos = torch.where(input_ids == self.image_start_id) + eos_pos = torch.where(input_ids == self.image_start_id + 1) + img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1) + images = [] + for i, a, b in img_pos: + image = input_ids[i][a + 1 : b - 1].tolist() + image = image[ : image.index(self.image_start_id + 2)] + images.append(bytes(image).decode('utf-8')) + images = self.visual.encode(images) + hidden_states = self.embed(input_ids).view(1, -1, self.hidden_size) + for idx, (i, a, b) in enumerate(img_pos): + hidden_states[i][a + 1 : b] = images[idx] + return hidden_states.view(-1, 1, self.hidden_size) + + def embedding(self, input_ids): + if self.visual is not None and 
self.token_len == 0: + input_embeds = self.visual_embed(input_ids) + else: + input_embeds = self.embed(input_ids) + return input_embeds + + def forward(self, input_ids, attention_mask, position_ids, past_key_values): + hidden_states = input_ids # llm forward without embedding + presents = [] + rotary_pos_emb = self.rotary(position_ids) + for i in range(self.num_hidden_layers): + hidden_states, kv = self.blocks[i](hidden_states, rotary_pos_emb, attention_mask, past_key_values[i]) + presents.append(kv) + logits = self.lm(hidden_states).reshape(-1) + presents = torch.stack(presents) + self.seq_len += 1 + self.token_len += 1 + return logits, presents + + # some test functions + def build_prompt(self, query): + # just for test + if 'Qwen2' in self.path: + return f'<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' + if 'Qwen' in self.path: + return f'\n<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n' + if 'Baichuan2' in self.path: + return f'{query}' + if 'internlm' in self.path: + return f'<|User|>:{query}\n<|Bot|>:' + if 'TinyLlama' in self.path: + return f'<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\n{query}\n<|assistant|>\n' + if 'Yi' in self.path: + return f'<|im_start|> user\n{query}<|im_end|>\n<|im_start|> assistant\n' + if 'deepseek' in self.path: + return f'<|begin_of_sentence|>User: {query}\n\nAssistant:' + if 'Llama-3.1' in self.path: + return f'<|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + if 'Llama-3' in self.path: + return f'<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{query}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' + if 'Llama-2' in self.path: + return f'[INST]{query}[/INST]' + if 'chatglm2' in self.path: + return f'[Round 1]\n\n问:{query}\n\n答:' + if 'chatglm3' in self.path or 'glm-4' in self.path: + return f'<|user|>\n{query}\n<|assistant|>\n' + if 'chatglm' in self.path: + return f'{query}[gMASK]' + if 'phi-2' in self.path: + return f'Instruct: {query}\nOutput:' + if 'gemma-2' in self.path: + return f'user\n{query}\nmodel\n' + return query + + def str_to_ids(self, prompt): + input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids'] + return input_ids + + def id_to_str(self, token_id): + word = self.tokenizer._convert_id_to_token(int(token_id)) + word = self.tokenizer.convert_tokens_to_string([word]) + return word + + def response(self, query): + self.imitate_quant() + prompt = self.build_prompt(query) + input_ids = self.str_to_ids(prompt) + # print(f'prompt = {prompt}, ids = {input_ids}') + self.seq_len = input_ids.numel() + self.context_len = self.seq_len - 2 + self.token_len = 0 + past_key_values = [None for i in range(self.num_hidden_layers)] + token_id = input_ids + while self.token_len < self.max_length: + attention_mask = self.get_attention_mask() + position_ids = self.get_position_ids() + input_ids = self.embed(token_id) + logits, past_key_values = self.forward(input_ids, attention_mask, position_ids, past_key_values) + token_id = torch.argmax(logits) + if token_id in self.stop_ids: + print("", end='\n') + break + word = self.id_to_str(token_id) + print(word, end="", flush=True) + + def export_visual(self): + if self.visual is None: + return + input_images = torch.randn((1, 3, self.image_size, self.image_size)) + model = self.visual + onnx_model = f'{self.dst_path}/visual.onnx' + torch.onnx.export(model, (input_images), + onnx_model, + input_names=['input_images'], + 
output_names=['image_embeds'], + dynamic_axes={"input_images": { + 0: "size" + }}, + do_constant_folding=True, + opset_version=15) + return onnx_model + if not self.skip_slim: + slim(onnx_model, output_model=onnx_model) + + @spinner_run(f'export embedding to ') + def export_embed(self): + import ctypes + if hasattr(self, 'word_embeddings'): + # embedding model's embed + tensor_data = self.word_embeddings.weight.data.bfloat16() + else: + tensor_data = self.embed.embed.weight.data.bfloat16() + data_ptr = tensor_data.untyped_storage().data_ptr() + buffer = (ctypes.c_byte * (tensor_data.numel() * 2)).from_address(data_ptr) + embedding_file = f'{self.dst_path}/embeddings_bf16.bin' + with open(embedding_file, 'wb') as f: + f.write(buffer) + return embedding_file + + @spinner_run(f'export config to ') + def export_config(self, mnn_config = False): + config_json = f'{self.dst_path}/llm_config.json' + with open(config_json, 'w', encoding='utf-8') as f: + json.dump(self.llm_config, f, ensure_ascii=False, indent=4) + if not mnn_config: + return config_json + with open(f'{self.dst_path}/config.json', 'w', encoding='utf-8') as f: + config = { + "llm_model": f"{self.dst_name}.mnn", + "llm_weight": f"{self.dst_name}.mnn.weight", + "backend_type": "cpu", + "thread_num": 4, + "precision": "low", + "memory": "low" + } + json.dump(config, f, ensure_ascii=False, indent=4) + return config_json + + def quant(self, weight, quant_bit, quant_block): + weight = weight.numpy() + oc, ic = weight.shape + if quant_block == 0: + block_size = ic + else: + block_size = quant_block + block_num = ic // block_size + weight = weight.reshape(oc, block_num, block_size) + max_val = np.max(weight, axis=-1, keepdims=True) + min_val = np.min(weight, axis=-1, keepdims=True) + offset = 1 << (quant_bit - 1) + clip_max = offset - 1 + clip_min = -offset + scale = (max_val - min_val) / (clip_max - clip_min) + q_weight = np.round((weight - min_val) / scale) + clip_min + q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8) + q_weight = q_weight.reshape(-1, 2) + if quant_bit == 4: + q_weight = q_weight[:, 0] * 16 + q_weight[:, 1] + alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten() + return q_weight, alpha, clip_min + + def imitate_quant(self): + def quant_dequant(linear, quant_bit = self.quant_bit, quant_block = self.quant_block): + weight = linear.weight.data + oc, ic = weight.shape + if quant_block == 0: + block_size = ic + else: + block_size = quant_block + block_num = ic // block_size + weight = weight.reshape(oc, block_num, block_size) + max_val = torch.max(weight, axis=-1, keepdims=True).values + min_val = torch.min(weight, axis=-1, keepdims=True).values + offset = 1 << (quant_bit - 1) + clip_max = offset - 1 + clip_min = -offset + scale = (max_val - min_val) / (clip_max - clip_min) + q_weight = torch.round((weight - min_val) / scale) + clip_min + q_weight = torch.clip(q_weight, clip_min, clip_max) + dq_weight = (q_weight - clip_min) * scale + min_val + dq_weight = dq_weight.reshape(oc, ic).float() + linear.weight.data = dq_weight + return linear + with torch.no_grad(): + for i in range(self.num_hidden_layers): + for name, child in self.blocks[i].self_attn.named_children(): + if isinstance(child, torch.nn.Linear): + setattr(self.blocks[i].self_attn, name, quant_dequant(child)) + for name, child in self.blocks[i].mlp.named_children(): + if isinstance(child, torch.nn.Linear): + setattr(self.blocks[i].mlp, name, quant_dequant(child)) + self.lm.lm = quant_dequant(self.lm.lm) + + 
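Both `quant` and the `quant_dequant` helper inside `imitate_quant` implement the same asymmetric, block-wise scheme: each block of `quant_block` input channels keeps a `(min, scale)` pair, and weights are rounded to signed `quant_bit`-wide codes in `[-2^(b-1), 2^(b-1)-1]`. The following is a minimal numpy sketch of that quantize/dequantize round trip; the function name and the epsilon guard are illustrative additions, not part of the exporter.

```python
import numpy as np

def quant_dequant_block(weight, quant_bit=4, quant_block=128):
    """Toy round trip mirroring quant()/imitate_quant(): asymmetric block-wise quantization."""
    oc, ic = weight.shape
    block_size = ic if quant_block == 0 else quant_block      # 0 -> a single block per output channel
    w = weight.reshape(oc, ic // block_size, block_size)
    min_val = w.min(axis=-1, keepdims=True)
    max_val = w.max(axis=-1, keepdims=True)
    offset = 1 << (quant_bit - 1)                              # 8 for 4-bit, 128 for 8-bit
    clip_min, clip_max = -offset, offset - 1
    scale = (max_val - min_val) / (clip_max - clip_min)
    scale = np.maximum(scale, 1e-8)                            # guard for constant blocks (toy example only)
    q = np.clip(np.round((w - min_val) / scale) + clip_min, clip_min, clip_max)
    dq = (q - clip_min) * scale + min_val                      # what imitate_quant writes back into each Linear
    return dq.reshape(oc, ic)

w = np.random.randn(8, 256).astype(np.float32)
w_hat = quant_dequant_block(w, quant_bit=4, quant_block=128)
print(float(np.abs(w - w_hat).max()))                          # per-block error is bounded by roughly scale / 2
```

The packed export path in `quant()` additionally shifts the codes by `offset` into unsigned range, packs two 4-bit codes per byte, and stores the per-block `(min, scale)` pairs in `alpha` so MNN can dequantize at load time.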
def unload_param(self): + self.unloaded_ops = {} + def build_faker(real, name): + faker = FakeLinear(real.in_features, real.out_features, real.bias is not None, name) + self.unloaded_ops[name] = real + return faker + # replace linear with fakelinear to save export memory and time + with torch.no_grad(): + for i in range(self.num_hidden_layers): + for name, child in self.blocks[i].self_attn.named_children(): + if isinstance(child, torch.nn.Linear): + setattr(self.blocks[i].self_attn, name, build_faker(child, f'/layers.{i}/self_attn/{name}/Linear')) + for name, child in self.blocks[i].mlp.named_children(): + if isinstance(child, torch.nn.Linear): + setattr(self.blocks[i].mlp, name, build_faker(child, f'/layers.{i}/mlp/{name}/Linear')) + self.lm.lm = build_faker(self.lm.lm, f'/lm/lm_head/Linear') + + @spinner_run(f'export model weight to ') + def onnx_load_param(self, onnx_path): + return OnnxRebuilder(onnx_path, self.unloaded_ops).rebuild() + + @spinner_run(f'slim the graph of ') + def onnx_slim(self, onnx_model): + import onnxslim + model = onnxslim.slim(onnx_model) + onnx.save(model, onnx_model) + return onnx_model + + @spinner_run(f'export onnx model to ') + def export_onnx(self): + # unload linear weight to save export memory + self.unload_param() + model = self + self.seq_len = 3 + self.token_len = 0 + input_ids = torch.arange(3, dtype=torch.long) + attention_mask = self.get_attention_mask() + position_ids = self.get_position_ids() + past_key_values = torch.zeros(self.past_kv_shape) + onnx_model = f'{self.dst_path}/{self.dst_name}.onnx' + input_ids = self.embedding(input_ids) + # export to onnx + torch.onnx.export( + model, (input_ids, attention_mask, position_ids, past_key_values), + onnx_model, + input_names=[ + 'input_ids', 'attention_mask', 'position_ids', 'past_key_values' + ], + output_names=['logits', 'presents'], + dynamic_axes=self.model_dynamic_axes, + do_constant_folding=True, + opset_version=15) + return onnx_model + + def export(self, export_type): + export_mnn = export_type == 'mnn' + # export tokenizer + self.export_tokenizer() + self.export_config(export_mnn) + self.export_embed() + if self.visual: + self.export_visual() + # export graph to llm.onnx + onnx_model = self.export_onnx() + if not self.skip_slim: + self.onnx_slim(onnx_model) + if export_mnn: + # convert onnx to mnn and quant weight + MNNConveter(onnx_model, self.unloaded_ops, self).export() + else: + # export weight to llm.onnx.data + self.onnx_load_param(onnx_model) + + @spinner_run(f'export tokenizer to ') + def export_tokenizer(self): + # load tokenizer file + tokenizer_model = os.path.join(self.path, 'tokenizer.model') + ice_text_model = os.path.join(self.path, 'ice_text.model') + try: + import sentencepiece as spm + if os.path.exists(tokenizer_model): + self.sp_model = spm.SentencePieceProcessor(tokenizer_model) + elif os.path.exists(ice_text_model): + self.sp_model = spm.SentencePieceProcessor(ice_text_model) + else: + self.sp_model = None + except: + self.sp_model = None + merge_file = os.path.join(self.path, 'merges.txt') + if os.path.exists(merge_file): + self.merge_txt = merge_file + else: + self.merge_txt = None + # TOKENIZER MAGIC NUMBER + MAGIC_NUMBER = 430 + # TOKENIZER TYPE + SENTENCEPIECE = 0; TIKTOIKEN = 1; BERT = 2; HUGGINGFACE = 3 + def write_line(fp, *args): + for arg in args: + for token in arg: + fp.write(str(token) + ' ') + fp.write('\n') + def write_header(fp, type, speicals, prefix = []): + fp.write(f'{MAGIC_NUMBER} {type}\n') + fp.write(f'{len(speicals)} {len(self.stop_ids)} 
{len(prefix)}\n') + write_line(fp, speicals, self.stop_ids, prefix) + + file_path = os.path.join(self.dst_path, "tokenizer.txt") + special_list = list(self.tokenizer.added_tokens_decoder.keys()) + if hasattr(self.tokenizer, 'special_tokens'): + for k, v in self.tokenizer.special_tokens.items(): + special_list.append(v) + if hasattr(self.tokenizer, 'gmask_token_id'): + special_list.append(self.tokenizer.gmask_token_id) + vocab_list = [] + prefix_list = [] + if hasattr(self.tokenizer, 'get_prefix_tokens'): + prefix_list = self.tokenizer.get_prefix_tokens() + if self.sp_model is not None: + # senetencepiece + NORMAL = 1; UNKNOWN = 2; CONTROL = 3 + USER_DEFINED = 4; UNUSED = 5; BYTE = 6 + for i in range(self.sp_model.GetPieceSize()): + token = self.sp_model.IdToPiece(i) + score = self.sp_model.GetScore(i) + token_type = NORMAL + if self.sp_model.IsUnknown(i): + token_type = UNKNOWN + elif self.sp_model.IsControl(i): + token_type = CONTROL + elif self.sp_model.IsUnused(i): + token_type = UNUSED + elif self.sp_model.IsByte(i): + token_type = BYTE + if self.path == 'Chatglm_6b': + if '' in token: token = '\n' + if '<|tab|>' in token: token = '\t' + if '<|blank_' in token: token = ' ' * int(token[8:token.find('|>')]) + if '▁' in token: token = token.replace('▁', ' ') + token_encode = base64.b64encode(token.encode("utf-8")).decode("utf8") + vocab_list.append(f'{token_encode} {score} {token_type}\n') + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, SENTENCEPIECE, special_list, prefix_list) + fp.write(f'{len(vocab_list)}\n') + for vocab in vocab_list: + fp.write(vocab) + elif hasattr(self.tokenizer, 'mergeable_ranks'): + # tikton + vocab_list = [] + for k, v in self.tokenizer.mergeable_ranks.items(): + line = base64.b64encode(k).decode("utf8") + "\n" + vocab_list.append(line) + if hasattr(self.tokenizer, 'special_tokens'): + for k, v in self.tokenizer.special_tokens.items(): + line = base64.b64encode(k.encode("utf-8")).decode("utf8") + "\n" + vocab_list.append(line) + if hasattr(self.tokenizer, 'added_tokens_decoder'): + for k, v in self.tokenizer.added_tokens_decoder.items(): + line = base64.b64encode(v.__str__().encode("utf-8")).decode("utf8") + "\n" + vocab_list.append(line) + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, TIKTOIKEN, special_list, prefix_list) + fp.write(f'{len(vocab_list)}\n') + for vocab in vocab_list: + fp.write(vocab) + elif self.merge_txt is not None: + # huggingface tokenizer + merge_list = [] + vocab = self.tokenizer.get_vocab() + special_list = list(self.tokenizer.added_tokens_decoder.keys()) + vocab_list = ['' for i in range(len(vocab))] + # load vocab + for k, v in vocab.items(): + vocab_list[int(v)] = k + # load merge + with open(self.merge_txt, 'rt') as merge: + for line in merge.readlines(): + merge_list.append(line) + # write to tokenizer.txt + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, HUGGINGFACE, special_list) + fp.write(f'{len(vocab_list)} {len(merge_list)}\n') + for v in vocab_list: + fp.write(v + '\n') + for m in merge_list: + fp.write(m) + else: + # tiktoken or bert + if 'bert' in type(self.tokenizer).__name__.lower(): + tokenizer_type = BERT + else: + tokenizer_type = TIKTOIKEN + # bert tokenizer + def unicode_to_byte(u: int): + if u >= 256 and u <= 288: + return u - 256 + if u >= 289 and u <= 322: + return u - 162 + if u == 323: + return 173 + if u == 65372: # | + return 124 + if u == 9601: # _ + return 95 + return u + vocab = self.tokenizer.get_vocab() + vocab_list = ['' for i in 
range(len(vocab))] + for k, v in vocab.items(): + try: + vocab_list[int(v)] = bytes([unicode_to_byte(ord(c)) for c in k]).decode('utf-8', errors='ignore') + except: + vocab_list[int(v)] = k + special_list = list(self.tokenizer.added_tokens_decoder.keys()) + with open(file_path, "w", encoding="utf8") as fp: + write_header(fp, tokenizer_type, special_list) + fp.write(f'{len(vocab_list)}\n') + for v in vocab_list: + line = base64.b64encode(v.encode('utf-8')).decode("utf8") + "\n" + fp.write(line) + return file_path + + +class EmbeddingExporter(LlmExporter): + def __init__(self, args): + super().__init__(args) + self.dst_name = 'embedding' + + def word_embed(self, input_ids): + return self.word_embeddings(input_ids.view(1, -1)) + + def bge_forward(self, inputs_embeds, position_ids, attention_mask): + # bert absolute position + inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size) + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeds + position_embeddings + self.token_type_embeddings + hidden_states = self.embedding_layernorm(embeddings) + for i in range(self.num_hidden_layers): + hidden_states = self.blocks[i](hidden_states, attention_mask)[0] + sentence_embeddings = hidden_states[:, 0] + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def gte_forward(self, inputs_embeds, position_ids, attention_mask): + # rope position + inputs_embeds = inputs_embeds.reshape(1, -1, self.hidden_size) + freqs = position_ids.float().reshape(-1, 1) * self.inv_freq + emb = torch.cat((freqs, freqs), dim=-1) + rope_embeds = torch.stack([emb.cos(), emb.sin()]).unsqueeze(-2).unsqueeze(1) + attention_bias = 1 - attention_mask.float() + hidden_states = self.embedding_layernorm(inputs_embeds + self.token_type_embeddings) + for i in range(self.num_hidden_layers): + hidden_states = self.blocks[i](hidden_states, attention_bias, rope_embeds)[0] + sentence_embeddings = hidden_states[:, 0] + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + return sentence_embeddings + + def forward(self, inputs_embeds, position_ids, attention_mask): + if self.model_type == 'bert': + return self.bge_forward(inputs_embeds, position_ids, attention_mask) + if self.model_type == 'new': + return self.gte_forward(inputs_embeds, position_ids, attention_mask) + raise RuntimeError(f'Not support embedding model: {self.model_type}!') + + def response(self, query): + self.eval() + input_ids = self.tokenizer(query)['input_ids'] + self.seq_len = len(input_ids) + input_ids = torch.tensor(input_ids) + position_ids = self.get_position_ids() + attention_mask = self.get_attention_mask() + inputs_embeds = self.word_embed(input_ids) + res = self.forward(inputs_embeds, position_ids, attention_mask) + print(res) + return res + + @spinner_run(f'load pretrained model ') + def load_model(self, model_path): + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float().eval() + self.config = self.model.config + transformer = self.model.encoder + self.model_type = self.config.model_type + self.lm_ = self.model.pooler + self.embed_ = self.model.embeddings + self.word_embeddings = self.embed_.word_embeddings + self.token_type_embeddings = self.embed_.token_type_embeddings.weight.data[0] + self.embedding_layernorm = self.embed_.LayerNorm + if hasattr(self.embed_, 'position_embeddings'): + self.position_embeddings = 
self.embed_.position_embeddings + self.hidden_size = self.word_embeddings.weight.shape[-1] + self.blocks = transformer.layer + if self.model_type == 'new': + self.inv_freq = self.embed_.rotary_emb.inv_freq + # some wrapper + self.stop_ids = [] + self.num_hidden_layers = len(self.blocks) + self.embed = self.embed_ + self.lm = self.lm_ + # some config for export + self.model_dynamic_axes = { + "input_ids" : { 1: "seq_len" }, + "position_ids" : { 1: "seq_len" }, + "attention_mask" : { 3: "seq_len" } + } + self.attention_mask_type = 'int' + self.llm_config = { + 'hidden_size' : self.hidden_size, + 'layer_nums' : self.num_hidden_layers, + 'attention_mask': self.attention_mask_type, + 'key_value_shape': [], + "prompt_template": self.build_prompt('%s'), + 'is_visual': False + } + return model_path + + @spinner_run(f'export onnx model to ') + def export_onnx(self): + model = self.eval() + self.seq_len = 3 + input_ids = torch.arange(3, dtype=torch.long) + position_ids = self.get_position_ids() + attention_mask = self.get_attention_mask() + inputs_embeds = self.word_embed(input_ids) + onnx_model = f'{self.dst_path}/{self.dst_name}.onnx' + torch.onnx.export( + model, (inputs_embeds, position_ids, attention_mask), + onnx_model, + input_names=[ + 'input_ids', + 'position_ids', + 'attention_mask' + ], + output_names=['sentence_embeddings'], + dynamic_axes=self.model_dynamic_axes, + do_constant_folding=True, + opset_version=15) + return onnx_model + + def export(self, export_type): + export_mnn = 'mnn' in export_type + self.export_tokenizer() + self.export_config(export_mnn) + self.export_embed() + onnx_model = self.export_onnx() + if not self.skip_slim: + self.onnx_slim(onnx_model) + if export_mnn: + MNNConveter(onnx_model, None, self).export() + + def build_prompt(self, query): + if self.model_type == 'bert': + return f'[CLS]{query}[SEP]' + if self.model_type == 'new': + return f' {query}' + + def get_position_ids(self) -> torch.Tensor: + return torch.arange(self.seq_len, dtype=torch.long).unsqueeze(0) + + def get_attention_mask(self) -> torch.Tensor: + return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long) + +def export(path, + type = None, + lora_path = None, + dst_path = './model', + export = 'onnx', + skip_slim = False, + quant_bit = 4, + quant_block = 128, + lm_quant_bit = None): + args = argparse.Namespace() + for k, v in { + 'path': path, + 'type': type, + 'lora_path': lora_path, + 'dst_path': dst_path, + 'export': export, + 'skip_slim': skip_slim, + 'quant_bit': quant_bit, + 'quant_block': quant_block, + 'lm_quant_bit': lm_quant_bit + }.items(): + setattr(args, k, v) + if 'bge' in path: + llm_exporter = EmbeddingExporter(args) + else: + llm_exporter = LlmExporter(args) + # export + llm_exporter.export(export) + +def main(): + parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--path', type=str, required=True, + help='path(`str` or `os.PathLike`):\nCan be either:' + '\n\t- A string, the *model id* of a pretrained model like `THUDM/chatglm-6b`. [TODO]' + '\n\t- A path to a *directory* clone from repo like `../chatglm-6b`.') + parser.add_argument('--type', type=str, default=None, + help='type(`str`, *optional*):' + '\n\tThe pretrain llm model type.' 
+                        )
+    parser.add_argument('--lora_path', type=str, default=None, help='LoRA path, default is `None`, meaning LoRA is not applied.')
+    parser.add_argument('--dst_path', type=str, default='./model', help='export onnx/mnn model to this path, default is `./model`.')
+    parser.add_argument('--test', type=str, help='test model inference with query `TEST`.')
+    parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
+    parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
+    parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
+    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block size, default is 128; 0 means channel-wise quantization.')
+    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
+    parser.add_argument('--mnnconvert', type=str, default='../../../build/MNNConvert', help='local MNNConvert path; if invalid, pymnn is used.')
+
+    args = parser.parse_args()
+
+    model_path = args.path
+    model_type = args.type
+
+    if 'gte' in model_path or 'bge' in model_path:
+        llm_exporter = EmbeddingExporter(args)
+    else:
+        llm_exporter = LlmExporter(args)
+
+    # some actions
+    if args.test is not None:
+        llm_exporter.response(args.test)
+
+    if args.export is not None:
+        llm_exporter.export(args.export)
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
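To tie the pieces together, here is a hedged usage sketch of the module-level `export()` helper defined above, which mirrors the CLI flags of `main()`. The module name and the checkpoint directory are placeholders, not paths shipped with MNN.

```python
# Hypothetical invocation of the export() helper; `llmexport` as a module name and the
# checkpoint directory are placeholders -- point `path` at a local Hugging Face style model.
from llmexport import export

export(
    path='../Qwen2-1.5B-Instruct',   # local model directory (placeholder)
    dst_path='./model',              # receives tokenizer.txt, llm_config.json, llm.onnx / llm.mnn
    export='mnn',                    # 'onnx' stops after ONNX export; 'mnn' also runs MNNConveter
    quant_bit=4,                     # 4- or 8-bit weight quantization
    quant_block=128,                 # quantization block size; 0 means channel-wise
)
```

The equivalent command-line run would pass the same values through `--path`, `--export`, `--quant_bit`, and `--quant_block` to this script's `main()`.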