Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 53 additions & 19 deletions olive/passes/onnx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import re
from collections.abc import Iterable
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Optional, Union
Expand Down Expand Up @@ -776,40 +777,46 @@ def update_llm_pipeline_genai_config(


def update_llm_pipeline_genai_config_gpu(
model: ONNXModelHandler,
model: Union[ONNXModelHandler, CompositeModelHandler],
output_model_dir: Union[str, Path],
input_model_path: Union[str, Path],
decoder_config_extra: Optional[dict[str, Any]] = None,
) -> ONNXModelHandler:
composite_components: Optional[Iterable[tuple[str, ONNXModelHandler]]] = None,
) -> Union[ONNXModelHandler, CompositeModelHandler]:
"""Update the LLM pipeline in the model's genai_config.json file.

:param model: The model to update.
:param model: The model (single or composite) to update.
:param output_model_dir: Directory where the updated genai_config.json should be written.
:param decoder_config_extra: Extra configuration for the decoder.
:param composite_components: Optional iterable of (component_name, ONNXModelHandler)
used to build a multi-component pipeline.
:return: The same `model` object (with its directory now having updated genai_config.json).
"""
output_model_dir = Path(output_model_dir)

# update genai_config if it exists
additional_files = model.model_attributes["additional_files"]
genai_config_path = None
genai_config_path = Path(input_model_path).parent / "genai_config.json"
for file_path in additional_files:
if Path(file_path).name == "genai_config.json":
genai_config_path = file_path
break

if genai_config_path.exists():
genai_config_path = str(genai_config_path.resolve())
else:
if not genai_config_path:
return model

with open(genai_config_path) as f:
genai_config = json.load(f)

# update model_type
genai_config["model"]["type"] = "decoder-pipeline"

# Update the provider_options list
provider_option = {"qnn": {"backend_type": "gpu"}}
genai_config["model"]["decoder"]["session_options"]["provider_options"] = [provider_option]
decoder = genai_config["model"].setdefault("decoder", {})
session_opts = decoder.setdefault("session_options", {})
session_opts["provider_options"] = [provider_option]

# update decoder config
decoder_config = genai_config["model"]["decoder"]
decoder_config.get("sliding_window", {}).pop("slide_inputs", None)

