Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Langchain SHARK Compilation support for all paths #1691

Merged
merged 1 commit into from
Jul 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 196 additions & 36 deletions apps/language_models/langchain/h2oai_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from apps.stable_diffusion.src.utils.utils import _compile_module
from io import BytesIO
import torch_mlir

from transformers import TextGenerationPipeline
from transformers.pipelines.text_generation import ReturnType
Expand All @@ -20,8 +22,38 @@
from pathlib import Path
from shark.shark_inference import SharkInference
from shark.shark_downloader import download_public_file
from shark.shark_importer import import_with_fx
from apps.stable_diffusion.src import args

# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
if len(lhs) == 3 and len(rhs) == 2:
return [lhs[0], lhs[1], rhs[0]]
elif len(lhs) == 2 and len(rhs) == 2:
return [lhs[0], rhs[0]]
else:
raise ValueError("Input shapes not supported.")


def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
# output dtype is the dtype of the lhs float input
lhs_rank, lhs_dtype = lhs_rank_dtype
return lhs_dtype


def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
return


brevitas_matmul_rhs_group_quant_library = [
brevitas〇matmul_rhs_group_quant〡shape,
brevitas〇matmul_rhs_group_quant〡dtype,
brevitas〇matmul_rhs_group_quant〡has_value_semantics]

global_device = "cuda"
global_precision = "fp16"

Expand All @@ -31,6 +63,67 @@
tensor_device = "cpu" if args.device == "cpu" else "cuda"


class H2OGPTModel(torch.nn.Module):
def __init__(self, device, precision):
super().__init__()
torch_dtype = (
torch.float32
if precision == "fp32" or device == "cpu"
else torch.float16
)
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
model_kwargs = {
"local_files_only": False,
"torch_dtype": torch_dtype,
"resume_download": True,
"use_auth_token": False,
"trust_remote_code": True,
"offload_folder": "offline_folder",
"device_map": device_map,
}
config = AutoConfig.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
use_auth_token=False,
trust_remote_code=True,
offload_folder="offline_folder",
)
self.model = AutoModelForCausalLM.from_pretrained(
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
config=config,
**model_kwargs,
)
if precision in ["int4", "int8"]:
print("Applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
self.model.transformer.h,
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
)
print("Weight quantization applied.")

def forward(self, input_ids, attention_mask):
input_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": None,
"use_cache": True,
}
output = self.model(
**input_dict,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)
return output.logits[:, -1, :]


class H2OGPTSHARKModel(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -42,47 +135,48 @@ def __init__(self):
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
shark_module = None

need_to_compile = False
if not vmfb_path.exists():
if args.device in ["cuda", "cpu"] and args.precision in [
"fp16",
"fp32",
]:
# Downloading VMFB from shark_tank
print("Downloading vmfb from shark tank.")
need_to_compile = True
# Downloading VMFB from shark_tank
print("Trying to download pre-compiled vmfb from shark tank.")
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
single_file=True,
)
if vmfb_path.exists():
print(
"Pre-compiled vmfb downloaded from shark tank successfully."
)
need_to_compile = False

if need_to_compile:
if not mlir_path.exists():
print("Trying to download pre-generated mlir from shark tank.")
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + str(vmfb_path),
vmfb_path.absolute(),
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
# Downloading MLIR from shark_tank
download_public_file(
"gs://shark_tank/langchain/" + str(mlir_path),
mlir_path.absolute(),
single_file=True,
)
if mlir_path.exists():
with open(mlir_path, "rb") as f:
bytecode = f.read()
else:
raise ValueError(
f"MLIR not found at {mlir_path.absolute()}"
" after downloading! Please check path and try again"
)
shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")
# Generating the mlir
bytecode = self.get_bytecode(tensor_device, args.precision)

shark_module = SharkInference(
mlir_module=bytecode,
device=args.device,
mlir_dialect="linalg",
)
print(f"[DEBUG] generating vmfb.")
shark_module = _compile_module(
shark_module, extended_model_name, []
)
print("Saved newly generated vmfb.")

if shark_module is None:
if vmfb_path.exists():
Expand All @@ -97,6 +191,72 @@ def __init__(self):

self.model = shark_module

def get_bytecode(self, device, precision):
h2ogpt_model = H2OGPTModel(device, precision)

compilation_input_ids = torch.randint(
low=1, high=10000, size=(1, 400)
).to(device=device)
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
device=device
)

h2ogptCompileInput = (
compilation_input_ids,
compilation_attention_mask,
)

print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
h2ogpt_model,
h2ogptCompileInput,
is_f16=False,
precision=precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del h2ogpt_model
del self.src_model

print(f"[DEBUG] generating torch mlir")
if precision in ["int4", "int8"]:
from torch_mlir.compiler_utils import (
run_pipeline_with_repro_report,
)

module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*h2ogptCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph

print(f"[DEBUG] converting to bytecode")
bytecode_stream = BytesIO()
module.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
del module

return bytecode

def forward(self, input_ids, attention_mask):
result = torch.from_numpy(
self.model(
Expand Down
Loading