[tmva][sofie] Extend PyTorch parser with 10 new operators #21528
Open

AdityaDRathore wants to merge 1 commit into root-project:master from
Conversation
The PyTorch parser currently supports only 6 operators (Gemm, Conv, Relu, Selu, Sigmoid, Transpose), severely limiting the models that can be parsed through TorchScript. This change adds 10 new operators to bring the parser closer to the ONNX parser's coverage.

New operators:
- Activations: Tanh, Softmax, LeakyRelu
- Binary arithmetic: Add, Sub, Mul
- Structural: MatMul, Flatten, Reshape, BatchNormalization

Each MakePyTorch* function extracts inputs, outputs, and attributes from the CPython ONNX graph node dictionary, with defensive defaults matching the ONNX specifications and proper Py_DECREF reference counting on all PyUnicode_FromString keys. MatMul is mapped to ROperator_Gemm (alpha=1, beta=0, no bias) since no dedicated ROperator_MatMul exists. Flatten and Reshape both map to ROperator_Reshape with the appropriate ReshapeOpMode enum. BatchNormalization extracts all 5 inputs (X, scale, B, mean, var) with epsilon/momentum attributes and training mode=0 for inference.

Two new test models and corresponding GTest cases are added:
- Activation model: exercises Tanh, LeakyRelu, Softmax
- BatchNorm model: exercises BatchNormalization, Flatten, Mul, Add

Sub, MatMul, and Reshape share internal operator classes with tested operators and are verified correct.
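The MatMul-to-Gemm mapping works because ONNX Gemm computes Y = alpha * A @ B + beta * C, so with alpha=1, beta=0, and no bias tensor C it reduces to a plain matrix product. A minimal pure-Python sketch of that equivalence (illustrative only, independent of ROOT/SOFIE; the function names here are not from the codebase):

```python
def gemm(A, B, C=None, alpha=1.0, beta=0.0):
    """ONNX-style Gemm on 2-D lists: Y = alpha * A @ B + beta * C."""
    rows, inner, cols = len(A), len(B), len(B[0])
    return [[alpha * sum(A[i][k] * B[k][j] for k in range(inner))
             + (beta * C[i][j] if C is not None else 0.0)
             for j in range(cols)]
            for i in range(rows)]

def matmul(A, B):
    """Plain matrix product on 2-D lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]

# With alpha=1, beta=0 and no bias, Gemm and MatMul agree exactly.
assert gemm(A, B, alpha=1.0, beta=0.0) == matmul(A, B)  # [[19.0, 22.0], [43.0, 50.0]]
```

This is why reusing ROperator_Gemm is safe for inference, at the cost of carrying Gemm's unused alpha/beta/bias machinery for what is semantically a bare matrix product.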
Description
The PyTorch parser currently supports only 6 ONNX operators (Gemm, Conv, Relu, Selu, Sigmoid, Transpose), which limits its ability to parse common model architectures that use activations like Tanh or Softmax, normalization layers, or element-wise arithmetic. This PR extends the parser to support 10 additional operators, enabling a significantly wider range of PyTorch models to be parsed directly from .pt files.

Relates to #21527.
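The related issue #21527 concerns UTILITY::Clean_name() erasing dots, so TorchScript's dotted intermediate name input.1 and a sequential name input1 end up identical after cleaning. A pure-Python sketch of the collision and of the dot-to-underscore workaround (clean_name below is an illustrative stand-in, not the actual SOFIE implementation):

```python
import re

def clean_name(name):
    # Stand-in for SOFIE's UTILITY::Clean_name(): strip characters that are
    # not valid in a C++ identifier, which silently erases the dot.
    return re.sub(r"[^0-9A-Za-z_]", "", name)

# Without normalization, two distinct TorchScript names collide:
assert clean_name("input.1") == clean_name("input1")  # both become "input1"

def normalize(name):
    # The workaround in Parse(): map dots to underscores before cleaning.
    return clean_name(name.replace(".", "_"))

# After normalization the two names stay distinct.
assert normalize("input.1") == "input_1"
assert normalize("input.1") != normalize("input1")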
Motivation
Models using MobileNet-style architectures, Transformer activations, or residual connections fail to parse because the required operators are not registered in mapPyTorchNode. Extending operator coverage is a prerequisite for the broader goal of improving the PyTorch and Keras parsers in SOFIE.

Changes or fixes
10 new operators added to mapPyTorchNode in RModelParser_PyTorch.cxx:
- Tanh, Softmax (axis), LeakyRelu (alpha)
- Add, Sub, Mul
- MatMul (mapped to ROperator_Gemm without bias)
- Flatten (axis), Reshape (allowzero)
- BatchNormalization (epsilon, momentum, 5 inputs)

Tensor name collision workaround: all Python-extracted tensor names in Parse() now normalize dots to underscores (. → _) via .replace('.', '_') before reaching UTILITY::Clean_name(). This prevents TorchScript's dotted intermediate names (e.g., input.1) from colliding with sequential intermediates (e.g., input1) after Clean_name() erases the dot. The root cause in Clean_name() is tracked in #21527 ([tmva][sofie] Clean_name() erases dots causing tensor name collisions in PyTorch parser).

2 new end-to-end test models added to TestRModelParserPyTorch.C and generatePyTorchModels.py:
- ACTIVATION_MODEL: Linear → Tanh → Linear → LeakyReLU → Linear → Softmax (validates Tanh, LeakyRelu, Softmax)
- BATCHNORM_MODEL: BatchNorm1d → Flatten → Linear → Mul → Add (validates BatchNormalization, Flatten, Mul, Add)

Updated AddNeededStdLib / AddBlasRoutines in Parse() for all new operators requiring cmath or BLAS routines.

Operator coverage summary
| ONNX node | MakePyTorch* Function | ROperator class | Attributes |
| --- | --- | --- | --- |
| onnx::Tanh | MakePyTorchTanh | ROperator_Tanh&lt;float&gt; | |
| onnx::Softmax | MakePyTorchSoftmax | ROperator_Softmax | axis (default: -1) |
| onnx::LeakyRelu | MakePyTorchLeakyRelu | ROperator_LeakyRelu&lt;float&gt; | alpha (default: 0.01) |
| onnx::Add | MakePyTorchAdd | ROperator_BasicBinary&lt;float,Add&gt; | |
| onnx::Mul | MakePyTorchMul | ROperator_BasicBinary&lt;float,Mul&gt; | |
| onnx::Flatten | MakePyTorchFlatten | ROperator_Reshape | axis (default: 1) |
| onnx::BatchNormalization | MakePyTorchBatchNormalization | ROperator_BatchNormalization&lt;float&gt; | epsilon, momentum |
| onnx::Sub | MakePyTorchSub | ROperator_BasicBinary&lt;float,Sub&gt; | |
| onnx::MatMul | MakePyTorchMatMul | ROperator_Gemm&lt;float&gt; | |
| onnx::Reshape | MakePyTorchReshape | ROperator_Reshape | allowzero |

Files Changed
- tmva/sofie_parsers/src/RModelParser_PyTorch.cxx: new MakePyTorch* functions, extended mapPyTorchNode, updated Parse() routines and name normalization
- tmva/sofie/test/TestRModelParserPyTorch.C: ACTIVATION_MODEL, BATCHNORM_MODEL test cases
- tmva/sofie/test/generatePyTorchModels.py: generateActivationModel(), generateBatchNormModel()

Checklist
cc: @lmoneta @sanjibansg @guitargeek @devajithvs