|
7 | 7 |
|
8 | 8 | import onnx |
9 | 9 |
|
| 10 | +from olive.hardware import Device |
10 | 11 | from olive.hardware.accelerator import AcceleratorSpec |
| 12 | +from olive.hardware.constants import ExecutionProvider |
11 | 13 | from olive.model import CompositeModelHandler, ONNXModelHandler |
12 | 14 | from olive.passes import Pass |
13 | 15 | from olive.passes.onnx.common import ( |
14 | 16 | add_version_metadata_to_model_proto, |
15 | 17 | fix_dim_params, |
16 | 18 | process_llm_pipeline, |
17 | 19 | resave_model, |
| 20 | + update_llm_pipeline_genai_config_gpu, |
18 | 21 | ) |
19 | 22 | from olive.passes.onnx.onnx_dag import OnnxDAG |
20 | 23 | from olive.passes.pass_config import BasePassConfig, PassConfigParam |
@@ -61,9 +64,18 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon |
61 | 64 | ), |
62 | 65 | } |
63 | 66 |
|
64 | | - def _run_for_config( |
65 | | - self, model: CompositeModelHandler, config: type[BasePassConfig], output_model_path: str |
66 | | - ) -> CompositeModelHandler: |
def _run_for_config(self, model, config: type[BasePassConfig], output_model_path: str):
    """Dispatch the pass to the implementation matching the target accelerator.

    The QNN execution provider on a GPU device is handled by ``_run_qnn_gpu`` and
    requires ``model`` to be a single ``ONNXModelHandler``. Every other target is
    handled by ``_run_generic``.

    :param model: input model handler.
    :param config: pass configuration, forwarded unchanged.
    :param output_model_path: output location, forwarded unchanged.
    :return: the result of the selected implementation.
    """
    spec = self.accelerator_spec
    targets_qnn_gpu = (
        spec.execution_provider == ExecutionProvider.QNNExecutionProvider and spec.accelerator_type == Device.GPU
    )
    if not targets_qnn_gpu:
        return self._run_generic(model, config, output_model_path)

    assert isinstance(model, ONNXModelHandler), "StaticLLM (qnn-gpu) requires a single ONNXModelHandler."
    return self._run_qnn_gpu(model, config, output_model_path)

| 77 | + |
| 78 | + def _run_generic(self, model: CompositeModelHandler, config: type[BasePassConfig], output_model_path: str): |
67 | 79 | assert isinstance(model, CompositeModelHandler), "StaticLLM pass only supports CompositeModelHandler" |
68 | 80 | model_components = list(model.model_components) |
69 | 81 | assert all(isinstance(m, ONNXModelHandler) for m in model_components), "All components must be ONNXModelHandler" |
@@ -169,6 +181,60 @@ def process_context_iterator(component_models, llm_pipeline, output_dir): |
169 | 181 | group_session_options=config.group_session_options, |
170 | 182 | ) |
171 | 183 |
|
def _run_qnn_gpu(self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str):
    """Create a static-shape copy of a single ONNX model for the QNN GPU target.

    Steps:
      1. Load the model together with its external data.
      2. Replace the two symbolic dimensions of ``input_ids`` with ``config.batch_size``
         and ``config.context_length``.
      3. Save the result as ``model.onnx`` with all tensors in ``model.onnx.data``.
      4. Delegate to ``update_llm_pipeline_genai_config_gpu`` with the decoder settings
         built here.

    :param model: the single ONNX model to convert.
    :param config: pass configuration; ``batch_size`` and ``context_length`` are read.
    :param output_model_path: output path; its suffix is dropped to obtain the output directory.
    :return: whatever ``update_llm_pipeline_genai_config_gpu`` returns.
    :raises RuntimeError: if the input model cannot be loaded.
    :raises ValueError: if either dimension of ``input_ids`` is not symbolic.
    """
    output_model_dir = Path(output_model_path).with_suffix("")
    input_model_path = model.model_path

    # --- Step 1: Load model (handle both single and external data) ---
    try:
        model_proto = onnx.load(Path(input_model_path), load_external_data=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load ONNX model: {e}") from e

    # --- Step 2: Fix symbolic dimensions ---
    batch_size, sequence_length = OnnxDAG(model_proto).get_io_shape("input_ids")
    if not (isinstance(batch_size, str) and isinstance(sequence_length, str)):
        raise ValueError("Input dimensions must be symbolic before static shape fixing.")

    param_mapping = {batch_size: config.batch_size, sequence_length: config.context_length}
    self.fix_shape(model_proto, param_mapping)

    # --- Step 3: Save model as external-data format ---
    # The save call below does not create missing parent directories, so make sure
    # the output directory exists first.
    output_model_dir.mkdir(parents=True, exist_ok=True)
    output_model_file = output_model_dir / "model.onnx"
    external_data_name = "model.onnx.data"

    onnx.save(
        model_proto,
        str(output_model_file),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=external_data_name,
        convert_attribute=False,
    )

    # --- Step 4: Build the decoder settings and update the genai config ---
    decoder_config_extra = {
        "inputs": {
            "past_sequence_length": "past_seq_len",
            "total_sequence_length": "total_seq_len",
        },
        "sliding_window": {
            "window_size": config.context_length,
            "pad_value": 0,
            "alignment": "left",
            "slide_key_value_cache": False,
        },
    }

    model_static = ONNXModelHandler(model_path=output_model_dir, onnx_file_name=output_model_file.name)

    return update_llm_pipeline_genai_config_gpu(
        model_static,
        output_model_dir,
        input_model_path,
        decoder_config_extra,
    )

| 237 | + |
172 | 238 | @staticmethod |
173 | 239 | def fix_shape(model_proto: onnx.ModelProto, param_mapping: dict[str, int]): |
174 | 240 | """Fix the shape of the model based on the param mapping. |
|
0 commit comments