[BUG] Does not support the recognition of SQL query statements with parentheses. #687

Open
kellan04 opened this issue Oct 25, 2024 · 3 comments
Labels
bug Something isn't working

kellan04 commented Oct 25, 2024

When the SQL statement generated by the LLM begins and ends with parentheses, the existing regular expressions cannot extract the complete statement.

Scenario:
There are two tables, and the question requires a query across both of them. An example of the generated SQL is: (SELECT xxx) UNION (SELECT xxx);

The existing extraction logic in src/vanna/base/base.py does not handle this case.
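The failure can be reproduced in isolation. This is a minimal sketch that assumes the non-Markdown branch of the original `extract_sql` uses the pattern `SELECT.*?;` (the pattern shown commented out in the modified version below); the input is the example from the scenario.

```python
import re

# Assumed to match the original non-Markdown branch of extract_sql
ORIGINAL_PATTERN = r"SELECT.*?;"

response = "(SELECT xxx) UNION (SELECT xxx);"

matches = re.findall(ORIGINAL_PATTERN, response, re.DOTALL)
extracted = matches[-1]

# The match starts at the first "SELECT", so the opening "(" is lost
print(extracted)                                   # SELECT xxx) UNION (SELECT xxx);
print(extracted.count("("), extracted.count(")"))  # 1 2
```

The unbalanced parentheses in the result are what the database then rejects.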

My modified version:

import re  # base.py already imports re


def extract_sql(self, llm_response: str) -> str:
    """
    Example:
    ```python
    vn.extract_sql("Here's the SQL query in a code block: ```sql\nSELECT * FROM customers\n```")
    ```

    Extracts the SQL query from the LLM response. This is useful in case the LLM response contains other information besides the SQL query.
    Override this function if your LLM responses need custom extraction logic.

    Args:
        llm_response (str): The LLM response.

    Returns:
        str: The extracted SQL query.
    """

    # If the llm_response contains a CTE (with clause), extract the last sql between WITH and ;
    sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response is not markdown formatted, extract last sql by finding select and ; in the response
    pattern = r"(\(?\s*SELECT\s+.*?\s*\)*?;)"  # Match SELECT statements with or without parentheses
    # Original pattern: r"SELECT.*?;"
    sqls = re.findall(pattern, llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    # If the llm_response contains a markdown code block, with or without the sql tag, extract the last sql from it
    sqls = re.findall(r"```sql\n(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
    if sqls:
        sql = sqls[-1]
        self.log(title="Extracted SQL", message=f"{sql}")
        return sql

    return llm_response
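To try the patterns outside Vanna, here is a standalone sketch of the first two branches of the modified method. `extract_sql_standalone` is a hypothetical name; it drops `self.log` and the Markdown fallbacks, and simply returns the input when neither pattern matches.

```python
import re

# Patterns copied from the modified extract_sql above
CTE_PATTERN = r"\bWITH\b .*?;"
SELECT_PATTERN = r"(\(?\s*SELECT\s+.*?\s*\)*?;)"


def extract_sql_standalone(llm_response: str) -> str:
    """Hypothetical helper mirroring only the CTE and SELECT branches."""
    # A CTE is checked first, as in the modified method
    sqls = re.findall(CTE_PATTERN, llm_response, re.DOTALL)
    if sqls:
        return sqls[-1]

    # The optional leading "(" is now part of the match
    sqls = re.findall(SELECT_PATTERN, llm_response, re.DOTALL)
    if sqls:
        return sqls[-1]

    # Markdown fallbacks are omitted in this sketch
    return llm_response


print(extract_sql_standalone("(SELECT xxx) UNION (SELECT xxx);"))
# (SELECT xxx) UNION (SELECT xxx);
```

One quirk: because the pattern allows whitespace after the optional `(`, a match that does not start with `(` can begin with whitespace. For example, `extract_sql_standalone("SELECT 1; SELECT 2;")` returns `" SELECT 2;"` with a leading space; calling `.strip()` on the result removes it if that matters.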
kellan04 added the bug label on Oct 25, 2024
kellan04 reopened this on Oct 25, 2024
kellan04 changed the title 不支持带括号的sql ("SQL with parentheses is not supported") to "Does not support the recognition of SQL query statements with parentheses." on Oct 25, 2024
kellan04 changed the title "Does not support the recognition of SQL query statements with parentheses." to [BUG] Does not support the recognition of SQL query statements with parentheses. on Oct 25, 2024
svetozar02 commented

I'm running into the same issue. The leading ( character is getting removed during SQL extraction.

LLM Response: 
        (
            SELECT ****
            LIMIT 1)
        UNION ALL 
        (
            SELECT ****
            LIMIT 1
        );
Extracted SQL: SELECT ****
            LIMIT 1)
        UNION ALL 
        (
            SELECT ****
            LIMIT 1
        );
An error occurred while executing SQL: syntax error at or near ")"
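The report can be reproduced directly. This sketch uses the response shape from the comment above, reindented, with `****` kept as the redacted column list, and compares the assumed original pattern `SELECT.*?;` with the proposed one.

```python
import re

OLD_PATTERN = r"SELECT.*?;"                       # assumed original pattern
NEW_PATTERN = r"(\(?\s*SELECT\s+.*?\s*\)*?;)"     # proposed pattern

# Shape of the reported LLM response ("****" stands in for the redacted columns)
response = """
(
    SELECT ****
    LIMIT 1)
UNION ALL
(
    SELECT ****
    LIMIT 1
);"""

old = re.findall(OLD_PATTERN, response, re.DOTALL)[-1]
new = re.findall(NEW_PATTERN, response, re.DOTALL)[-1]

print(old.startswith("SELECT"))  # True: the opening "(" is gone
print(new.startswith("("))       # True: the opening "(" is kept
# Parentheses balance only in the new extraction
print(old.count("(") == old.count(")"), new.count("(") == new.count(")"))  # False True
```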

kellan04 (Author) commented Oct 30, 2024

Yes, that is why I use the new regex.

svetozar02 commented

@kellan04 yup, your regex worked and I'm using it.
