@@ -647,21 +647,22 @@ def test_disable_params_and_buffer_check():
647
647
648
648
assert len (check_bsyms ) == 1 # We only have the check for input.
649
649
650
- def test_buffer_quantization ():
650
+
651
+ def test_buffer_dtype_casting ():
651
652
import torch .nn as nn
652
653
import itertools
653
654
654
655
from typing import Any , Optional , Tuple , Union , List
655
656
656
657
class CastBuffers (thunder .core .transform_common .Transform ):
657
658
def __init__ (self ):
658
- self .quant_states = {}
659
+ self .cast_states = {}
659
660
660
661
def transform_module (self , model : thunder .ThunderModule ):
661
662
self .thunder_module = model
662
663
for n , b in model ._model .named_buffers ():
663
664
qb = b .to (torch .bfloat16 )
664
- self .quant_states [n ] = {
665
+ self .cast_states [n ] = {
665
666
"dtype" : b .dtype ,
666
667
"shape" : tuple (b .shape ),
667
668
"qb.dtype" : qb .dtype ,
@@ -679,13 +680,13 @@ def transform_traces_pre_prologue(
679
680
680
681
prologue_proxy_map = {
681
682
get_param_bsym .output .name : dict (
682
- shape = self .quant_states [model_weight_name ]["qb.shape" ],
683
+ shape = self .cast_states [model_weight_name ]["qb.shape" ],
683
684
dtype = thunder .dtypes .to_dtype (
684
- self .quant_states [model_weight_name ]["qb.dtype" ]
685
+ self .cast_states [model_weight_name ]["qb.dtype" ]
685
686
),
686
687
)
687
688
for model_weight_name , (check_bsym , get_param_bsym ) in checks .items ()
688
- if model_weight_name in self .quant_states
689
+ if model_weight_name in self .cast_states
689
690
}
690
691
691
692
# here we switch the prologue_trace to a copy with new metadata
@@ -696,7 +697,7 @@ def transform_traces_pre_prologue(
696
697
)
697
698
698
699
checks = thunder .transforms .utils .get_checks (prologue_trace )
699
- for n , qs in self .quant_states .items ():
700
+ for n , qs in self .cast_states .items ():
700
701
check , get_param = checks [n ]
701
702
# check has args: tensor, shape, device, dtype, requires_grad
702
703
proxy , _ , device , _ , requires_grad = check .args
@@ -823,4 +824,4 @@ def check_dtypes(bsym):
823
824
for tr in thunder .last_traces (cast_jit ):
824
825
if str (tr .get_provenance ()) == "# Constructed by Dtype Convert" :
825
826
for bsym in tr .bound_symbols :
826
- check_dtypes (bsym )
827
+ check_dtypes (bsym )
0 commit comments