102 changes: 64 additions & 38 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
@@ -100,22 +100,37 @@ def remove_signature_from_tool_description(name: str, description: str) -> str:
     @staticmethod
     def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
         """Convert an OCI tool call to a LangChain ToolCall."""
-        parsed = json.loads(tool_call.arguments)
-
-        # If the parsed result is a string, it means the JSON was escaped, so parse again  # noqa: E501
-        if isinstance(parsed, str):
-            try:
-                parsed = json.loads(parsed)
-            except json.JSONDecodeError:
-                # If it's not valid JSON, keep it as a string
-                pass
+        # Check if this is a Generic/Meta format (has arguments as JSON string)
+        # or Cohere format (has parameters as dict)
+        attribute_map = getattr(tool_call, "attribute_map", None) or {}
+
+        if "arguments" in attribute_map and tool_call.arguments is not None:
+            # Generic/Meta format: parse JSON arguments
+            parsed = json.loads(tool_call.arguments)
+
+            # If the parsed result is a string, it means JSON was escaped
+            if isinstance(parsed, str):
+                try:
+                    parsed = json.loads(parsed)
+                except json.JSONDecodeError:
+                    # If it's not valid JSON, keep it as a string
+                    pass
+            args = parsed
+        else:
+            # Cohere format: parameters is already a dict
+            args = tool_call.parameters
+
+        # Get tool call ID (generate one if not present)
+        tool_id = (
+            tool_call.id
+            if "id" in attribute_map
+            else uuid.uuid4().hex[:]
+        )
+
         return ToolCall(
             name=tool_call.name,
-            args=parsed
-            if "arguments" in tool_call.attribute_map
-            else tool_call.parameters,
-            id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
+            args=args,
+            id=tool_id,
         )
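
To see what the dual-format handling buys, here is a standalone sketch of the same logic run against mocked tool-call objects. The `SimpleNamespace` stand-ins and their `attribute_map` contents are assumptions for illustration; the real objects come from the OCI SDK response models, and the helper below returns a plain dict rather than a LangChain `ToolCall`.

```python
import json
import uuid
from types import SimpleNamespace

# Hypothetical stand-ins for OCI SDK tool-call models; the real models
# expose an ``attribute_map`` dict describing their serialized fields.
meta_call = SimpleNamespace(
    attribute_map={"name": "name", "arguments": "arguments", "id": "id"},
    name="get_weather",
    arguments='"{\\"city\\": \\"Austin\\"}"',  # double-encoded JSON string
    id="call_123",
)
cohere_call = SimpleNamespace(
    attribute_map={"name": "name", "parameters": "parameters"},
    name="get_weather",
    parameters={"city": "Austin"},  # already a dict; no id field
)


def convert(tool_call):
    """Mirror of the patched conversion logic, returning a plain dict."""
    attribute_map = getattr(tool_call, "attribute_map", None) or {}
    if "arguments" in attribute_map and tool_call.arguments is not None:
        # Generic/Meta format: arguments arrive as a JSON string
        parsed = json.loads(tool_call.arguments)
        if isinstance(parsed, str):  # escaped JSON: parse a second time
            try:
                parsed = json.loads(parsed)
            except json.JSONDecodeError:
                pass
        args = parsed
    else:
        # Cohere format: parameters is already a dict
        args = tool_call.parameters
    tool_id = tool_call.id if "id" in attribute_map else uuid.uuid4().hex
    return {"name": tool_call.name, "args": args, "id": tool_id}


print(convert(meta_call))    # {'name': 'get_weather', 'args': {'city': 'Austin'}, 'id': 'call_123'}
print(convert(cohere_call))  # args taken from .parameters, id freshly generated
```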


@@ -263,19 +278,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
         }
 
         # Include token usage if available
-        if (
-            hasattr(response.data.chat_response, "usage")
-            and response.data.chat_response.usage
-        ):
-            generation_info["total_tokens"] = (
-                response.data.chat_response.usage.total_tokens
-            )
+        try:
+            if (
+                hasattr(response.data.chat_response, "usage")
+                and response.data.chat_response.usage
+            ):
+                generation_info["total_tokens"] = (
+                    response.data.chat_response.usage.total_tokens
+                )
+        except (KeyError, AttributeError):
+            pass
 
-        # Include tool calls if available
-        if self.chat_tool_calls(response):
-            generation_info["tool_calls"] = self.format_response_tool_calls(
-                self.chat_tool_calls(response)
-            )
+        # Note: tool_calls are now handled in _generate() to avoid redundant conversions
+        # The formatted tool calls will be added there if present
         return generation_info
 
     def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
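
A quick sketch of why the `try/except` matters: `hasattr` only swallows `AttributeError`, so a response object that proxies attribute access through a dict lookup can still leak a `KeyError` out of the `hasattr` call itself. The defensive wrapper below mirrors the patched lookup; the `SimpleNamespace` response shapes are assumed for illustration, not the real OCI response types.

```python
from types import SimpleNamespace


def extract_total_tokens(chat_response) -> dict:
    """Defensive version of the usage lookup from the patch."""
    generation_info = {}
    try:
        if hasattr(chat_response, "usage") and chat_response.usage:
            generation_info["total_tokens"] = chat_response.usage.total_tokens
    except (KeyError, AttributeError):
        # Some SDK objects route attribute access through dict lookups,
        # so a missing field can surface as KeyError, not AttributeError.
        pass
    return generation_info


# Assumed shapes for illustration:
with_usage = SimpleNamespace(usage=SimpleNamespace(total_tokens=42))
without_usage = SimpleNamespace(usage=None)

print(extract_total_tokens(with_usage))     # {'total_tokens': 42}
print(extract_total_tokens(without_usage))  # {}
```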
@@ -643,18 +658,19 @@ def chat_generation_info(self, response: Any) -> Dict[str, Any]:
         }
 
         # Include token usage if available
-        if (
-            hasattr(response.data.chat_response, "usage")
-            and response.data.chat_response.usage
-        ):
-            generation_info["total_tokens"] = (
-                response.data.chat_response.usage.total_tokens
-            )
+        try:
+            if (
+                hasattr(response.data.chat_response, "usage")
+                and response.data.chat_response.usage
+            ):
+                generation_info["total_tokens"] = (
+                    response.data.chat_response.usage.total_tokens
+                )
+        except (KeyError, AttributeError):
+            pass
 
-        if self.chat_tool_calls(response):
-            generation_info["tool_calls"] = self.format_response_tool_calls(
-                self.chat_tool_calls(response)
-            )
+        # Note: tool_calls are now handled in _generate() to avoid redundant conversions
+        # The formatted tool calls will be added there if present
         return generation_info
 
     def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
@@ -1400,6 +1416,9 @@ def _generate(
         if stop is not None:
             content = enforce_stop_tokens(content, stop)
 
+        # Fetch raw tool calls once to avoid redundant calls
+        raw_tool_calls = self._provider.chat_tool_calls(response)
+
         generation_info = self._provider.chat_generation_info(response)
 
         llm_output = {
@@ -1408,12 +1427,19 @@
             "request_id": response.request_id,
             "content-length": response.headers["content-length"],
         }
+
+        # Convert tool calls once for LangChain format
         tool_calls = []
-        if "tool_calls" in generation_info:
+        if raw_tool_calls:
             tool_calls = [
                 OCIUtils.convert_oci_tool_call_to_langchain(tool_call)
-                for tool_call in self._provider.chat_tool_calls(response)
+                for tool_call in raw_tool_calls
             ]
+            # Add formatted version to generation_info if not already present
+            # This avoids redundant formatting in chat_generation_info()
+            if "tool_calls" not in generation_info:
+                formatted = self._provider.format_response_tool_calls(raw_tool_calls)
+                generation_info["tool_calls"] = formatted
         message = AIMessage(
             content=content or "",
             additional_kwargs=generation_info,
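
The point of fetching `raw_tool_calls` once in `_generate()` is to cut repeated provider lookups: previously `chat_tool_calls(response)` could be invoked in `chat_generation_info()`, in the guard, and again in the list comprehension. A toy counter makes the saving visible; the `CountingProvider` class below is hypothetical and exists only to illustrate the caching pattern, not to model the library's provider API.

```python
class CountingProvider:
    """Hypothetical provider that counts chat_tool_calls() invocations."""

    def __init__(self, calls):
        self._calls = calls
        self.lookups = 0

    def chat_tool_calls(self, response):
        self.lookups += 1
        return self._calls

    def format_response_tool_calls(self, raw_calls):
        return [{"name": call["name"]} for call in raw_calls]


provider = CountingProvider([{"name": "get_weather"}])

# New flow: fetch once, then reuse for both conversion and formatting.
raw_tool_calls = provider.chat_tool_calls(response=None)
if raw_tool_calls:
    formatted = provider.format_response_tool_calls(raw_tool_calls)

assert provider.lookups == 1  # previously this could be two or three lookups
```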