Skip to content

Commit f01a10f

Browse files
author
Ali Alshaarawy
committed
test propagation using simple buffer casting transform test
1 parent c09c553 commit f01a10f

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

thunder/tests/test_transforms.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -647,21 +647,22 @@ def test_disable_params_and_buffer_check():
647647

648648
assert len(check_bsyms) == 1 # We only have the check for input.
649649

650-
def test_buffer_quantization():
650+
651+
def test_buffer_dtype_casting():
651652
import torch.nn as nn
652653
import itertools
653654

654655
from typing import Any, Optional, Tuple, Union, List
655656

656657
class CastBuffers(thunder.core.transform_common.Transform):
657658
def __init__(self):
658-
self.quant_states = {}
659+
self.cast_states = {}
659660

660661
def transform_module(self, model: thunder.ThunderModule):
661662
self.thunder_module = model
662663
for n, b in model._model.named_buffers():
663664
qb = b.to(torch.bfloat16)
664-
self.quant_states[n] = {
665+
self.cast_states[n] = {
665666
"dtype": b.dtype,
666667
"shape": tuple(b.shape),
667668
"qb.dtype": qb.dtype,
@@ -679,13 +680,13 @@ def transform_traces_pre_prologue(
679680

680681
prologue_proxy_map = {
681682
get_param_bsym.output.name: dict(
682-
shape=self.quant_states[model_weight_name]["qb.shape"],
683+
shape=self.cast_states[model_weight_name]["qb.shape"],
683684
dtype=thunder.dtypes.to_dtype(
684-
self.quant_states[model_weight_name]["qb.dtype"]
685+
self.cast_states[model_weight_name]["qb.dtype"]
685686
),
686687
)
687688
for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
688-
if model_weight_name in self.quant_states
689+
if model_weight_name in self.cast_states
689690
}
690691

691692
# here we switch the prologue_trace to a copy with new metadata
@@ -696,7 +697,7 @@ def transform_traces_pre_prologue(
696697
)
697698

698699
checks = thunder.transforms.utils.get_checks(prologue_trace)
699-
for n, qs in self.quant_states.items():
700+
for n, qs in self.cast_states.items():
700701
check, get_param = checks[n]
701702
# check has args: tensor, shape, device, dtype, requires_grad
702703
proxy, _, device, _, requires_grad = check.args
@@ -823,4 +824,4 @@ def check_dtypes(bsym):
823824
for tr in thunder.last_traces(cast_jit):
824825
if str(tr.get_provenance()) == "# Constructed by Dtype Convert":
825826
for bsym in tr.bound_symbols:
826-
check_dtypes(bsym)
827+
check_dtypes(bsym)

0 commit comments

Comments
 (0)