Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,19 @@ def process_stream_tool_calls(
if tool_id:
tool_call_ids.add(tool_id)

args = tool_call["function"].get("arguments")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will take more than this to fix this problem.
When LangGraph consumes the streaming chunks and tries to build a tool call, parsing fails if the argument string is not valid JSON, so an invalid tool call is created long before control reaches our code.
https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/messages/ai.py#L508-L522

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One option is to extend AIMessageChunk and override init_tool_calls to do the double parsing ourselves. Make sure you use the new class in this file instead of AIMessageChunk

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class OCIAIMessageChunk(AIMessageChunk):
    """``AIMessageChunk`` whose tool-call arguments may be JSON-encoded twice.

    When a chunk's ``args`` decode to a string rather than an object, that
    string is decoded a second time before the tool call is built.
    """

    @model_validator(mode="after")
    def init_tool_calls(self) -> Self:
        """Reconcile tool calls with tool call chunks.

        With no tool call chunks present, chunks are derived from the existing
        ``tool_calls`` and ``invalid_tool_calls``. Otherwise both lists are
        rebuilt from the chunks; a chunk whose arguments do not decode to a
        dict (or whose processing raises) is recorded as an invalid tool call.

        Returns:
            This instance, with its tool call fields populated.
        """
        if self.tool_call_chunks:
            accepted = []
            rejected = []

            def reject(piece: ToolCallChunk) -> None:
                # Keep the raw, undecoded arguments on the invalid entry.
                rejected.append(
                    create_invalid_tool_call(
                        name=piece["name"],
                        args=piece["args"],
                        id=piece["id"],
                        error=None,
                    )
                )

            for piece in self.tool_call_chunks:
                try:
                    raw_args = piece["args"]
                    decoded = parse_partial_json(raw_args) if raw_args else {}
                    if isinstance(decoded, str):
                        # Double-encoded payload: the first pass only yielded
                        # the inner JSON text, so decode once more.
                        decoded = parse_partial_json(decoded)
                    if not isinstance(decoded, dict):
                        reject(piece)
                        continue
                    accepted.append(
                        create_tool_call(
                            name=piece["name"] or "",
                            args=decoded,
                            id=piece["id"],
                        )
                    )
                except Exception:
                    reject(piece)

            self.tool_calls = accepted
            self.invalid_tool_calls = rejected
            return self

        # No chunks yet: back-fill them from whatever calls already exist.
        if self.tool_calls:
            self.tool_call_chunks = [
                create_tool_call_chunk(
                    name=call["name"],
                    args=json.dumps(call["args"]),
                    id=call["id"],
                    index=None,
                )
                for call in self.tool_calls
            ]
        if self.invalid_tool_calls:
            merged_chunks = self.tool_call_chunks
            from_invalid = [
                create_tool_call_chunk(
                    name=bad["name"], args=bad["args"], id=bad["id"], index=None
                )
                for bad in self.invalid_tool_calls
            ]
            merged_chunks.extend(from_invalid)
            self.tool_call_chunks = merged_chunks

        return self

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my experience, the OCI GenAI endpoint returns either a plain JSON string or a double-escaped JSON string. The PR code handles both situations:

  • Normal JSON ('{"key": "value"}'): First parse succeeds → dict → second parse raises TypeError → keep original
  • Double-escaped JSON ('"{\"key\": \"value\"}"'): First parse → string → second parse → dict → convert back to unescaped JSON

The result passed to LangChain is therefore always valid JSON, which LangChain then parses with:
args_ = parse_partial_json(chunk["args"]) if chunk["args"] else {}

# If args is a double-escaped JSON string, parse it twice to get the original JSON # noqa: E501
try:
parsed_args = json.loads(json.loads(args))
args = json.dumps(parsed_args)
# If args is not a double-escaped JSON string, keep it as is
except (json.JSONDecodeError, TypeError):
pass

tool_call_chunks.append(
tool_call_chunk(
name=tool_call["function"].get("name"),
args=tool_call["function"].get("arguments"),
args=args,
id=tool_id,
index=len(tool_call_ids) - 1, # index tracking
)
Expand Down Expand Up @@ -1027,10 +1036,19 @@ def process_stream_tool_calls(
if tool_id:
tool_call_ids.add(tool_id)

args = tool_call["function"].get("arguments")
# If args is a double-escaped JSON string, parse it twice to get the original JSON # noqa: E501
try:
parsed_args = json.loads(json.loads(args))
args = json.dumps(parsed_args)
# If args is not a double-escaped JSON string, keep it as is
except (json.JSONDecodeError, TypeError):
pass

tool_call_chunks.append(
tool_call_chunk(
name=tool_call["function"].get("name"),
args=tool_call["function"].get("arguments"),
args=args,
id=tool_id,
index=len(tool_call_ids) - 1, # index tracking
)
Expand Down
Loading