Skip to content

Commit

Permalink
Fix OpenAPI generation to have text/event-stream for streamable methods
Browse files Browse the repository at this point in the history
  • Loading branch information
ashwinb committed Nov 14, 2024
1 parent acbecbf commit bba6edd
Show file tree
Hide file tree
Showing 4 changed files with 595 additions and 597 deletions.
16 changes: 0 additions & 16 deletions docs/openapi_generator/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,6 @@
from llama_stack.distribution.stack import LlamaStack


# TODO: this should be fixed in the generator itself so it reads appropriate annotations
STREAMING_ENDPOINTS = [
"/agents/turn/create",
"/inference/chat_completion",
]


def patch_sse_stream_responses(spec: Specification):
for path, path_item in spec.document.paths.items():
if path in STREAMING_ENDPOINTS:
content = path_item.post.responses["200"].content.pop("application/json")
path_item.post.responses["200"].content["text/event-stream"] = content


def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():
Expand All @@ -74,8 +60,6 @@ def main(output_dir: str):
),
)

patch_sse_stream_responses(spec)

with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True)

Expand Down
14 changes: 14 additions & 0 deletions docs/openapi_generator/pyopenapi/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import collections
import hashlib
import ipaddress
import typing
Expand Down Expand Up @@ -176,9 +177,20 @@ def build_content(
) -> Dict[str, MediaType]:
"Creates the content subtree for a request or response."

def has_iterator_type(t):
if typing.get_origin(t) is typing.Union:
return any(has_iterator_type(a) for a in typing.get_args(t))
else:
# TODO: needs a proper fix where we let all types correctly flow upwards
# and then test against AsyncIterator
return "StreamChunk" in str(t)

if is_generic_list(payload_type):
media_type = "application/jsonl"
item_type = unwrap_generic_list(payload_type)
elif has_iterator_type(payload_type):
item_type = payload_type
media_type = "text/event-stream"
else:
media_type = "application/json"
item_type = payload_type
Expand Down Expand Up @@ -671,6 +683,8 @@ def generate(self) -> Document:
for extra_tag_group in extra_tag_groups.values():
tags.extend(extra_tag_group)

tags = sorted(tags, key=lambda t: t.name)

tag_groups = []
if operation_tags:
tag_groups.append(
Expand Down
Loading

0 comments on commit bba6edd

Please sign in to comment.