From 3228d4d61e61311f7da5c98cdff189e0e8a99a25 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Tue, 25 Jul 2023 12:11:45 +0000
Subject: [PATCH] Add Langchain SHARK Compilation support for all paths

---
 .../langchain/h2oai_pipeline.py | 232 +++++++++++++++---
 1 file changed, 196 insertions(+), 36 deletions(-)

diff --git a/apps/language_models/langchain/h2oai_pipeline.py b/apps/language_models/langchain/h2oai_pipeline.py
index 51920158ea..8f09cb486f 100644
--- a/apps/language_models/langchain/h2oai_pipeline.py
+++ b/apps/language_models/langchain/h2oai_pipeline.py
@@ -1,5 +1,7 @@
 import os
 from apps.stable_diffusion.src.utils.utils import _compile_module
+from io import BytesIO
+import torch_mlir
 
 from transformers import TextGenerationPipeline
 from transformers.pipelines.text_generation import ReturnType
@@ -20,8 +22,38 @@
 from pathlib import Path
 
 from shark.shark_inference import SharkInference
 from shark.shark_downloader import download_public_file
+from shark.shark_importer import import_with_fx
 from apps.stable_diffusion.src import args
+# Brevitas
+from typing import List, Tuple
+from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
+
+def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+    if len(lhs) == 3 and len(rhs) == 2:
+        return [lhs[0], lhs[1], rhs[0]]
+    elif len(lhs) == 2 and len(rhs) == 2:
+        return [lhs[0], rhs[0]]
+    else:
+        raise ValueError("Input shapes not supported.")
+
+
+def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+    # output dtype is the dtype of the lhs float input
+    lhs_rank, lhs_dtype = lhs_rank_dtype
+    return lhs_dtype
+
+
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+    return
+
+
+brevitas_matmul_rhs_group_quant_library = [
+    brevitas〇matmul_rhs_group_quant〡shape,
+    brevitas〇matmul_rhs_group_quant〡dtype,
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+
 
 global_device = "cuda"
 global_precision = "fp16"
@@ -31,6 +63,67 @@
 tensor_device = "cpu" if args.device == "cpu" else "cuda"
 
 
+class H2OGPTModel(torch.nn.Module):
+    def __init__(self, device, precision):
+        super().__init__()
+        torch_dtype = (
+            torch.float32
+            if precision == "fp32" or device == "cpu"
+            else torch.float16
+        )
+        device_map = {"": "cpu"} if device == "cpu" else {"": 0}
+        model_kwargs = {
+            "local_files_only": False,
+            "torch_dtype": torch_dtype,
+            "resume_download": True,
+            "use_auth_token": False,
+            "trust_remote_code": True,
+            "offload_folder": "offline_folder",
+            "device_map": device_map,
+        }
+        config = AutoConfig.from_pretrained(
+            "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+            use_auth_token=False,
+            trust_remote_code=True,
+            offload_folder="offline_folder",
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
+            config=config,
+            **model_kwargs,
+        )
+        if precision in ["int4", "int8"]:
+            print("Applying weight quantization..")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.model.transformer.h,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=128,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
+
+    def forward(self, input_ids, attention_mask):
+        input_dict = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": None,
+            "use_cache": True,
+        }
+        output = self.model(
+            **input_dict,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+        return output.logits[:, -1, :]
+
+
 class H2OGPTSHARKModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -42,47 +135,48 @@ def __init__(self):
         mlir_path = Path(model_name + "_" + args.precision + ".mlir")
         shark_module = None
+        need_to_compile = False
 
         if not vmfb_path.exists():
-            if args.device in ["cuda", "cpu"] and args.precision in [
-                "fp16",
-                "fp32",
-            ]:
-                # Downloading VMFB from shark_tank
-                print("Downloading vmfb from shark tank.")
+            need_to_compile = True
+            # Downloading VMFB from shark_tank
+            print("Trying to download pre-compiled vmfb from shark tank.")
+            download_public_file(
+                "gs://shark_tank/langchain/" + str(vmfb_path),
+                vmfb_path.absolute(),
+                single_file=True,
+            )
+            if vmfb_path.exists():
+                print(
+                    "Pre-compiled vmfb downloaded from shark tank successfully."
+                )
+                need_to_compile = False
+
+        if need_to_compile:
+            if not mlir_path.exists():
+                print("Trying to download pre-generated mlir from shark tank.")
+                # Downloading MLIR from shark_tank
                 download_public_file(
-                    "gs://shark_tank/langchain/" + str(vmfb_path),
-                    vmfb_path.absolute(),
+                    "gs://shark_tank/langchain/" + str(mlir_path),
+                    mlir_path.absolute(),
                     single_file=True,
                 )
+            if mlir_path.exists():
+                with open(mlir_path, "rb") as f:
+                    bytecode = f.read()
             else:
-                if mlir_path.exists():
-                    with open(mlir_path, "rb") as f:
-                        bytecode = f.read()
-                else:
-                    # Downloading MLIR from shark_tank
-                    download_public_file(
-                        "gs://shark_tank/langchain/" + str(mlir_path),
-                        mlir_path.absolute(),
-                        single_file=True,
-                    )
-                    if mlir_path.exists():
-                        with open(mlir_path, "rb") as f:
-                            bytecode = f.read()
-                    else:
-                        raise ValueError(
-                            f"MLIR not found at {mlir_path.absolute()}"
-                            " after downloading! Please check path and try again"
-                        )
-                shark_module = SharkInference(
-                    mlir_module=bytecode,
-                    device=args.device,
-                    mlir_dialect="linalg",
-                )
-                print(f"[DEBUG] generating vmfb.")
-                shark_module = _compile_module(
-                    shark_module, extended_model_name, []
-                )
-                print("Saved newly generated vmfb.")
+                # Generating the mlir
+                bytecode = self.get_bytecode(tensor_device, args.precision)
+
+            shark_module = SharkInference(
+                mlir_module=bytecode,
+                device=args.device,
+                mlir_dialect="linalg",
+            )
+            print(f"[DEBUG] generating vmfb.")
+            shark_module = _compile_module(
+                shark_module, extended_model_name, []
+            )
+            print("Saved newly generated vmfb.")
 
         if shark_module is None:
             if vmfb_path.exists():
@@ -97,6 +191,72 @@ def __init__(self):
 
         self.model = shark_module
 
+    def get_bytecode(self, device, precision):
+        h2ogpt_model = H2OGPTModel(device, precision)
+
+        compilation_input_ids = torch.randint(
+            low=1, high=10000, size=(1, 400)
+        ).to(device=device)
+        compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
+            device=device
+        )
+
+        h2ogptCompileInput = (
+            compilation_input_ids,
+            compilation_attention_mask,
+        )
+
+        print(f"[DEBUG] generating torchscript graph")
+        ts_graph = import_with_fx(
+            h2ogpt_model,
+            h2ogptCompileInput,
+            is_f16=False,
+            precision=precision,
+            f16_input_mask=[False, False],
+            mlir_type="torchscript",
+        )
+        del h2ogpt_model
+        del self.src_model
+
+        print(f"[DEBUG] generating torch mlir")
+        if precision in ["int4", "int8"]:
+            from torch_mlir.compiler_utils import (
+                run_pipeline_with_repro_report,
+            )
+
+            module = torch_mlir.compile(
+                ts_graph,
+                [*h2ogptCompileInput],
+                output_type=torch_mlir.OutputType.TORCH,
+                backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+                extra_library=brevitas_matmul_rhs_group_quant_library,
+                use_tracing=False,
+                verbose=False,
+            )
+            print(f"[DEBUG] converting torch to linalg")
+            run_pipeline_with_repro_report(
+                module,
+                "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
+                description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
+            )
+        else:
+            module = torch_mlir.compile(
+                ts_graph,
+                [*h2ogptCompileInput],
+                torch_mlir.OutputType.LINALG_ON_TENSORS,
+                use_tracing=False,
+                verbose=False,
+            )
+        del ts_graph
+
+        print(f"[DEBUG] converting to bytecode")
+        bytecode_stream = BytesIO()
+        module.operation.write_bytecode(bytecode_stream)
+        bytecode = bytecode_stream.getvalue()
+        del module
+
+        return bytecode
+
     def forward(self, input_ids, attention_mask):
         result = torch.from_numpy(
             self.model(