
Commit 4b0c0fb

Allow multiple OpenAI clients per Pipeline
This change allows a user to construct a PipelineContext with multiple OpenAI clients, such as:

```python
PipelineContext(
    clients={
        "default": OpenAI(base_url="https://foo.local"),
        "server_a": OpenAI(base_url="https://server_a.local"),
        "server_b": OpenAI(base_url="https://server_b.local"),
    }
)
```

And then, within the pipeline yaml, choose which client to apply to which LLMBlock via a new `client` key, such as:

```yaml
version: "1.0"
blocks:
  - name: server_a_client
    type: LLMBlock
    config:
      client: server_a
      ...
  - name: server_b_client
    type: LLMBlock
    config:
      client: server_b
      ...
```

See `docs/examples/multiple_llm_clients` for more details and a full example.

Resolves #521

Signed-off-by: Ben Browning <[email protected]>
1 parent 3890d99 commit 4b0c0fb

10 files changed: +327 -56 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -2,6 +2,12 @@
 
 ### Features
 
+### Pipelines can now have LLMBlocks with different OpenAI clients
+
+For advanced use cases, PipelineContext now accepts a `clients` dictionary mapping strings to OpenAI clients. The special key "default" sets the OpenAI client used by LLMBlocks by default, but individual LLMBlocks can override which client they use via the `client` parameter in their yaml config.
+
+Backwards compatibility is maintained for Pipelines that only need a single client: setting the `client` property on a PipelineContext simply sets the default client in the `clients` dictionary automatically.
+
 ### LLMBlocks can now specify `model_family` or `model_id` in their config
 
 Each `LLMBlock` in a `Pipeline` can now specify `model_family` or `model_id` in their yaml configuration to set the values to use for these blocks, as opposed to setting this for the entire `Pipeline` in the `PipelineContext` object. This is useful for the cases where multiple `LLMBlocks` exist in the same `Pipeline` where each one uses a different model.
```
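To make the backwards-compatibility note above concrete, here is a minimal sketch; the `base_url` and `api_key` values are placeholders, not real endpoints or credentials, and the import paths follow the tests added in this commit:

```python
# Minimal sketch of the backwards-compatible behavior described in the
# changelog entry above. The base_url/api_key values are placeholders.
from openai import OpenAI

from instructlab.sdg.pipeline import PipelineContext

# Legacy form: a single client, which becomes the "default" entry.
single = PipelineContext(
    client=OpenAI(base_url="https://foo.local/v1", api_key="EMPTY")
)
assert single.clients["default"] is single.client

# New form: several named clients, with "default" picked up by LLMBlocks
# that don't name a client in their yaml config.
multi = PipelineContext(
    clients={
        "default": OpenAI(base_url="https://foo.local/v1", api_key="EMPTY"),
        "server_a": OpenAI(base_url="https://server_a.local/v1", api_key="EMPTY"),
        "server_b": OpenAI(base_url="https://server_b.local/v1", api_key="EMPTY"),
    }
)
assert multi.client is multi.clients["default"]
```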
Lines changed: 5 additions & 0 deletions
```diff
@@ -0,0 +1,5 @@
+# Multiple LLM clients in a single Pipeline
+
+For advanced use cases, PipelineContext accepts a `clients` dictionary mapping strings to OpenAI clients. The special key "default" sets the OpenAI client used by LLMBlocks by default, but individual LLMBlocks can override which client they use via the `client` parameter in their yaml config.
+
+See `pipeline.yaml` in this directory for an example of a Pipeline that uses different clients per `LLMBlock`.
```
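As a rough sketch of how the example might be driven from Python, following the pattern in the functional test further below; the `MagicMock` clients and the pipeline path are assumptions for illustration only:

```python
# Sketch of loading the example pipeline with per-block clients, following
# tests/functional/test_examples.py below. MagicMock stands in for real
# OpenAI clients, and the pipeline path is illustrative.
from unittest.mock import MagicMock

from instructlab.sdg.pipeline import Pipeline, PipelineContext

context = PipelineContext(
    clients={
        "default": MagicMock(),
        "server_a": MagicMock(),
        "server_b": MagicMock(),
    }
)
pipeline = Pipeline.from_file(
    context, "docs/examples/multiple_llm_clients/pipeline.yaml"
)
# A real run would then call something like pipeline.generate(dataset)
# with actual OpenAI clients rather than mocks.
```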
Lines changed: 16 additions & 0 deletions
```diff
@@ -0,0 +1,16 @@
+system: You are a helpful AI assistant.
+
+introduction: |
+  Repeat the document below back to me verbatim.
+
+principles: |
+  Do not change anything.
+
+examples: ""
+
+generation: |
+  Document:
+  {{document}}
+
+start_tags: [""]
+end_tags: [""]
```
Lines changed: 44 additions & 0 deletions
```diff
@@ -0,0 +1,44 @@
+version: "1.0"
+blocks:
+  # This uses the default client, since we don't specify one
+  - name: default_client
+    type: LLMBlock
+    config:
+      model_family: mixtral
+      model_id: Mixtral-8x7B-Instruct-v0.1
+      config_path: llm_config.yaml
+      output_cols:
+        - column_one
+
+  # We can also explicitly specify the default client
+  - name: also_default_client
+    type: LLMBlock
+    config:
+      client: default
+      model_family: mixtral
+      model_id: Mixtral-8x7B-Instruct-v0.1
+      config_path: llm_config.yaml
+      output_cols:
+        - column_two
+
+  # This uses the "server_a" client explicitly
+  - name: server_a_client
+    type: LLMBlock
+    config:
+      client: server_a
+      model_family: granite
+      model_id: granite-7b-lab
+      config_path: llm_config.yaml
+      output_cols:
+        - column_three
+
+  # This uses the "server_b" client explicitly
+  - name: server_b_client
+    type: LLMBlock
+    config:
+      client: server_b
+      model_family: granite
+      model_id: granite-7b-lab
+      config_path: llm_config.yaml
+      output_cols:
+        - column_four
```

src/instructlab/sdg/blocks/llmblock.py

Lines changed: 17 additions & 6 deletions
```diff
@@ -63,6 +63,15 @@ def template_from_struct_and_config(struct, config):
     return PromptRegistry.template_from_string(struct.format(**filtered_config))
 
 
+def _resolve_client(client_name, context, block):
+    client = context.clients.get(client_name, None)
+    if not client:
+        raise BlockConfigParserError(
+            f"{type(block).__name__} {block.block_name} requests a client named {client_name} but no client of that name was found in the PipelineContext clients"
+        )
+    return client
+
+
 def _resolve_model_id(model_id, ctx_model_id, block):
     # If a model id was passed in the PipelineContext, use that
     if ctx_model_id:
@@ -105,6 +114,7 @@ def __init__(
         model_id=None,
         model_family=None,
         model_prompt=None,
+        client="default",
         gen_kwargs={},
         parser_kwargs={},
         batch_kwargs={},
@@ -117,6 +127,7 @@
         self.prompt_template = template_from_struct_and_config(
             self.prompt_struct, self.block_config
         )
+        self.client = _resolve_client(client, self.ctx, self)
        self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
         self.model_family = models.get_model_family(
             _resolve_model_family(model_family, self.ctx.model_family),
@@ -146,7 +157,7 @@ def __init__(
         # Whether the LLM server supports a list of input prompts
         # and supports the n parameter to generate n outputs per input
         self.server_supports_batched = server_supports_batched(
-            self.ctx.client, self.model_id
+            self.client, self.model_id
         )
 
     def _parse(self, generated_string) -> dict:
@@ -236,9 +247,7 @@ def _generate(self, samples) -> list:
         logger.debug(f"STARTING GENERATION FOR LLMBlock USING PROMPTS: {prompts}")
         logger.debug(f"Generation arguments: {self.gen_kwargs}")
         if self.server_supports_batched:
-            response = self.ctx.client.completions.create(
-                prompt=prompts, **self.gen_kwargs
-            )
+            response = self.client.completions.create(prompt=prompts, **self.gen_kwargs)
             return [choice.text.strip() for choice in response.choices]
 
         results = []
@@ -248,7 +257,7 @@
         for prompt in prompts:
             logger.debug(f"CREATING COMPLETION FOR PROMPT: {prompt}")
             for _ in range(self.gen_kwargs.get("n", 1)):
-                response = self.ctx.client.completions.create(
+                response = self.client.completions.create(
                     prompt=prompt, **self.gen_kwargs
                 )
                 results.append(response.choices[0].text.strip())
@@ -514,9 +523,11 @@ def __init__(
         input_col,
         output_col,
         model_id=None,
+        client="default",
         gen_kwargs={},
     ) -> None:
         super().__init__(ctx, pipe, block_name)
+        self.client = _resolve_client(client, self.ctx, self)
         self.model_id = _resolve_model_id(model_id, self.ctx.model_id, self)
         self.input_col = input_col
         self.output_col = output_col
@@ -553,7 +564,7 @@ def _generate(self, samples) -> list:
         n = self.gen_kwargs.get("n", 1)
         for message in messages:
             logger.debug(f"CREATING CHAT COMPLETION FOR MESSAGE: {message}")
-            responses = self.ctx.client.chat.completions.create(
+            responses = self.client.chat.completions.create(
                 messages=message, **self.gen_kwargs
             )
             if n > 1:
```

src/instructlab/sdg/pipeline.py

Lines changed: 34 additions & 1 deletion
```diff
@@ -60,7 +60,10 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
     # on individual datasets
     DEFAULT_DATASET_NUM_PROCS = 8
 
-    client: OpenAI
+    # The key of our default client
+    DEFAULT_CLIENT_KEY = "default"
+
+    client: Optional[OpenAI] = None
     model_family: Optional[str] = None
     model_id: Optional[str] = None
     num_instructions_to_generate: Optional[int] = None
@@ -70,6 +73,9 @@ class PipelineContext: # pylint: disable=too-many-instance-attributes
     max_num_tokens: Optional[int] = llmblock.DEFAULT_MAX_NUM_TOKENS
     batch_size: int = DEFAULT_BATCH_SIZE
     batch_num_workers: Optional[int] = None
+    clients: Optional[Dict[str, OpenAI]] = None
+
+    _clients = None
 
     @property
     def batching_enabled(self) -> bool:
@@ -78,6 +84,33 @@ def batching_enabled(self) -> bool:
         """
         return self.batch_size > 0 and self.batch_num_workers != 1
 
+    @property  # type: ignore
+    def client(self):
+        return self.clients.get(self.DEFAULT_CLIENT_KEY, None)
+
+    @client.setter
+    def client(self, value):
+        if isinstance(value, property):
+            # No default value
+            value = None
+        self.clients[self.DEFAULT_CLIENT_KEY] = value
+
+    @property  # type: ignore
+    def clients(self):
+        if self._clients is None:
+            self._clients = {}
+        return self._clients
+
+    @clients.setter
+    def clients(self, value):
+        if isinstance(value, property):
+            # Empty hash default value
+            value = {}
+        if value:
+            # Only set _clients if passed in a value, so we don't
+            # override it with the default of None from the @dataclass
+            self._clients = value
+
 
 # This is part of the public API.
 class PipelineBlockError(Exception):
```
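A small sketch of how these property-backed defaults behave in practice, using `MagicMock` in place of real OpenAI clients as the tests in this commit do:

```python
# Sketch of the client/clients property behavior defined above; MagicMock
# stands in for real OpenAI clients.
from unittest.mock import MagicMock

from instructlab.sdg.pipeline import PipelineContext

ctx = PipelineContext(clients={"server_a": MagicMock()})

# No "default" entry yet, so the legacy `client` property reads as None.
assert ctx.client is None

# Assigning the legacy attribute routes through the setter and lands
# under the DEFAULT_CLIENT_KEY ("default") entry of the clients mapping.
ctx.client = MagicMock()
assert ctx.clients[PipelineContext.DEFAULT_CLIENT_KEY] is ctx.client
```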

tests/functional/test_examples.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 
 # Standard
 from pathlib import Path
+from unittest.mock import MagicMock, patch
 import shlex
 import shutil
 import subprocess
@@ -11,6 +12,7 @@
 from docling.document_converter import DocumentConverter
 
 # First Party
+from instructlab.sdg.pipeline import Pipeline, PipelineContext, _lookup_block_type
 from instructlab.sdg.utils.json import jlload
 
 
@@ -74,3 +76,29 @@ def test_example_iterblock(tmp_path: Path, examples_path: Path):
     output = jlload(output_jsonl)
     assert len(output) == 5
     assert output[4]["baz"] == "bar"
+
+
+def test_example_multiple_llm_clients(examples_path: Path):
+    pipeline_path = examples_path.joinpath("multiple_llm_clients", "pipeline.yaml")
+    default_client = MagicMock()
+    server_a_client = MagicMock()
+    server_b_client = MagicMock()
+    context = PipelineContext(
+        clients={
+            "default": default_client,
+            "server_a": server_a_client,
+            "server_b": server_b_client,
+        }
+    )
+    pipeline = Pipeline.from_file(context, pipeline_path)
+    blocks = []
+    for block_prop in pipeline.chained_blocks:
+        block_name = block_prop["name"]
+        block_type = _lookup_block_type(block_prop["type"])
+        block_config = block_prop["config"]
+        block = block_type(pipeline.ctx, pipeline, block_name, **block_config)
+        blocks.append(block)
+    assert blocks[0].client == default_client
+    assert blocks[1].client == default_client
+    assert blocks[2].client == server_a_client
+    assert blocks[3].client == server_b_client
```

tests/test_default_pipeline_configs.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -2,7 +2,7 @@
 
 # Standard
 from importlib import resources
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 import unittest
 
 # Third Party
@@ -53,7 +53,7 @@ def setUp(self):
 
     def test_pipeline_from_config(self):
         ctx = PipelineContext(
-            client=None,
+            client=MagicMock(),
            model_family="mixtral",
             model_id="model",
             num_instructions_to_generate=1,
```
