Skip to content

Commit 435c399

Browse files
authored
Fix (brevitas_examples/diffusion): workaround for svdquant with SDXL (Xilinx#1256)
1 parent 16a27c1 commit 435c399

File tree

1 file changed

+5
-0
lines changed
  • src/brevitas_examples/stable_diffusion

1 file changed

+5
-0
lines changed

src/brevitas_examples/stable_diffusion/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from brevitas_examples.common.generative.quantize import generate_quantizers
5555
from brevitas_examples.common.parse_utils import add_bool_arg
5656
from brevitas_examples.common.parse_utils import quant_format_validator
57+
from brevitas_examples.common.svd_quant import ErrorCorrectedModule
5758
from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
5859
from brevitas_examples.llm.llm_quant.svd_quant import apply_svd_quant
5960
from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
@@ -606,6 +607,10 @@ def sdpa_zp_stats_type():
606607
rank=args.svd_quant_rank,
607608
iters=args.svd_quant_iters,
608609
dtype=torch.float32)
610+
# Workaround to expose `in_features` attribute from the ErrorCorrectedModule Wrapper
611+
for m in denoising_network.modules():
612+
if isinstance(m, ErrorCorrectedModule) and hasattr(m.layer, 'in_features'):
613+
m.in_features = m.layer.in_features
609614
print("SVDQuant applied.")
610615

611616
if args.compile_ptq:

0 commit comments

Comments
 (0)