-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathtorchex.py
2220 lines (1782 loc) · 89.3 KB
/
torchex.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import operator
import importlib
from dataclasses import replace
from contextlib import ContextDecorator
from functools import wraps, partial
from inspect import signature
from itertools import groupby
from numbers import Number
from typing import TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Hashable, Sequence
from collections.abc import Sequence
from types import ModuleType
from enum import Enum, auto
import torch
import math
from looseversion import LooseVersion
import thunder.core.dtypes as dtypes
from thunder.core.dtypes import to_torch_dtype, to_dtype
import thunder.core.devices as devices
from thunder.core.devices import to_torch_device, to_device
import thunder.core.prims as prims
from thunder.core.trace import TraceCtx, set_tracectx, reset_tracectx, from_trace
from thunder.core.proxies import NumberProxy, TensorProxy, FutureTensorProxy, variableify, pytype
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol, BoundSymbol
from thunder.distributed.prims import DistributedReduceOps
import thunder.distributed.prims as dist_prims
import thunder.core.utils as utils
import thunder.torch as ltorch
from thunder.extend import OperatorExecutor, register_executor, add_always_executor
from thunder.core.transforms import (
get_grad,
put_grad,
)
if TYPE_CHECKING:
from thunder.common import CompileData
# The PyTorch executor: an OperatorExecutor named "torch", tagged with the installed
# PyTorch version, registered globally and also added via add_always_executor.
# NOTE(review): "always executor" presumably means thunder consults it even when it is
#   not explicitly requested — confirm against thunder.extend.add_always_executor.
ex = OperatorExecutor("torch", version=torch.__version__)
register_executor(ex)
add_always_executor(ex)
# Common annotations
# Type aliases used in the signatures of the transforms and checkers below.
# NOTE The unions below are evaluated at runtime (they are assignments, not annotations),
#   so they require Python 3.10+ `X | Y` support on these types.
TensorLike = TensorProxy
FutureTensorLike = FutureTensorProxy
DeviceLike = str | devices.Device | torch.device
dtypeLike = dtypes.dtype | torch.dtype
#
# Helper functions
#
def _always_executable(*args, **kwargs) -> bool:
return True
def _register_torch_operation(name: str, *, like: None | Symbol = None, module: type | ModuleType = torch) -> Symbol:
    """Register the PyTorch callable ``name`` as an operator of the torch executor ``ex``.

    Args:
        name: name of the PyTorch callable; may be dotted (e.g. "ops.rngprims.philox_rand").
        like: symbol the new operator is modeled on. When omitted, ``ltorch.<name>`` is used.
        module: forwarded to ``ex.register_operator``. Defaults to ``torch``; method-style
            operations pass ``torch.Tensor``.

    Returns:
        The operator symbol returned by ``ex.register_operator``.
    """
    if like is None:
        like = getattr(ltorch, name)
    return ex.register_operator(name, like=like, module=module)
def _register_implementation(
    id_or_symbol: Hashable | Symbol,
    op: None | Symbol = None,
    *,
    checker: Callable,
    execution_transform: Callable = None,
):
    """Forward an implementation registration to the torch executor ``ex``.

    Args:
        id_or_symbol: the symbol (or its id) being implemented.
        op: a direct operator implementing it, or None when ``execution_transform`` is used.
        checker: callable that receives the call's arguments and returns whether this
            executor accepts them.
        execution_transform: function that rewrites the call in terms of torch operators,
            or None when ``op`` is given.
    """
    ex.register_implementation(
        id_or_symbol,
        op,
        checker=checker,
        execution_transform=execution_transform,
    )
#
# Data movement operations
#
# Tensor.to, registered as a method-style operator.
to = _register_torch_operation(name="to", module=torch.Tensor)
def _convert_element_type_prim_checker(a: Number | TensorProxy, dtype: dtypes.dtype) -> bool:
    """Accept convert_element_type only for tensor inputs; numbers are rejected."""
    is_tensor_input = isinstance(a, TensorProxy)
    return is_tensor_input
# NOTE The convert element type primitive is (currently) modeled as always creating a copy
def _convert_element_type_transform(
    a: TensorLike,
    /,
    dtype: dtypes.dtype,
) -> TensorLike:
    """Lower prims.convert_element_type to ``Tensor.to(dtype, copy=True)``."""
    return to(a, to_torch_dtype(dtype), copy=True)
def _to_transform(
    a: TensorLike,
    tensor_dtype_or_device: None | TensorLike | dtypeLike | DeviceLike = None,
    optional_positional_dtype: None | dtypeLike = None,
    /,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
    copy: bool = False,
    memory_format: None | torch.memory_format = None,
) -> TensorLike:
    """Lower ltorch.to to Tensor.to.

    The positional/keyword device and dtype forms are first resolved by
    ``ltorch._parse_to_device_and_dtype``. Only the options that resolve to a
    non-None value are forwarded; ``copy`` is always forwarded.
    """
    device, dtype = ltorch._parse_to_device_and_dtype(
        tensor_dtype_or_device, optional_positional_dtype, device=device, dtype=dtype
    )
    torch_device = to_torch_device(device)
    torch_dtype = to_torch_dtype(dtype)

    # Keyword order (copy, device, dtype, memory_format) matches what Tensor.to receives.
    candidates = {"device": torch_device, "dtype": torch_dtype, "memory_format": memory_format}
    kwargs = {"copy": copy}
    kwargs.update((key, value) for key, value in candidates.items() if value is not None)
    return to(a, **kwargs)
def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy:
    """Lower prims.device_put to Tensor.to, passing the device in its string form."""
    return to(a, device.device_str())
# Registers the data-movement transforms, in order, as (symbol, checker, execution transform).
for _sym, _checker, _transform in (
    (prims.device_put, _always_executable, _device_put_transform),
    (prims.convert_element_type, _convert_element_type_prim_checker, _convert_element_type_transform),
    (ltorch.to, _always_executable, _to_transform),
):
    _register_implementation(_sym, checker=_checker, execution_transform=_transform)
