diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 89d846a79..99b77a275 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -40,18 +40,18 @@ def handle_quant_param(layer, layer_dict): layer_dict['output_scale_shape'] = output_scale.shape layer_dict['input_scale'] = input_scale.numpy().tolist() layer_dict['input_scale_shape'] = input_scale.shape - layer_dict['input_zp'] = input_zp.numpy().tolist() + layer_dict['input_zp'] = input_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['input_zp_shape'] = input_zp.shape - layer_dict['input_zp_dtype'] = str(torch.int8) + layer_dict['input_zp_dtype'] = str(input_zp.dtype) layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist() nelems = layer.weight.shape[0] weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1) layer_dict['weight_scale_shape'] = weight_scale_shape - if torch.sum(weight_zp) != 0.: + if torch.sum(weight_zp.to(torch.float32)) != 0.: weight_zp = weight_zp - 128. # apply offset to have signed z - layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist() + layer_dict['weight_zp'] = weight_zp.to(torch.float32).cpu().numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.int8) + layer_dict['weight_zp_dtype'] = str(weight_zp.dtype) return layer_dict @@ -63,6 +63,8 @@ def export_quant_params(pipe, output_dir, export_vae=False): vae_output_path = os.path.join(output_dir, 'vae.safetensors') print(f"Saving vae to {vae_output_path} ...") from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager + export_manager = StdQCDQONNXManager + export_manager.change_weight_export(export_weight_q_node=True) # We're exporting FP weights + quantization parameters quant_params = dict() state_dict = pipe.unet.state_dict() state_dict = {k: v for (k, v) in state_dict.items() if 'tensor_quant' not in k} @@ -70,7 +72,7 @@ def export_quant_params(pipe, output_dir, export_vae=False): state_dict = {k.replace('.layer.', '.'): v for (k, v) in state_dict.items()} handled_quant_layers = set() - with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, StdQCDQONNXManager): + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): for name, module in pipe.unet.named_modules(): if isinstance(module, EqualizedModule): if id(module.layer) in handled_quant_layers: