@@ -646,3 +646,182 @@ def test_disable_params_and_buffer_check():
     )

     assert len(check_bsyms) == 1  # We only have the check for input.
+
+def test_buffer_quantization():
+    import torch.nn as nn
+    import itertools
+
+    from typing import Optional, Tuple
+
+    class QuantizeBuffers(thunder.core.transform_common.Transform):
+        def __init__(self):
+            self.quant_states = {}
+            self.quantized_submodule_names = set()
+
+        def transform_module(self, model: thunder.ThunderModule):
+            self.thunder_module = model
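+            # Replace every buffer with a bfloat16 copy and record both the
+            # original and the quantized metadata for the trace rewrites below.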
+            for n, b in model._model.named_buffers():
+                qb = b.to(torch.bfloat16)
+                self.quant_states[n] = {
+                    "dtype": b.dtype,
+                    "shape": tuple(b.shape),
+                    "qb.dtype": qb.dtype,
+                    "qb.shape": tuple(qb.shape),
+                }
+                model._overrides_buffers[n] = qb
+
+        def transform_traces_pre_prologue(
+            self, prologue_trace, computation_trace, epilogue_trace, **kwargs
+        ):
+            checks = thunder.transforms.utils.get_checks(prologue_trace)
+
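+            # Map each quantized buffer's prologue proxy to its new
+            # bfloat16 shape/dtype as recorded in transform_module.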
+            prologue_proxy_map = {
+                get_param_bsym.output.name: dict(
+                    shape=self.quant_states[model_weight_name]["qb.shape"],
+                    dtype=thunder.dtypes.to_dtype(
+                        self.quant_states[model_weight_name]["qb.dtype"]
+                    ),
+                )
+                for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
+                if model_weight_name in self.quant_states
+            }
+
+            # here we switch the prologue_trace to a copy with new metadata
+            prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                prologue_trace, prologue_proxy_map
+            )
+
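+            # Update the prologue's metadata checks so they assert the
+            # quantized shape/dtype (and requires_grad=False) instead.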
+            checks = thunder.transforms.utils.get_checks(prologue_trace)
+            for n, qs in self.quant_states.items():
+                check, get_param = checks[n]
+                # check has args: tensor, shape, device, dtype, requires_grad
+                proxy, _, device, _, _ = check.args
+                check.args = (proxy, qs["qb.shape"], device, qs["qb.dtype"], False)
+
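+            # Carry the changed shapes/dtypes over from the prologue outputs
+            # to the corresponding computation-trace inputs.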
+            computation_proxy_map = {
+                csym.name: dict(shape=psym.shape, dtype=psym.dtype)
+                for psym, csym in zip(
+                    prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args
+                )
+                if psym.shape != csym.shape or psym.dtype != csym.dtype
+            }
+
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                computation_trace, computation_proxy_map
+            )
+
+            producers, consumers = thunder.core.utils.producers_and_consumers(
+                new_computation_trace
+            )
+
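+            # Rebuild the signature argument list from the (possibly renamed)
+            # trace arguments before reconstructing the body below.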
+            bound_symbols = new_computation_trace.bound_symbols
+            new_computation_trace.bound_symbols = []
+
+            new_computation_trace._siginfo.args = [
+                (a.name, None) for a in new_computation_trace.args
+            ]
+
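+            # Rewrite `to` calls that consume a freshly unpacked buffer so they
+            # cast to the input's own (now bfloat16) dtype, turning the AMP
+            # cast in the module's forward into a no-op.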
+            computation_proxy_map = {}
+            new_bound_symbols = []
+            for bsym in bound_symbols:
+                if (
+                    bsym.sym == thunder.torch.to
+                    and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial
+                ):
+                    inp = bsym.args[0]
+                    args = (inp, inp.dtype, *bsym.args[2:])
+                    computation_proxy_map[bsym.output.name] = dict(
+                        shape=inp.shape, dtype=inp.dtype
+                    )
+                    assert (
+                        len(bsym.subsymbols) == 1
+                        and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
+                    )
+                    subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
+                    new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
+                else:
+                    new_bound_symbols.append(bsym.from_bsym())
+
+            new_computation_trace.bound_symbols = new_bound_symbols
+
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                new_computation_trace, computation_proxy_map
+            )
+
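+            # Tag the rewritten trace so the test can find it by provenance.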
+            new_computation_trace.set_provenance(
+                thunder.core.trace.TraceProvenance("Dtype Convert")
+            )
+            return prologue_trace, new_computation_trace, epilogue_trace
+
+    class cast(nn.Module):
+        def __init__(
+            self,
+            k_shape: Tuple[int, int, int, int],
+            v_shape: Tuple[int, int, int, int],
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+        ) -> None:
+            super().__init__()
+            self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
+            self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
+
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            # move the buffers to the activation dtype for when AMP is used
+            self.k = self.k.to(k.dtype)
+            self.v = self.v.to(v.dtype)
+            # return the updated cache
+            return self.k, self.v
+
+    # BUG: issue 1637
+    class ParentModule(nn.Module):
+        def __init__(
+            self,
+            k_shape: Tuple[int, int, int, int],
+            v_shape: Tuple[int, int, int, int],
+            device: Optional[torch.device] = None,
+            dtype: Optional[torch.dtype] = None,
+        ):
+            super().__init__()
+            self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)
+
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            return self.cast_module(k, v)
+
+    with torch.device("cpu"):
+        k_shape = (2, 3, 4, 5)
+        v_shape = (2, 3, 4, 5)
+        device = torch.device("cpu")
+        dtype = torch.float32
+        model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
+
+        k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
+        v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
+        cast_jit = thunder.jit(model, transforms=[QuantizeBuffers()])
+        output_k, output_v = cast_jit(k, v)
+
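+    # Recursively check that every tensor proxy in a bound symbol and its
+    # subsymbols is bfloat16.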
+    def check_dtypes(bsym):
+        for a in itertools.chain(bsym.flat_args, bsym.flat_outs):
+            if isinstance(a, thunder.TensorProxy):
+                assert a.dtype == thunder.dtypes.bfloat16
+        for sbsym in bsym.subsymbols:
+            check_dtypes(sbsym)
+
+    for tr in thunder.last_traces(cast_jit):
+        if str(tr.get_provenance()) == "# Constructed by Dtype Convert":
+            for bsym in tr.bound_symbols:
+                check_dtypes(bsym)