Allow multiple OpenAI clients per Pipeline #563

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,12 @@

### Features

### Pipelines can now have LLMBlocks with different OpenAI clients

For advanced use cases, `PipelineContext` now accepts a `clients` dictionary mapping strings to OpenAI clients. The special key "default" sets the OpenAI client used by `LLMBlock`s by default, but individual `LLMBlock`s can override which client they use via the `client` parameter in their YAML config.

Backwards compatibility is maintained for `Pipeline`s that only need a single client: setting the `client` property on a `PipelineContext` object simply sets the default client in the `clients` dictionary automatically.
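
For illustration, a minimal sketch of the two equivalent ways to configure a single client (the `base_url` and `api_key` values below are placeholders):

```python
from openai import OpenAI

from instructlab.sdg.pipeline import PipelineContext

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Existing single-client style, unchanged
ctx = PipelineContext(client=client)

# Equivalent style using the new `clients` dictionary
ctx = PipelineContext(clients={"default": client})

# Either way, the default client is the same object
assert ctx.client is ctx.clients["default"]
```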

### LLMBlocks can now specify `model_family` or `model_id` in their config

Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in its YAML configuration, instead of setting these values for the entire `Pipeline` in the `PipelineContext` object. This is useful when multiple `LLMBlock`s in the same `Pipeline` each use a different model.
5 changes: 5 additions & 0 deletions docs/examples/multiple_llm_clients/README.md
@@ -0,0 +1,5 @@
# Multiple LLM clients in a single Pipeline

For advanced use cases, `PipelineContext` accepts a `clients` dictionary mapping strings to OpenAI clients. The special key "default" sets the OpenAI client used by `LLMBlock`s by default, but individual `LLMBlock`s can override which client they use via the `client` parameter in their YAML config.

See `pipeline.yaml` in this directory for an example of a Pipeline that uses different clients per `LLMBlock`.
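
As a rough sketch of how this example might be wired up in Python (the server URLs and input dataset are placeholders, and `pipeline.generate` is assumed as the usual entry point for running a loaded `Pipeline`):

```python
from datasets import Dataset
from openai import OpenAI

from instructlab.sdg.pipeline import Pipeline, PipelineContext

# One OpenAI-compatible client per backing server; the URLs are placeholders.
context = PipelineContext(
    clients={
        "default": OpenAI(base_url="http://mixtral-server:8000/v1", api_key="EMPTY"),
        "server_a": OpenAI(base_url="http://granite-server-a:8000/v1", api_key="EMPTY"),
        "server_b": OpenAI(base_url="http://granite-server-b:8000/v1", api_key="EMPTY"),
    }
)

# Load this directory's pipeline.yaml; each LLMBlock resolves its client by name.
pipeline = Pipeline.from_file(context, "pipeline.yaml")

# The llm_config.yaml template renders a "document" column from the input dataset.
samples = Dataset.from_list([{"document": "Hello, world."}])
output = pipeline.generate(samples)
```

The functional test added in this PR (`tests/functional/test_examples.py`) exercises the same pipeline with `MagicMock` clients and asserts that each block resolves the expected client.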
16 changes: 16 additions & 0 deletions docs/examples/multiple_llm_clients/llm_config.yaml
@@ -0,0 +1,16 @@
system: You are a helpful AI assistant.

introduction: |
  Repeat the document below back to me verbatim.

principles: |
  Do not change anything.

examples: ""

generation: |
  Document:
  {{document}}

start_tags: [""]
end_tags: [""]
44 changes: 44 additions & 0 deletions docs/examples/multiple_llm_clients/pipeline.yaml
@@ -0,0 +1,44 @@
version: "1.0"
blocks:
  # This uses the default client, since we don't specify one
  - name: default_client
    type: LLMBlock
    config:
      model_family: mixtral
      model_id: Mixtral-8x7B-Instruct-v0.1
      config_path: llm_config.yaml
      output_cols:
        - column_one

  # We can also explicitly specify the default client
  - name: also_default_client
    type: LLMBlock
    config:
      client: default
      model_family: mixtral
      model_id: Mixtral-8x7B-Instruct-v0.1
      config_path: llm_config.yaml
      output_cols:
        - column_two

  # This uses the "server_a" client explicitly
  - name: server_a_client
    type: LLMBlock
    config:
      client: server_a
      model_family: granite
      model_id: granite-7b-lab
      config_path: llm_config.yaml
      output_cols:
        - column_three

  # This uses the "server_b" client explicitly
  - name: server_b_client
    type: LLMBlock
    config:
      client: server_b
      model_family: granite
      model_id: granite-7b-lab
      config_path: llm_config.yaml
      output_cols:
        - column_four
23 changes: 17 additions & 6 deletions src/instructlab/sdg/blocks/llmblock.py
@@ -63,6 +63,15 @@ def template_from_struct_and_config(struct, config):
    return PromptRegistry.template_from_string(struct.format(**filtered_config))


def _resolve_client(client_name, context, block):
    client = context.clients.get(client_name, None)
    if not client:
        raise BlockConfigParserError(
            f"{type(block).__name__} {block.block_name} requests a client named {client_name} but no client of that name was found in the PipelineContext clients"
        )
    return client

Review comment (Member): Should we include type-hints here?


def _resolve_model_id(model_id, ctx_model_id, block):
    # If a model id was passed in the PipelineContext, use that
    if ctx_model_id:
@@ -105,6 +114,7 @@ def __init__(
        model_id=None,
        model_family=None,
        model_prompt=None,
        client="default",
        gen_kwargs={},
        parser_kwargs={},
        batch_kwargs={},
@@ -117,6 +127,7 @@
        self.prompt_template = template_from_struct_and_config(
            self.prompt_struct, self.block_config
        )
        self.client = _resolve_client(client, self.ctx, self)
        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
        self.model_family = models.get_model_family(
            _resolve_model_family(model_family, self.ctx.model_family),
@@ -146,7 +157,7 @@ def __init__(
        # Whether the LLM server supports a list of input prompts
        # and supports the n parameter to generate n outputs per input
        self.server_supports_batched = server_supports_batched(
-            self.ctx.client, self.model_id
+            self.client, self.model_id
        )

    def _parse(self, generated_string) -> dict:
@@ -236,9 +247,7 @@ def _generate(self, samples) -> list:
        logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
        logger.debug(f"Generation arguments: {self.gen_kwargs}")
        if self.server_supports_batched:
-            response = self.ctx.client.completions.create(
-                prompt=prompts, **self.gen_kwargs
-            )
+            response = self.client.completions.create(prompt=prompts, **self.gen_kwargs)
            return [choice.text.strip() for choice in response.choices]

        results = []
@@ -248,7 +257,7 @@
        for prompt in prompts:
            logger.debug(f"CREATING COMPLETION FOR PROMPT: {prompt}")
            for _ in range(self.gen_kwargs.get("n", 1)):
-                response = self.ctx.client.completions.create(
+                response = self.client.completions.create(
                    prompt=prompt, **self.gen_kwargs
                )
                results.append(response.choices[0].text.strip())
@@ -514,9 +523,11 @@ def __init__(
        input_col,
        output_col,
        model_id=None,
        client="default",
        gen_kwargs={},
    ) -> None:
        super().__init__(ctx, pipe, block_name)
        self.client = _resolve_client(client, self.ctx, self)
        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
        self.input_col = input_col
        self.output_col = output_col
@@ -553,7 +564,7 @@ def _generate(self, samples) -> list:
        n = self.gen_kwargs.get("n", 1)
        for message in messages:
            logger.debug(f"CREATING CHAT COMPLETION FOR MESSAGE: {message}")
-            responses = self.ctx.client.chat.completions.create(
+            responses = self.client.chat.completions.create(
                messages=message, **self.gen_kwargs
            )
            if n > 1:
35 changes: 34 additions & 1 deletion src/instructlab/sdg/pipeline.py
@@ -60,7 +60,10 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
    # on individual datasets
    DEFAULT_DATASET_NUM_PROCS = 8

-    client: OpenAI
+    # The key of our default client
+    DEFAULT_CLIENT_KEY = "default"
+
+    client: Optional[OpenAI] = None
    model_family: Optional[str] = None
    model_id: Optional[str] = None
    num_instructions_to_generate: Optional[int] = None
@@ -70,6 +73,9 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
    max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
    batch_size: int = DEFAULT_BATCH_SIZE
    batch_num_workers: Optional[int] = None
    clients: Optional[Dict[str, OpenAI]] = None

Review comment (Contributor): Great changes and excellent thoughts for supporting backwards compatibility. One small nit:

There's no mandate in PipelineContext that a "default" key must be present in the dictionary passed to the clients property. But subsequently, when a Block is initialized with the Pipeline created from this PipelineContext, it could potentially throw an error if the block config does not explicitly specify a client key (it will expect "default" to exist), if my understanding is correct.

So maybe we should have a check, either in PipelineContext or in LLMBlock, that explicitly requires the "default" key and value, for robustness?


    _clients = None

    @property
    def batching_enabled(self) -> bool:
@@ -78,6 +84,33 @@ def batching_enabled(self) -> bool:
        """
        return self.batch_size > 0 and self.batch_num_workers != 1

    @property  # type: ignore
    def client(self):
        return self.clients.get(self.DEFAULT_CLIENT_KEY, None)

Review comment (Member): A few questions:

1. Should we also handle the case where `self.client = None` but `self.clients` contains a single client not set to the default client? It seems like in this case, we could perhaps make the assumption that this is the default client. Otherwise, if there are multiple clients then it would become ambiguous.
2. Why do we specify `None` as the explicit default here? Won't this be the default of the `.get` method?


    @client.setter
    def client(self, value):
        if isinstance(value, property):
            # No default value
            value = None
        self.clients[self.DEFAULT_CLIENT_KEY] = value

    @property  # type: ignore
    def clients(self):
        if self._clients is None:
            self._clients = {}
        return self._clients

    @clients.setter
    def clients(self, value):
        if isinstance(value, property):
            # Empty hash default value
            value = {}
        if value:
            # Only set _clients if passed in a value, so we don't
            # override it with the default of None from the @dataclass
            self._clients = value

Review comment (Contributor): Something like this could be added to make sure the 'default' key is explicitly mentioned, so the blocks know how to behave during the fallback path (?). Just a suggestion, please feel free to disregard or consider alternate implementations for the same.

Suggested change:

    def __post_init__(self):
        if self.clients is not None and "default" not in self.clients:
            raise ValueError("PipelineContext requires a 'default' client in the clients dictionary")


# This is part of the public API.
class PipelineBlockError(Exception):
6 changes: 6 additions & 0 deletions src/instructlab/sdg/pipelines/schema/v1.json
@@ -103,6 +103,9 @@
"model_prompt": {
"type": "string"
},
"client": {
"type": "string"
},
"parser_kwargs": {
"type": "object",
"properties": {
@@ -192,6 +195,9 @@
"model_prompt": {
"type": "string"
},
"client": {
"type": "string"
},
"selector_column_name": {
"type": "string"
},
28 changes: 28 additions & 0 deletions tests/functional/test_examples.py
@@ -2,6 +2,7 @@

# Standard
from pathlib import Path
from unittest.mock import MagicMock, patch
import shlex
import shutil
import subprocess
@@ -11,6 +12,7 @@
from docling.document_converter import DocumentConverter

# First Party
from instructlab.sdg.pipeline import Pipeline, PipelineContext, _lookup_block_type
from instructlab.sdg.utils.json import jlload


@@ -74,3 +76,29 @@ def test_example_iterblock(tmp_path: Path, examples_path: Path):
    output = jlload(output_jsonl)
    assert len(output) == 5
    assert output[4]["baz"] == "bar"


def test_example_multiple_llm_clients(examples_path: Path):
    pipeline_path = examples_path.joinpath("multiple_llm_clients", "pipeline.yaml")
    default_client = MagicMock()
    server_a_client = MagicMock()
    server_b_client = MagicMock()
    context = PipelineContext(
        clients={
            "default": default_client,
            "server_a": server_a_client,
            "server_b": server_b_client,
        }
    )
    pipeline = Pipeline.from_file(context, pipeline_path)
    blocks = []
    for block_prop in pipeline.chained_blocks:
        block_name = block_prop["name"]
        block_type = _lookup_block_type(block_prop["type"])
        block_config = block_prop["config"]
        block = block_type(pipeline.ctx, pipeline, block_name, **block_config)
        blocks.append(block)
    assert blocks[0].client == default_client
    assert blocks[1].client == default_client
    assert blocks[2].client == server_a_client
    assert blocks[3].client == server_b_client
4 changes: 2 additions & 2 deletions tests/test_default_pipeline_configs.py
@@ -2,7 +2,7 @@

# Standard
from importlib import resources
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
import unittest

# Third Party
@@ -53,7 +53,7 @@ def setUp(self):

    def test_pipeline_from_config(self):
        ctx = PipelineContext(
-            client=None,
+            client=MagicMock(),
            model_family="mixtral",
            model_id="model",
            num_instructions_to_generate=1,