del _sym, _checker, _transform
#
# Disable torch.autocast operations
#
def no_autocast(fn):
"""
A decorator that disables torch.autocast for the duration of the decorated
function.
In Thunder this is useful when you want to ensure that the generated
function is not run with PyTorch's autocast enabled to execute exactly as
generated.
Args:
fn: The function to decorate.
Returns:
The decorated function.
"""
# This decorator intentionally does not use the torch.autocast decorator
# because it is much slower than the implementation here. This is because
# the torch.autocast decorator has a lot more overhead to support various
# features that are not needed in Thunder.
from torch import set_autocast_enabled
prev_cpu = torch.is_autocast_cpu_enabled()
prev = torch.is_autocast_enabled()
@wraps(fn)
def no_autocast_fn(*args, **kwargs):
try:
set_autocast_enabled("cpu", False)
set_autocast_enabled("cuda", False)
return fn(*args, **kwargs)
finally:
set_autocast_enabled("cpu", prev_cpu)
set_autocast_enabled("cuda", prev)
return no_autocast_fn
#
# Tensor creation operations
#
# Each binding mirrors the ltorch symbol of the same name (see _register_torch_operation).
arange = _register_torch_operation(name="arange")
full = _register_torch_operation(name="full")
full_like = _register_torch_operation(name="full_like")
ones = _register_torch_operation(name="ones")
ones_like = _register_torch_operation(name="ones_like")
tensor_from_sequence = _register_torch_operation(name="tensor")
zeros = _register_torch_operation(name="zeros")
zeros_like = _register_torch_operation(name="zeros_like")
randn = _register_torch_operation(name="randn")
empty = _register_torch_operation(name="empty")
einsum = _register_torch_operation(name="einsum")
clone = _register_torch_operation(name="clone")
def _uniform_philox_like(
    shape: Sequence[int],
    *,
    stride: None = None,
    device: DeviceLike,
    dtype: dtypeLike,
    seed: TensorProxy,
    offset: TensorProxy,
) -> tuple[TensorLike, TensorLike]:
    """Model for torch.ops.rngprims.philox_rand.

    Returns the values of ``ltorch.uniform_philox`` with minval 0.0 and maxval 1.0,
    together with a fresh 0-dim int64 CPU tensor proxy for the second output.
    That second output is presumably the updated offset — confirm against rngprims.
    """
    values = ltorch.uniform_philox(shape, 0.0, 1.0, device=device, dtype=dtype, seed=seed, offset=offset)
    next_offset = TensorProxy(shape=(), device=devices.cpu, dtype=dtypes.int64)
    return values, next_offset


