Skip to content

Commit c09c553

Browse files
author
Ali Alshaarawy
committed
test propagation using simple buffer casting transform test
1 parent b85f1b0 commit c09c553

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

thunder/tests/test_transforms.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,9 @@ def test_buffer_quantization():
653653

654654
from typing import Any, Optional, Tuple, Union, List
655655

656-
class QuantizeBuffers(thunder.core.transform_common.Transform):
656+
class CastBuffers(thunder.core.transform_common.Transform):
657657
def __init__(self):
658658
self.quant_states = {}
659-
self.quantized_submodule_names = set()
660659

661660
def transform_module(self, model: thunder.ThunderModule):
662661
self.thunder_module = model
@@ -811,7 +810,7 @@ def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch
811810

812811
k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
813812
v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
814-
cast_jit = thunder.jit(model, transforms=[QuantizeBuffers(),])
813+
cast_jit = thunder.jit(model, transforms=[CastBuffers(),])
815814
output_k, output_v = cast_jit(k, v)
816815

817816
def check_dtypes(bsym):

0 commit comments

Comments
 (0)