Skip to content

Commit f2a38f5

Browse files
narendasanperi044
andauthored
fix(aten::instance_norm): Handle optional inputs in instance norm con… (#3367)
Co-authored-by: Dheeraj Peri <[email protected]>
1 parent 1359e74 commit f2a38f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+582
-656
lines changed

Diff for: .github/scripts/generate_binary_build_matrix.py

-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,6 @@ def generate_wheels_matrix(
469469
ret: List[Dict[str, Any]] = []
470470
for python_version in python_versions:
471471
for arch_version in arches:
472-
473472
# TODO: Enable Python 3.13 support for ROCM
474473
if arch_version in ROCM_ARCHES and python_version == "3.13":
475474
continue

Diff for: core/conversion/converters/impl/batch_norm.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,14 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
134134

135135
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
136136

137-
auto scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
138-
auto bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
139-
137+
auto scales = at::ones(shape[1], options);
138+
if (!args[1].IValue()->isNone()) {
139+
scales = args[1].unwrapToTensor(at::ones(shape[1], options)).cpu().contiguous();
140+
}
141+
auto bias = at::zeros(shape[1], options);
142+
if (!args[2].IValue()->isNone()) {
143+
bias = args[2].unwrapToTensor(at::zeros(shape[1], options)).cpu().contiguous();
144+
}
140145
// track_running_stats=True
141146
if (!args[3].IValue()->isNone() || !args[4].IValue()->isNone()) {
142147
auto running_mean = args[3].unwrapToTensor();
@@ -154,6 +159,8 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
154159
return true;
155160
}
156161

162+
// Not sure this actually does something since the cudnn_enabled is from the PyTorch context.
163+
// We need cuDNN either way to run this converter
157164
auto cudnn_enabled = static_cast<bool>(args[8].unwrapToBool(false));
158165
if (!cudnn_enabled) {
159166
LOG_DEBUG(
@@ -162,7 +169,7 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
162169
so for some functionalities, users need to install correct \
163170
cuDNN version by themselves. Please see our support matrix \
164171
here: https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html.");
165-
return false;
172+
// return false;
166173
}
167174

168175
const int relu = 0;

Diff for: core/util/prelude.h

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// A collection of headers from util that will typically get included in most
44
// files
5+
#include <cstdint>
56
#include "core/util/Exception.h"
67
#include "core/util/build_info.h"
78
#include "core/util/jit_util.h"

Diff for: docs/_downloads/c0341280f3b022df00c4241c42d9ee8b/custom_kernel_plugins.py

-4
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
316316

317317
import cupy as cp # Needed to work around API gaps in PyTorch to build torch.Tensors around preallocated CUDA memory
318318
import numpy as np
319-
320319
import tensorrt as trt
321320

322321

@@ -348,7 +347,6 @@ def get_output_dimensions(
348347
inputs: List[trt.DimsExprs],
349348
exprBuilder: trt.IExprBuilder,
350349
) -> trt.DimsExprs:
351-
352350
output_dims = trt.DimsExprs(inputs[0])
353351

354352
for i in range(np.size(self.pads) // 2):
@@ -404,7 +402,6 @@ def enqueue(
404402
workspace: int,
405403
stream: int,
406404
) -> None:
407-
408405
# Host code is slightly different as this will be run as part of the TRT execution
409406
in_dtype = torchtrt.dtype.try_from(input_desc[0].type).to(np.dtype)
410407

@@ -528,7 +525,6 @@ def circular_padding_converter(
528525
kwargs: Dict[str, Argument],
529526
name: str,
530527
):
531-
532528
# How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
533529
plugin_registry = trt.get_plugin_registry()
534530
plugin_creator = plugin_registry.get_plugin_creator(

Diff for: examples/dynamo/custom_kernel_plugins.py

-4
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
316316

317317
import cupy as cp # Needed to work around API gaps in PyTorch to build torch.Tensors around preallocated CUDA memory
318318
import numpy as np
319-
320319
import tensorrt as trt
321320

322321

@@ -348,7 +347,6 @@ def get_output_dimensions(
348347
inputs: List[trt.DimsExprs],
349348
exprBuilder: trt.IExprBuilder,
350349
) -> trt.DimsExprs:
351-
352350
output_dims = trt.DimsExprs(inputs[0])
353351

354352
for i in range(np.size(self.pads) // 2):
@@ -404,7 +402,6 @@ def enqueue(
404402
workspace: int,
405403
stream: int,
406404
) -> None:
407-
408405
# Host code is slightly different as this will be run as part of the TRT execution
409406
in_dtype = torchtrt.dtype.try_from(input_desc[0].type).to(np.dtype)
410407

@@ -528,7 +525,6 @@ def circular_padding_converter(
528525
kwargs: Dict[str, Argument],
529526
name: str,
530527
):
531-
532528
# How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
533529
plugin_registry = trt.get_plugin_registry()
534530
plugin_creator = plugin_registry.get_plugin_creator(

Diff for: notebooks/CitriNet-example.ipynb

+4-10
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,11 @@
384384
"metadata": {},
385385
"outputs": [],
386386
"source": [
387-
"import nemo\n",
388387
"import torch\n",
389388
"\n",
390389
"import nemo.collections.asr as nemo_asr\n",
391390
"from nemo.core import typecheck\n",
392-
"typecheck.set_typecheck_enabled(False) "
391+
"typecheck.set_typecheck_enabled(False)"
393392
]
394393
},
395394
{
@@ -572,11 +571,8 @@
572571
"from __future__ import absolute_import\n",
573572
"from __future__ import division\n",
574573
"\n",
575-
"import argparse\n",
576574
"import timeit\n",
577575
"import numpy as np\n",
578-
"import torch\n",
579-
"import torch_tensorrt as trtorch\n",
580576
"import torch.backends.cudnn as cudnn\n",
581577
"\n",
582578
"def benchmark(model, input_tensor, num_loops, model_name, batch_size):\n",
@@ -632,7 +628,7 @@
632628
" else:\n",
633629
" model_name = f\"{variant}.ts\"\n",
634630
"\n",
635-
" print(f\"Loading model: {model_name}\") \n",
631+
" print(f\"Loading model: {model_name}\")\n",
636632
" # Load traced model to CPU first\n",
637633
" model = torch.jit.load(model_name).cuda()\n",
638634
" cudnn.benchmark = True\n",
@@ -727,9 +723,7 @@
727723
],
728724
"source": [
729725
"import torch\n",
730-
"import torch.nn as nn\n",
731726
"import torch_tensorrt as torchtrt\n",
732-
"import argparse\n",
733727
"\n",
734728
"variant = \"stt_en_citrinet_256\"\n",
735729
"precisions = [torch.float, torch.half]\n",
@@ -827,7 +821,7 @@
827821
" else:\n",
828822
" model_name = f\"{variant}.ts\"\n",
829823
"\n",
830-
" print(f\"Loading model: {model_name}\") \n",
824+
" print(f\"Loading model: {model_name}\")\n",
831825
" # Load traced model to CPU first\n",
832826
" model = torch.jit.load(model_name).cuda()\n",
833827
" cudnn.benchmark = True\n",
@@ -906,7 +900,7 @@
906900
" else:\n",
907901
" model_name = f\"{variant}.ts\"\n",
908902
"\n",
909-
" print(f\"Loading model: {model_name}\") \n",
903+
" print(f\"Loading model: {model_name}\")\n",
910904
" # Load traced model to CPU first\n",
911905
" model = torch.jit.load(model_name).cuda()\n",
912906
" cudnn.benchmark = True\n",

Diff for: notebooks/EfficientNet-example.ipynb

+9-9
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@
167167
"import torch.backends.cudnn as cudnn\n",
168168
"from timm.data import resolve_data_config\n",
169169
"from timm.data.transforms_factory import create_transform\n",
170-
"import json \n",
170+
"import json\n",
171171
"\n",
172172
"efficientnet_b0_model = timm.create_model('efficientnet_b0',pretrained=True)\n",
173173
"model = efficientnet_b0_model.eval().to(\"cuda\")"
@@ -305,13 +305,13 @@
305305
" transforms.ToTensor(),\n",
306306
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
307307
" ])\n",
308-
" input_tensor = preprocess(img) \n",
308+
" input_tensor = preprocess(img)\n",
309309
" plt.subplot(2,2,i+1)\n",
310310
" plt.imshow(img)\n",
311311
" plt.axis('off')\n",
312312
"\n",
313313
"# loading labels\n",
314-
"with open(\"./data/imagenet_class_index.json\") as json_file: \n",
314+
"with open(\"./data/imagenet_class_index.json\") as json_file:\n",
315315
" d = json.load(json_file)"
316316
]
317317
},
@@ -341,7 +341,7 @@
341341
" preprocess = efficientnet_preprocess()\n",
342342
" input_tensor = preprocess(img)\n",
343343
" input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model\n",
344-
" \n",
344+
"\n",
345345
" # move the input and model to GPU for speed if available\n",
346346
" if torch.cuda.is_available():\n",
347347
" input_batch = input_batch.to('cuda')\n",
@@ -351,7 +351,7 @@
351351
" output = model(input_batch)\n",
352352
" # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes\n",
353353
" sm_output = torch.nn.functional.softmax(output[0], dim=0)\n",
354-
" \n",
354+
"\n",
355355
" ind = torch.argmax(sm_output)\n",
356356
" return d[str(ind.item())], sm_output[ind] #([predicted class, description], probability)\n",
357357
"\n",
@@ -360,7 +360,7 @@
360360
" input_data = input_data.to(\"cuda\")\n",
361361
" if dtype=='fp16':\n",
362362
" input_data = input_data.half()\n",
363-
" \n",
363+
"\n",
364364
" print(\"Warm up ...\")\n",
365365
" with torch.no_grad():\n",
366366
" for _ in range(nwarmup):\n",
@@ -430,13 +430,13 @@
430430
"for i in range(4):\n",
431431
" img_path = './data/img%d.JPG'%i\n",
432432
" img = Image.open(img_path)\n",
433-
" \n",
433+
"\n",
434434
" pred, prob = predict(img_path, efficientnet_b0_model)\n",
435435
" print('{} - Predicted: {}, Probablility: {}'.format(img_path, pred, prob))\n",
436436
"\n",
437437
" plt.subplot(2,2,i+1)\n",
438-
" plt.imshow(img);\n",
439-
" plt.axis('off');\n",
438+
" plt.imshow(img)\n",
439+
" plt.axis('off')\n",
440440
" plt.title(pred[1])"
441441
]
442442
},

Diff for: notebooks/Hugging-Face-BERT.ipynb

+6-6
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@
233233
"metadata": {},
234234
"outputs": [],
235235
"source": [
236-
"masked_sentences = ['Paris is the [MASK] of France.', \n",
237-
" 'The primary [MASK] of the United States is English.', \n",
238-
" 'A baseball game consists of at least nine [MASK].', \n",
236+
"masked_sentences = ['Paris is the [MASK] of France.',\n",
237+
" 'The primary [MASK] of the United States is English.',\n",
238+
" 'A baseball game consists of at least nine [MASK].',\n",
239239
" 'Topology is a branch of [MASK] concerned with the properties of geometric objects that remain unchanged under continuous transformations.']\n",
240240
"pos_masks = [4, 3, 9, 6]"
241241
]
@@ -357,7 +357,7 @@
357357
"metadata": {},
358358
"outputs": [],
359359
"source": [
360-
"trt_model = torch_tensorrt.compile(traced_mlm_model, \n",
360+
"trt_model = torch_tensorrt.compile(traced_mlm_model,\n",
361361
" inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32), # input_ids\n",
362362
" torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32), # token_type_ids\n",
363363
" torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # attention_mask\n",
@@ -396,7 +396,7 @@
396396
"enc_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)\n",
397397
"enc_inputs = {k: v.type(torch.int32).cuda() for k, v in enc_inputs.items()}\n",
398398
"output_trt = trt_model(enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])\n",
399-
"most_likely_token_ids_trt = [torch.argmax(output_trt[i, pos, :]) for i, pos in enumerate(pos_masks)] \n",
399+
"most_likely_token_ids_trt = [torch.argmax(output_trt[i, pos, :]) for i, pos in enumerate(pos_masks)]\n",
400400
"unmasked_tokens_trt = enc.decode(most_likely_token_ids_trt).split(' ')\n",
401401
"unmasked_sentences_trt = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens_trt)]\n",
402402
"for sentence in unmasked_sentences_trt:\n",
@@ -418,7 +418,7 @@
418418
"metadata": {},
419419
"outputs": [],
420420
"source": [
421-
"trt_model_fp16 = torch_tensorrt.compile(traced_mlm_model, \n",
421+
"trt_model_fp16 = torch_tensorrt.compile(traced_mlm_model,\n",
422422
" inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32), # input_ids\n",
423423
" torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32), # token_type_ids\n",
424424
" torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # attention_mask\n",

Diff for: notebooks/Resnet50-CPP.ipynb

-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
"outputs": [],
7171
"source": [
7272
"import torch\n",
73-
"import torchvision\n",
7473
"\n",
7574
"torch.hub._validate_not_a_forked_repo=lambda a,b,c: True\n",
7675
"\n",

Diff for: notebooks/Resnet50-example.ipynb

+10-11
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,6 @@
428428
],
429429
"source": [
430430
"import torch\n",
431-
"import torchvision\n",
432431
"\n",
433432
"torch.hub._validate_not_a_forked_repo=lambda a,b,c: True\n",
434433
"\n",
@@ -558,7 +557,7 @@
558557
"from PIL import Image\n",
559558
"from torchvision import transforms\n",
560559
"import matplotlib.pyplot as plt\n",
561-
"import json \n",
560+
"import json\n",
562561
"\n",
563562
"fig, axes = plt.subplots(nrows=2, ncols=2)\n",
564563
"\n",
@@ -571,13 +570,13 @@
571570
" transforms.ToTensor(),\n",
572571
" transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
573572
" ])\n",
574-
" input_tensor = preprocess(img) \n",
573+
" input_tensor = preprocess(img)\n",
575574
" plt.subplot(2,2,i+1)\n",
576575
" plt.imshow(img)\n",
577576
" plt.axis('off')\n",
578577
"\n",
579-
"# loading labels \n",
580-
"with open(\"./data/imagenet_class_index.json\") as json_file: \n",
578+
"# loading labels\n",
579+
"with open(\"./data/imagenet_class_index.json\") as json_file:\n",
581580
" d = json.load(json_file)"
582581
]
583582
},
@@ -614,7 +613,7 @@
614613
" preprocess = rn50_preprocess()\n",
615614
" input_tensor = preprocess(img)\n",
616615
" input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model\n",
617-
" \n",
616+
"\n",
618617
" # move the input and model to GPU for speed if available\n",
619618
" if torch.cuda.is_available():\n",
620619
" input_batch = input_batch.to('cuda')\n",
@@ -624,7 +623,7 @@
624623
" output = model(input_batch)\n",
625624
" # Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes\n",
626625
" sm_output = torch.nn.functional.softmax(output[0], dim=0)\n",
627-
" \n",
626+
"\n",
628627
" ind = torch.argmax(sm_output)\n",
629628
" return d[str(ind.item())], sm_output[ind] #([predicted class, description], probability)\n",
630629
"\n",
@@ -633,7 +632,7 @@
633632
" input_data = input_data.to(\"cuda\")\n",
634633
" if dtype=='fp16':\n",
635634
" input_data = input_data.half()\n",
636-
" \n",
635+
"\n",
637636
" print(\"Warm up ...\")\n",
638637
" with torch.no_grad():\n",
639638
" for _ in range(nwarmup):\n",
@@ -695,13 +694,13 @@
695694
"for i in range(4):\n",
696695
" img_path = './data/img%d.JPG'%i\n",
697696
" img = Image.open(img_path)\n",
698-
" \n",
697+
"\n",
699698
" pred, prob = predict(img_path, resnet50_model)\n",
700699
" print('{} - Predicted: {}, Probablility: {}'.format(img_path, pred, prob))\n",
701700
"\n",
702701
" plt.subplot(2,2,i+1)\n",
703-
" plt.imshow(img);\n",
704-
" plt.axis('off');\n",
702+
" plt.imshow(img)\n",
703+
" plt.axis('off')\n",
705704
" plt.title(pred[1])"
706705
]
707706
},

0 commit comments

Comments
 (0)