Commit b85f1b0

Author: Ali Alshaarawy (committed)

test propagation using simple buffer casting transform test

1 parent f41f6cd commit b85f1b0

File tree: 1 file changed (+179 −0 lines)

thunder/tests/test_transforms.py

Lines changed: 179 additions & 0 deletions
@@ -646,3 +646,182 @@ def test_disable_params_and_buffer_check():
    )

    assert len(check_bsyms) == 1  # We only have the check for input.


def test_buffer_quantization():
    import torch.nn as nn
    import itertools

    from typing import Any, Optional, Tuple, Union, List

    class QuantizeBuffers(thunder.core.transform_common.Transform):
        def __init__(self):
            self.quant_states = {}
            self.quantized_submodule_names = set()

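        # Module-time hook: cast every buffer to bfloat16, record the
        # original and quantized metadata, and register the casted buffer
        # as an override on the ThunderModule.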
        def transform_module(self, model: thunder.ThunderModule):
            self.thunder_module = model
            for n, b in model._model.named_buffers():
                qb = b.to(torch.bfloat16)
                self.quant_states[n] = {
                    "dtype": b.dtype,
                    "shape": tuple(b.shape),
                    "qb.dtype": qb.dtype,
                    "qb.shape": tuple(qb.shape),
                }
                model._overrides_buffers[n] = qb

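        # Trace-time hook: receives the prologue, computation, and epilogue
        # traces before execution. The proxies that stand in for the
        # overridden buffers still carry the original metadata, so it has
        # to be replaced here and propagated into the computation trace.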
        def transform_traces_pre_prologue(
            self, prologue_trace, computation_trace, epilogue_trace, **kwargs
        ):
            tm = self.thunder_module
            from thunder.core.trace import tracectx

            checks = thunder.transforms.utils.get_checks(prologue_trace)

            prologue_proxy_map = {
                get_param_bsym.output.name: dict(
                    shape=self.quant_states[model_weight_name]["qb.shape"],
                    dtype=thunder.dtypes.to_dtype(
                        self.quant_states[model_weight_name]["qb.dtype"]
                    ),
                )
                for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
                if model_weight_name in self.quant_states
            }

            # here we switch the prologue_trace to a copy with new metadata
            prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
                prologue_trace, prologue_proxy_map
            )

            checks = thunder.transforms.utils.get_checks(prologue_trace)
            for n, qs in self.quant_states.items():
                check, get_param = checks[n]
                # check has args: tensor, shape, device, dtype, requires_grad
                proxy, _, device, _, requires_grad = check.args
                check.args = (proxy, qs["qb.shape"], device, qs["qb.dtype"], False)

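            # Propagate the new metadata across the trace boundary: remap any
            # computation-trace input whose shape or dtype no longer matches
            # the corresponding prologue output.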
            computation_proxy_map = {
                csym.name: dict(
                    shape=psym.shape,
                    dtype=psym.dtype,
                )
                for psym, csym in zip(
                    prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args
                )
                if psym.shape != csym.shape or psym.dtype != csym.dtype
            }

            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
                computation_trace, computation_proxy_map
            )

            producers, consumers = thunder.core.utils.producers_and_consumers(
                new_computation_trace
            )

            bound_symbols = new_computation_trace.bound_symbols
            new_computation_trace.bound_symbols = []

            new_computation_trace._siginfo.args = [
                (a.name, None) for a in new_computation_trace.args
            ]

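            # Rewrite dtype casts applied directly to unpacked inputs: since
            # the inputs are already bfloat16, the torch.to target dtype is
            # replaced with the input dtype, and the output proxies are
            # remapped in a second metadata pass below.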
            computation_proxy_map = {}
            new_bound_symbols = []
            for bsym in bound_symbols:
                if (
                    bsym.sym == thunder.torch.to
                    and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial
                ):
                    inp = bsym.args[0]
                    args = (inp, inp.dtype, *bsym.args[2:])
                    computation_proxy_map[bsym.output.name] = dict(
                        shape=inp.shape, dtype=inp.dtype
                    )
                    assert (
                        len(bsym.subsymbols) == 1
                        and bsym.subsymbols[0].sym
                        == thunder.core.prims.convert_element_type
                    )
                    subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
                    new_bound_symbols.append(
                        bsym.from_bsym(args=args, subsymbols=subsymbols)
                    )
                else:
                    new_bound_symbols.append(bsym.from_bsym())

            new_computation_trace.bound_symbols = new_bound_symbols

            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
                new_computation_trace, computation_proxy_map
            )

            new_computation_trace.set_provenance(
                thunder.core.trace.TraceProvenance("Dtype Convert")
            )
            return prologue_trace, new_computation_trace, epilogue_trace

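    # A minimal KV-cache-like module: two non-persistent buffers that the
    # forward pass casts to the dtype of the incoming activations.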
    class cast(nn.Module):
        def __init__(
            self,
            k_shape: Tuple[int, int, int, int],
            v_shape: Tuple[int, int, int, int],
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
        ) -> None:
            super().__init__()
            self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
            self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)

        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # move the buffer to the activation dtype for when AMP is used
            self.k = self.k.to(k.dtype)
            self.v = self.v.to(v.dtype)
            # update the cache
            return self.k, self.v

    # BUG: issue: 1637
    class ParentModule(nn.Module):
        def __init__(
            self,
            k_shape: Tuple[int, int, int, int],
            v_shape: Tuple[int, int, int, int],
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
        ):
            super().__init__()
            self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)

        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return self.cast_module(k, v)

    with torch.device("cpu"):
        k_shape = (2, 3, 4, 5)
        v_shape = (2, 3, 4, 5)
        device = torch.device("cpu")
        dtype = torch.float32
        model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)

        k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
        v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
        cast_jit = thunder.jit(model, transforms=[QuantizeBuffers()])
        output_k, output_v = cast_jit(k, v)

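    # Every tensor proxy in the converted trace must be bfloat16; recurse
    # into subsymbols so decomposed ops are checked as well.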
    def check_dtypes(bsym):
        for a in itertools.chain(bsym.flat_args, bsym.flat_outs):
            if isinstance(a, thunder.TensorProxy):
                assert a.dtype == thunder.dtypes.bfloat16
        for sbsym in bsym.subsymbols:
            check_dtypes(sbsym)

    for tr in thunder.last_traces(cast_jit):
        if str(tr.get_provenance()) == "# Constructed by Dtype Convert":
            for bsym in tr.bound_symbols:
                check_dtypes(bsym)
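
The override hook can also be exercised on its own. A minimal sketch, assuming only the transform_module entry point and the _overrides_buffers mapping used in the test above; the HalfBuffers name is illustrative, not part of this commit:

    import torch
    import thunder

    class HalfBuffers(thunder.core.transform_common.Transform):
        # Hypothetical minimal variant: swap every buffer for a float16 copy.
        # Unlike QuantizeBuffers above, it does not patch the prologue or
        # computation traces, so it only covers cases where no metadata
        # propagation is needed.
        def transform_module(self, model: thunder.ThunderModule):
            for name, buf in model._model.named_buffers():
                model._overrides_buffers[name] = buf.to(torch.float16)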

0 commit comments
