Skip to content

Commit

Permalink
Fix (example/sdxl): Allow export of FP8 linear/conv layers.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Sep 10, 2024
1 parent 60a94d8 commit 223700a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def handle_quant_param(layer, layer_dict):
layer_dict['output_scale_shape'] = output_scale.shape
layer_dict['input_scale'] = input_scale.numpy().tolist()
layer_dict['input_scale_shape'] = input_scale.shape
layer_dict['input_zp'] = input_zp.numpy().tolist()
layer_dict['input_zp'] = input_zp.to(torch.float32).cpu().numpy().tolist()
layer_dict['input_zp_shape'] = input_zp.shape
layer_dict['input_zp_dtype'] = str(torch.int8)
layer_dict['input_zp_dtype'] = str(input_zp.dtype)
layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist()
nelems = layer.weight.shape[0]
weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1)
layer_dict['weight_scale_shape'] = weight_scale_shape
if torch.sum(weight_zp) != 0.:
if torch.sum(weight_zp.to(torch.float32)) != 0.:
weight_zp = weight_zp - 128. # apply offset to have signed z
layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist()
layer_dict['weight_zp'] = weight_zp.to(torch.float32).cpu().numpy().tolist()
layer_dict['weight_zp_shape'] = weight_scale_shape
layer_dict['weight_zp_dtype'] = str(torch.int8)
layer_dict['weight_zp_dtype'] = str(weight_zp.dtype)
return layer_dict


Expand All @@ -63,14 +63,16 @@ def export_quant_params(pipe, output_dir, export_vae=False):
vae_output_path = os.path.join(output_dir, 'vae.safetensors')
print(f"Saving vae to {vae_output_path} ...")
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True) # We're exporting FP weights + quantization parameters
quant_params = dict()
state_dict = pipe.unet.state_dict()
state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k}
state_dict = {k: v for (k, v) in state_dict.items() if not k.endswith('.scale.weight')}
state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()}

handled_quant_layers = set()
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager):
with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
for name, module in pipe.unet.named_modules():
if isinstance(module, EqualizedModule):
if id(module.layer) in handled_quant_layers:
Expand Down

0 comments on commit 223700a

Please sign in to comment.