uniform_philox = _register_torch_operation("ops.rngprims.philox_rand", like=_uniform_philox_like)
# NOTE We define a custom PyTorch uniform operation, because PyTorch has no out-of-place uniform
def _uniform(
shape: Sequence[int], minval: Number, maxval: Number, *, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
t: torch.Tensor = torch.empty(shape, device=device, dtype=dtype)
t.uniform_(minval, maxval)
return t
def _uniform_meta(
    shape: Sequence[int], minval: Number, maxval: Number, *, device: torch.device, dtype: torch.dtype
) -> TensorProxy:
    """Meta for the custom ``uniform`` operator.

    Returns a proxy with the requested shape and with thunder's device/dtype
    converted from the torch ones; requires_grad is False.
    """
    return TensorProxy(
        shape=shape,
        device=to_device(device),
        dtype=to_dtype(dtype),
        requires_grad=False,
    )


uniform = ex.register_operator("uniform", meta=_uniform_meta, fn=_uniform)
def _arange_transform(
    start: Number,
    end: None | Number = None,
    step: Number = 1,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
):
    """Lower ltorch.arange to torch.arange, supporting the single-argument ``arange(end)`` form."""
    torch_device = to_torch_device(device)
    torch_dtype = to_torch_dtype(dtype)
    # With one positional bound, that value is the end and the range starts at 0.
    lo, hi = (0, start) if end is None else (start, end)
    return arange(start=lo, step=step, end=hi, device=torch_device, dtype=torch_dtype)
# TODO Remove or restore exogenous_like
# def _exogenous_like_helper(likes: Sequence[torch.Tensor], /) -> tuple[torch.Tensor, ...]:
# return tuple([torch.zeros_like(x) for x in likes])
# def exogenous_like(bsym: BoundSymbol, likes: Sequence[TensorProxy], /) -> BoundSymbol:
# sym = Symbol(name="exogenous_like", meta=None)
# ctx: dict[str, Any] = {"exogenous_like": _exogenous_like_helper}
# return sym.bind(likes, output=bsym.output, _call_ctx=ctx)
def _full_transform(
    shape: Sequence[int], fill_value: Number, *, device: None | devices.Device, dtype: None | dtypes.dtype
) -> TensorProxy:
    """Lower full to torch.full after converting device and dtype to their PyTorch forms."""
    return full(shape, fill_value, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _full_like_transform(
    a: TensorLike, /, fill_value: Number, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
    """Lower full_like to torch.full_like, converting the optional device/dtype overrides."""
    return full_like(
        a,
        fill_value=fill_value,
        device=to_torch_device(device),
        dtype=to_torch_dtype(dtype),
    )
def _ones_transform(*shape: int, device: None | DeviceLike = None, dtype: None | dtypeLike = None) -> TensorLike:
    """Lower ones to torch.ones, forwarding the varargs shape unchanged."""
    return ones(*shape, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _ones_like_transform(
    a: TensorLike, /, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
    """Lower ones_like to torch.ones_like, converting the optional device/dtype overrides."""
    return ones_like(a, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _iota_transform(
    length: Number, *, start: Number, step: Number, device: devices.Device, dtype: dtypes.dtype
) -> TensorLike:
    """Lower prims.iota to torch.arange.

    The sequence has ``length`` values, starting at ``start`` and spaced by ``step``.
    """
    torch_device = to_torch_device(device)
    torch_dtype = to_torch_dtype(dtype)
    # arange's end is exclusive, so `length` steps past start gives exactly `length` values.
    stop = start + length * step
    return arange(start=start, step=step, end=stop, device=torch_device, dtype=torch_dtype)
def _uniform_transform(
    shape: Sequence[int],
    minval: Number = 0.0,
    maxval: Number = 1.0,
    *,
    device: DeviceLike,
    dtype: dtypeLike,
) -> TensorLike:
    """Lower uniform to the custom out-of-place ``uniform`` operator defined in this module."""
    return uniform(shape, minval, maxval, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
# NOTE minval == 0. and maxval == 1. due to the checker
def _uniform_philox_prim_transform(
    shape: Sequence[int],
    minval: float,
    maxval: float,
    *,
    device: devices.Device,
    dtype: dtypes.dtype,
    seed: int | TensorProxy,
    offset: int | TensorProxy,
) -> TensorLike:
    """Lower prims.uniform_philox to torch.ops.rngprims.philox_rand.

    Integer ``seed``/``offset`` values are wrapped with ``ltorch.tensor`` first.
    Only the random values are returned; philox_rand's second output is dropped.
    """
    torch_device = to_torch_device(device)
    torch_dtype = to_torch_dtype(dtype)
    if isinstance(seed, int):
        seed = ltorch.tensor(seed)
    if isinstance(offset, int):
        offset = ltorch.tensor(offset)
    values, _ = uniform_philox(
        shape, stride=None, seed=seed, offset=offset, device=torch_device, dtype=torch_dtype
    )
    return values
# TODO Consider restricting to seed and offset being tensors, too?
def _uniform_philox_prim_checker(
shape: Sequence[int],
minval: float,
maxval: float,
*,
device: devices.Device,
dtype: dtypes.dtype,
seed: int | TensorProxy,
offset: int | TensorProxy,
) -> bool:
if minval != 0 or maxval != 1:
return False
if offset % 4 != 0:
return False
if device.devicetype != devices.DeviceType.CUDA:
return False
return True
# NOTE minval == 0. and maxval == 1. due to the checker
def _uniform_philox_transform(
    shape: Sequence[int],
    minval: Number = 0.0,
    maxval: Number = 1.0,
    *,
    device: DeviceLike,
    dtype: dtypeLike,
    seed: int | TensorProxy,
    offset: int | TensorProxy,
) -> TensorLike:
    """Lower ltorch.uniform_philox to torch.ops.rngprims.philox_rand.

    Integer ``seed``/``offset`` values are wrapped with ``ltorch.tensor`` first.
    Only the random values are returned; philox_rand's second output is dropped.
    """
    torch_device = to_torch_device(device)
    torch_dtype = to_torch_dtype(dtype)
    if isinstance(seed, int):
        seed = ltorch.tensor(seed)
    if isinstance(offset, int):
        offset = ltorch.tensor(offset)
    values, _ = uniform_philox(
        shape, stride=None, seed=seed, offset=offset, device=torch_device, dtype=torch_dtype
    )
    return values
# TODO -- How can we validate that the tensor has the appropriate offset?
def _uniform_philox_checker(
shape: Sequence[int],
minval: Number = 0.0,
maxval: Number = 1.0,
*,
device: DeviceLike,
dtype: dtypeLike,
seed: int | TensorProxy,
offset: int | TensorProxy,
) -> bool:
if minval != 0 or maxval != 1:
return False
if isinstance(offset, TensorProxy) or offset % 4 != 0:
return False
if to_device(device).devicetype != devices.DeviceType.CUDA:
return False
return True
def _zeros_transform(*shape: int, device: None | DeviceLike = None, dtype: None | dtypeLike = None) -> TensorLike:
    """Lower zeros to torch.zeros, forwarding the varargs shape unchanged."""
    return zeros(*shape, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _zeros_like_transform(
    a: TensorLike, /, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
    """Lower zeros_like to torch.zeros_like, converting the optional device/dtype overrides."""
    return zeros_like(a, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _randn_prims_transform(
    shape: tuple[int, ...],
    *,
    device: devices.Device,
    dtype: dtypes.dtype,
) -> TensorLike:
    """Lower prims.randn to torch.randn with the device/dtype converted to PyTorch forms."""
    return randn(shape, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _empty_prims_transform(
    shape: tuple[int, ...],
    *,
    device: devices.Device,
    dtype: dtypes.dtype,
) -> TensorLike:
    """Lower prims.empty to torch.empty with the device/dtype converted to PyTorch forms."""
    return empty(shape, device=to_torch_device(device), dtype=to_torch_dtype(dtype))
def _clone_prims_transform(a: TensorLike, **kwargs) -> TensorLike:
    """Lower prims.clone to torch.clone; extra keyword arguments are accepted but not forwarded."""
    del kwargs  # not passed on to torch.clone
    return clone(a)
def _tensor_from_sequence_prims_transform(
    seq_or_number, *, device: devices.Device, dtype: None | dtypes.dtype
) -> TensorLike:
    """Lower prims.tensor_from_sequence to torch.tensor with converted device/dtype."""
    return tensor_from_sequence(
        seq_or_number,
        device=to_torch_device(device),
        dtype=to_torch_dtype(dtype),
    )
def _get_and_update_rng_state_impl(seed, offset, device):
    """Read the CUDA RNG state of ``device``, advance it, and return the pre-advance (seed, offset).

    The ``seed`` and ``offset`` arguments are not used; both values come from
    ``torch.cuda.get_rng_state``. The returned offset is in nvFuser units
    (PyTorch offset // 4). The state written back carries the PyTorch offset
    ``(nvfuser_offset + 1) * 4``.
    """
    rng_state = torch.cuda.get_rng_state(device)
    seed_bytes, offset_bytes = torch.chunk(rng_state, 2)
    # We follow the nvFuser way here: nvfuser_offset = pytorch_offset // 4
    # See Note [Divide offset by 4] https://github.com/NVIDIA/Fuser/blob/729f36c/csrc/rng.cpp#L54
    current_seed = seed_bytes.view(torch.int64).item()
    current_offset = offset_bytes.view(torch.int64).item() // 4
    # Advance by one nvFuser step: pytorch_new_offset = (nvfuser_offset + 1) * 4
    advanced_offset = (current_offset + 1) * 4
    updated_state = torch.cat(
        [
            torch.tensor([current_seed], device="cpu").view(torch.uint8),
            torch.tensor([advanced_offset], device="cpu").view(torch.uint8),
        ]
    )
    torch.cuda.set_rng_state(updated_state, device)
    return current_seed, current_offset
get_and_update_rng_state_impl = ex.register_operator(
    "get_and_update_rng_state_impl",
    fn=_get_and_update_rng_state_impl,
    meta=prims.get_and_update_rng_state.meta,
)

# Registers the creation prims and ltorch creation ops, in order. Each entry is
# (symbol, direct operator or None, checker, execution transform or None).
for _sym, _op, _checker, _transform in (
    (prims.full, None, _always_executable, _full_transform),
    (prims.iota, None, _always_executable, _iota_transform),
    (prims.uniform, None, _always_executable, _uniform_transform),
    (prims.uniform_philox, None, _uniform_philox_prim_checker, _uniform_philox_prim_transform),
    (prims.get_and_update_rng_state, get_and_update_rng_state_impl, _always_executable, None),
    (prims.randn, None, _always_executable, _randn_prims_transform),
    (prims.empty, None, _always_executable, _empty_prims_transform),
    (prims.clone, None, _always_executable, _clone_prims_transform),
    (prims.tensor_from_sequence, None, _always_executable, _tensor_from_sequence_prims_transform),
    (ltorch.arange, None, _always_executable, _arange_transform),
    (ltorch.full, None, _always_executable, _full_transform),
    (ltorch.full_like, None, _always_executable, _full_like_transform),
    (ltorch.ones, None, _always_executable, _ones_transform),
    (ltorch.ones_like, None, _always_executable, _ones_like_transform),
    (ltorch.uniform, None, _always_executable, _uniform_transform),
    (ltorch.uniform_philox, None, _uniform_philox_checker, _uniform_philox_transform),
    (ltorch.zeros, None, _always_executable, _zeros_transform),
    (ltorch.zeros_like, None, _always_executable, _zeros_like_transform),
):
    _register_implementation(_sym, _op, checker=_checker, execution_transform=_transform)
del _sym, _op, _checker, _transform
#
# Reshaping and permuting operations
#
cat = _register_torch_operation(name="cat")
chunk = _register_torch_operation(name="chunk")
diagonal = _register_torch_operation(name="diagonal")
expand = _register_torch_operation(name="expand", module=torch.Tensor)
flatten = _register_torch_operation(name="flatten")
flip = _register_torch_operation(name="flip")
# Tensor.__getitem__, modeled on ltorch.getitem
getitem = _register_torch_operation(name="__getitem__", like=ltorch.getitem, module=torch.Tensor)
movedim = _register_torch_operation(name="movedim")
permute = _register_torch_operation(name="permute")
repeat = _register_torch_operation(name="repeat", module=torch.Tensor)
reshape = _register_torch_operation(name="reshape")
select = _register_torch_operation(name="select")
split = _register_torch_operation(name="split")
stack = _register_torch_operation(name="stack")
squeeze = _register_torch_operation(name="squeeze")
tensor_split = _register_torch_operation(name="tensor_split")
transpose = _register_torch_operation(name="transpose")
unbind = _register_torch_operation(name="unbind")
unfold = _register_torch_operation(name="unfold", module=torch.Tensor)
unsqueeze = _register_torch_operation(name="unsqueeze")
view = _register_torch_operation(name="view", module=torch.Tensor)
view_as = _register_torch_operation(name="view_as", module=torch.Tensor)
# torch.all / torch.any, modeled on ltorch's tensor-returning variants
all_tensor = _register_torch_operation(name="all", like=ltorch.all_tensor)
any_tensor = _register_torch_operation(name="any", like=ltorch.any_tensor)
def _broadcast_in_dim_prim_transform(
    a: TensorProxy, /, shape: Sequence[int], broadcast_dimensions: Sequence[int]
) -> TensorProxy:
    """Lower prims.broadcast_in_dim to unsqueeze + expand.

    Every output position not listed in ``broadcast_dimensions`` gets a new
    size-1 axis via unsqueeze; the result is then expanded to ``shape``.
    """
    # -1 marks output positions that already come from an input dimension.
    marked = list(shape)
    for kept_dim in broadcast_dimensions:
        marked[kept_dim] = -1

    result = a
    for position in (i for i, size in enumerate(marked) if size != -1):
        result = unsqueeze(result, position)
    return expand(result, shape)
def _flip_transform(a: TensorLike, /, *dims: int) -> TensorLike:
    """Lower ltorch.flip, normalizing the varargs ``dims`` with utils.extract_shape_from_varargs."""
    return flip(a, utils.extract_shape_from_varargs(dims))
def _permute_transform(a: TensorLike, /, *dims: int) -> TensorLike:
    """Lower ltorch.permute, normalizing the varargs ``dims`` with utils.extract_shape_from_varargs."""
    return permute(a, utils.extract_shape_from_varargs(dims))
# NOTE The transpose prim is analogous to PyTorch's permute operation, and the argument names do not match
def _transpose_prim_transform(a: TensorProxy, /, permutation: Sequence[int]) -> TensorLike:
    """Lower prims.transpose (a full permutation) to torch.permute."""
    dims = permutation
    return permute(a, dims)
def _reshape_transform(a: TensorLike, /, *dims: int) -> TensorLike:
    """Lower ltorch.reshape, normalizing the varargs shape with utils.extract_shape_from_varargs."""
    return reshape(a, utils.extract_shape_from_varargs(dims))
# TODO When getitem is fully supported this can be changed to be an execution transform instead of a direct impl
def _slice_prim_impl(
a: torch.Tensor, start_indices: Sequence[int], end_indices: Sequence[int], strides: None | Sequence[int] = None
) -> torch.Tensor:
_strides = strides if strides is not None else [1] * len(start_indices)
slices: list = []
for start, stop, step in zip(start_indices, end_indices, _strides):
slices.append(slice(start, stop, step))
return operator.getitem(a, slices)
# NOTE PyTorch has a bug where it doesn't interpret calls like squeeze(a, None) correctly
def _squeeze_transform(a: TensorLike, /, dim: None | int | Sequence[int] = None) -> TensorLike:
    """Lower squeeze, omitting ``dim`` entirely when it is None to sidestep that PyTorch bug."""
    return squeeze(a) if dim is None else squeeze(a, dim)
# Registers the reshaping/permuting prims and ltorch ops, in order. Every entry uses
# _always_executable; each is (symbol, direct operator or None, execution transform or None).
for _sym, _op, _transform in (
    (prims.broadcast_in_dim, None, _broadcast_in_dim_prim_transform),
    (prims.cat, cat, None),
    (prims.flip, flip, None),
):
    _register_implementation(_sym, _op, checker=_always_executable, execution_transform=_transform)

# NOTE - `ltorch.reshape` short circuits when the new shape is the same as the original shape and
#   returns the input proxy as output. `prims.reshape` doesn't do that and returns a new proxy.
#   So we add `torch_prims_reshape_impl`, which is consistent with `prims.reshape` semantics;
#   otherwise this can lead to incorrectness.
torch_prims_reshape_impl = ex.register_operator("torch_prims_reshape_impl", meta=prims.reshape.meta, fn=torch.reshape)
_register_implementation(prims.reshape, torch_prims_reshape_impl, checker=_always_executable)
slice_prim_impl = ex.register_operator("torch_slice_prim_impl", meta=prims.slice_prim.meta, fn=_slice_prim_impl)
_register_implementation(prims.slice_prim, slice_prim_impl, checker=_always_executable)

for _sym, _op, _transform in (
    (prims.squeeze, None, _squeeze_transform),
    (prims.transpose, None, _transpose_prim_transform),
    (prims.unfold, unfold, None),
    (prims.view, view, None),
    (ltorch.cat, cat, None),
    (ltorch.chunk, chunk, None),
    (ltorch.diagonal, diagonal, None),
    (ltorch.expand, expand, None),
    (ltorch.flatten, flatten, None),
    (ltorch.flip, None, _flip_transform),
    (ltorch.getitem, getitem, None),
    (ltorch.movedim, movedim, None),
    (ltorch.permute, None, _permute_transform),
    (ltorch.repeat, repeat, None),
    (ltorch.reshape, None, _reshape_transform),
    (ltorch.select, select, None),
    (ltorch.split, split, None),
    (ltorch.stack, stack, None),
    (ltorch.squeeze, None, _squeeze_transform),
    (ltorch.tensor_split, tensor_split, None),
    (ltorch.transpose, transpose, None),
    (ltorch.unbind, unbind, None),
    (ltorch.unfold, unfold, None),
    (ltorch.unsqueeze, unsqueeze, None),
    (ltorch.view, view, None),
    (ltorch.view_as, view_as, None),
    (ltorch.all_tensor, all_tensor, None),
    (ltorch.any_tensor, any_tensor, None),
):
    _register_implementation(_sym, _op, checker=_always_executable, execution_transform=_transform)
del _sym, _op, _transform
#
# Memory format operations
#
# Tensor.contiguous, registered as a method-style operator.
contiguous = _register_torch_operation(name="contiguous", module=torch.Tensor)
# TODO Detect if the tensor is already contiguous as requested, and if so just return it
# TODO Review how strides are set if the tensor contains no elements
def _stride_order_prim_impl(a: torch.Tensor, order: Sequence[int]) -> torch.Tensor:
# Canonicalizes permutation as a tuple so it can be compared to the channels_last special cases below
order = tuple(order)
# Special cases channels_last and channels_last_3d cases
if order == (3, 0, 2, 1):
return a.contiguous(memory_format=torch.channels_last)
elif order == (4, 0, 3, 2, 1):
return a.contiguous(memory_format=torch.channels_last_3d)
# Creates a tensor with the appropriate shape and strides, then copies the input
# tensor into it
ordered_dims = sorted(zip(a.shape, order), key=lambda x: x[1])
ordered_strides = [1]
accum = ordered_dims[0][0]
for dim_length, _ in ordered_dims[1:]:
ordered_strides.append(accum)
accum *= dim_length
strides = tuple(ordered_strides[x] for x in order)
return torch.empty_strided(a.shape, strides, device=a.device, dtype=a.dtype).copy_(a)
stride_order_prim_impl = ex.register_operator(
    "torch_stride_order_prim_impl",
    fn=_stride_order_prim_impl,
    meta=prims.stride_order.meta,
)
_register_implementation(prims.stride_order, op=stride_order_prim_impl, checker=_always_executable)
_register_implementation(ltorch.contiguous, op=contiguous, checker=_always_executable)
#
# Elementwise unary operations
#
# NOTE torch_abs to avoid a conflict with Python's builtin abs()
torch_abs = _register_torch_operation(name="abs")
acos = _register_torch_operation(name="acos")
acosh = _register_torch_operation(name="acosh")
asin = _register_torch_operation(name="asin")
asinh = _register_torch_operation(name="asinh")
atan = _register_torch_operation(name="atan")
atanh = _register_torch_operation(name="atanh")
bitwise_not = _register_torch_operation(name="bitwise_not")
ceil = _register_torch_operation(name="ceil")
cos = _register_torch_operation(name="cos")
cosh = _register_torch_operation(name="cosh")
digamma = _register_torch_operation(name="digamma")
erf = _register_torch_operation(name="erf")
erfc = _register_torch_operation(name="erfc")
erfinv = _register_torch_operation(name="erfinv")
exp = _register_torch_operation(name="exp")
exp2 = _register_torch_operation(name="exp2")
expm1 = _register_torch_operation(name="expm1")
floor = _register_torch_operation(name="floor")
isfinite = _register_torch_operation(name="isfinite")
lgamma = _register_torch_operation(name="lgamma")
log = _register_torch_operation(name="log")
log10 = _register_torch_operation(name="log10")
log1p = _register_torch_operation(name="log1p")
log2 = _register_torch_operation(name="log2")
# TODO Update ndtri to be like thunder.torch...ndtri when it's available
ndtri = _register_torch_operation(name="ndtri", like=prims.ndtri, module=torch.special)
neg = _register_torch_operation(name="neg")
reciprocal = _register_torch_operation(name="reciprocal")
# NOTE torch_round to avoid a name conflict with the builtin round
torch_round = _register_torch_operation(name="round")
rsqrt = _register_torch_operation(name="rsqrt")
# NOTE PyTorch's "sgn" corresponds with the "sign" primitive
sgn = _register_torch_operation(name="sgn", like=ltorch.sign)
# NOTE torch.sign isn't bound here because thunder always uses sgn
# sign = _register_torch_operation("sign")
signbit = _register_torch_operation(name="signbit")
sin = _register_torch_operation(name="sin")
sinh = _register_torch_operation(name="sinh")
sqrt = _register_torch_operation(name="sqrt")
tan = _register_torch_operation(name="tan")
tanh = _register_torch_operation(name="tanh")
trunc = _register_torch_operation(name="trunc")
real = _register_torch_operation(name="real")
def _elementwise_unary_checker(a: Number | TensorLike) -> bool:
    """Accept elementwise unary prims only for tensor operands; numbers are rejected."""
    operand_is_tensor = isinstance(a, TensorLike)
    return operand_is_tensor
# NOTE PyTorch doesn't have an erfcinv implementation
def _erfcinv_impl(a: torch.Tensor) -> torch.Tensor:
return torch.erfinv(1 - a)
# Registration helper for unary elementwise ops: by default the torch executor only
# claims a call when its operand is a tensor (see _elementwise_unary_checker)
_register_elementwise_unary_implementation = partial(_register_implementation, checker=_elementwise_unary_checker)
# Map each unary primitive to its torch symbol
_register_elementwise_unary_implementation(prims.abs, torch_abs)
_register_elementwise_unary_implementation(prims.acos, acos)
_register_elementwise_unary_implementation(prims.acosh, acosh)
_register_elementwise_unary_implementation(prims.asin, asin)
_register_elementwise_unary_implementation(prims.asinh, asinh)
_register_elementwise_unary_implementation(prims.atan, atan)
_register_elementwise_unary_implementation(prims.atanh, atanh)
_register_elementwise_unary_implementation(prims.bitwise_not, bitwise_not)
_register_elementwise_unary_implementation(prims.ceil, ceil)
_register_elementwise_unary_implementation(prims.cos, cos)
_register_elementwise_unary_implementation(prims.cosh, cosh)
_register_elementwise_unary_implementation(prims.digamma, digamma)
_register_elementwise_unary_implementation(prims.erf, erf)
_register_elementwise_unary_implementation(prims.erfc, erfc)
# PyTorch has no erfcinv, so prims.erfcinv is backed by a custom operator whose
# implementation (_erfcinv_impl) computes erfinv(1 - a)
erfcinv = ex.register_operator("torch_erfcinv_impl", meta=prims.erfcinv, fn=_erfcinv_impl)
_register_elementwise_unary_implementation(prims.erfcinv, erfcinv)
_register_elementwise_unary_implementation(prims.erfinv, erfinv)
_register_elementwise_unary_implementation(prims.exp, exp)
_register_elementwise_unary_implementation(prims.exp2, exp2)
_register_elementwise_unary_implementation(prims.expm1, expm1)
_register_elementwise_unary_implementation(prims.floor, floor)
_register_elementwise_unary_implementation(prims.isfinite, isfinite)
_register_elementwise_unary_implementation(prims.lgamma, lgamma)
_register_elementwise_unary_implementation(prims.log, log)
_register_elementwise_unary_implementation(prims.log10, log10)
_register_elementwise_unary_implementation(prims.log1p, log1p)
_register_elementwise_unary_implementation(prims.log2, log2)
_register_elementwise_unary_implementation(prims.ndtri, ndtri)
_register_elementwise_unary_implementation(prims.neg, neg)
_register_elementwise_unary_implementation(prims.reciprocal, reciprocal)
_register_elementwise_unary_implementation(prims.round, torch_round)
_register_elementwise_unary_implementation(prims.rsqrt, rsqrt)
# thunder's sign primitive is executed with torch.sgn
_register_elementwise_unary_implementation(prims.sign, sgn)
_register_elementwise_unary_implementation(prims.signbit, signbit)
_register_elementwise_unary_implementation(prims.sin, sin)
_register_elementwise_unary_implementation(prims.sinh, sinh)
_register_elementwise_unary_implementation(prims.sqrt, sqrt)
_register_elementwise_unary_implementation(prims.tan, tan)
_register_elementwise_unary_implementation(prims.tanh, tanh)
_register_elementwise_unary_implementation(prims.trunc, trunc)
_register_elementwise_unary_implementation(prims.real, real)
# Map each thunder.torch (ltorch) unary operation to the same torch symbol used
# for the corresponding primitive above
_register_elementwise_unary_implementation(ltorch.abs, torch_abs)
_register_elementwise_unary_implementation(ltorch.acos, acos)
_register_elementwise_unary_implementation(ltorch.acosh, acosh)
_register_elementwise_unary_implementation(ltorch.asin, asin)
_register_elementwise_unary_implementation(ltorch.asinh, asinh)
_register_elementwise_unary_implementation(ltorch.atan, atan)
_register_elementwise_unary_implementation(ltorch.atanh, atanh)
_register_elementwise_unary_implementation(ltorch.bitwise_not, bitwise_not)
_register_elementwise_unary_implementation(ltorch.ceil, ceil)
_register_elementwise_unary_implementation(ltorch.cos, cos)
_register_elementwise_unary_implementation(ltorch.cosh, cosh)
_register_elementwise_unary_implementation(ltorch.digamma, digamma)
_register_elementwise_unary_implementation(ltorch.erf, erf)
_register_elementwise_unary_implementation(ltorch.erfc, erfc)
_register_elementwise_unary_implementation(ltorch.erfinv, erfinv)
_register_elementwise_unary_implementation(ltorch.exp, exp)
_register_elementwise_unary_implementation(ltorch.exp2, exp2)
_register_elementwise_unary_implementation(ltorch.expm1, expm1)
_register_elementwise_unary_implementation(ltorch.floor, floor)
_register_elementwise_unary_implementation(ltorch.isfinite, isfinite)
_register_elementwise_unary_implementation(ltorch.lgamma, lgamma)
_register_elementwise_unary_implementation(ltorch.log, log)
_register_elementwise_unary_implementation(ltorch.log10, log10)
_register_elementwise_unary_implementation(ltorch.log1p, log1p)
_register_elementwise_unary_implementation(ltorch.log2, log2)
# TODO Update ndtri when it's added back to thunder.torch...
# _register_elementwise_unary_implementation(ltorch.ndtri, ndtri)
_register_elementwise_unary_implementation(ltorch.neg, neg)
_register_elementwise_unary_implementation(ltorch.reciprocal, reciprocal)
_register_elementwise_unary_implementation(ltorch.round, torch_round)
_register_elementwise_unary_implementation(ltorch.rsqrt, rsqrt)
# thunder's sign operation is executed with torch.sgn
_register_elementwise_unary_implementation(ltorch.sign, sgn)
_register_elementwise_unary_implementation(ltorch.signbit, signbit)
_register_elementwise_unary_implementation(ltorch.sin, sin)
# Previously ltorch.sin was registered twice here and ltorch.sinh never was
_register_elementwise_unary_implementation(ltorch.sinh, sinh)
_register_elementwise_unary_implementation(ltorch.sqrt, sqrt)
_register_elementwise_unary_implementation(ltorch.tan, tan)
_register_elementwise_unary_implementation(ltorch.tanh, tanh)
_register_elementwise_unary_implementation(ltorch.trunc, trunc)
_register_elementwise_unary_implementation(ltorch.real, real)
# nn.functional elementwise unary
# Torch symbols looked up in torch.nn.functional (module=)
celu = _register_torch_operation("celu", module=torch.nn.functional)
elu = _register_torch_operation("elu", module=torch.nn.functional)
gelu = _register_torch_operation("gelu", module=torch.nn.functional)
hardshrink = _register_torch_operation("hardshrink", module=torch.nn.functional)
hardswish = _register_torch_operation("hardswish", module=torch.nn.functional)
leaky_relu = _register_torch_operation("leaky_relu", module=torch.nn.functional)
logsigmoid = _register_torch_operation("logsigmoid", module=torch.nn.functional)
# NOTE log_sigmoid_backward is given as a full aten op path string rather than a
#   torch.nn.functional name; presumably _register_torch_operation resolves dotted
#   names itself — confirm
log_sigmoid_backward = _register_torch_operation(
"torch.ops.aten.log_sigmoid_backward", like=ltorch.log_sigmoid_backward
)
relu = _register_torch_operation("relu", module=torch.nn.functional)
relu6 = _register_torch_operation("relu6", module=torch.nn.functional)
selu = _register_torch_operation("selu", module=torch.nn.functional)
silu = _register_torch_operation("silu", module=torch.nn.functional)
def _elementwise_unary_with_inplace_checker(a: TensorProxy, /, inplace: bool = False) -> bool:
    """Accept a unary activation call only for a tensor input that is not updated in place.

    Args:
        a: The operand of the activation.
        inplace: The operation's ``inplace`` flag.

    Returns:
        True when ``a`` is a tensor and ``inplace`` is falsy, otherwise False.
    """
    # Check the operand first; inplace is only inspected for tensor inputs.
    if not isinstance(a, TensorProxy):
        return False
    return not inplace
# Map thunder.torch activation functions to their torch.nn.functional symbols.
# Those registered with _elementwise_unary_with_inplace_checker are only claimed for
# tensor inputs with inplace falsy; those with _always_executable are always claimed.
# NOTE(review): elu, celu, leaky_relu and silu use _always_executable, while hardswish,
#   relu, relu6 and selu use the inplace checker — confirm whether the ltorch versions
#   of the former accept an inplace flag that should also be checked.
_register_elementwise_unary_implementation(ltorch.elu, elu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.celu, celu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.gelu, gelu, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardshrink, hardshrink, checker=_always_executable)
_register_elementwise_unary_implementation(ltorch.hardswish, hardswish, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.leaky_relu, leaky_relu, checker=_always_executable)
_register_elementwise_unary_implementation(
ltorch.log_sigmoid_backward, log_sigmoid_backward, checker=_always_executable
)
# logsigmoid keeps the default tensor-only checker (_elementwise_unary_checker)
_register_elementwise_unary_implementation(ltorch.logsigmoid, logsigmoid)
_register_elementwise_unary_implementation(ltorch.relu, relu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.relu6, relu6, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.selu, selu, checker=_elementwise_unary_with_inplace_checker)
_register_elementwise_unary_implementation(ltorch.silu, silu, checker=_always_executable)
#
# Elementwise binary operations
#
# TODO Review type promotion differences
add = _register_torch_operation("add")
atan2 = _register_torch_operation("atan2")
bitwise_and = _register_torch_operation("bitwise_and")
bitwise_or = _register_torch_operation("bitwise_or")
bitwise_xor = _register_torch_operation("bitwise_xor")
copysign = _register_torch_operation("copysign")
eq = _register_torch_operation("eq")
floor_divide = _register_torch_operation("floor_divide")
fmod = _register_torch_operation("fmod")
ge = _register_torch_operation("ge")
gt = _register_torch_operation("gt")
logical_and = _register_torch_operation("logical_and")
le = _register_torch_operation("le")
lt = _register_torch_operation("lt")
maximum = _register_torch_operation("maximum")
minimum = _register_torch_operation("minimum")
mul = _register_torch_operation("mul")
ne = _register_torch_operation("ne")
nextafter = _register_torch_operation("nextafter")
# polygamma and zeta are looked up in torch.special (module=)
polygamma = _register_torch_operation("polygamma", module=torch.special)
pow = _register_torch_operation("pow")
remainder = _register_torch_operation("remainder")
sub = _register_torch_operation("sub")
true_divide = _register_torch_operation("true_divide")
zeta = _register_torch_operation("zeta", module=torch.special)
# div is called by _div_transform; prims.div uses the custom div_prim_impl operator instead
div = _register_torch_operation("div")
# NOTE PyTorch elementwise operations require at least one input to be a tensor
def _elementwise_binary_checker(a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
    """Decide whether the torch executor accepts a binary elementwise call.

    Args:
        a: First operand.
        b: Second operand.

    Returns:
        True if at least one of the two operands is a tensor.
    """
    return any(isinstance(operand, TensorLike) for operand in (a, b))
def _div_checker(
    a: Number | TensorProxy,
    b: Number | TensorProxy,
    *,
    rounding_mode: None | str = None,
    out: None | TensorProxy = None,
) -> bool:
    """Decide whether the torch executor accepts a ``div`` call.

    Args:
        a: Dividend.
        b: Divisor.
        rounding_mode: Must be ``None`` or a string to be accepted.
        out: Accepted but not inspected.

    Returns:
        True when at least one operand is a tensor (per ``_elementwise_binary_checker``)
        and ``rounding_mode`` is ``None`` or a string.
    """
    # The return annotation previously claimed TensorProxy; like every other
    # checker in this module, this predicate returns a bool.
    return _elementwise_binary_checker(a, b) and (rounding_mode is None or isinstance(rounding_mode, str))
# NOTE add and sub have special check and factory functions to support alpha
def _add_sub_checker(
    a: Number | TensorProxy, b: Number | TensorProxy, *, alpha: None | Number | TensorProxy = None
) -> bool:
    """Decide whether the torch executor accepts an ``add``/``sub`` call.

    Args:
        a: First operand.
        b: Second operand.
        alpha: Optional scale applied to ``b``.

    Returns:
        True when at least one operand is a tensor and ``alpha`` is either
        absent (``None``) or a Python number; a tensor ``alpha`` is rejected.
    """
    if not _elementwise_binary_checker(a, b):
        return False
    return alpha is None or isinstance(alpha, Number)
# NOTE add and sub have a custom execution transform because the torch operations don't support alpha=None
def _add_transform(
    a: Number | TensorProxy, b: Number | TensorProxy, *, alpha: None | Number | TensorProxy = None
) -> TensorProxy:
    """Lower thunder's ``add`` to the torch ``add`` symbol.

    ``alpha`` is forwarded only when it was provided, so torch never sees
    ``alpha=None``.
    """
    extra_kwargs = {} if alpha is None else {"alpha": alpha}
    return add(a, b, **extra_kwargs)
# Maps exact inputs to truncation division
def _div_prim_impl(a: Number | torch.Tensor, b: Number | torch.Tensor) -> torch.Tensor:
    """Execute ``prims.div`` with PyTorch.

    When BOTH operands are exact (boolean or integer), the quotient is
    truncated toward zero; otherwise true division is performed.

    Args:
        a: Dividend, a tensor or a Python number.
        b: Divisor, a tensor or a Python number.

    Returns:
        The quotient produced by ``torch.div(..., rounding_mode="trunc")`` or
        ``torch.true_divide``.
    """

    def _is_exact(x: Number | torch.Tensor) -> bool:
        # Tensors are classified by their dtype; Python scalars have no .dtype,
        # and bool is a subclass of int, so isinstance(x, int) covers both exact kinds.
        if isinstance(x, torch.Tensor):
            return dtypes.is_exact_dtype(to_dtype(x.dtype))
        return isinstance(x, int)

    # Previously only `a` was checked (twice), so exact / inexact inputs were
    # wrongly truncated and scalar operands raised AttributeError on `.dtype`.
    if _is_exact(a) and _is_exact(b):
        return torch.div(a, b, rounding_mode="trunc")
    return torch.true_divide(a, b)
# NOTE add and sub have a custom execution transform because the torch operations don't support alpha=None
def _sub_transform(a: Number | TensorProxy, b: Number | TensorProxy, *, alpha: None | Number = None) -> TensorProxy:
    """Lower thunder's ``sub`` to the torch ``sub`` symbol.

    ``alpha`` is passed through only when it is not ``None``.
    """
    if alpha is not None:
        return sub(a, b, alpha=alpha)
    return sub(a, b)
def _div_transform(
    a: Number | TensorProxy,
    b: Number | TensorProxy,
    /,
    *,
    rounding_mode: None | str = None,
    out: None | TensorProxy = None,
) -> TensorProxy:
    """Lower thunder's ``div`` to the torch ``div`` symbol.

    ``rounding_mode`` is forwarded only when it was given. ``out`` is accepted
    (presumably for signature compatibility) but is not forwarded.
    """
    if rounding_mode is not None:
        return div(a, b, rounding_mode=rounding_mode)
    return div(a, b)
# Registration helper for binary elementwise ops: by default the torch executor only
# claims a call when at least one operand is a tensor (see _elementwise_binary_checker)
_register_elementwise_binary_implementation = partial(_register_implementation, checker=_elementwise_binary_checker)
# Map each binary primitive to its torch symbol
_register_elementwise_binary_implementation(prims.add, add)
_register_elementwise_binary_implementation(prims.atan2, atan2)
_register_elementwise_binary_implementation(prims.bitwise_and, bitwise_and)
_register_elementwise_binary_implementation(prims.bitwise_or, bitwise_or)
_register_elementwise_binary_implementation(prims.bitwise_xor, bitwise_xor)
# prims.div is backed by a custom operator (_div_prim_impl) that chooses between
# truncation and true division depending on whether the inputs are exact
div_prim_impl = ex.register_operator("torch_div_prim_impl", meta=prims.div.meta, fn=_div_prim_impl)
_register_elementwise_binary_implementation(prims.div, div_prim_impl)
_register_elementwise_binary_implementation(prims.eq, eq)
_register_elementwise_binary_implementation(prims.fmod, fmod)
_register_elementwise_binary_implementation(prims.ge, ge)
_register_elementwise_binary_implementation(prims.gt, gt)
_register_elementwise_binary_implementation(prims.le, le)
_register_elementwise_binary_implementation(prims.lt, lt)
_register_elementwise_binary_implementation(prims.maximum, maximum)
_register_elementwise_binary_implementation(prims.minimum, minimum)
_register_elementwise_binary_implementation(prims.mul, mul)
_register_elementwise_binary_implementation(prims.ne, ne)
_register_elementwise_binary_implementation(prims.nextafter, nextafter)
_register_elementwise_binary_implementation(prims.pow, pow)
_register_elementwise_binary_implementation(prims.remainder, remainder)
_register_elementwise_binary_implementation(prims.sub, sub)
_register_elementwise_binary_implementation(prims.zeta, zeta)
# ltorch.add passes no torch symbol directly: _add_sub_checker validates alpha and
# _add_transform lowers the call, omitting alpha when it is None
_register_elementwise_binary_implementation(ltorch.add, checker=_add_sub_checker, execution_transform=_add_transform)
_register_elementwise_binary_implementation(ltorch.atan2, atan2)
_register_elementwise_binary_implementation(ltorch.bitwise_and, bitwise_and)
_register_elementwise_binary_implementation(ltorch.bitwise_or, bitwise_or)
_register_elementwise_binary_implementation(ltorch.bitwise_xor, bitwise_xor)
_register_elementwise_binary_implementation(ltorch.copysign, copysign)
_register_elementwise_binary_implementation(ltorch.eq, eq)
_register_elementwise_binary_implementation(ltorch.floor_divide, floor_divide)
_register_elementwise_binary_implementation(ltorch.fmod, fmod)
_register_elementwise_binary_implementation(ltorch.ge, ge)
_register_elementwise_binary_implementation(ltorch.gt, gt)