|
9 | 9 | from datetime import datetime |
10 | 10 | from pydub import AudioSegment |
11 | 11 | from pydantic import BaseModel |
12 | | -from typing import Dict, List, Any, Optional, Callable |
| 12 | +from typing import ( |
| 13 | + Dict, |
| 14 | + List, |
| 15 | + Any, |
| 16 | + Optional, |
| 17 | + Callable, |
| 18 | + Type, |
| 19 | + get_args, |
| 20 | + get_origin, |
| 21 | + Union, |
| 22 | +) |
| 23 | +from enum import Enum |
| 24 | +from pydantic import BaseModel |
| 25 | +import json |
13 | 26 |
|
14 | 27 |
|
15 | 28 | class ChatCompletions(BaseModel): |
@@ -62,9 +75,9 @@ def __init__( |
62 | 75 | "Authorization": f"{api_key}", |
63 | 76 | "Content-Type": "application/json", |
64 | 77 | } |
65 | | - |
66 | 78 | if self.base_uri[-1] == "/": |
67 | 79 | self.base_uri = self.base_uri[:-1] |
| 80 | + self.failures = 0 |
68 | 81 |
|
69 | 82 | def handle_error(self, error) -> str: |
70 | 83 | print(f"Error: {error}") |
@@ -1660,3 +1673,86 @@ def plan_task( |
1660 | 1673 | return response.json()["response"] |
1661 | 1674 | except Exception as e: |
1662 | 1675 | return self.handle_error(e) |
| 1676 | + |
| 1677 | + def convert_to_model( |
| 1678 | + self, |
| 1679 | + input_string: str, |
| 1680 | + model: Type[BaseModel], |
| 1681 | + agent_name: str = "gpt4free", |
| 1682 | + max_failures: int = 3, |
| 1683 | + response_type: str = None, |
| 1684 | + ): |
| 1685 | + input_string = str(input_string) |
| 1686 | + fields = model.__annotations__ |
| 1687 | + field_descriptions = [] |
| 1688 | + for field, field_type in fields.items(): |
| 1689 | + description = f"{field}: {field_type}" |
| 1690 | + if get_origin(field_type) == Union: |
| 1691 | + field_type = get_args(field_type)[0] |
| 1692 | + if isinstance(field_type, type) and issubclass(field_type, Enum): |
| 1693 | + enum_values = ", ".join([f"{e.name} = {e.value}" for e in field_type]) |
| 1694 | + description += f" (Enum values: {enum_values})" |
| 1695 | + field_descriptions.append(description) |
| 1696 | + schema = "\n".join(field_descriptions) |
| 1697 | + response = self.prompt_agent( |
| 1698 | + agent_name=agent_name, |
| 1699 | + prompt_name="Convert to Pydantic Model", |
| 1700 | + prompt_args={ |
| 1701 | + "schema": schema, |
| 1702 | + "user_input": input_string, |
| 1703 | + }, |
| 1704 | + ) |
| 1705 | + if "```json" in response: |
| 1706 | + response = response.split("```json")[1].split("```")[0].strip() |
| 1707 | + elif "```" in response: |
| 1708 | + response = response.split("```")[1].strip() |
| 1709 | + try: |
| 1710 | + response = json.loads(response) |
| 1711 | + if response_type == "json": |
| 1712 | + return response |
| 1713 | + else: |
| 1714 | + return model(**response) |
| 1715 | + except Exception as e: |
| 1716 | + self.failures += 1 |
| 1717 | + if self.failures > max_failures: |
| 1718 | + print( |
| 1719 | + f"Error: {e} . Failed to convert the response to the model after 3 attempts. Response: {response}" |
| 1720 | + ) |
| 1721 | + return ( |
| 1722 | + response |
| 1723 | + if response |
| 1724 | + else "Failed to convert the response to the model." |
| 1725 | + ) |
| 1726 | + else: |
| 1727 | + self.failures = 1 |
| 1728 | + print( |
| 1729 | + f"Error: {e} . Failed to convert the response to the model, trying again. {self.failures}/3 failures. Response: {response}" |
| 1730 | + ) |
| 1731 | + return self.convert_to_model( |
| 1732 | + input_string=input_string, |
| 1733 | + model=model, |
| 1734 | + agent_name=agent_name, |
| 1735 | + max_failures=max_failures, |
| 1736 | + failures=self.failures, |
| 1737 | + ) |
| 1738 | + |
| 1739 | + def convert_list_of_dicts( |
| 1740 | + self, |
| 1741 | + data: List[dict], |
| 1742 | + model: Type[BaseModel], |
| 1743 | + agent_name: str = "gpt4free", |
| 1744 | + ): |
| 1745 | + converted_data = self.convert_to_model( |
| 1746 | + input_string=json.dumps(data[0], indent=4), |
| 1747 | + model=model, |
| 1748 | + agent_name=agent_name, |
| 1749 | + ) |
| 1750 | + mapped_list = [] |
| 1751 | + for info in data: |
| 1752 | + new_data = {} |
| 1753 | + for key, value in converted_data.items(): |
| 1754 | + item = [k for k, v in data[0].items() if v == value] |
| 1755 | + if item: |
| 1756 | + new_data[key] = info[item[0]] |
| 1757 | + mapped_list.append(new_data) |
| 1758 | + return mapped_list |
0 commit comments