for key, value in (decoder_config_extra or {}).items():
exisiting_value = decoder_config.get(key)
if isinstance(exisiting_value, dict):
Expand All @@ -819,20 +826,47 @@ def update_llm_pipeline_genai_config_gpu(
else:
decoder_config[key] = value

pipeline_config = {}
component_io_config = model.io_config
pipeline_config["model_onnx"] = {
"filename": Path(model.model_path).name,
"inputs": component_io_config["input_names"],
"outputs": component_io_config["output_names"],
}
# --- Build pipeline_config ---
pipeline_config: dict[str, Any] = {}

if composite_components is None:
if not isinstance(model, ONNXModelHandler):
handlers = list(model.get_model_components())
if not handlers:
return model
_, single_handler = handlers[0]
else:
single_handler = model

component_io_config = single_handler.io_config
pipeline_config["model_onnx"] = {
"filename": Path(single_handler.model_path).name,
"inputs": component_io_config["input_names"],
"outputs": component_io_config["output_names"],
}

else:
# Composite case: one entry per component
for comp_name, comp_handler in composite_components:
component_io_config = comp_handler.io_config
pipeline_config[comp_name] = {
"filename": Path(comp_handler.model_path).name,
"inputs": component_io_config["input_names"],
"outputs": component_io_config["output_names"],
}
if comp_name == "ctx_1":
pipeline_config[comp_name]["run_on_prompt"] = False
else:
pipeline_config[comp_name]["run_on_token_gen"] = False

decoder_config["pipeline"] = [pipeline_config]

# save the updated genai_config
new_genai_config_path = output_model_dir / "genai_config.json"
with new_genai_config_path.open("w") as f:
json.dump(genai_config, f, indent=4)
additional_files.remove(genai_config_path)
additional_files.append(str(new_genai_config_path))

return model

Expand Down
2 changes: 2 additions & 0 deletions olive/passes/onnx/context_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def _generate_context_binary(
if execution_provider == ExecutionProvider.QNNExecutionProvider:
if str(device).lower() == "gpu":
provider_options["backend_path"] = "libQnnGpu.so" if platform.system() == "Linux" else "QnnGpu.dll"
if share_ep_contexts:
provider_options["enable_gpu_weight_sharing"] = "1"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is enable_gpu_weight_sharing a supported provider option? The corresponding PR to onnxruntime-qnn (onnxruntime/onnxruntime-qnn#67) does not add it as one.

update_llm_pipeline_genai_config_gpu_ctxbin(model_path)
else:
if version.parse(OrtVersion).release < version.parse("1.22.0").release:
Expand Down
138 changes: 114 additions & 24 deletions olive/passes/onnx/static_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from pathlib import Path

import onnx
Expand Down Expand Up @@ -56,6 +57,14 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=64,
description="Input length of the context model.",
),
"context_lengths": PassConfigParam(
type_=list[int],
default_value=None,
description=(
"List of context lengths for which to generate static models for QNN GPU. "
"If None or empty, falls back to the single 'context_length' value."
),
),

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vjatoth-qti, would the intended use of this in a recipe look like this?

        "st": {
            "type": "StaticLLM",
            "batch_size": 1,
            "context_lengths": [1, 64]
        }

Do we ever expect len(context_lengths) > 2? If not, it seems like we could follow the NPU strategy for StaticLLM, where "context_length": x always implies a hybrid AR1 + ARx model, and we wouldn't need a new "context_lengths" key.

"group_session_options": PassConfigParam(
type_=dict,
description=(
Expand Down Expand Up @@ -182,57 +191,138 @@ def process_context_iterator(component_models, llm_pipeline, output_dir):
)

def _run_qnn_gpu(self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: Path):
"""QNN_GPU path: generate one or more static ONNX models for different context lengths.

- If config.context_lengths is None/empty: use config.context_length (single model).
- If config.context_lengths has 1 value: use that context length (single model).
- If config.context_lengths has >1 values: generate multiple models and return CompositeModelHandler.
"""
output_model_dir = Path(output_model_path).with_suffix("")
model_path = Path(model.model_path)

# --- Step 1: Load model (handle both single and external data) ---
try:
model_proto = onnx.load(model_path, load_external_data=True)
base_model_proto = onnx.load(model_path, load_external_data=True)
except Exception as e:
raise RuntimeError(f"Failed to load ONNX model: {e}") from e

# --- Step 2: Fix symbolic dimensions ---
batch_size, sequence_length = OnnxDAG(model_proto).get_io_shape("input_ids")
# --- Step 2: Get symbolic batch and sequence dims once ---
batch_size, sequence_length = OnnxDAG(base_model_proto).get_io_shape("input_ids")
if not (isinstance(batch_size, str) and isinstance(sequence_length, str)):
raise ValueError("Input dimensions must be symbolic before static shape fixing.")

param_mapping = {batch_size: config.batch_size, sequence_length: config.context_length}
self.fix_shape(model_proto, param_mapping)
# --- Determine which context lengths to use ---
cfg_ctx_lengths = getattr(config, "context_lengths", None) or []
ctx_lengths_list = [int(x) for x in cfg_ctx_lengths if x is not None]

if not ctx_lengths_list:
# Fall back to single context_length in config
ctx_lengths_list = [int(config.context_length)]

# If only one context length, we still treat it uniformly but return a single handler.
multiple = len(ctx_lengths_list) > 1

generated_handlers: dict[int, ONNXModelHandler] = {}
generated_names: dict[int, str] = {}

for ctx_len in ctx_lengths_list:
# --- Clone base model proto for this variant ---
model_proto = onnx.ModelProto()
model_proto.CopyFrom(base_model_proto)

# --- Step 3: Fix symbolic dimensions for this context length ---
param_mapping = {batch_size: config.batch_size, sequence_length: ctx_len}
self.fix_shape(model_proto, param_mapping)

add_version_metadata_to_model_proto(model_proto)

# --- Step 4: Save as external-data ONNX ---
onnx_file_name = f"model_ctx{ctx_len}.onnx"
output_model_file = Path(output_model_dir) / onnx_file_name
external_data_file = Path(output_model_dir) / f"{onnx_file_name}.data"

# --- Step 3: Save model as external-data format ---
output_model_file = Path(output_model_dir) / "model.onnx"
external_data_file = Path(output_model_dir) / "model.onnx.data"
output_model_dir.mkdir(parents=True, exist_ok=True)
onnx.save(
model_proto,
str(output_model_file),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_file.name,
convert_attribute=False,
)

# Build handler for this static model
new_model_attributes = deepcopy(model.model_attributes) or {}
handler = ONNXModelHandler(
model_path=output_model_dir,
onnx_file_name=output_model_file.name,
model_attributes=new_model_attributes,
)

# Store handler + a logical component name (e.g., ctx_128)
generated_handlers[ctx_len] = handler
generated_names[ctx_len] = f"ctx_{ctx_len}"

# --- Step 5: Update genai_config.json ---
# For single model: pipeline with one component.
# For multiple models: pipeline with multiple components (composite).
if not multiple:
# Single context length
ctx_len = ctx_lengths_list[0]
handler = generated_handlers[ctx_len]

decoder_config_extra = {
"inputs": {
"past_sequence_length": "past_seq_len",
"total_sequence_length": "total_seq_len",
},
"sliding_window": {
"window_size": ctx_len,
"pad_value": 0,
"alignment": "left",
"slide_key_value_cache": False,
},
}

return update_llm_pipeline_genai_config_gpu(
model=handler,
output_model_dir=output_model_dir,
decoder_config_extra=decoder_config_extra,
composite_components=None,
)

# Multiple context lengths -> wrap in CompositeModelHandler and create composite pipeline
components = []
component_names = []
for ctx_len, handler in sorted(generated_handlers.items(), key=lambda kv: kv[0]):
components.append(handler)
component_names.append(generated_names[ctx_len])

new_model_attributes = deepcopy(model.model_attributes) or {}

onnx.save(
model_proto,
str(output_model_file),
save_as_external_data=True,
all_tensors_to_one_file=True,
location=external_data_file.name,
convert_attribute=False,
composite = CompositeModelHandler(
model_components=components, model_component_names=component_names, model_attributes=new_model_attributes
)

decoder_config_extra = {
# Build per-component sliding_window config keyed by name
composite_decoder_extra = {
"inputs": {
"past_sequence_length": "past_seq_len",
"total_sequence_length": "total_seq_len",
},
"sliding_window": {
"window_size": config.context_length,
"window_size": max(ctx_lengths_list),
"pad_value": 0,
"alignment": "left",
"slide_key_value_cache": False,
},
}

input_model_path = model.model_path
model_static = ONNXModelHandler(model_path=output_model_dir, onnx_file_name=output_model_file.name)

return update_llm_pipeline_genai_config_gpu(
model_static,
output_model_dir,
input_model_path,
decoder_config_extra,
model=composite,
output_model_dir=output_model_dir,
decoder_config_extra=composite_decoder_extra,
composite_components=list(zip(component_names, components)),
)

@staticmethod
Expand Down
Loading