From be423e09ed9bcaed59c1ddfb3750e4e4a51bfae0 Mon Sep 17 00:00:00 2001 From: Alessandro Pireno Date: Wed, 5 Mar 2025 19:32:53 -0500 Subject: [PATCH 1/9] this is creating new examples project surreal_rag to replace surreal_openai --- surrealdb-rag/.gitattributes | 2 + surrealdb-rag/.gitignore | 159 +++++++ surrealdb-rag/LICENSE | 21 + surrealdb-rag/Makefile | 23 + surrealdb-rag/README.md | 134 ++++++ surrealdb-rag/__init__.py | 0 surrealdb-rag/pyproject.toml | 34 ++ surrealdb-rag/requrements.txt | 15 + surrealdb-rag/schema/function_ddl.surql | 443 ++++++++++++++++++ surrealdb-rag/schema/table_ddl.surql | 104 ++++ surrealdb-rag/src/surrealdb_rag/__init__.py | 0 surrealdb-rag/src/surrealdb_rag/app.py | 279 +++++++++++ surrealdb-rag/src/surrealdb_rag/constants.py | 224 +++++++++ .../src/surrealdb_rag/create_database.py | 64 +++ .../src/surrealdb_rag/download_data.py | 84 ++++ .../src/surrealdb_rag/download_glove.py | 38 ++ surrealdb-rag/src/surrealdb_rag/embeddings.py | 42 ++ .../surrealdb_rag/insert_embedding_model.py | 84 ++++ .../src/surrealdb_rag/insert_wiki.py | 105 +++++ .../src/surrealdb_rag/llm_handler.py | 216 +++++++++ surrealdb-rag/src/surrealdb_rag/loggers.py | 27 ++ .../src/surrealdb_rag/train_fastText.py | 73 +++ surrealdb-rag/static/style.css | 180 +++++++ surrealdb-rag/static/surrealdb-icon.svg | 18 + surrealdb-rag/templates/chats.html | 12 + surrealdb-rag/templates/create_chat.html | 9 + surrealdb-rag/templates/index.html | 100 ++++ surrealdb-rag/templates/load_chat.html | 18 + .../templates/send_system_message.html | 30 ++ .../templates/send_user_message.html | 9 + 30 files changed, 2547 insertions(+) create mode 100644 surrealdb-rag/.gitattributes create mode 100644 surrealdb-rag/.gitignore create mode 100644 surrealdb-rag/LICENSE create mode 100644 surrealdb-rag/Makefile create mode 100644 surrealdb-rag/README.md create mode 100644 surrealdb-rag/__init__.py create mode 100644 surrealdb-rag/pyproject.toml create mode 100644 surrealdb-rag/requrements.txt create mode 100644 surrealdb-rag/schema/function_ddl.surql create mode 100644 surrealdb-rag/schema/table_ddl.surql create mode 100644 surrealdb-rag/src/surrealdb_rag/__init__.py create mode 100644 surrealdb-rag/src/surrealdb_rag/app.py create mode 100644 surrealdb-rag/src/surrealdb_rag/constants.py create mode 100644 surrealdb-rag/src/surrealdb_rag/create_database.py create mode 100644 surrealdb-rag/src/surrealdb_rag/download_data.py create mode 100644 surrealdb-rag/src/surrealdb_rag/download_glove.py create mode 100644 surrealdb-rag/src/surrealdb_rag/embeddings.py create mode 100644 surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py create mode 100644 surrealdb-rag/src/surrealdb_rag/insert_wiki.py create mode 100644 surrealdb-rag/src/surrealdb_rag/llm_handler.py create mode 100644 surrealdb-rag/src/surrealdb_rag/loggers.py create mode 100644 surrealdb-rag/src/surrealdb_rag/train_fastText.py create mode 100644 surrealdb-rag/static/style.css create mode 100644 surrealdb-rag/static/surrealdb-icon.svg create mode 100644 surrealdb-rag/templates/chats.html create mode 100644 surrealdb-rag/templates/create_chat.html create mode 100644 surrealdb-rag/templates/index.html create mode 100644 surrealdb-rag/templates/load_chat.html create mode 100644 surrealdb-rag/templates/send_system_message.html create mode 100644 surrealdb-rag/templates/send_user_message.html diff --git a/surrealdb-rag/.gitattributes b/surrealdb-rag/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/surrealdb-rag/.gitattributes @@ 
-0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/surrealdb-rag/.gitignore b/surrealdb-rag/.gitignore new file mode 100644 index 0000000..6038254 --- /dev/null +++ b/surrealdb-rag/.gitignore @@ -0,0 +1,159 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
+
+# SurrealDB RAG
+data/*
+!data/.gitkeep
+history.txt
+.python-version
+
diff --git a/surrealdb-rag/LICENSE b/surrealdb-rag/LICENSE
new file mode 100644
index 0000000..f9b42f0
--- /dev/null
+++ b/surrealdb-rag/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Cellan Hall
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/surrealdb-rag/Makefile b/surrealdb-rag/Makefile
new file mode 100644
index 0000000..e4b01ff
--- /dev/null
+++ b/surrealdb-rag/Makefile
@@ -0,0 +1,23 @@
+.DEFAULT_GOAL := help
+
+
+.PHONY: help
+# See for explanation
+help:
+	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+
+.PHONY: server-start
+server-start: ## Start FastAPI server
+	uvicorn src.surrealdb_rag.app:app --reload
+
+.PHONY: pycache-remove
+pycache-remove:
+	find . | grep -E "(__pycache__|\.pyc|\.pyo$$)" | xargs rm -rf
+
+.PHONY: dsstore-remove
+dsstore-remove:
+	find . | grep -E ".DS_Store" | xargs rm -rf
+
+.PHONY: cleanup
+cleanup: pycache-remove dsstore-remove
\ No newline at end of file
diff --git a/surrealdb-rag/README.md b/surrealdb-rag/README.md
new file mode 100644
index 0000000..ac04c21
--- /dev/null
+++ b/surrealdb-rag/README.md
@@ -0,0 +1,134 @@
+# SurrealDB x OpenAI: A Chat Playground
+
+Hey there! You've just stumbled upon a cool little project that's all about mashing up the brilliance of OpenAI with the database wizardry of SurrealDB. Think of this as a sandbox where we're diving deep into the realms of Retrieval-Augmented Generation (RAG) by mixing up a large dataset of Wikipedia articles with some fancy vector storage and search capabilities.
+
+https://github.com/Ce11an/surrealdb-openai/assets/60790416/a4e6a967-321a-4ca0-8f64-1687512aab38
+
+## So, What's the Big Idea?
+
+We're on a mission to explore the frontiers of what's possible when you pair up SurrealDB with OpenAI. We're talking about importing a whopping 25k Wikipedia articles, complete with their vectors (thanks to OpenAI's smarts), and then whipping up a RAG question-answering system that's as cool as it sounds.
+
+We've got a cozy little FastAPI server acting as our backstage crew, Jinja2 spinning up the templates, and htmx making our frontend chat application as lively as a chat at your favorite coffee shop.
+
+## Gear Up
+
+Before diving in, here's what we're playing with:
+
+- A shiny Apple M2 Pro running macOS Sonoma 14.4
+- SurrealDB 1.3.0, cozied up on disk
+- Python 3.11, because we like to keep things fresh
+
+Hit a snag?
Just holler, and we'll sort it out together. + +## Getting the Party Started + +First off, make sure SurrealDB is ready to rock on your machine (check out [how to get it up and running](https://surrealdb.com/install)). For Python 3.11, [pyenv](https://github.com/pyenv/pyenv) is your best buddy. + +Grab this repo with: + +```bash +git clone https://github.com/Ce11an/surrealdb-openai.git +``` + +You're gonna need an OpenAI API key for this shindig. Not sure where to snag one? Peek at the [OpenAI Developer Quickstart](https://platform.openai.com/docs/quickstart). Now, because SurrealDB and environment variables are currently in a complicated relationship, we've got a nifty workaround in [chats.surql](https://github.com/Ce11an/surrealdb-openai/blob/main/schema/chats.surql) for you to slip your OpenAI API key into: + +```sql +DEFINE FUNCTION IF NOT EXISTS fn::get_openai_token() { + RETURN "Bearer " +}; +``` + +*Heads up:* This is all for kicks and not meant for the production grind. Keep your OpenAI API key under wraps! + +### Setting Up SurrealDB + +With your setup ready, hit up some `make` commands to get SurrealDB into gear: + +Fire up SurrealDB for some on-disk action: + +```bash +make surreal-start +``` + +To lay down the database blueprint with table and function definitions: + +```bash +make surreal-init +``` + +Need a clean slate? Here's how to clear your database: + +```bash +make surreal-remove +``` + +### Python Time + +Jump into the Python virtual environment: + +```bash +source venv/bin/activate +``` + +Get all the project goodies installed: + +```bash +pip install -e . +``` + +### Grabbing the Dataset + +We're going for the Simple English Wikipedia dataset by OpenAI (it's a biggie — ~700MB zipped, sprawling into a 1.7GB CSV file) that includes those nifty vector embeddings. Ready to download it? + +```bash +get-data +``` + +### Populating SurrealDB + +Time to move that dataset into SurrealDB: + +```bash +surreal-insert +``` + +### Let's Do Some RAG! + +Dive into SurrealDB with SurrealQL: + +```bash +make surreal-sql +``` + +And here's a taste of what you can do with a RAG operation: + +```sql +RETURN fn::surreal_rag("gpt-3.5-turbo", "Who is the greatest basketball player of all time?", 0.85, 0.5); +``` + +### Let's chat? + +To start chatting with the RAG: + +``` +make server-start +``` + +## Extra Bits + +Beyond the RAG adventure, feel free to query, explore, and play with the data in any way you fancy. And if you're looking to amp up your game, tools like LangChain are there to spice things up. + +## Features! More features! + +- [ ] Handle them darn errors! Reply with a system message that informs the user there has been an oopsie. +- [ ] Add user chat history as context. +- [ ] There are way too many steps to get started - docker-compose? +- [ ] Perform RAG to generate SurrealQL QA - this I will need help with. +- [ ] Ummm where are the tests? You, the user, are the test! (seriously, I need to add some...). + +## Coffee, Anyone? 
+
+If this little project made your day or saved you a coffee break's worth of time, consider fueling my caffeine love:
+
+Buy Me A Coffee
+
diff --git a/surrealdb-rag/__init__.py b/surrealdb-rag/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/surrealdb-rag/pyproject.toml b/surrealdb-rag/pyproject.toml
new file mode 100644
index 0000000..5a48e1b
--- /dev/null
+++ b/surrealdb-rag/pyproject.toml
@@ -0,0 +1,34 @@
+[build-system]
+requires = ["setuptools >= 61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "surrealdb_rag"
+version = "0"
+description = "Example of RAG using SurrealDB"
+readme = "README.md"
+license = {file = "LICENSE"}
+requires-python = ">=3.11"
+dependencies = [
+    "pandas",
+    "wget",
+    "pandas-stubs",
+    "surrealdb",
+    "tqdm",
+    "fastapi",
+    "uvicorn",
+    "jinja2",
+    "python-multipart",
+    "python-dotenv",
+    "ollama",
+    "google.generativeai",
+    "openai",
+    "fasttext"
+]
+
+[project.scripts]
+surreal-create-db = "surrealdb_rag.create_database:surreal_create_database"
+surreal-insert-wiki = "surrealdb_rag.insert_wiki:surreal_wiki_insert"
+download-data = "surrealdb_rag.download_data:download_data"
+download-glove = "surrealdb_rag.download_glove:download_glove_model"
+surreal-insert-embeddings = "surrealdb_rag.insert_embedding_model:surreal_embeddings_insert"
\ No newline at end of file
diff --git a/surrealdb-rag/requrements.txt b/surrealdb-rag/requrements.txt
new file mode 100644
index 0000000..003579f
--- /dev/null
+++ b/surrealdb-rag/requrements.txt
@@ -0,0 +1,15 @@
+pandas
+surrealdb
+wget
+pandas-stubs
+surrealdb
+tqdm
+fastapi
+uvicorn
+jinja2
+python-multipart
+python-dotenv
+ollama
+google.generativeai
+openai
+fasttext
\ No newline at end of file
diff --git a/surrealdb-rag/schema/function_ddl.surql b/surrealdb-rag/schema/function_ddl.surql
new file mode 100644
index 0000000..e1614e2
--- /dev/null
+++ b/surrealdb-rag/schema/function_ddl.surql
@@ -0,0 +1,443 @@
+/*
+This file defines the SurrealQL functions for the chat functionality of this project, plus functions shared by the embedding models.
+*/
+
+
+
+DEFINE FUNCTION OVERWRITE fn::retrieve_vector_size_for_model($model: string)
+{
+    RETURN (SELECT VALUE array::len(embedding) FROM embedding_model:[
+        $model,
+        NONE
+    ]..[
+        $model,
+        ..
+    ] LIMIT 1)[0]
+};
+
+DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string, $model: string) {
+
+    # Pull the first row to determine the size of the vector (they should all be the same)
+    LET $vector_size = fn::retrieve_vector_size_for_model($model);
+
+    # Select the vectors from the embedding table that match the words
+    LET $vectors = fn::retrieve_vectors_for_sentence($sentence, $model);
+
+    # Remove any non-matches
+    LET $vectors = array::filter($vectors, |$v| { RETURN $v != NONE; });
+
+    # Transpose the vectors to be able to average them
+    LET $transposed = array::transpose($vectors);
+
+    # Sum up the individual floats in the arrays
+    LET $sum_vector = $transposed.map(|$sub_array| math::sum($sub_array));
+
+    # Calculate the mean of each component by dividing by the total number of
+    # vectors contributing to each of the floats
+    LET $mean_vector = vector::scale($sum_vector, 1.0 / array::len($vectors));
+
+    # If the array size is correct return it, otherwise return NONE
+    RETURN
+        IF array::len($mean_vector) == $vector_size {$mean_vector}
+        ELSE {NONE}
+    ;
+};
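+
+# A quick sanity check of the function above (hypothetical usage, assuming the
+# "GLOVE" rows have already been loaded into `embedding_model` by
+# insert_embedding_model.py): the mean of the word vectors of a short sentence
+# should be a 300-element array.
+#
+# RETURN array::len(fn::sentence_to_vector("the quick brown fox", "GLOVE"));
+# -- expected: 300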
+
+/* Search for documents using embeddings.
+
+Args:
+    input_vector: embedding to search for within the embedding field
+    threshold: minimum cosine-similarity score a KNN match must have to be returned
+    model: the name of the embedding model to use... "GLOVE", "CUST_FASTTEXT" or "OPENAI"
+Returns:
+    array: Array of {score, doc} objects.
+*/
+
+DEFINE FUNCTION OVERWRITE fn::search_for_documents($input_vector: array, $threshold: float, $model: string) {
+
+    LET $first_pass =
+        (IF $model = "GLOVE" THEN
+            (
+            SELECT
+                id,
+                vector::similarity::cosine(content_glove_vector, $input_vector) AS similarity_score
+            FROM embedded_wiki
+            WHERE content_glove_vector <|5,40|> $input_vector
+            );
+        ELSE IF $model = "CUST_FASTTEXT" THEN
+            (
+            SELECT
+                id,
+                vector::similarity::cosine(content_fasttext_vector, $input_vector) AS similarity_score
+            FROM embedded_wiki
+            WHERE content_fasttext_vector <|5,40|> $input_vector
+            );
+        ELSE IF $model = "OPENAI" THEN
+            (
+            SELECT
+                id,
+                vector::similarity::cosine(content_openai_vector, $input_vector) AS similarity_score
+            FROM embedded_wiki
+            WHERE content_openai_vector <|5,40|> $input_vector
+            );
+        END);
+
+    # Keep only the KNN matches above the similarity threshold.
+    RETURN SELECT similarity_score AS score, id AS doc FROM $first_pass WHERE similarity_score > $threshold;
+};
+
+/* Get prompt for RAG.
+
+Args:
+    documents: Documents to add to the prompt as context.
+
+Returns:
+    string: Prompt with context.
+*/
+DEFINE FUNCTION OVERWRITE fn::get_prompt_with_context($documents: array<record<embedded_wiki>>) {
+
+    LET $prompt = "You are an AI assistant answering questions about anything from Simple English Wikipedia. The context will provide you with the most relevant data from Simple English Wikipedia, including the page title, url, and page content.
+
+    If referencing the text/context, refer to it as Simple English Wikipedia.
+
+    Please provide your response in Markdown converted to HTML format. Include appropriate headings and lists where relevant.
+
+    At the end of the response, add an HTML link and replace the title and url with the associated title and url of the most relevant page from the context.
+
+    The maximum number of links you can include is 1, do not provide any other references or annotations.
+
+    Only reply with the context provided. If the context is an empty string, reply with 'I am sorry, I do not know the answer.'.
+
+    Do not use any prior knowledge that you have been trained on.
+
+    <context>
+    $context
+    </context>";
+    LET $context = (SELECT VALUE "\n ------------- \n URL: " + url + "\nTitle: " + title + "\n Content:\n" + text AS content
+    FROM $documents).join("\n");
+    RETURN string::replace($prompt, '$context', $context);
+};
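+
+# Example of chaining the functions above (hypothetical usage): embed a
+# question with the GLOVE model, keep the documents scoring above 0.7, and
+# build the RAG prompt from them.
+#
+# LET $v = fn::sentence_to_vector("who invented the telephone?", "GLOVE");
+# LET $docs = fn::search_for_documents($v, 0.7, "GLOVE");
+# RETURN fn::get_prompt_with_context((SELECT VALUE doc FROM $docs));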
+
+/* Create a message.
+
+Args:
+    chat_id: Record ID from the `chat` table that the message was sent in.
+    role: Role that sent the message. Allowed values are `user` or `system`.
+    content: Sent message content.
+
+Returns:
+    object: Content and timestamp.
+*/
+
+DEFINE FUNCTION OVERWRITE fn::create_message(
+    $chat_id: string,
+    $role: string,
+    $content: string,
+    $documents: option<array<{score: float, doc: record<embedded_wiki>}>>,
+    $embedding_model: option<string>,
+    $llm_model: option<string>,
+) {
+    # Create a message record and get the resulting ID.
+    LET $message_id =
+        SELECT VALUE
+            id
+        FROM ONLY
+            CREATE ONLY message
+            SET role = $role,
+                content = $content;
+
+    # Create a relation between the chat record and the message record and get the resulting timestamp.
+    LET $chat = type::record($chat_id);
+    LET $timestamp =
+        SELECT VALUE
+            timestamp
+        FROM ONLY
+            RELATE ONLY $chat->sent->$message_id CONTENT {
+                referenced_documents: $documents,
+                embedding_model: $embedding_model,
+                llm_model: $llm_model
+            };
+
+    RETURN {
+        content: $content,
+        timestamp: $timestamp
+    };
+};
+
+/* Create a user message.
+
+Args:
+    chat_id: Record ID from the `chat` table that the message was sent in.
+    content: Sent message content.
+
+Returns:
+    object: Content and timestamp.
+*/
+DEFINE FUNCTION OVERWRITE fn::create_user_message($chat_id: string, $content: string, $embedding_model: string, $openai_token: option<string>) {
+
+    LET $threshold = 0.7;
+
+    LET $vector = IF $embedding_model == "OPENAI" THEN
+        fn::openai_embeddings_complete("text-embedding-ada-002", $content, $openai_token)
+    ELSE
+        fn::sentence_to_vector($content, $embedding_model)
+    END;
+    LET $documents = fn::search_for_documents($vector, $threshold, $embedding_model);
+    RETURN fn::create_message($chat_id, "user", $content, $documents, $embedding_model, NONE);
+};
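+
+# Example of the write path above (hypothetical usage): create a chat, then
+# record a user message in it with the GLOVE embeddings. The OpenAI token is
+# only needed when $embedding_model is "OPENAI", so NONE is passed here.
+#
+# LET $chat = fn::create_chat();
+# RETURN fn::create_user_message(<string> $chat.id, "Tell me about Paris", "GLOVE", NONE);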
+
+/* Get the last user message and the referenced docs, for generating a prompt.
+
+Args:
+    chat_id: Record ID from the `chat` table that the message was sent in.
+
+Returns:
+    object: Content, referenced documents [{score, embedded_wiki}], timestamp.
+*/
+
+DEFINE FUNCTION OVERWRITE fn::get_last_user_message_input_and_prompt($chat_id: string) {
+    LET $message =
+        SELECT content, fn::get_prompt_with_context(docs) AS prompt_text FROM (
+            SELECT
+                out.content AS content,
+                referenced_documents.doc AS docs,
+                timestamp
+            FROM ONLY type::record($chat_id)->sent
+            WHERE out.role = "user"
+            ORDER BY timestamp DESC
+            LIMIT 1
+            FETCH out);
+
+    RETURN $message[0];
+};
+
+/* Get the user's first message in a chat, for generating a title.
+
+Args:
+    chat_id: Record ID from the `chat` table to generate a title for.
+
+Returns:
+    string: first chat content.
+*/
+
+DEFINE FUNCTION OVERWRITE fn::get_first_message($chat_id: string) {
+    # Get the `content` of the user's initial message.
+    RETURN (
+        SELECT
+            out.content AS content,
+            timestamp
+        FROM ONLY type::record($chat_id)->sent
+        ORDER BY timestamp
+        LIMIT 1
+        FETCH out
+    ).content;
+
+};
+
+
+/* Create a new chat.
+
+Returns:
+    object: Object containing `id` and `title`.
+*/
+DEFINE FUNCTION OVERWRITE fn::create_chat() {
+    RETURN CREATE ONLY chat
+        RETURN id, title;
+};
+
+/* Load a chat.
+
+Args:
+    chat_id: Record ID from the `chat` table to load.
+
+Returns:
+    array[objects]: Array of messages containing `role` and `content`.
+*/
+DEFINE FUNCTION OVERWRITE fn::load_chat($chat_id: string) {
+    RETURN
+        SELECT
+            out.role AS role,
+            out.content AS content,
+            timestamp
+        FROM type::record($chat_id)->sent
+        ORDER BY timestamp
+        FETCH out;
+};
+
+/* Load all chats.
+
+Returns:
+    array[objects]: array of chat records containing `id`, `title`, and `created_at`.
+*/
+DEFINE FUNCTION OVERWRITE fn::load_all_chats() {
+    RETURN
+        SELECT
+            id, title, created_at
+        FROM chat
+        ORDER BY created_at DESC;
+};
+
+/* Get chat title.
+
+Args:
+    chat_id: Record ID of the chat to get the title for.
+
+Returns:
+    string: Chat title.
+*/
+DEFINE FUNCTION OVERWRITE fn::get_chat_title($chat_id: string) {
+    RETURN SELECT VALUE title FROM ONLY type::record($chat_id);
+};
+
+/* Delete a chat and its sent messages.
+
+Args:
+    chat_id: Record ID of the chat to delete.
+
+Returns:
+    string: chat id that was deleted.
+*/
+
+DEFINE FUNCTION OVERWRITE fn::delete_chat($chat_id: string) {
+    LET $chat = type::record($chat_id);
+    DELETE message WHERE id IN (SELECT ->sent->message FROM $chat);
+    DELETE sent WHERE in = $chat;
+    DELETE $chat;
+    RETURN $chat;
+};
+
+
+/* OpenAI embeddings complete.
+
+Args:
+    embedding_model: Embedding model from OpenAI.
+    input: User input.
+    openai_token: the token used to authorize calling the API
+
+Returns:
+    array: Array of embeddings.
+*/
+DEFINE FUNCTION OVERWRITE fn::openai_embeddings_complete($embedding_model: string, $input: string, $openai_token: string) {
+    RETURN http::post(
+        "https://api.openai.com/v1/embeddings",
+        {
+            "model": $embedding_model,
+            "input": $input
+        },
+        {
+            "Authorization": "Bearer " + $openai_token
+        }
+    )["data"][0]["embedding"]
+};
+
+
+/* OpenAI chat complete.
+
+Args:
+    llm: Large Language Model to use for generation.
+    input: Initial user input.
+    prompt_with_context: Prompt with context for the system.
+
+Returns:
+    string: Response from LLM.
+*/
+DEFINE FUNCTION OVERWRITE fn::openai_chat_complete($llm: string, $input: string, $prompt_with_context: string, $temperature: float, $openai_token: string) {
+    LET $response = http::post(
+        "https://api.openai.com/v1/chat/completions",
+        {
+            "model": $llm,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": $prompt_with_context
+                },
+                {
+                    "role": "user", "content": $input
+                },
+            ],
+            "temperature": $temperature
+        },
+        {
+            # Prepend "Bearer ", matching fn::openai_embeddings_complete above.
+            "Authorization": "Bearer " + $openai_token
+        }
+    )["choices"][0]["message"]["content"];
+
+    # Sometimes there are double quotes
+    RETURN string::replace($response, '"', '');
+};
+
+
+/* Gemini's endpoint format has the model name and key in the query string.
+
+Args:
+    llm: Large Language Model to use for generation.
+    google_token: the API token for Gemini
+Returns:
+    string: path to query for LLM.
+*/
+DEFINE FUNCTION OVERWRITE fn::get_gemini_api_url($llm: string, $google_token: string) {
+    RETURN string::concat("https://generativelanguage.googleapis.com/v1beta/models/", $llm, ":generateContent?key=", $google_token);
+};
+
+
+/* Gemini chat complete.
+
+Args:
+    llm: Large Language Model to use for generation.
+    input: Initial user input.
+    prompt_with_context: Prompt with context for the system.
+    google_token: the API token for Gemini
+
+Returns:
+    string: Response from LLM.
+*/
+DEFINE FUNCTION OVERWRITE fn::gemini_chat_complete($llm: string, $prompt_with_context: string, $input: string, $google_token: string) {
+
+    LET $body = {
+        "contents": [{
+            "parts": [{"text": $prompt_with_context}, {"text": $input}]
+        }],
+        "safetySettings": []
+    };
+    RETURN http::post(
+        fn::get_gemini_api_url($llm, $google_token),
+        $body
+    );
+};
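+
+# Example calls for the two providers above (hypothetical; both require valid
+# API keys, and note the differing argument order between the two functions).
+#
+# RETURN fn::openai_chat_complete("gpt-3.5-turbo", "What is SurrealDB?",
+#     "You are a helpful assistant.", 0.5, $openai_token);
+# RETURN fn::gemini_chat_complete("gemini-2.0-flash",
+#     "You are a helpful assistant.", "What is SurrealDB?", $google_token);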
+
+
+# These functions calculate the mean vector for the tokens in a sentence using an in-database word-embedding model (GloVe or the custom fastText).
+DEFINE FUNCTION OVERWRITE fn::retrieve_vectors_for_sentence($sentence: string, $model: string)
+{
+    # Put a space before punctuation so it tokenizes as separate words.
+    LET $sentence = $sentence.lowercase().
+        replace('.',' .').
+        replace(',',' ,').
+        replace('?',' ?').
+        replace('!',' !').
+        replace(';',' ;').
+        replace(':',' :').
+        replace('(',' (').
+        replace(')',' )').
+        replace('[',' [').
+        replace(']',' ]').
+        replace('{',' {').
+        replace('}',' }').
+        replace('"',' "').
+        replace("'"," '").
+        replace('`',' `').
+        replace('/',' /').
+        replace('\\',' \\').
+        replace('<',' <').
+        replace('>',' >').
+        replace('—',' —').
+        replace('–',' –');
+    LET $words = $sentence.words();
+    LET $words = array::filter($words, |$word: any| $word != '');
+
+    # Select the vectors from the embedding table that match the words
+
+    RETURN (SELECT VALUE embedding_model:[
+        $model, $this].embedding FROM $words);
+
+
+};
\ No newline at end of file
diff --git a/surrealdb-rag/schema/table_ddl.surql b/surrealdb-rag/schema/table_ddl.surql
new file mode 100644
index 0000000..673bbfc
--- /dev/null
+++ b/surrealdb-rag/schema/table_ddl.surql
@@ -0,0 +1,104 @@
+/*
+This file defines the SurrealQL DDL for the chat functionality of this project, plus the tables shared by the embedding models.
+*/
+
+
+# Define the `chat` table.
+DEFINE TABLE IF NOT EXISTS chat SCHEMAFULL;
+
+DEFINE FIELD IF NOT EXISTS title ON TABLE chat TYPE string
+    DEFAULT "Untitled chat";
+
+# Field is populated on creation and is readonly.
+DEFINE FIELD IF NOT EXISTS created_at ON TABLE chat TYPE datetime
+    VALUE time::now() READONLY;
+
+# Field automatically updates when a field is edited.
+DEFINE FIELD IF NOT EXISTS updated_at ON TABLE chat TYPE datetime
+    VALUE time::now();
+
+# Define the `message` table.
+DEFINE TABLE IF NOT EXISTS message SCHEMAFULL;
+
+/* Field can only be populated with `user` or `system`.
+
+There are CSS and HTML rules that rely on these values.
+*/
+DEFINE FIELD IF NOT EXISTS role ON message TYPE string
+    ASSERT $input IN ["user", "system"];
+
+DEFINE FIELD IF NOT EXISTS content ON message TYPE string;
+
+# Field is populated on creation and is readonly.
+DEFINE FIELD IF NOT EXISTS created_at ON TABLE message TYPE datetime
+    VALUE time::now() READONLY;
+
+# Field automatically updates when a field is edited.
+DEFINE FIELD IF NOT EXISTS updated_at ON TABLE message TYPE datetime
+    VALUE time::now();
+
+# Define the `sent` edge table.
+DEFINE TABLE sent TYPE RELATION IN chat OUT message ENFORCED;
+DEFINE FIELD IF NOT EXISTS timestamp ON TABLE sent TYPE datetime
+    VALUE time::now();
+DEFINE FIELD IF NOT EXISTS referenced_documents ON TABLE sent TYPE option<array<{score: float, doc: record<embedded_wiki>}>>;
+DEFINE FIELD IF NOT EXISTS llm_model ON TABLE sent TYPE option<string>;
+DEFINE FIELD IF NOT EXISTS embedding_model ON TABLE sent TYPE option<string>;
+
+# A message can only be sent in one chat.
+DEFINE INDEX IF NOT EXISTS unique_sent_message_in_chat
+    ON TABLE sent
+    COLUMNS in, out UNIQUE;
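+
+# Example traversal of the tables above (hypothetical chat id): a chat's
+# transcript is read by walking the `sent` edge from `chat` to `message`.
+#
+# SELECT out.role AS role, out.content AS content
+# FROM chat:some_chat_id->sent ORDER BY timestamp FETCH out;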
+
+
+/*
+The remainder of this file defines the SurrealQL DDL for the embedded Wikipedia articles and the in-database word-embedding models.
+*/
+
+
+# Define the `embedded_wiki` table.
+
+DEFINE TABLE IF NOT EXISTS embedded_wiki SCHEMAFULL;
+
+DEFINE FIELD IF NOT EXISTS url ON TABLE embedded_wiki TYPE string
+    # Field must be a URL.
+    ASSERT string::is::url($value);
+
+DEFINE FIELD IF NOT EXISTS title ON TABLE embedded_wiki TYPE string
+    # Field must be non-empty
+    ASSERT string::len($value) > 0;
+
+DEFINE FIELD IF NOT EXISTS text ON TABLE embedded_wiki TYPE string
+    # Field must be non-empty
+    ASSERT string::len($value) > 0;
+
+DEFINE FIELD IF NOT EXISTS content_glove_vector ON TABLE embedded_wiki TYPE option<array<float>>
+    # Field must have length 300 to use embedding model: glove 300d
+    ASSERT array::len($value) = 300;
+
+DEFINE INDEX IF NOT EXISTS embedded_wiki_content_glove_vector_index ON embedded_wiki
+    FIELDS content_glove_vector
+    HNSW DIMENSION 300 M 32 EFC 300;
+
+DEFINE FIELD IF NOT EXISTS content_openai_vector ON TABLE embedded_wiki TYPE option<array<float>>
+    # Field must have length 1536 to use embedding model: text-embedding-ada-002
+    ASSERT array::len($value) = 1536;
+
+DEFINE INDEX IF NOT EXISTS embedded_wiki_content_openai_vector_index ON embedded_wiki
+    FIELDS content_openai_vector
+    HNSW DIMENSION 1536 M 32 EFC 300;
+
+DEFINE FIELD IF NOT EXISTS content_fasttext_vector ON TABLE embedded_wiki TYPE option<array<float>>
+    # Field must have length 100 to use the custom fastText embedding model
+    ASSERT array::len($value) = 100;
+
+DEFINE INDEX IF NOT EXISTS embedded_wiki_content_fasttext_vector_index ON embedded_wiki
+    FIELDS content_fasttext_vector
+    HNSW DIMENSION 100 M 32 EFC 300;
+
+
+# This table stores word-embedding models (GloVe and the custom fastText) in the database.
+DEFINE TABLE IF NOT EXISTS embedding_model TYPE NORMAL SCHEMAFULL;
+DEFINE FIELD IF NOT EXISTS word ON embedding_model TYPE string;
+DEFINE FIELD IF NOT EXISTS model ON embedding_model TYPE string;
+DEFINE FIELD IF NOT EXISTS embedding ON embedding_model TYPE array<float>;
diff --git a/surrealdb-rag/src/surrealdb_rag/__init__.py b/surrealdb-rag/src/surrealdb_rag/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/surrealdb-rag/src/surrealdb_rag/app.py b/surrealdb-rag/src/surrealdb_rag/app.py
new file mode 100644
index 0000000..3479138
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/app.py
@@ -0,0 +1,279 @@
+"""Backend for SurrealDB chat interface."""
+
+import contextlib
+import datetime
+from collections.abc import AsyncGenerator
+import fastapi
+from surrealdb import AsyncSurreal, RecordID
+from fastapi import responses, staticfiles, templating
+from surrealdb_rag.llm_handler import LLMModelHander, ModelListHandler
+
+import uvicorn
+
+from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams
+
+db_params = DatabaseParams()
+model_params = ModelParams()
+args_loader = ArgsLoader("LLM Model Handler", db_params, model_params)
+args_loader.LoadArgs()
+
+
+def extract_id(surrealdb_id: RecordID) -> str:
+    """Extract a CSS-safe ID from a SurrealDB record ID.
+
+    SurrealDB record IDs come in the form `<table_name>:<unique_id>`.
+    CSS classes cannot be named with a `:`, so we replace it with a `-`.
+
+    Args:
+        surrealdb_id: SurrealDB record ID.
+
+    Returns:
+        ID.
+    """
+    if isinstance(surrealdb_id, RecordID):
+        return surrealdb_id.id.replace(":", "-")
+    else:
+        return surrealdb_id.replace(":", "-")
+
+
+def convert_timestamp_to_date(timestamp: str) -> str:
+    """Convert a SurrealDB `datetime` to a readable string.
+
+    The result will be of the format: `April 05 2024, 15:30`.
+
+    Args:
+        timestamp: SurrealDB `datetime` value.
+
+    Returns:
+        Date as a string.
+ """ + # parsed_timestamp = datetime.datetime.fromisoformat(timestamp.rstrip("Z")) + # return parsed_timestamp.strftime("%B %d %Y, %H:%M") + return timestamp + +templates = templating.Jinja2Templates(directory="templates") +templates.env.filters["extract_id"] = extract_id +templates.env.filters["convert_timestamp_to_date"] = convert_timestamp_to_date +life_span = {} + + +@contextlib.asynccontextmanager +async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: + """FastAPI lifespan to create and destroy objects.""" + + connection = AsyncSurreal(db_params.DB_PARAMS.url) + await connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) + await connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) + life_span["surrealdb"] = connection + + + model_list = ModelListHandler(model_params,life_span["surrealdb"]) + + life_span["llm_models"] = await model_list.available_llm_models() + life_span["embed_models"] = await model_list.available_embed_models() + + yield + life_span.clear() + + +app = fastapi.FastAPI(lifespan=lifespan) +app.mount("/static", staticfiles.StaticFiles(directory="static"), name="static") + + +@app.get("/", response_class=responses.HTMLResponse) +async def index(request: fastapi.Request) -> responses.HTMLResponse: + + + return templates.TemplateResponse("index.html", { + "request": request, + "available_llm_models": life_span["llm_models"], + "available_embed_models": life_span["embed_models"], + "default_llm_model": life_span["llm_models"][next(iter(life_span["llm_models"]))], + "default_embed_model":life_span["embed_models"][next(iter(life_span["embed_models"]))] + }) + +@app.get("/get_llm_model_details") +async def get_llm_model_details(llm_model: str = fastapi.Query(...)): + model_data = life_span["llm_models"].get(llm_model) + if model_data: + s = f"Model Version: {model_data['model_version']}, Host: {model_data['host']}" + else: + s = "Model details not found." + return fastapi.Response(s, media_type="text/html") #Return response object + +@app.get("/get_embed_model_details") +async def get_embed_model_details(embed_model: str = fastapi.Query(...)): + model_data = life_span["embed_models"].get(embed_model) + if model_data: + s = f"Version: {model_data['dimensions']}, Host: {model_data['host']}" + else: + s = "Model details not found." 
+ return fastapi.Response(s, media_type="text/html") #Return response object + + + +@app.post("/chats", response_class=responses.HTMLResponse) +async def create_chat(request: fastapi.Request) -> responses.HTMLResponse: + """Create a chat.""" + chat_record = await life_span["surrealdb"].query( + """RETURN fn::create_chat();""" + ) + return templates.TemplateResponse( + "create_chat.html", + { + "request": request, + "chat_id": chat_record["id"], + "chat_title": chat_record["title"], + }, + ) + + +@app.delete("/chats/{chat_id}/delete", response_class=responses.HTMLResponse) +async def delete_chat( + request: fastapi.Request, chat_id: str +) -> responses.HTMLResponse: + + SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::delete_chat($chat_id)""",params = {"chat_id":chat_id} + )) + return fastapi.Response(status_code=fastapi.status.HTTP_204_NO_CONTENT) + +@app.get("/chats/{chat_id}", response_class=responses.HTMLResponse) +async def load_chat( + request: fastapi.Request, chat_id: str +) -> responses.HTMLResponse: + """Load a chat.""" + message_records = await life_span["surrealdb"].query( + """RETURN fn::load_chat($chat_id)""",params = {"chat_id":chat_id} + ) + return templates.TemplateResponse( + "load_chat.html", + { + "request": request, + "messages": message_records, + "chat_id": chat_id, + }, + ) + + +@app.get("/chats", response_class=responses.HTMLResponse) +async def load_all_chats(request: fastapi.Request) -> responses.HTMLResponse: + """Load all chats.""" + chat_records = await life_span["surrealdb"].query( + """RETURN fn::load_all_chats();""" + ) + return templates.TemplateResponse( + "chats.html", {"request": request, "chats": chat_records} + ) + + +@app.post( + "/chats/{chat_id}/send-user-message", response_class=responses.HTMLResponse +) +async def send_user_message( + request: fastapi.Request, + chat_id: str, + content: str = fastapi.Form(...), + embed_model: str = fastapi.Form(...) +) -> responses.HTMLResponse: + """Send user message.""" + if embed_model == "OPENAI": + message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::create_user_message($chat_id, $content,$embedding_model,$openaitoken);""",params = {"chat_id":chat_id,"content":content,"embedding_model":embed_model,"openaitoken":model_params.openai_token} + )) + else: + message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::create_user_message($chat_id, $content,$embedding_model);""",params = {"chat_id":chat_id,"content":content,"embedding_model":embed_model} + )) + + + return templates.TemplateResponse( + "send_user_message.html", + { + "request": request, + "chat_id": chat_id, + "content": message["result"][0]["result"]["content"], + "timestamp": message["result"][0]["result"]["timestamp"] + }, + ) + + +@app.post( + "/chats/{chat_id}/send-system-message", + response_class=responses.HTMLResponse, +) +async def send_system_message( + request: fastapi.Request, chat_id: str, + llm_model: str = fastapi.Form(...) 
+) -> responses.HTMLResponse: + """Send system message.""" + + + + + message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::get_last_user_message_input_and_prompt($chat_id);""",params = {"chat_id":chat_id} + )) + result = message["result"][0]["result"] + prompt_text = result["prompt_text"] + content = result["content"] + #call the LLM + model_data = life_span["llm_models"].get(llm_model) + if not model_data: + raise SystemError(f"Error in outcome: Invalid model {llm_model}") + + llm_handler = LLMModelHander(model_data,model_params,life_span["surrealdb"]) + + llm_response = llm_handler.get_chat_response(prompt_text,content) + + #save the response in the DB + message = SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( + """RETURN fn::create_message($chat_id, "system", $llm_response);""",params = {"chat_id":chat_id,"llm_response":llm_response} + )) + + title = await life_span["surrealdb"].query( + """RETURN fn::get_chat_title($chat_id);""",params = {"chat_id":chat_id} + ) + new_title = "" + if title == "Untitled chat": + first_message_text = await life_span["surrealdb"].query( + "RETURN fn::get_first_message($chat_id);",params={"chat_id":chat_id} + ) + system_prompt = "You are a conversation title generator for a ChatGPT type app. Respond only with a simple title using the user input." + new_title = llm_handler.get_chat_response(system_prompt,first_message_text) + #update chat title in database + SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( + """UPDATE type::record($chat_id) SET title=$title;""",params = {"chat_id":chat_id,"title":new_title} + )) + + + + result = message["result"][0]["result"] + + + + return templates.TemplateResponse( + "send_system_message.html", + { + "request": request, + "content": result["content"], + "timestamp": result["timestamp"], + "new_title": new_title.strip(), + "chat_id": chat_id, + }, + ) + + + +def run_app(): + + uvicorn.run( + "__main__:app", reload=True + ) + + +if __name__ == "__main__": + run_app() \ No newline at end of file diff --git a/surrealdb-rag/src/surrealdb_rag/constants.py b/surrealdb-rag/src/surrealdb_rag/constants.py new file mode 100644 index 0000000..097a956 --- /dev/null +++ b/surrealdb-rag/src/surrealdb_rag/constants.py @@ -0,0 +1,224 @@ +import argparse +import os + + +WIKI_URL = "https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip" +WIKI_ZIP_PATH = "data/vector_database_wikipedia_articles_embedded.zip" +WIKI_PATH = "data/vector_database_wikipedia_articles_embedded.csv" + +GLOVE_URL = "https://nlp.stanford.edu/data/glove.6B.zip" +GLOVE_ZIP_PATH = "data/glove.6B.zip" +GLOVE_PATH = "data/glove.6B.300d.txt" + +CUSTOM_FS_PATH = "data/custom_fast_text.txt" + + +class SurrealParams(): + def __init__(self = None, url = None,username = None, password = None, namespace = None, database = None): + self.username = username + self.password = password + self.namespace = namespace + self.database = database + self.url = url + + @staticmethod + def ParseResponseForErrors(outcome): + + if outcome: + if "result" in outcome: + for item in outcome["result"]: + if item["status"]=="ERR": + raise SystemError("Error in results: {0}".format(item["result"])) + + if "error" in outcome: + raise SystemError("Error in outcome: {0}".format(outcome["error"])) + + return outcome + else: + return None + + +class ModelParams(): + + LLM_MODELS = { + "GEMINI-SURREAL": 
{"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None}, + "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None}, + "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None}, + "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}, + "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} + } + + EMBED_MODELS = { + "CUST_FASTTEXT": {"dimensions":100,"host":"SQL"}, + "GLOVE": {"dimensions":300,"host":"SQL"}, + "OPENAI": {"dimensions":1536,"host":"API"} + } + def __init__(self): + self.openai_token_env_var = "OPENAI_API_KEY" + self.openai_token = None + self.gemini_token_env_var = "GOOGLE_GENAI_API_KEY" + self.gemini_token = None + # self.embedding_model_env_var = "SURREAL_RAG_EMBEDDING_MODEL" + # self.embedding_model = None + # self.llm_model_env_var = "SURREAL_RAG_LLM_MODEL" + # self.llm_model = None + # self.version = None + # self.host = None + # self.temperature = 0.5 + + def AddArgs(self, parser:argparse.ArgumentParser): + parser.add_argument("-oenv","--openai_token_env", help="Your env variable for LLM openai_token (Default: {0} for ollama hosted ignore)".format(self.openai_token_env_var)) + parser.add_argument("-genv","--gemini_token_env", help="Your env variable for LLM gemini_token (Default: {0} for ollama hosted ignore)".format(self.gemini_token_env_var)) + + #parser.add_argument("-emenv","--embedding_model_env_var", help="Your env variable for embedding model value can be 'OPENAI' or 'GLOVE' (Default: {0})".format(self.embedding_model_env_var)) + #parser.add_argument("-em","--embedding_model", help="Embedding model value can be 'OPENAI' or 'GLOVE', if none it will use env var (Default: {0})".format("")) + # parser.add_argument("-llmenv","--llm_model_env_var", help="Your env variable for LLM model value can be 'OPENAI','DEEPSEEK' or 'GEMINI' (Default: {0})".format(self.llm_model_env_var)) + # parser.add_argument("-llm","--llm_model", help="LLM model value can be 'OPENAI'.'DEEPSEEK' or 'GEMINI', if none it will use env var (Default: {0})".format("")) + + def SetArgs(self,args:argparse.Namespace): + if args.openai_token_env: + self.openai_token_env_var = args.openai_token_env + + self.openai_token = os.getenv(self.openai_token_env_var) + + if args.gemini_token_env: + self.gemini_token_env_var = args.gemini_token_env + self.gemini_token = os.getenv(self.gemini_token_env_var) + + + # if args.embedding_model_env_var: + # self.embedding_model_env_var = args.embedding_model_env_var + # if args.llm_model_env_var: + # self.llm_model_env_var = args.llm_model_env_var + + + # if args.embedding_model: + # self.embedding_model = args.embedding_model + # else: + # self.embedding_model = os.getenv(self.embedding_model_env_var) + + # if self.embedding_model not in ["OPENAI","GLOVE"]: + # raise ValueError("Embedding model must be 'OPENAI' or 'GLOVE'") + + # if args.llm_model: + # self.llm_model = args.llm_model + # else: + # self.llm_model = os.getenv(self.llm_model_env_var) + + # self.version = self.LLM_MODELS[self.llm_model]["model_version"] + # self.host = self.LLM_MODELS[self.llm_model]["host"] + + + + +class DatabaseParams(): + def __init__(self): + #export SURREAL_CLOUD_TEST_USER=xxx + #export SURREAL_CLOUD_TEST_PASS=xxx + self.DB_USER_ENV_VAR = "SURREAL_RAG_USER" + self.DB_PASS_ENV_VAR = "SURREAL_RAG_PASS" + self.DB_URL_ENV_VAR = "SURREAL_RAG_DB_URL" + self.DB_NS_ENV_VAR = 
"SURREAL_RAG_DB_NS" + self.DB_DB_ENV_VAR = "SURREAL_RAG_DB_DB" + + + #The path to your SurrealDB instance + #The the SurrealDB namespace and database to upload the model to + self.DB_PARAMS = SurrealParams() + + #For use in authenticating your database in database.py + #These are just the pointers to the env variables + #Don't put the actual passwords here + def AddArgs(self, parser:argparse.ArgumentParser): + + parser.add_argument("-urlenv","--url_env", help="Your env variable for Path to your SurrealDB instance (Default: {0})".format(self.DB_URL_ENV_VAR)) + parser.add_argument("-nsenv","--namespace_env", help="Your env variable for SurrealDB namespace to create and install the data (Default: {0})".format(self.DB_NS_ENV_VAR)) + parser.add_argument("-dbenv","--database_env", help="Your env variable for SurrealDB database to create and install the data (Default: {0})".format(self.DB_DB_ENV_VAR)) + parser.add_argument("-uenv","--user_env", help="Your env variable for db username (Default: {0})".format(self.DB_USER_ENV_VAR)) + parser.add_argument("-penv","--pass_env", help="Your env variable for db password (Default: {0})".format(self.DB_PASS_ENV_VAR)) + + + parser.add_argument("-url","--url", help="Your Path to your SurrealDB instance (Default: {0})".format("")) + parser.add_argument("-ns","--namespace", help="Your SurrealDB namespace to create and install the data (Default: {0})".format("")) + parser.add_argument("-db","--database", help="Your SurrealDB database to create and install the data (Default: {0})".format("")) + parser.add_argument("-u","--username", help="Your db username (Default: {0})".format("")) + parser.add_argument("-p","--password", help="Your db password (Default: {0})".format("")) + + def SetArgs(self,args:argparse.Namespace): + if args.url_env: + self.DB_URL_ENV_VAR = args.url_env + if args.namespace_env: + self.DB_NS_ENV_VAR = args.namespace_env + if args.database_env: + self.DB_DB_ENV_VAR = args.database_env + if args.user_env: + self.DB_USER_ENV_VAR = args.user_env + if args.pass_env: + self.DB_PASS_ENV_VAR = args.pass_env + + if args.url: + self.DB_PARAMS.url = args.url + else: + self.DB_PARAMS.url = os.getenv(self.DB_URL_ENV_VAR) + + if args.namespace: + self.DB_PARAMS.namespace = args.namespace + else: + self.DB_PARAMS.namespace = os.getenv(self.DB_NS_ENV_VAR) + + if args.database: + self.DB_PARAMS.database = args.database + else: + self.DB_PARAMS.database = os.getenv(self.DB_DB_ENV_VAR) + + if args.username: + self.DB_PARAMS.username = args.username + else: + self.DB_PARAMS.username = os.getenv(self.DB_USER_ENV_VAR) + + if args.password: + self.DB_PARAMS.password = args.password + else: + self.DB_PARAMS.password = os.getenv(self.DB_PASS_ENV_VAR) + + + + + + +class ArgsLoader(): + + def __init__(self,description, + db_params: DatabaseParams,model_params: ModelParams): + self.parser = argparse.ArgumentParser(description=description) + self.db_params = db_params + self.model_params = model_params + self.model_params.AddArgs(self.parser) + self.db_params.AddArgs(self.parser) + + + + def LoadArgs(self): + self.args = self.parser.parse_args() + self.db_params.SetArgs(self.args) + self.model_params.SetArgs(self.args) + + def string_to_print(self): + ret_val = self.parser.description + ret_val += f"/n{self.db_params.DB_PARAMS.__dict__}" + ret_val += f"/n{self.model_params.__dict__}" + return ret_val + + def print(self): + print(self.string_to_print()) + + + + + + + + + + + \ No newline at end of file diff --git a/surrealdb-rag/src/surrealdb_rag/create_database.py 
b/surrealdb-rag/src/surrealdb_rag/create_database.py
new file mode 100644
index 0000000..801ee2b
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/create_database.py
@@ -0,0 +1,64 @@
+"""Create the SurrealDB database for Wikipedia embeddings."""
+
+
+from surrealdb import Surreal
+
+from surrealdb_rag import loggers
+
+from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams
+
+db_params = DatabaseParams()
+model_params = ModelParams()
+args_loader = ArgsLoader("Input Embeddings Model", db_params, model_params)
+args_loader.LoadArgs()
+
+def surreal_create_database() -> None:
+    """Create SurrealDB database for Wikipedia embeddings."""
+    logger = loggers.setup_logger("SurrealCreateDatabase")
+
+    logger.info(args_loader.string_to_print())
+    with Surreal(db_params.DB_PARAMS.url) as connection:
+        logger.info("Connected to SurrealDB")
+        connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password})
+        logger.info("Creating database")
+        query = f"""
+
+            DEFINE NAMESPACE IF NOT EXISTS {db_params.DB_PARAMS.namespace};
+            USE NAMESPACE {db_params.DB_PARAMS.namespace};
+            REMOVE DATABASE IF EXISTS {db_params.DB_PARAMS.database};
+            DEFINE DATABASE {db_params.DB_PARAMS.database};
+            USE DATABASE {db_params.DB_PARAMS.database};
+        """
+        logger.info(query)
+        SurrealParams.ParseResponseForErrors(connection.query_raw(
+            query
+        ))
+        logger.info("Database created successfully")
+        connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database)
+
+        logger.info("Executing common DDL")
+        with open("./schema/table_ddl.surql") as f:
+            surql_to_execute = f.read()
+            SurrealParams.ParseResponseForErrors(connection.query_raw(surql_to_execute))
+
+        with open("./schema/function_ddl.surql") as f:
+            surql_to_execute = f.read()
+            SurrealParams.ParseResponseForErrors(connection.query_raw(surql_to_execute))
+
+if __name__ == "__main__":
+    surreal_create_database()
\ No newline at end of file
diff --git a/surrealdb-rag/src/surrealdb_rag/download_data.py b/surrealdb-rag/src/surrealdb_rag/download_data.py
new file mode 100644
index 0000000..8c403ea
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/download_data.py
@@ -0,0 +1,84 @@
+"""Download OpenAI Wikipedia data and enrich it with GloVe and fastText embeddings."""
+
+import zipfile
+
+import wget
+import os
+
+from surrealdb_rag import loggers
+
+import surrealdb_rag.constants as constants
+
+from surrealdb_rag.embeddings import WordEmbeddingModel
+
+import pandas as pd
+import tqdm
+
+def download_data() -> None:
+    """Extract `vector_database_wikipedia_articles_embedded.csv` to `/data`."""
+    logger = loggers.setup_logger("DownloadData")
+
+    logger.info("Downloading Wikipedia")
+    # Skip the download if the archive is already present (the zip is ~700MB).
+    if not os.path.exists("data"):
+        os.makedirs("data")
+    if not os.path.exists(constants.WIKI_ZIP_PATH):
+        wget.download(
+            url=constants.WIKI_URL,
+            out=constants.WIKI_ZIP_PATH,
+        )
+
+    logger.info("Extracting to data directory")
+    with zipfile.ZipFile(
+        constants.WIKI_ZIP_PATH, "r"
+    ) as zip_ref:
zip_ref.extractall("data")
+
+    if not os.path.exists(constants.WIKI_PATH):
+        raise FileNotFoundError(f"File not found: {constants.WIKI_PATH}")
+
+    logger.info("Loading GloVe embedding model")
+
+    try:
+        gloveEmbeddingModel = WordEmbeddingModel(constants.GLOVE_PATH)
+    except Exception as e:
+        logger.error(f"Error opening embedding model. Please check that the model file was downloaded using download_glove_model {e}")
+        raise
+
+    try:
+        fastTextEmbeddingModel = WordEmbeddingModel(constants.CUSTOM_FS_PATH)
+    except Exception as e:
+        logger.error(f"Error opening embedding model. Train the model using train_fastText {e}")
+        raise
+
+    usecols = [
+        "url",
+        "title",
+        "text",
+        "content_vector"
+    ]
+
+    logger.info("Loading Wiki data to data frame")
+    wiki_records_df = pd.read_csv(constants.WIKI_PATH, usecols=usecols)
+
+    logger.info("Processing glove embeddings")
+    wiki_records_df['content_glove_vector'] = [gloveEmbeddingModel.sentence_to_vec(x) for x in tqdm.tqdm(wiki_records_df["text"], desc="Processing content glove embeddings")]
+
+    logger.info("Processing fastText embeddings")
+    wiki_records_df['content_fasttext_vector'] = [fastTextEmbeddingModel.sentence_to_vec(x) for x in tqdm.tqdm(wiki_records_df["text"], desc="Processing content fastText embeddings")]
+
+    logger.info(f"Backing up file {constants.WIKI_PATH}.bak")
+    os.rename(constants.WIKI_PATH, constants.WIKI_PATH + ".bak")
+    logger.info(f"Saving file {constants.WIKI_PATH}")
+    wiki_records_df.to_csv(constants.WIKI_PATH, index=False)  # index=False prevents writing the index
+
+    logger.info("Processed and saved the data successfully. Please check the data directory")
+
+if __name__ == "__main__":
+    download_data()
\ No newline at end of file
diff --git a/surrealdb-rag/src/surrealdb_rag/download_glove.py b/surrealdb-rag/src/surrealdb_rag/download_glove.py
new file mode 100644
index 0000000..1f477d2
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/download_glove.py
@@ -0,0 +1,38 @@
+"""Download the GloVe word-embedding model."""
+
+import zipfile
+
+import wget
+import os
+
+from surrealdb_rag import loggers
+
+import surrealdb_rag.constants as constants
+
+
+def download_glove_model() -> None:
+    """Extract `glove.6B.300d.txt` to `/data`."""
+    logger = loggers.setup_logger("DownloadGloveModel")
+
+    logger.info("Downloading GloVe model")
+    if not os.path.exists("data"):
+        os.makedirs("data")
+    wget.download(
+        url=constants.GLOVE_URL,
+        out=constants.GLOVE_ZIP_PATH,
+    )
+
+    logger.info("Extracting to data directory")
+    with zipfile.ZipFile(
+        constants.GLOVE_ZIP_PATH, "r"
+    ) as zip_ref:
+        zip_ref.extractall("data")
+
+    if not os.path.exists(constants.GLOVE_PATH):
+        raise FileNotFoundError(f"File not found: {constants.GLOVE_PATH}")
+
+    logger.info("Extracted file successfully.
Please check the data directory") + +if __name__ == "__main__": + download_glove_model() + diff --git a/surrealdb-rag/src/surrealdb_rag/embeddings.py b/surrealdb-rag/src/surrealdb_rag/embeddings.py new file mode 100644 index 0000000..eacc1f3 --- /dev/null +++ b/surrealdb-rag/src/surrealdb_rag/embeddings.py @@ -0,0 +1,42 @@ +import numpy as np +import re + +PUNCTUATION_TO_SEPARATE = [ + ".", ",", "?", "!", ";", ":", "(", ")", "[", "]", "{", "}", "\"", "'", "`", "/", "\\", "<", ">", "—", "–" + ] + +class WordEmbeddingModel: + + + def __init__(self,model_path): + self.dictionary = {} + self.vector_size = 0 + self.model_path = model_path + + with open(self.model_path, 'r', encoding='utf-8') as f: + for line in f: + values = line.split() + word = values[0] + vector = np.asarray(values[1:], "float32") + self.dictionary[word] = vector + if self.vector_size==0: + self.vector_size = len(vector) + + + def separate_punctuation(sentence): + """Separates specified punctuation characters with a space before them.""" + punctuation_regex = re.compile(r"([{}])".format(re.escape("".join(PUNCTUATION_TO_SEPARATE)))) + return punctuation_regex.sub(r" \1", sentence) + + + #This method will generate an embedding for a piece of text + def sentence_to_vec(self,sentence): + + words = WordEmbeddingModel.separate_punctuation(sentence).lower().split() + + vectors = [self.dictionary[w] for w in words if w in self.dictionary] + + if vectors: + return np.mean(vectors, axis=0).tolist() + else: + return np.zeros(self.vector_size).tolist() diff --git a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py new file mode 100644 index 0000000..40798f2 --- /dev/null +++ b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py @@ -0,0 +1,84 @@ +"""Insert Wikipedia data into SurrealDB.""" + + +import pandas as pd +from surrealdb import Surreal +import tqdm + +from surrealdb_rag import loggers + + +from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams +from surrealdb_rag.embeddings import WordEmbeddingModel + +import surrealdb_rag.constants as constants + +db_params = DatabaseParams() +model_params = ModelParams() +args_loader = ArgsLoader("Input Glove embeddings model",db_params,model_params) +args_loader.LoadArgs() + +INSERT_GLOVE_EMBEDDINGS = """ + FOR $row IN $embeddings { + CREATE embedding_model:[$model,$row.word] CONTENT { + word : $row.word, + model : $model, + embedding: $row.embedding + } RETURN NONE; + }; +""" + +DELETE_GLOVE_EMBEDDINGS = "DELETE embedding_model WHERE model = $model;" + +CHUNK_SIZE = 1000 + +def surreal_model_insert(model_name,model_path,logger): + + logger.info(f"Reading {model_name} model") + embeddingModel = WordEmbeddingModel(model_path) + embeddings_df = pd.DataFrame({'word': embeddingModel.dictionary.keys(), 'embedding': embeddingModel.dictionary.values()}) + total_rows = len(embeddings_df) + total_chunks = (total_rows + CHUNK_SIZE - 1) // CHUNK_SIZE # ceiling division + with Surreal(db_params.DB_PARAMS.url) as connection: + connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) + connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) + logger.info("Connected to SurrealDB") + logger.info("Inserting rows into SurrealDB") + + #remove any data from the table + SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_GLOVE_EMBEDDINGS)) + with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar: + + for i in range(0, 
diff --git a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py
new file mode 100644
index 0000000..40798f2
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py
@@ -0,0 +1,84 @@
+"""Insert word embedding models (GloVe and custom fastText) into SurrealDB."""
+
+
+import pandas as pd
+from surrealdb import Surreal
+import tqdm
+
+from surrealdb_rag import loggers
+
+
+from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams
+from surrealdb_rag.embeddings import WordEmbeddingModel
+
+import surrealdb_rag.constants as constants
+
+db_params = DatabaseParams()
+model_params = ModelParams()
+args_loader = ArgsLoader("Input Glove embeddings model",db_params,model_params)
+args_loader.LoadArgs()
+
+INSERT_GLOVE_EMBEDDINGS = """
+    FOR $row IN $embeddings {
+        CREATE embedding_model:[$model,$row.word] CONTENT {
+            word : $row.word,
+            model : $model,
+            embedding: $row.embedding
+        } RETURN NONE;
+    };
+"""
+
+DELETE_GLOVE_EMBEDDINGS = "DELETE embedding_model WHERE model = $model;"
+
+CHUNK_SIZE = 1000
+
+def surreal_model_insert(model_name,model_path,logger):
+    """Load a word-embedding text file and bulk-insert it into SurrealDB."""
+    logger.info(f"Reading {model_name} model")
+    embeddingModel = WordEmbeddingModel(model_path)
+    embeddings_df = pd.DataFrame({'word': embeddingModel.dictionary.keys(), 'embedding': embeddingModel.dictionary.values()})
+    total_rows = len(embeddings_df)
+    total_chunks = (total_rows + CHUNK_SIZE - 1) // CHUNK_SIZE  # ceiling division
+    with Surreal(db_params.DB_PARAMS.url) as connection:
+        connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password})
+        connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database)
+        logger.info("Connected to SurrealDB")
+        logger.info("Inserting rows into SurrealDB")
+
+        # remove any existing rows for this model from the table
+        SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_GLOVE_EMBEDDINGS, params={"model":model_name}))
+        with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar:
+
+            for i in range(0, total_rows, CHUNK_SIZE):
+                chunk = embeddings_df.iloc[i:i + CHUNK_SIZE]
+
+                formatted_rows = [
+                    {
+                        "word":str(row["word"]),
+                        "embedding":row["embedding"].tolist()
+                    }
+                    for _, row in chunk.iterrows()
+                ]
+
+                SurrealParams.ParseResponseForErrors(connection.query_raw(
+                    INSERT_GLOVE_EMBEDDINGS, params={"embeddings": formatted_rows,"model":model_name}
+                ))
+                pbar.update(1)
+
+
+def surreal_embeddings_insert() -> None:
+    """Main entrypoint to insert the GloVe and custom fastText models into SurrealDB."""
+    logger = loggers.setup_logger("SurrealEmbeddingsInsert")
+
+    logger.info(args_loader.string_to_print())
+    surreal_model_insert("GLOVE",constants.GLOVE_PATH,logger)
+    surreal_model_insert("CUST_FASTTEXT",constants.CUSTOM_FS_PATH,logger)
+
+
+if __name__ == "__main__":
+    surreal_embeddings_insert()
\ No newline at end of file
diff --git a/surrealdb-rag/src/surrealdb_rag/insert_wiki.py b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py
new file mode 100644
index 0000000..516ed7f
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py
@@ -0,0 +1,105 @@
+"""Insert Wikipedia data into SurrealDB."""
+
+import ast
+
+import pandas as pd
+from surrealdb import Surreal
+import tqdm
+
+from surrealdb_rag import loggers
+import surrealdb_rag.constants as constants
+
+
+from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams
+
+db_params = DatabaseParams()
+model_params = ModelParams()
+args_loader = ArgsLoader("Input wiki data",db_params,model_params)
+args_loader.LoadArgs()
+
+
+INSERT_WIKI_RECORDS = """
+    FOR $row IN $records {
+        CREATE type::thing("embedded_wiki",$row.url) CONTENT {
+            url : $row.url,
+            title: $row.title,
+            text: $row.text,
+            content_glove_vector: $row.content_glove_vector,
+            content_openai_vector: $row.content_openai_vector,
+            content_fasttext_vector: $row.content_fasttext_vector
+        } RETURN NONE;
+    };
+"""
+
+DELETE_WIKI_RECORDS = "DELETE embedded_wiki;"
+
+CHUNK_SIZE = 50
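+
+# type::thing("embedded_wiki", $row.url) keys each record by its URL, so
+# re-running the loader would collide on CREATE; the table is cleared first
+# via DELETE_WIKI_RECORDS. One generated statement resolves to roughly
+# (illustrative value):
+#   CREATE embedded_wiki:⟨https://simple.wikipedia.org/wiki/April⟩ CONTENT { ... };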
+
+
+def surreal_wiki_insert() -> None:
+    """Main entrypoint to insert Wikipedia embeddings into SurrealDB."""
+    logger = loggers.setup_logger("SurrealWikiInsert")
+
+    logger.info(args_loader.string_to_print())
+
+    logger.info(f"Loading file {constants.WIKI_PATH}")
+
+    usecols=[
+        "url",
+        "title",
+        "text",
+        "content_vector",
+        "content_glove_vector",
+        "content_fasttext_vector"
+    ]
+
+    wiki_records_df = pd.read_csv(constants.WIKI_PATH,usecols=usecols)
+
+    total_rows = len(wiki_records_df)
+    total_chunks = total_rows // CHUNK_SIZE + (
+        1 if total_rows % CHUNK_SIZE else 0
+    )
+    with Surreal(db_params.DB_PARAMS.url) as connection:
+        connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password})
+        connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database)
+        logger.info("Connected to SurrealDB")
+
+        logger.info("Deleting any existing wiki rows from SurrealDB")
+        # remove any data from the table
+        SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_WIKI_RECORDS))
+
+        logger.info("Inserting rows into SurrealDB")
+        with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar:
+            for i in range(0, total_rows, CHUNK_SIZE):
+                chunk = wiki_records_df.iloc[i:i + CHUNK_SIZE]
+                formatted_rows = [
+                    {
+                        "url":str(row["url"]),
+                        "title":str(row["title"]),
+                        "text":str(row["text"]),
+                        "content_openai_vector":ast.literal_eval(row["content_vector"]),
+                        "content_glove_vector":ast.literal_eval(row["content_glove_vector"]),
+                        "content_fasttext_vector":ast.literal_eval(row["content_fasttext_vector"])
+                    }
+                    for _, row in chunk.iterrows()
+                ]
+                try:
+                    SurrealParams.ParseResponseForErrors(connection.query_raw(
+                        INSERT_WIKI_RECORDS, params={"records": formatted_rows}
+                    ))
+                except Exception as e:
+                    logger.error(f"Error inserting chunk starting at row {i}: {e}")
+                    raise
+
+                pbar.update(1)
+
+
+if __name__ == "__main__":
+    surreal_wiki_insert()
\ No newline at end of file
diff --git a/surrealdb-rag/src/surrealdb_rag/llm_handler.py b/surrealdb-rag/src/surrealdb_rag/llm_handler.py
new file mode 100644
index 0000000..83c4fe5
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/llm_handler.py
@@ -0,0 +1,216 @@
+import openai
+import ollama
+from ollama import generate,GenerateResponse
+
+
+import google.generativeai as genai
+
+
+from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader
+from surrealdb import AsyncSurreal
+
+
+# LLM_MODELS = {
+#     "GEMINI-SURREAL": {"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None},
+#     "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None},
+#     "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None},
+#     "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5},
+#     "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}
+# }
+
+# EMBED_MODELS = {
+#     "CUST_FASTTEXT": {"dimensions":100,"host":"SQL"},
+#     "GLOVE": {"dimensions":300,"host":"SQL"},
+#     "OPENAI": {"dimensions":1536,"host":"API"}
+# }
+
+class ModelListHandler():
+
+    def __init__(self, model_params, connection):
+        self.LLM_MODELS = {}
+        self.EMBED_MODELS = {}
+        self.model_params = model_params
+        self.connection = connection
+
+    async def populate_models(self):
+        self.LLM_MODELS = {}
+        self.EMBED_MODELS = {}
+
+        check_for_vectors = await self.connection.query(
+            """SELECT
+                content_openai_vector!=None AS has_openai_vectors,
+                content_glove_vector!=None AS has_glove_vectors,
+                content_fasttext_vector!=None AS has_fasttext_vectors
+            FROM embedded_wiki LIMIT 1;""")
+        check_for_vectors = check_for_vectors[0]
+        # you need an api key for gemini
+        if self.model_params.gemini_token:
+            self.LLM_MODELS["GEMINI"] = ModelParams.LLM_MODELS["GEMINI"]
+            self.LLM_MODELS["GEMINI-SURREAL"] = ModelParams.LLM_MODELS["GEMINI-SURREAL"]
+
+        # you need an api key for openai
+        if self.model_params.openai_token:
+            self.LLM_MODELS["OPENAI"] = ModelParams.LLM_MODELS["OPENAI"]
+            self.LLM_MODELS["OPENAI-SURREAL"] = ModelParams.LLM_MODELS["OPENAI-SURREAL"]
+            # you need the vector field populated for openai
+            if check_for_vectors["has_openai_vectors"] == True:
+                self.EMBED_MODELS["OPENAI"] = ModelParams.EMBED_MODELS["OPENAI"]
+
+        # you need the vector field populated for glove
+        if check_for_vectors["has_glove_vectors"] == True:
+            self.EMBED_MODELS["GLOVE"] = ModelParams.EMBED_MODELS["GLOVE"]
+
+        # you need the vector field populated for fasttext
+        if check_for_vectors["has_fasttext_vectors"] == True:
+            self.EMBED_MODELS["CUST_FASTTEXT"] = ModelParams.EMBED_MODELS["CUST_FASTTEXT"]
+
+        response: ollama.ListResponse = ollama.list()
+
+        for model in response.models:
+            self.LLM_MODELS[model.model] = {"model_version":model.model,"host":"OLLAMA","platform":"local","temperature":None}
+
+        # print('Name:', model.model)
+        # print('  Size (MB):', f'{(model.size.real / 1024 / 1024):.2f}')
+        # if model.details:
+        #     print('  Format:', model.details.format)
+        #     print('  Family:', model.details.family)
+        #     print('  Parameter Size:', model.details.parameter_size)
+        #     print('  Quantization Level:', model.details.quantization_level)
+        # print('\n')
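+
+        # The resulting dictionaries are keyed by display name, e.g. (sketch):
+        #   LLM_MODELS["GEMINI"]  -> {"model_version": "gemini-2.0-flash", "host": "API", ...}
+        #   EMBED_MODELS["GLOVE"] -> {"dimensions": 300, "host": "SQL"}
+        # Only models whose API keys and vector columns are actually present
+        # end up in the lists offered to the UI.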
+
+    async def available_llm_models(self):
+        if self.LLM_MODELS != {}:
+            return self.LLM_MODELS
+        else:
+            await self.populate_models()
+            return self.LLM_MODELS
+
+    async def available_embed_models(self):
+        if self.EMBED_MODELS != {}:
+            return self.EMBED_MODELS
+        else:
+            await self.populate_models()
+            return self.EMBED_MODELS
+
+
+class LLMModelHander():
+
+    def __init__(self,model_data:dict,model_params:ModelParams,connection:AsyncSurreal):
+        self.model_data = model_data
+        self.model_params = model_params
+        self.connection = connection
+
+    def get_chat_response(self,prompt_with_context:str,input:str):
+        match self.model_data["host"]:
+            case "SQL":
+                return self.get_chat_response_from_surreal(prompt_with_context,input)
+            case "API":
+                return self.get_chat_response_from_api(prompt_with_context,input)
+            case "OLLAMA":
+                return self.get_chat_response_from_ollama(prompt_with_context,input)
+            case _:
+                raise SystemError(f"Invalid host method {self.model_data['host']}")
+
+    def get_chat_response_from_api(self,prompt_with_context:str,input:str):
+        match self.model_data["platform"]:
+            case "OPENAI":
+                return self.get_chat_openai_response_from_api(prompt_with_context,input)
+            case "GOOGLE":
+                return self.get_chat_gemini_response_from_api(prompt_with_context,input)
+            case _:
+                raise SystemError(f"Error in outcome: Invalid model for API execution {self.model_data['platform']}")
+
+    def get_chat_response_from_surreal(self,prompt_with_context:str,input:str):
+        match self.model_data["platform"]:
+            case "OPENAI":
+                return self.get_chat_openai_response_from_surreal(prompt_with_context,input)
+            case "GOOGLE":
+                return self.get_chat_gemini_response_from_surreal(prompt_with_context,input)
+            case _:
+                raise SystemError(f"Error in outcome: Invalid model for SQL execution {self.model_data['platform']}")
+
+    def get_chat_openai_response_from_surreal(self,prompt_with_context:str,input:str):
+        return "get_chat_openai_response_from_surreal"
+
+    def get_chat_gemini_response_from_surreal(self,prompt_with_context:str,input:str):
+        return "get_chat_gemini_response_from_surreal"
+
+    def get_chat_openai_response_from_api(self,prompt_with_context:str,input:str):
+        messages = [
+            {
+                "role": "system",
+                "content": prompt_with_context
+            },
+            {
+                "role": "user",
+                "content": input
+            }
+        ]
+        openai.api_key = self.model_params.openai_token
+        if openai.api_key is None:
+            raise ValueError("OPENAI_API_KEY environment variable not set.")
+        try:
+            response = openai.chat.completions.create(
+                model=self.model_data["model_version"],
+                messages=messages,
+                temperature=self.model_data["temperature"]
+            )
+            return response.choices[0].message.content
+        except openai.OpenAIError as e:
+            print(f"An error occurred: {e}")
+            return None
+        except Exception as e:
+            print(f"An unexpected error occurred: {e}")
+            return None
+
+    def get_chat_gemini_response_from_api(self,prompt_with_context:str,input:str):
+        messages = [
+            {
+                "text": prompt_with_context
+            },
+            {
+                "text": input
+            }
+        ]
+        genai.configure(api_key=self.model_params.gemini_token)
+        model = genai.GenerativeModel(self.model_data["model_version"])
+        response = model.generate_content(messages)
+        return response.text
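+
+    # Example usage (a sketch; model_data is one entry from ModelListHandler,
+    # e.g. the "OPENAI" record with host "API"):
+    #   handler = LLMModelHander(model_data, model_params, connection)
+    #   answer = handler.get_chat_response(prompt_with_context, "What is April?")
+    # The SQL-hosted ("SURREAL") variants are still stubs at this point in the
+    # patch series and simply return their own method names.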
+
+    def get_chat_response_from_ollama(self,prompt_with_context:str,input:str):
+        messages = [
+            {
+                "role": "system",
+                "content": prompt_with_context
+            },
+            {
+                "role": "user",
+                "content": input
+            }
+        ]
+        response: GenerateResponse = generate(model=self.model_data["model_version"], prompt=str(messages))
+        #parsed_response = parse_deepseek_response(response.response)
+        #return {"response":response, "think": parsed_response["think"],"content":parsed_response["content"]}
+        return response.response
+
diff --git a/surrealdb-rag/src/surrealdb_rag/loggers.py b/surrealdb-rag/src/surrealdb_rag/loggers.py
new file mode 100644
index 0000000..27d7d5d
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/loggers.py
@@ -0,0 +1,27 @@
+"""Module to configure logging."""
+
+import logging
+
+
+def setup_logger(name: str) -> logging.Logger:
+    """Configure and return a logger with the given name.
+
+    Args:
+        name: Name of the logger.
+
+    Returns:
+        Configured Python logger.
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+
+    ch = logging.StreamHandler()
+    ch.setLevel(logging.DEBUG)
+
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    ch.setFormatter(formatter)
+    logger.addHandler(ch)
+    return logger
diff --git a/surrealdb-rag/src/surrealdb_rag/train_fastText.py b/surrealdb-rag/src/surrealdb_rag/train_fastText.py
new file mode 100644
index 0000000..6ed2078
--- /dev/null
+++ b/surrealdb-rag/src/surrealdb_rag/train_fastText.py
@@ -0,0 +1,73 @@
+"""Train a custom fastText embedding model on the Wikipedia data."""
+
+
+import fasttext
+
+from surrealdb_rag import loggers
+import surrealdb_rag.constants as constants
+import pandas as pd
+import re
+import os
+
+# Preprocess the text (example - adjust as needed)
+def preprocess_text(text):
+    token = str(text).lower()
+    token = re.sub(r'[^\w\s]', '', token)  # Remove punctuation
+    token = re.sub(r'\s+', ' ', token)  # Normalize whitespace (replace multiple spaces, tabs, newlines with a single space)
+    token = token.strip()
+    return token
+
+
+def train_fastText() -> None:
+    logger = loggers.setup_logger("Train FastText Embedding Model")
+
+    usecols=[
+        "title",
+        "text"
+    ]
+
+    logger.info(f"Loading Wiki data {constants.WIKI_PATH} to data frame")
+    wiki_records_df = pd.read_csv(constants.WIKI_PATH,usecols=usecols)
+
+    # Combine relevant columns
+    wiki_records_df['combined_text'] = 'title:' + wiki_records_df['title'] + '\ntext:\n' + wiki_records_df['text']
+    all_text = wiki_records_df['combined_text'].apply(preprocess_text)
+
+    logger.info(all_text.head())
+    logger.info(all_text.describe())
+    logger.info(len(all_text))
+
+    training_data_file = constants.CUSTOM_FS_PATH + "_train.txt"
+    model_bin_file = constants.CUSTOM_FS_PATH + ".bin"
+    model_txt_file = constants.CUSTOM_FS_PATH
+    # Save the combined text to a file
+    with open(training_data_file, "w") as f:
+        for text in all_text:
+            f.write(text + "\n")
+
+    # Train the FastText model
+    model = fasttext.train_unsupervised(training_data_file, model='skipgram')
+    model.save_model(model_bin_file)
+    model_dim = model.get_dimension()
+
+    logger.info(f"Model dimension: {model_dim}")
+
+    with open(model_txt_file, "w") as f:
+        words = model.words
+        for word in words:
+            # ensure it's not an empty string
+            word = preprocess_text(word)  # Clean the token
+            if word:
+                vector = model.get_word_vector(word)
+                if(len(vector) == model_dim):
+                    vector_str = " ".join([str(v) for v in vector])  # More robust conversion to string
+                    f.write(f"{word} {vector_str}\n")
+    os.remove(training_data_file)
+
+
+if __name__ == "__main__":
+    train_fastText()
\ No newline at end of file
diff --git a/surrealdb-rag/static/style.css b/surrealdb-rag/static/style.css
new file mode 100644
index 0000000..42ea592
--- /dev/null
+++ 
b/surrealdb-rag/static/style.css @@ -0,0 +1,180 @@ +html, +body { + width: 100%; + height: 100%; + margin: 0; + padding: 0; +} + +body { + font-family: sans-serif; + display: flex; +} + +nav { + width: calc(20%); + overflow-y: auto; + padding: 20px; + box-sizing: border-box; + background: #000000; + color: #f5f5f7; + height: 100%; + margin-bottom: 10px; + display: flex; + flex-direction: column; +} + +#chats { + flex: 1; /* Allow chats to grow and take available space */ + overflow-y: auto; /* Enable scrolling within chats */ +} +.model_selector { + + font-size: 12px; + margin-bottom: 10px; + +} +.model_selector select { + + background: #000000; + color: #fff; + +} + +button { + background: #000000; + color: #fff; + border: 2px solid #000; + border-radius: 8px; + font-size: 12px; + cursor: pointer; + transition: background-color 0.3s ease; + box-sizing: border-box; +} + +button.chat { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + text-align: left; + width: calc(100%); + padding: 10px 20px; +} + +button.delete { + text-align: center; + padding: 4px 8px; +} + +#chats > div { /* Targets direct children divs of #chats */ + display: flex; + align-items: center; +} + +nav button:last-child { + margin-bottom: 0; +} + +nav button:hover { + background: #ff00a0; +} + +button img.surrealdb-logo { + height: 40px; + padding-right: 10px; + vertical-align: middle; +} + +button img { + -webkit-user-drag: none; +} + +main { + display: flex; + flex-direction: column; + width: 100%; + height: 100%; + padding: 20px; + box-sizing: border-box; + background: #151517; +} + +main form { + display: flex; + height: 50px; +} + +main form input { + border-radius: 10px; + border: 2px solid #444748; + padding: 10px; + box-sizing: border-box; + font-family: inherit; + margin-right: 10px; + flex: 1; + background: #151517; + color: #f5f5f7; +} + +main form button { + width: 100px; + padding: 10px; + box-sizing: border-box; + border-radius: 10px; + background: #ff00a0; + color: #fff; + border: none; + text-align: center; + display: block; + height: 100%; +} + +a:-webkit-any-link { + color: #ff00a0; + cursor: pointer; + text-decoration: underline; +} + +.messages { + overflow-y: auto; + flex: 1; + display: flex; + flex-direction: column; + padding: 10px; +} + +.message { + border-radius: 15px; + padding: 15px; + background: #151517; + margin-bottom: 15px; + color: #f5f5f7; + display: flex; + flex-direction: column; +} + +.message .message-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 10px; +} + +.message .messenger-name { + font-weight: bold; + font-size: 1rem; + color: #ff00a0; +} + +.message .message-time { + font-size: 0.85rem; + color: #f5f5f7; +} + +.message p.message-content { + margin: 0; +} + +.system.message .messenger-name { + color: #9600ff; +} diff --git a/surrealdb-rag/static/surrealdb-icon.svg b/surrealdb-rag/static/surrealdb-icon.svg new file mode 100644 index 0000000..e4072a7 --- /dev/null +++ b/surrealdb-rag/static/surrealdb-icon.svg @@ -0,0 +1,18 @@ + + + + + + + + + + diff --git a/surrealdb-rag/templates/chats.html b/surrealdb-rag/templates/chats.html new file mode 100644 index 0000000..29a10c6 --- /dev/null +++ b/surrealdb-rag/templates/chats.html @@ -0,0 +1,12 @@ +{% for chat in chats %} + +
+ + +
+{% endfor %} diff --git a/surrealdb-rag/templates/create_chat.html b/surrealdb-rag/templates/create_chat.html new file mode 100644 index 0000000..39e1d68 --- /dev/null +++ b/surrealdb-rag/templates/create_chat.html @@ -0,0 +1,9 @@ + +
+ + +
\ No newline at end of file diff --git a/surrealdb-rag/templates/index.html b/surrealdb-rag/templates/index.html new file mode 100644 index 0000000..27af8fb --- /dev/null +++ b/surrealdb-rag/templates/index.html @@ -0,0 +1,100 @@ + + + + + + + + ChatSurrealDB + + + + + + + + +
+ +
+ + + diff --git a/surrealdb-rag/templates/load_chat.html b/surrealdb-rag/templates/load_chat.html new file mode 100644 index 0000000..9f66fe7 --- /dev/null +++ b/surrealdb-rag/templates/load_chat.html @@ -0,0 +1,18 @@ +
+ {% for message in messages %} +
+
+ {{ message.role | capitalize }} + {{ message.timestamp | convert_timestamp_to_date }} +
+

{{ message.content | safe }}

+
+ {% endfor %} +
+ +
+ + + +
diff --git a/surrealdb-rag/templates/send_system_message.html b/surrealdb-rag/templates/send_system_message.html new file mode 100644 index 0000000..32e2f77 --- /dev/null +++ b/surrealdb-rag/templates/send_system_message.html @@ -0,0 +1,30 @@ +
+
+ System + {{ timestamp | convert_timestamp_to_date }} +
+

{{ content | safe }}

+
+ +
+ +{% if new_title %} + + + +{% endif %} \ No newline at end of file diff --git a/surrealdb-rag/templates/send_user_message.html b/surrealdb-rag/templates/send_user_message.html new file mode 100644 index 0000000..fd9188f --- /dev/null +++ b/surrealdb-rag/templates/send_user_message.html @@ -0,0 +1,9 @@ +
+
+ User + {{ timestamp | convert_timestamp_to_date }} +
+

{{ content | safe }}

+
+ From 883cc3c5a2c447ac84ad9aec36cfaed4fa70cebe Mon Sep 17 00:00:00 2001 From: Alessandro Pireno Date: Wed, 5 Mar 2025 22:33:46 -0500 Subject: [PATCH 2/9] added message detail and dynamic model loading --- surrealdb-rag/schema/function_ddl.surql | 1 + surrealdb-rag/src/surrealdb_rag/app.py | 22 +++++++- surrealdb-rag/src/surrealdb_rag/constants.py | 16 +++--- .../src/surrealdb_rag/llm_handler.py | 51 +++++++++++++------ surrealdb-rag/static/style.css | 48 ++++++++++++++++- surrealdb-rag/templates/index.html | 31 +++++++++++ surrealdb-rag/templates/load_chat.html | 8 ++- .../templates/load_message_detail.html | 25 +++++++++ 8 files changed, 174 insertions(+), 28 deletions(-) create mode 100644 surrealdb-rag/templates/load_message_detail.html diff --git a/surrealdb-rag/schema/function_ddl.surql b/surrealdb-rag/schema/function_ddl.surql index e1614e2..3353b59 100644 --- a/surrealdb-rag/schema/function_ddl.surql +++ b/surrealdb-rag/schema/function_ddl.surql @@ -258,6 +258,7 @@ Returns: DEFINE FUNCTION OVERWRITE fn::load_chat($chat_id: string) { RETURN SELECT + out.id AS id, out.role AS role, out.content AS content, timestamp diff --git a/surrealdb-rag/src/surrealdb_rag/app.py b/surrealdb-rag/src/surrealdb_rag/app.py index 3479138..0861e39 100644 --- a/surrealdb-rag/src/surrealdb_rag/app.py +++ b/surrealdb-rag/src/surrealdb_rag/app.py @@ -99,7 +99,7 @@ async def index(request: fastapi.Request) -> responses.HTMLResponse: async def get_llm_model_details(llm_model: str = fastapi.Query(...)): model_data = life_span["llm_models"].get(llm_model) if model_data: - s = f"Model Version: {model_data['model_version']}, Host: {model_data['host']}" + s = f"Version: {model_data['model_version']}, Host: {model_data['host']}" else: s = "Model details not found." return fastapi.Response(s, media_type="text/html") #Return response object @@ -108,7 +108,7 @@ async def get_llm_model_details(llm_model: str = fastapi.Query(...)): async def get_embed_model_details(embed_model: str = fastapi.Query(...)): model_data = life_span["embed_models"].get(embed_model) if model_data: - s = f"Version: {model_data['dimensions']}, Host: {model_data['host']}" + s = f"Dimensions: {model_data['dimensions']}, Host: {model_data['host']}" else: s = "Model details not found." 
return fastapi.Response(s, media_type="text/html") #Return response object @@ -157,6 +157,24 @@ async def load_chat( "chat_id": chat_id, }, ) +@app.get("/messages/{message_id}", response_class=responses.HTMLResponse) +async def load_chat( + request: fastapi.Request, message_id: str +) -> responses.HTMLResponse: + """Load a chat.""" + message = await life_span["surrealdb"].query( + """RETURN fn::load_message_detail($message_id)""",params = {"message_id":message_id} + ) + return templates.TemplateResponse( + "load_message_detail.html", + { + "request": request, + "message": message, + "message_id": message_id, + }, + ) + + @app.get("/chats", response_class=responses.HTMLResponse) diff --git a/surrealdb-rag/src/surrealdb_rag/constants.py b/surrealdb-rag/src/surrealdb_rag/constants.py index 097a956..aacf59f 100644 --- a/surrealdb-rag/src/surrealdb_rag/constants.py +++ b/surrealdb-rag/src/surrealdb_rag/constants.py @@ -40,13 +40,15 @@ def ParseResponseForErrors(outcome): class ModelParams(): - LLM_MODELS = { - "GEMINI-SURREAL": {"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None}, - "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None}, - "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None}, - "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}, - "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} - } + # GEMINI_MODELS = ["gemini-2.0-flash-lite","gemini-2.0-flash","gemini-1.5-flash","gemini-1.5-flash-8b","gemini-1.5-pro"] + # # OPENAI_MODELS = ["gemini-2.0-flash-lite","gemini-2.0-flash","gemini-1.5-flash","gemini-1.5-flash-8b","gemini-1.5-pro"] + # # LLM_MODELS = { + # # "GEMINI-SURREAL": {"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None}, + # # "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None}, + # # "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None}, + # # "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}, + # # "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} + # # } EMBED_MODELS = { "CUST_FASTTEXT": {"dimensions":100,"host":"SQL"}, diff --git a/surrealdb-rag/src/surrealdb_rag/llm_handler.py b/surrealdb-rag/src/surrealdb_rag/llm_handler.py index 83c4fe5..3c1d5ac 100644 --- a/surrealdb-rag/src/surrealdb_rag/llm_handler.py +++ b/surrealdb-rag/src/surrealdb_rag/llm_handler.py @@ -43,33 +43,52 @@ async def populate_models(self): content_fasttext_vector!=None AS has_fasttext_vectors FROM embedded_wiki LIMIT 1;""") check_for_vectors = check_for_vectors[0] + #you need the vector field populated for fasttext + if check_for_vectors["has_fasttext_vectors"] == True: + self.EMBED_MODELS["CUST_FASTTEXT"] = ModelParams.EMBED_MODELS["CUST_FASTTEXT"] + + #you need the vector field populated for glove + if check_for_vectors["has_glove_vectors"] == True: + self.EMBED_MODELS["GLOVE"] = ModelParams.EMBED_MODELS["GLOVE"] + #you need an api key for gemini if self.model_params.gemini_token: - self.LLM_MODELS["GEMINI"] = ModelParams.LLM_MODELS["GEMINI"] - self.LLM_MODELS["GEMINI-SURREAL"] = ModelParams.LLM_MODELS["GEMINI-SURREAL"] - + genai.configure(api_key=self.model_params.gemini_token) + + for model in genai.list_models(): + #print(model) + if ( 
model.supported_generation_methods in + [ + ['generateContent', 'countTokens'] , + ['generateContent', 'countTokens', 'createCachedContent'] + ] + and "gemini" in model.name + and (model.display_name == model.description + or "stable" in model.description.lower()) ): + self.LLM_MODELS["GOOGLE - " + model.display_name] = {"model_version":model.name,"host":"API","platform":"GOOGLE","temperature":None} + self.LLM_MODELS["GOOGLE - " + model.display_name + " (surreal)"] = {"model_version":model.name, "host":"SQL","platform":"GOOGLE","temperature":None} + #you need an api key for openai if self.model_params.openai_token: - self.LLM_MODELS["OPENAI"] = ModelParams.LLM_MODELS["OPENAI"] - self.LLM_MODELS["OPENAI-SURREAL"] = ModelParams.LLM_MODELS["OPENAI-SURREAL"] + openai.api_key = self.model_params.openai_token + models = openai.models.list() + for model in models.data: + if(model.owned_by == "openai" and "gpt" in model.id): + #print(model) + self.LLM_MODELS["OPENAI - " + model.id] = {"model_version":model.id,"host":"API","platform":"OPENAI","temperature":0.5} + self.LLM_MODELS["OPENAI - " + model.id + " (surreal)"] = {"model_version":model.id,"host":"SQL","platform":"OPENAI","temperature":0.5} + + + # self.LLM_MODELS["OPENAI"] = ModelParams.LLM_MODELS["OPENAI"] + # self.LLM_MODELS["OPENAI-SURREAL"] = ModelParams.LLM_MODELS["OPENAI-SURREAL"] #you need the vector field populated for openai if check_for_vectors["has_openai_vectors"] == True: self.EMBED_MODELS["OPENAI"] = ModelParams.EMBED_MODELS["OPENAI"] - #you need the vector field populated for glove - if check_for_vectors["has_glove_vectors"] == True: - self.EMBED_MODELS["GLOVE"] = ModelParams.EMBED_MODELS["GLOVE"] - - #you need the vector field populated for fasttext - if check_for_vectors["has_fasttext_vectors"] == True: - self.EMBED_MODELS["CUST_FASTTEXT"] = ModelParams.EMBED_MODELS["CUST_FASTTEXT"] - response: ollama.ListResponse = ollama.list() - - for model in response.models: - self.LLM_MODELS[model.model] = {"model_version":model.model,"host":"OLLAMA","platform":"local","temperature":None} + self.LLM_MODELS["OLLAMA " + model.model] = {"model_version":model.model,"host":"OLLAMA","platform":"local","temperature":None} # print('Name:', model.model) # print(' Size (MB):', f'{(model.size.real / 1024 / 1024):.2f}') diff --git a/surrealdb-rag/static/style.css b/surrealdb-rag/static/style.css index 42ea592..ea12eb2 100644 --- a/surrealdb-rag/static/style.css +++ b/surrealdb-rag/static/style.css @@ -23,7 +23,45 @@ nav { display: flex; flex-direction: column; } +/* Modal Styles */ +.modal { + display: none; /* Hidden by default */ + position: fixed; /* Stay in place */ + z-index: 1; /* Sit on top */ + left: 0; + top: 0; + width: 100%; /* Full width */ + height: 100%; /* Full height */ + overflow: auto; /* Enable scroll if needed */ + background-color: rgb(0,0,0); /* Fallback color */ + background-color: rgba(0,0,0,0.4); /* Black w/ opacity */ +} + +.modal-content { + background-color: #fefefe; + margin: 15% auto; /* 15% from the top and centered */ + padding: 20px; + border: 1px solid #888; + width: 80%; /* Could be more or less, depending on screen size */ + position: relative; /* For close button positioning */ +} +.close { + color: #aaa; + float: right; + font-size: 28px; + font-weight: bold; + position: absolute; + top: 10px; + right: 15px; +} + +.close:hover, +.close:focus { + color: black; + text-decoration: none; + cursor: pointer; +} #chats { flex: 1; /* Allow chats to grow and take available space */ overflow-y: auto; /* Enable scrolling 
within chats */ @@ -40,14 +78,16 @@ nav { color: #fff; } - +.chat-id{ + color: #fff; +} button { background: #000000; color: #fff; + cursor: pointer; border: 2px solid #000; border-radius: 8px; font-size: 12px; - cursor: pointer; transition: background-color 0.3s ease; box-sizing: border-box; } @@ -169,6 +209,10 @@ a:-webkit-any-link { .message .message-time { font-size: 0.85rem; color: #f5f5f7; + display: flex; +} +.message .message-time button { + padding: 4px 8px; } .message p.message-content { diff --git a/surrealdb-rag/templates/index.html b/surrealdb-rag/templates/index.html index 27af8fb..4ded1d0 100644 --- a/surrealdb-rag/templates/index.html +++ b/surrealdb-rag/templates/index.html @@ -43,6 +43,31 @@ } }); + document.addEventListener("DOMContentLoaded", function() { + const modal = document.getElementById("myModal"); + const closeBtn = document.querySelector(".close"); + + // Event delegation for dynamically created buttons + document.body.addEventListener("click", function(event) { + if (event.target.classList.contains("message")) { + modal.style.display = "block"; + } + }); + + if (modal && closeBtn) { + closeBtn.onclick = function() { + modal.style.display = "none"; + }; + + window.onclick = function(event) { + if (event.target === modal) { + modal.style.display = "none"; + } + }; + } else { + console.error("Modal or close button not found."); + } + }); @@ -95,6 +120,12 @@
+ diff --git a/surrealdb-rag/templates/load_chat.html b/surrealdb-rag/templates/load_chat.html index 9f66fe7..9fe35cb 100644 --- a/surrealdb-rag/templates/load_chat.html +++ b/surrealdb-rag/templates/load_chat.html @@ -1,9 +1,15 @@ +
Chat ID: {{chat_id}}
{% for message in messages %}
{{ message.role | capitalize }} - {{ message.timestamp | convert_timestamp_to_date }} + + {{ message.timestamp | convert_timestamp_to_date }} + +

{{ message.content | safe }}

diff --git a/surrealdb-rag/templates/load_message_detail.html b/surrealdb-rag/templates/load_message_detail.html new file mode 100644 index 0000000..a02e3b2 --- /dev/null +++ b/surrealdb-rag/templates/load_message_detail.html @@ -0,0 +1,25 @@ + +
+message detail: +
ID:{{message.id}}
+ +
role: {{message.role}}
+
created_at: {{message.created_at}}
+
updated_at: {{message.updated_at}}
+
embedding_model: {{message.sent[0].embedding_model}}
+
llm_model: {{message.sent[0].llm_model}}
+
timestamp: {{message.sent[0].timestamp}}
+ +{% if message.sent[0].referenced_documents %} +
+ Referenced Documents: + {%for doc in message.sent[0].referenced_documents %} +
+ + Score: {{doc.score}} + Doc: {{doc.doc}} +
+ {% endfor %} +
+{% endif %} +
\ No newline at end of file
From 73c5a73670aff6eae30630de478a273b5f8db074 Mon Sep 17 00:00:00 2001
From: Alessandro Pireno
Date: Sun, 9 Mar 2025 22:51:14 -0400
Subject: [PATCH 3/9] refactors abound

---
 surrealdb-rag/Makefile                        |   6 +-
 surrealdb-rag/pyproject.toml                  |   1 -
 surrealdb-rag/schema/function_ddl.surql       | 106 ++++---
 surrealdb-rag/schema/table_ddl.surql          |  55 +++-
 surrealdb-rag/src/surrealdb_rag/app.py        | 162 ++++++++---
 surrealdb-rag/src/surrealdb_rag/constants.py  |  52 +++-
 .../src/surrealdb_rag/create_database.py      |   5 +-
 ...download_data.py => download_wiki_data.py} |   4 +-
 .../surrealdb_rag/insert_embedding_model.py   |  88 +++++-
 .../src/surrealdb_rag/insert_wiki.py          | 112 +++++++-
 .../src/surrealdb_rag/llm_handler.py          | 272 ++++++++++++------
 surrealdb-rag/static/style.css                |   2 +-
 surrealdb-rag/templates/chat.html             |  16 ++
 surrealdb-rag/templates/chats.html            |   3 +-
 surrealdb-rag/templates/create_chat.html      |   3 +-
 surrealdb-rag/templates/document.html         |  13 +
 surrealdb-rag/templates/index.html            | 162 ++++++++++-
 surrealdb-rag/templates/load_chat.html        |  24 --
 surrealdb-rag/templates/message.html          |  61 ++++
 ...essage_detail.html => message_detail.html} |  14 +-
 .../templates/send_system_message.html        |  30 --
 .../templates/send_user_message.html          |   9 -
 22 files changed, 895 insertions(+), 305 deletions(-)
 rename surrealdb-rag/src/surrealdb_rag/{download_data.py => download_wiki_data.py} (96%)
 create mode 100644 surrealdb-rag/templates/chat.html
 create mode 100644 surrealdb-rag/templates/document.html
 delete mode 100644 surrealdb-rag/templates/load_chat.html
 create mode 100644 surrealdb-rag/templates/message.html
 rename surrealdb-rag/templates/{load_message_detail.html => message_detail.html} (57%)
 delete mode 100644 surrealdb-rag/templates/send_system_message.html
 delete mode 100644 surrealdb-rag/templates/send_user_message.html

diff --git a/surrealdb-rag/Makefile b/surrealdb-rag/Makefile
index e4b01ff..7d34226 100644
--- a/surrealdb-rag/Makefile
+++ b/surrealdb-rag/Makefile
@@ -20,4 +20,8 @@ dsstore-remove:
 	find . | grep -E ".DS_Store" | xargs rm -rf
 
 .PHONY: cleanup
-cleanup: pycache-remove dsstore-remove
\ No newline at end of file
+cleanup: pycache-remove dsstore-remove
+
+.PHONY: surreal-insert-glove
+surreal-insert-glove:
+	python src/surrealdb_rag/insert_embedding_model.py -emtr GLOVE -emv 300d -emp data/glove.6B.300d.txt -des "Standard pretrained GLoVE model from https://nlp.stanford.edu/projects/glove/ 300 dimensions version"
diff --git a/surrealdb-rag/pyproject.toml b/surrealdb-rag/pyproject.toml
index 5a48e1b..cf0e57b 100644
--- a/surrealdb-rag/pyproject.toml
+++ b/surrealdb-rag/pyproject.toml
@@ -31,4 +31,3 @@ surreal-create-db = "surrealdb_rag.create_database:surreal_create_database"
 surreal-insert-wiki = "surrealdb_rag.insert_wiki:surreal_wiki_insert"
 download-data = "surrealdb_rag.download:download_data"
 download-glove = "surrealdb_rag.download:download_glove_model"
-surreal-insert-glove = "surrealdb_rag.insert_glove:surreal_glove_insert"
\ No newline at end of file
diff --git a/surrealdb-rag/schema/function_ddl.surql b/surrealdb-rag/schema/function_ddl.surql
index 3353b59..3764a83 100644
--- a/surrealdb-rag/schema/function_ddl.surql
+++ b/surrealdb-rag/schema/function_ddl.surql
@@ -4,21 +4,10 @@ This file defines the SurrealQL for the chat functionality of this project.
 and 
 
 
-DEFINE FUNCTION OVERWRITE fn::retrieve_vector_size_for_model($model:string)
-{
-    RETURN (SELECT VALUE array::len(embedding) FROM embedding_model:[
-        $model,
-        NONE
-    ]..[
-        $model,
-        ..
-    ] LIMIT 1)[0]
-};
-
-DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: string) {
+DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: record<embedding_model_definition>) {
 
     #Pull the first row to determine the size of the vector (they should all be the same)
-    LET $vector_size = fn::retrieve_vector_size_for_model($model);
+    LET $vector_size = $model.dimensions;
 
     #select the vectors from the embedding table that match the words
     LET $vectors = fn::retrieve_vectors_for_sentence($sentence,$model);
@@ -38,7 +27,7 @@ DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: strin
 
     #if the array size is correct return it, otherwise return array of zeros
     RETURN 
         IF array::len($mean_vector) == $vector_size {$mean_vector}
         ELSE {None}
     ;
 };
@@ -50,43 +39,45 @@ DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: strin
 /* Search for documents using embeddings.
 
 Args:
     input_vector: embedding to search for within the embedding field
     threshold: min threshold above and beyond the N returned
-    model: the name of the embedding model to use... "GLOVE", "CUST_FASTTEXT" or "OPENAI"
+    model: the name of the embedding model to use... "GLOVE", "OPENAI" or "FASTTEXT"
 
 Returns:
     array: Array of embeddings.
*/
-DEFINE FUNCTION OVERWRITE fn::search_for_documents($input_vector: array<float>, $threshold: float, $model: string) {
-    
+DEFINE FUNCTION OVERWRITE fn::search_for_documents($corpus_table: string, $input_vector: array<float>, $threshold: float, $model: record<embedding_model_definition>) {
+    
     LET $first_pass = 
-        (IF $model = "GLOVE" THEN
+        (IF $model.model_trainer = "GLOVE" THEN
         (
             SELECT
                 id,
                 vector::similarity::cosine(content_glove_vector, $input_vector) AS similarity_score
-            FROM embedded_wiki
+            FROM type::table($corpus_table)
             WHERE content_glove_vector <|5,40|> $input_vector
         );
-        ELSE IF $model = "CUST_FASTTEXT" THEN
+        ELSE IF $model.model_trainer = "FASTTEXT" THEN
         (
             SELECT
                 id,
                 vector::similarity::cosine(content_fasttext_vector, $input_vector) AS similarity_score
-            FROM embedded_wiki
+            FROM type::table($corpus_table)
            WHERE content_fasttext_vector <|5,40|> $input_vector
         );
-        ELSE IF $model = "OPENAI" THEN
+        ELSE IF $model.model_trainer = "OPENAI" THEN
         SELECT id FROM (
             SELECT
                 id,
                 vector::similarity::cosine(content_openai_vector, $input_vector) AS similarity_score
-            FROM embedded_wiki
+            FROM type::table($corpus_table)
             WHERE content_openai_vector <|5,40|> $input_vector
         )
         ;
         END);
     RETURN SELECT similarity_score as score,id as doc FROM $first_pass WHERE similarity_score > $threshold;
-}
+};
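+
+# Illustrative call sequence (a sketch; the record ids are assumptions
+# matching the commented examples in table_ddl.surql):
+#   LET $model = embedding_model_definition:['GLOVE','glove 300d'];
+#   LET $vector = fn::sentence_to_vector("april is a month", $model);
+#   RETURN fn::search_for_documents('embedded_wiki', $vector, 0.7, $model);
+# <|5,40|> is SurrealDB's KNN operator: the 5 nearest neighbours, searched
+# through the HNSW index with an exploration factor (EF) of 40.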
+
+
 /* Get prompt for RAG.
 
 Args:
     documents: Documents to include in the prompt.
 
 Returns:
     string: Prompt with context.
*/
-DEFINE FUNCTION OVERWRITE fn::get_prompt_with_context($documents: array<record<embedded_wiki>>) {
+DEFINE FUNCTION OVERWRITE fn::get_prompt_with_context($documents: array<record>) {
     LET $prompt = "You are an AI assistant answering questions about anything from Simple English Wikipedia the context will provide you with the most relevant data from Simple English Wikipedia including the page title, url, and page content.
@@ -136,9 +127,11 @@ DEFINE FUNCTION OVERWRITE fn::create_message(
     $chat_id: string,
     $role: string,
     $content: string,
-    $documents: option<array<{score: float, doc: record<embedded_wiki>}>>,
-    $embedding_model: option<string>,
+    $documents: option<array<{score: float, doc: record}>>,
+    $embedding_model: option<record<embedding_model_definition>>,
     $llm_model: option<string>,
+    $prompt_text: option<string>
+
 ) {
     # Create a message record and get the resulting ID.
     LET $message_id = 
@@ -158,44 +151,69 @@ DEFINE FUNCTION OVERWRITE fn::create_message(
     RELATE ONLY $chat->sent->$message_id CONTENT {
         referenced_documents: $documents,
         embedding_model: $embedding_model,
-        llm_model: $llm_model
+        llm_model: $llm_model,
+        prompt_text: $prompt_text
     };
-    RETURN {
-        content: $content,
-        timestamp: $timestamp
-    };
+
+    RETURN fn::load_message_detail($message_id);
+
 };
 
+
+
 /* Create a user message.
 
 Args:
     chat_id: Record ID from the `chat` table that the message was sent in.
     content: Sent message content.
+    embedding_model: the embed model used to find docs
+    openai_token: token if using openai embeddings
 
 Returns:
     object: Content and timestamp.
*/
-DEFINE FUNCTION OVERWRITE fn::create_user_message($chat_id: string, $content: string, $embedding_model: string,$openai_token: option<string>) {
+DEFINE FUNCTION OVERWRITE fn::create_user_message($chat_id: string, $corpus_table: string, $content: string, $embedding_model: option<record<embedding_model_definition>>,$openai_token: option<string>) {
 
     LET $threshold = 0.7;
     LET $vector = 
-        IF $embedded_model == "OPENAI" THEN
-            fn::openai_embeddings_complete("text-embedding-ada-002", $content, $openai_token)
+        IF $embedding_model.model_trainer == "OPENAI" THEN
+            fn::openai_embeddings_complete($embedding_model.version, $content, $openai_token)
         ELSE
             fn::sentence_to_vector($content,$embedding_model)
         END;
-    LET $documents = fn::search_for_documents($vector, $threshold ,$embedding_model);
-    RETURN fn::create_message($chat_id, "user", $content, $documents,$embedding_model,None);
+    LET $documents = fn::search_for_documents($corpus_table,$vector, $threshold ,$embedding_model);
+
+    RETURN fn::create_message($chat_id, "user", $content, $documents,$embedding_model);
 };
 
+
+/* Create a system message.
+
+Args:
+    chat_id: Record ID from the `chat` table that the message was sent in.
+    content: Sent message content.
+    llm_model: the llm model used to generate the content
+
+Returns:
+    object: Content and timestamp.
+*/
+DEFINE FUNCTION OVERWRITE fn::create_system_message($chat_id: string, $content: string, $llm_model: string,$prompt_text:string) {
+    RETURN fn::create_message($chat_id, "system", $content, None,None,$llm_model,$prompt_text);
+};
+
+
+
 /* Create get the last user message and the reference docs for generating a prompt.
 
Args:
     chat_id: Record ID from the `chat` table that the message was sent in
 
 Returns:
-    object: Content, referenced documents [{score,embedded_wiki}], timestamp.
+    object: Content, referenced documents [{score,documents}], timestamp.
*/
 DEFINE FUNCTION OVERWRITE fn::get_last_user_message_input_and_prompt($chat_id: string) {
@@ -407,8 +425,20 @@ DEFINE FUNCTION OVERWRITE fn::gemini_chat_complete($llm: string, $prompt_with_co
 };
 
+
+
+DEFINE FUNCTION OVERWRITE fn::load_document_detail($corpus_table:string,$document_id: string) {
+    RETURN SELECT * FROM type::thing($corpus_table,$document_id);
+};
+
+
+DEFINE FUNCTION OVERWRITE fn::load_message_detail($message_id: string) {
+    RETURN (SELECT *,<-sent.{referenced_documents,embedding_model,llm_model,timestamp,prompt_text} AS sent FROM type::record($message_id))[0];
+};
+
 #these functions calculate the mean vector for the tokens in a sentence using the glove model
-DEFINE FUNCTION OVERWRITE fn::retrieve_vectors_for_sentence($sentence:string,$model:string)
+DEFINE FUNCTION OVERWRITE fn::retrieve_vectors_for_sentence($sentence:string,$model:record<embedding_model_definition>)
 {
     LET $sentence = $sentence.lowercase().
         replace('.',' .').
diff --git a/surrealdb-rag/schema/table_ddl.surql b/surrealdb-rag/schema/table_ddl.surql
index 673bbfc..573fd58 100644
--- a/surrealdb-rag/schema/table_ddl.surql
+++ b/surrealdb-rag/schema/table_ddl.surql
@@ -41,9 +41,12 @@ DEFINE FIELD IF NOT EXISTS updated_at ON TABLE message TYPE datetime
 DEFINE TABLE sent TYPE RELATION IN chat OUT message ENFORCED;
 DEFINE FIELD IF NOT EXISTS timestamp ON TABLE sent TYPE datetime VALUE time::now();
-DEFINE FIELD IF NOT EXISTS referenced_documents ON TABLE sent TYPE option<array<{score: float, doc: record<embedded_wiki>}>>;
+DEFINE FIELD IF NOT EXISTS referenced_documents ON TABLE sent TYPE option<array<{score: float, doc: record}>>;
 DEFINE FIELD IF NOT EXISTS llm_model ON TABLE sent TYPE option<string>;
-DEFINE FIELD IF NOT EXISTS embedding_model ON TABLE sent TYPE option<string>;
+DEFINE FIELD IF NOT EXISTS embedding_model ON TABLE sent TYPE option<record<embedding_model_definition>>;
+DEFINE FIELD IF NOT EXISTS prompt_text ON TABLE sent TYPE option<string>;
+
 
 # A message can only be sent in one chat
 DEFINE INDEX IF NOT EXISTS unique_sent_message_in_chat
@@ -55,44 +58,68 @@ DEFINE INDEX IF NOT EXISTS unique_sent_message_in_chat
 This file defines the SurrealQL DDL for the glove model embedding functionality of this project.
*/
+DEFINE TABLE IF NOT EXISTS corpus_table SCHEMAFULL;
+DEFINE FIELD IF NOT EXISTS table_name ON TABLE corpus_table TYPE string;
+DEFINE FIELD IF NOT EXISTS display_name ON TABLE corpus_table TYPE string;
+DEFINE FIELD IF NOT EXISTS embed_models ON TABLE corpus_table TYPE array<record<embedding_model_definition>>;
+
+
+DEFINE TABLE IF NOT EXISTS embedding_model_definition SCHEMAFULL;
+DEFINE FIELD IF NOT EXISTS model_trainer ON TABLE embedding_model_definition TYPE string;
+DEFINE FIELD IF NOT EXISTS host ON TABLE embedding_model_definition TYPE string;
+DEFINE FIELD IF NOT EXISTS dimensions ON TABLE embedding_model_definition TYPE int;
+DEFINE FIELD IF NOT EXISTS version ON TABLE embedding_model_definition TYPE string;
+DEFINE FIELD IF NOT EXISTS corpus ON TABLE embedding_model_definition TYPE string;
+DEFINE FIELD IF NOT EXISTS description ON TABLE embedding_model_definition TYPE string;
+
+# UPSERT embedding_model_definition:['FASTTEXT','fasttext wiki'] CONTENT {model_trainer:'FASTTEXT',host:'SQL',dimensions:100,version:'fasttext wiki', corpus:"trained on openai wiki sample data"};
+# UPSERT embedding_model_definition:['GLOVE','glove 300d'] CONTENT {model_trainer:'GLOVE',host:'SQL',dimensions:300,version:'glove 300d', corpus:"generic pretrained"};
+UPSERT embedding_model_definition:['OPENAI','text-embedding-ada-002'] CONTENT {
+    model_trainer:'OPENAI',host:'API',dimensions:1536, version:"text-embedding-ada-002", corpus:"generic pretrained" , description:'The standard OPENAI embedding model'
+    };
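+
+# A corpus registration row would then reference these model definitions;
+# a sketch with assumed values:
+#   UPSERT corpus_table:embedded_wiki CONTENT {
+#       table_name: 'embedded_wiki',
+#       display_name: 'Simple English Wikipedia',
+#       embed_models: [embedding_model_definition:['OPENAI','text-embedding-ada-002']]
+#   };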
+
+DEFINE TABLE IF NOT EXISTS corpus_table_model SCHEMAFULL;
+DEFINE FIELD IF NOT EXISTS corpus_table ON TABLE corpus_table_model TYPE record<corpus_table>;
+DEFINE FIELD IF NOT EXISTS model ON TABLE corpus_table_model TYPE record<embedding_model_definition>;
+DEFINE FIELD IF NOT EXISTS field_name ON TABLE corpus_table_model TYPE string;
 
 
-# Define the `embedded_wiki` table.
+# Define the `{corpus_table}` table.
 
-DEFINE TABLE IF NOT EXISTS embedded_wiki SCHEMAFULL;
+DEFINE TABLE IF NOT EXISTS {corpus_table} SCHEMAFULL;
 
-DEFINE FIELD IF NOT EXISTS url ON TABLE embedded_wiki TYPE string
+DEFINE FIELD IF NOT EXISTS url ON TABLE {corpus_table} TYPE string
     # Field must be a URL.
     ASSERT string::is::url($value);
 
-DEFINE FIELD IF NOT EXISTS title ON TABLE embedded_wiki TYPE string
+DEFINE FIELD IF NOT EXISTS title ON TABLE {corpus_table} TYPE string
     # Field must be non-empty
     ASSERT string::len($value) > 0;
 
-DEFINE FIELD IF NOT EXISTS text ON TABLE embedded_wiki TYPE string
+DEFINE FIELD IF NOT EXISTS text ON TABLE {corpus_table} TYPE string
     # Field must be non-empty
     ASSERT string::len($value) > 0;
 
-DEFINE FIELD IF NOT EXISTS content_glove_vector ON TABLE embedded_wiki TYPE option<array<float>>
+DEFINE FIELD IF NOT EXISTS content_glove_vector ON TABLE {corpus_table} TYPE option<array<float>>
     # Field must have length 300 to use embedding model: glove 300d
     ASSERT array::len($value) = 300;
 
-DEFINE INDEX IF NOT EXISTS embedded_wiki_content_glove_vector_index ON embedded_wiki
+DEFINE INDEX IF NOT EXISTS {corpus_table}_content_glove_vector_index ON {corpus_table}
     FIELDS content_glove_vector
     HNSW DIMENSION 300 M 32 EFC 300;
 
-DEFINE FIELD IF NOT EXISTS content_openai_vector ON TABLE embedded_wiki TYPE option<array<float>>
+DEFINE FIELD IF NOT EXISTS content_openai_vector ON TABLE {corpus_table} TYPE option<array<float>>
     # Field must have length 1536 to use embedding model: text-embedding-ada-002
     ASSERT array::len($value) = 1536;
 
-DEFINE INDEX IF NOT EXISTS embedded_wiki_content_openai_vector_index ON embedded_wiki
+DEFINE INDEX IF NOT EXISTS {corpus_table}_content_openai_vector_index ON {corpus_table}
     FIELDS content_openai_vector
     HNSW DIMENSION 1536 M 32 EFC 300;
 
-DEFINE FIELD IF NOT EXISTS content_fasttext_vector ON TABLE embedded_wiki TYPE option<array<float>>
-    # Field must have length 1536 to use embedding model: text-embedding-ada-002
+DEFINE FIELD IF NOT EXISTS content_fasttext_vector ON TABLE {corpus_table} TYPE option<array<float>>
+    # Field must have length 100 to use the custom fastText embedding model
     ASSERT array::len($value) = 100;
 
-DEFINE INDEX IF NOT EXISTS embedded_wiki_content_fasttext_vector_index ON embedded_wiki
+DEFINE INDEX IF NOT EXISTS {corpus_table}_content_fasttext_vector_index ON {corpus_table}
     FIELDS content_fasttext_vector
     HNSW DIMENSION 100 M 32 EFC 300;
 
@@ -100,5 +127,5 @@ DEFINE INDEX IF NOT EXISTS embedded_wiki_content_fasttext_vector_index ON embedd
 # this is a table to store a glove word model in database
 DEFINE TABLE IF NOT EXISTS embedding_model TYPE NORMAL SCHEMAFULL;
 DEFINE FIELD IF NOT EXISTS word ON embedding_model TYPE string;
-DEFINE FIELD IF NOT EXISTS model ON embedding_model TYPE string;
+DEFINE FIELD IF NOT EXISTS model ON embedding_model TYPE record<embedding_model_definition>;
 DEFINE FIELD IF NOT EXISTS embedding ON embedding_model TYPE array<float>;
diff --git a/surrealdb-rag/src/surrealdb_rag/app.py b/surrealdb-rag/src/surrealdb_rag/app.py
index 0861e39..32b67f9 100644
--- a/surrealdb-rag/src/surrealdb_rag/app.py
+++ b/surrealdb-rag/src/surrealdb_rag/app.py
@@ -7,9 +7,12 @@
 from surrealdb import AsyncSurreal,RecordID
 from fastapi import responses, staticfiles, templating
 from surrealdb_rag.llm_handler import LLMModelHander,ModelListHandler
+from urllib.parse import quote
+import json
 
 import uvicorn
-
+import ast
+from urllib.parse import urlencode
 from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams
 db_params = DatabaseParams()
 model_params = ModelParams()
@@ -18,6 +21,23 @@
 
 
+def format_url_id(surrealdb_id: RecordID) -> str:
+    # Make a record ID safe to embed in a URL path
+    if RecordID == type(surrealdb_id):
+        str_to_format = surrealdb_id.id
+    else:
+        str_to_format = surrealdb_id
+    return quote(str_to_format).replace("/","|")
+
+
+def unformat_url_id(surrealdb_id: str) -> str:
+    return surrealdb_id.replace("|","/")
+    if RecordID == type(surrealdb_id):
+        str_to_format = surrealdb_id.id
+    else:
+        str_to_format = surrealdb_id
+    return 
quote(str_to_format) + def extract_id(surrealdb_id: RecordID) -> str: """Extract numeric ID from SurrealDB record ID. @@ -56,6 +76,7 @@ def convert_timestamp_to_date(timestamp: str) -> str: templates = templating.Jinja2Templates(directory="templates") templates.env.filters["extract_id"] = extract_id +templates.env.filters["format_url_id"] = format_url_id templates.env.filters["convert_timestamp_to_date"] = convert_timestamp_to_date life_span = {} @@ -71,9 +92,11 @@ async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: model_list = ModelListHandler(model_params,life_span["surrealdb"]) - + + life_span["llm_models"] = await model_list.available_llm_models() - life_span["embed_models"] = await model_list.available_embed_models() + life_span["corpus_tables"] = await model_list.available_corpus_tables() + yield life_span.clear() @@ -87,30 +110,61 @@ async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: async def index(request: fastapi.Request) -> responses.HTMLResponse: + available_llm_models_json = json.dumps(life_span["llm_models"]) + available_corpus_tables_json = json.dumps(life_span["corpus_tables"]) + + default_llm_model = life_span["llm_models"][next(iter(life_span["llm_models"]))] + default_corpus_table = life_span["corpus_tables"][next(iter(life_span["corpus_tables"]))] + default_embed_model = default_corpus_table["embed_models"][0] + return templates.TemplateResponse("index.html", { "request": request, - "available_llm_models": life_span["llm_models"], - "available_embed_models": life_span["embed_models"], - "default_llm_model": life_span["llm_models"][next(iter(life_span["llm_models"]))], - "default_embed_model":life_span["embed_models"][next(iter(life_span["embed_models"]))] + "available_llm_models": available_llm_models_json, + "available_corpus_tables": available_corpus_tables_json, + "default_llm_model": default_llm_model, + "default_corpus_table": default_corpus_table, + "default_embed_model":default_embed_model }) +@app.get("/get_corpus_table_details") +async def get_corpus_table_details(corpus_table: str = fastapi.Query(...)): + corpus_table_detail = life_span["corpus_tables"].get(corpus_table) + if corpus_table_detail: + s = f"Table: {corpus_table_detail['table_name']}" + else: + s = "Corpus table details not found." + return fastapi.Response(s, media_type="text/html") #Return response object + + + @app.get("/get_llm_model_details") async def get_llm_model_details(llm_model: str = fastapi.Query(...)): model_data = life_span["llm_models"].get(llm_model) if model_data: - s = f"Version: {model_data['model_version']}, Host: {model_data['host']}" + s = f" Platform: {model_data['platform']}, Host: {model_data['host']}
Version: {model_data['model_version']}" else: s = "Model details not found." return fastapi.Response(s, media_type="text/html") #Return response object @app.get("/get_embed_model_details") -async def get_embed_model_details(embed_model: str = fastapi.Query(...)): - model_data = life_span["embed_models"].get(embed_model) - if model_data: - s = f"Dimensions: {model_data['dimensions']}, Host: {model_data['host']}" +async def get_embed_model_details(corpus_table: str = fastapi.Query(...),embed_model: str = fastapi.Query(...)): + embed_models = life_span["corpus_tables"][corpus_table]["embed_models"] + embed_model_detail = None + embed_model = ast.literal_eval(embed_model) + for model in embed_models: + #arrays are passed as csv from the html + if embed_model == model["model"]: + embed_model_detail = model + break + if embed_model_detail==None : + raise Exception(f"Invalid embedd model {embed_model}") else: - s = "Model details not found." + s = f""" + Dimensions: {embed_model_detail['dimensions']}, Host: {embed_model_detail['host']}
+ Corpus: {embed_model_detail['corpus']}
+ Description: {embed_model_detail['description']} + """ + return fastapi.Response(s, media_type="text/html") #Return response object @@ -149,16 +203,21 @@ async def load_chat( message_records = await life_span["surrealdb"].query( """RETURN fn::load_chat($chat_id)""",params = {"chat_id":chat_id} ) + + title = await life_span["surrealdb"].query( + """RETURN fn::get_chat_title($chat_id);""",params = {"chat_id":chat_id} + ) + return templates.TemplateResponse( - "load_chat.html", + "chat.html", { "request": request, "messages": message_records, - "chat_id": chat_id, + "chat": {"id":chat_id,"title": title } }, ) @app.get("/messages/{message_id}", response_class=responses.HTMLResponse) -async def load_chat( +async def load_message( request: fastapi.Request, message_id: str ) -> responses.HTMLResponse: """Load a chat.""" @@ -166,11 +225,31 @@ async def load_chat( """RETURN fn::load_message_detail($message_id)""",params = {"message_id":message_id} ) return templates.TemplateResponse( - "load_message_detail.html", + "message_detail.html", { "request": request, "message": message, - "message_id": message_id, + "message_id": message_id + }, + ) + + +@app.get("/documents/{document_id}", response_class=responses.HTMLResponse) +async def load_document( + request: fastapi.Request, document_id: str, + corpus_table: str = fastapi.Query(...) +) -> responses.HTMLResponse: + """Load a chat.""" + document_id = unformat_url_id(document_id) + document = await life_span["surrealdb"].query( + """RETURN fn::load_document_detail($corpus_table,$document_id)""",params = {"corpus_table":corpus_table,"document_id":document_id} + ) + return templates.TemplateResponse( + "document.html", + { + "request": request, + "document": document[0], + "document_id": document_id }, ) @@ -195,26 +274,30 @@ async def send_user_message( request: fastapi.Request, chat_id: str, content: str = fastapi.Form(...), - embed_model: str = fastapi.Form(...) + embed_model: str = fastapi.Form(...), + corpus_table: str = fastapi.Form(...) 
) -> responses.HTMLResponse: """Send user message.""" - if embed_model == "OPENAI": - message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( - """RETURN fn::create_user_message($chat_id, $content,$embedding_model,$openaitoken);""",params = {"chat_id":chat_id,"content":content,"embedding_model":embed_model,"openaitoken":model_params.openai_token} + + embed_model = ast.literal_eval(embed_model) + # need to fix for model_trainer + if embed_model[0] == "OPENAI": + outcome = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::create_user_message($chat_id,$corpus_table, $content,type::thing('embedding_model_definition',$embedding_model),$openaitoken);""",params = {"chat_id":chat_id,"corpus_table":corpus_table,"content":content,"embedding_model":embed_model,"openaitoken":model_params.openai_token} )) else: - message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( - """RETURN fn::create_user_message($chat_id, $content,$embedding_model);""",params = {"chat_id":chat_id,"content":content,"embedding_model":embed_model} + outcome = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + """RETURN fn::create_user_message($chat_id,$corpus_table, $content,type::thing('embedding_model_definition',$embedding_model));""",params = {"chat_id":chat_id,"corpus_table":corpus_table,"content":content,"embedding_model":embed_model} )) - + message = outcome["result"][0]["result"] return templates.TemplateResponse( - "send_user_message.html", + "message.html", { "request": request, "chat_id": chat_id, - "content": message["result"][0]["result"]["content"], - "timestamp": message["result"][0]["result"]["timestamp"] + "new_message": True, + "message" : message }, ) @@ -232,10 +315,10 @@ async def send_system_message( - message = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + outcome = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( """RETURN fn::get_last_user_message_input_and_prompt($chat_id);""",params = {"chat_id":chat_id} )) - result = message["result"][0]["result"] + result = outcome["result"][0]["result"] prompt_text = result["prompt_text"] content = result["content"] #call the LLM @@ -248,10 +331,10 @@ async def send_system_message( llm_response = llm_handler.get_chat_response(prompt_text,content) #save the response in the DB - message = SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( - """RETURN fn::create_message($chat_id, "system", $llm_response);""",params = {"chat_id":chat_id,"llm_response":llm_response} + outcome = SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( + """RETURN fn::create_system_message($chat_id,$llm_response,$llm_model,$prompt_text);""",params = {"chat_id":chat_id,"llm_response":llm_response,"llm_model":llm_model,"prompt_text":prompt_text} )) - + title = await life_span["surrealdb"].query( """RETURN fn::get_chat_title($chat_id);""",params = {"chat_id":chat_id} ) @@ -261,7 +344,7 @@ async def send_system_message( "RETURN fn::get_first_message($chat_id);",params={"chat_id":chat_id} ) system_prompt = "You are a conversation title generator for a ChatGPT type app. Respond only with a simple title using the user input." 
- new_title = llm_handler.get_chat_response(system_prompt,first_message_text) + new_title = llm_handler.get_short_plain_text_response(system_prompt,first_message_text) #update chat title in database SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( """UPDATE type::record($chat_id) SET title=$title;""",params = {"chat_id":chat_id,"title":new_title} @@ -269,18 +352,17 @@ async def send_system_message( - result = message["result"][0]["result"] - - + message = outcome["result"][0]["result"] + return templates.TemplateResponse( - "send_system_message.html", + "message.html", { "request": request, - "content": result["content"], - "timestamp": result["timestamp"], "new_title": new_title.strip(), "chat_id": chat_id, + "new_message": True, + "message": message }, ) diff --git a/surrealdb-rag/src/surrealdb_rag/constants.py b/surrealdb-rag/src/surrealdb_rag/constants.py index aacf59f..5ac7e9a 100644 --- a/surrealdb-rag/src/surrealdb_rag/constants.py +++ b/surrealdb-rag/src/surrealdb_rag/constants.py @@ -10,7 +10,7 @@ GLOVE_ZIP_PATH = "data/glove.6B.zip" GLOVE_PATH = "data/glove.6B.300d.txt" -CUSTOM_FS_PATH = "data/custom_fast_text.txt" +FS_WIKI_PATH = "data/custom_fast_wiki_text.txt" class SurrealParams(): @@ -50,11 +50,11 @@ class ModelParams(): # # "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} # # } - EMBED_MODELS = { - "CUST_FASTTEXT": {"dimensions":100,"host":"SQL"}, - "GLOVE": {"dimensions":300,"host":"SQL"}, - "OPENAI": {"dimensions":1536,"host":"API"} - } + # EMBED_MODELS = { + # "FASTTEXT": {"dimensions":100,"host":"SQL"}, + # "GLOVE": {"dimensions":300,"host":"SQL"}, + # "OPENAI": {"dimensions":1536,"host":"API"} + # } def __init__(self): self.openai_token_env_var = "OPENAI_API_KEY" self.openai_token = None @@ -197,19 +197,55 @@ def __init__(self,description, self.model_params = model_params self.model_params.AddArgs(self.parser) self.db_params.AddArgs(self.parser) + self.AdditionalArgs = {} + + def AddArg(self,name:str,flag:str,action:str,help:str,default:str): + self.parser.add_argument(f"-{flag}",f"--{action}", help=help.format(default)) + self.AdditionalArgs[name] = {"flag":flag,"action":action,"value":default} + def LoadArgs(self): self.args = self.parser.parse_args() self.db_params.SetArgs(self.args) self.model_params.SetArgs(self.args) + for key in self.AdditionalArgs.keys(): + if getattr(self.args, self.AdditionalArgs[key]["action"]): + self.AdditionalArgs[key]["value"] = getattr(self.args, self.AdditionalArgs[key]["action"]) def string_to_print(self): ret_val = self.parser.description - ret_val += f"/n{self.db_params.DB_PARAMS.__dict__}" - ret_val += f"/n{self.model_params.__dict__}" + + + ret_val += "\n\nDB Params:" + ret_val += ArgsLoader.dict_to_str(self.db_params.DB_PARAMS.__dict__) + ret_val += "\n\nModel Params:" + ret_val += ArgsLoader.dict_to_str(self.model_params.__dict__) + ret_val += "\n\nAdditional Params:" + ret_val += ArgsLoader.additional_args_dict_to_str(self.AdditionalArgs) return ret_val + + ret_val += f"\n{self.db_params.DB_PARAMS.__dict__}" + ret_val += f"\n{self.model_params.__dict__}" + for key in self.AdditionalArgs.keys(): + ret_val += f"\n{key} : {self.AdditionalArgs[key]["value"]}" + return ret_val + + def additional_args_dict_to_str(the_dict:dict): + ret_val = "" + for key in the_dict.keys(): + ret_val += f"\n{key} : {the_dict[key]["value"]}" + return ret_val + + + def dict_to_str(the_dict:dict): + ret_val = "" + for key in the_dict.keys(): + ret_val += f"\n{key} : 
{the_dict[key]}" + return ret_val + + def print(self): print(self.string_to_print()) diff --git a/surrealdb-rag/src/surrealdb_rag/create_database.py b/surrealdb-rag/src/surrealdb_rag/create_database.py index 801ee2b..6914ab7 100644 --- a/surrealdb-rag/src/surrealdb_rag/create_database.py +++ b/surrealdb-rag/src/surrealdb_rag/create_database.py @@ -37,10 +37,7 @@ def surreal_create_database() -> None: logger.info("Database created successfully") connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) - logger.info("Executing common DDL") - with open("./schema/table_ddl.surql") as f: - surlql_to_execute = f.read() - SurrealParams.ParseResponseForErrors( connection.query_raw(surlql_to_execute)) + logger.info("Executing common function DDL") with open("./schema/function_ddl.surql") as f: surlql_to_execute = f.read() diff --git a/surrealdb-rag/src/surrealdb_rag/download_data.py b/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py similarity index 96% rename from surrealdb-rag/src/surrealdb_rag/download_data.py rename to surrealdb-rag/src/surrealdb_rag/download_wiki_data.py index 8c403ea..727fc0d 100644 --- a/surrealdb-rag/src/surrealdb_rag/download_data.py +++ b/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py @@ -14,6 +14,8 @@ import pandas as pd import tqdm + + def download_data() -> None: """Extract `vector_database_wikipedia_articles_embedded.csv` to `/data`.""" logger = loggers.setup_logger("DownloadData") @@ -45,7 +47,7 @@ def download_data() -> None: logger.error(f"Error opening embedding model. please check the model file was downloaded using download_glove_model {e}") try: - fastTextEmbeddingModel = WordEmbeddingModel(constants.CUSTOM_FS_PATH) + fastTextEmbeddingModel = WordEmbeddingModel(constants.FS_WIKI_PATH) except Exception as e: logger.error(f"Error opening embedding model. 
train the model using train_fastText {e}") diff --git a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py index 40798f2..960b1e3 100644 --- a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py +++ b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py @@ -16,9 +16,9 @@ db_params = DatabaseParams() model_params = ModelParams() args_loader = ArgsLoader("Input Glove embeddings model",db_params,model_params) -args_loader.LoadArgs() -INSERT_GLOVE_EMBEDDINGS = """ +INSERT_EMBEDDINGS = """ + LET $model = type::thing('embedding_model_definition',[$model_trainer,$model_version]); FOR $row IN $embeddings { CREATE embedding_model:[$model,$row.word] CONTENT { word : $row.word, @@ -28,13 +28,30 @@ }; """ -DELETE_GLOVE_EMBEDDINGS = "DELETE embedding_model WHERE model = $model;" +DELETE_EMBEDDINGS = """ +LET $model = type::thing('embedding_model_definition',[$model_trainer,$model_version]); +DELETE embedding_model WHERE model = $model; +""" + + +UPDATE_EMBEDDING_MODEL_DEF = """ +LET $model = type::thing('embedding_model_definition',[$model_trainer,$model_version]); +UPSERT embedding_model_definition:[$model_trainer,$model_version] CONTENT { + model_trainer:$model_trainer, + host:'SQL', + dimensions:$dimensions, + version:$model_version, + corpus:$corpus, + description:$description +}; +""" + CHUNK_SIZE = 1000 -def surreal_model_insert(model_name,model_path,logger): +def surreal_model_insert(model_trainer,model_version,model_path,description,corpus,logger): - logger.info(f"Reading {model_name} model") + logger.info(f"Reading {model_trainer} {model_version} model") embeddingModel = WordEmbeddingModel(model_path) embeddings_df = pd.DataFrame({'word': embeddingModel.dictionary.keys(), 'embedding': embeddingModel.dictionary.values()}) total_rows = len(embeddings_df) @@ -43,10 +60,12 @@ def surreal_model_insert(model_name,model_path,logger): connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) logger.info("Connected to SurrealDB") + logger.info(f"Deleting any rows from {model_trainer} {model_version}") + + SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_EMBEDDINGS,params={"model_trainer":model_trainer,"model_version":model_version})) logger.info("Inserting rows into SurrealDB") #remove any data from the table - SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_GLOVE_EMBEDDINGS)) with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar: for i in range(0, total_rows, CHUNK_SIZE): @@ -62,21 +81,68 @@ def surreal_model_insert(model_name,model_path,logger): SurrealParams.ParseResponseForErrors(connection.query_raw( - INSERT_GLOVE_EMBEDDINGS, params={"embeddings": formatted_rows,"model":model_name} + INSERT_EMBEDDINGS, params={"embeddings": formatted_rows,"model_trainer":model_trainer,"model_version":model_version} )) pbar.update(1) - + SurrealParams.ParseResponseForErrors(connection.query_raw(UPDATE_EMBEDDING_MODEL_DEF, + params={ + "model_trainer":model_trainer, + "model_version":model_version, + "dimensions":embeddingModel.vector_size, + "version":model_version, + "description":description, + "corpus":corpus + })) + + def surreal_embeddings_insert() -> None: """Main entrypoint to insert glove embedding model into SurrealDB.""" logger = loggers.setup_logger("SurrealEmbeddingsInsert") - + args_loader.AddArg( + "model_trainer","emtr","model_trainer","The name of the training 
algorithm: 'GLOVE' or 'FASTTEXT' (Default{0})",None
+ )
+ args_loader.AddArg(
+ "model_version","emv","model_version","The name of the version of the model: e.g. '300d' or 'custom wiki' (Default{0})",None
+ )
+ args_loader.AddArg(
+ "model_path","emp","model_path","The path to the txt file with the words and vectors 'data/glove.6B.300d.txt' or 'data/custom_wiki_fast_text.txt' (Default{0})",None
+ )
+ args_loader.AddArg(
+ "description","des","description","A description of the embedding model. Include source and other notes (Default{0})",None
+ )
+ args_loader.AddArg(
+ "corpus","cor","corpus","A description of the embedding model training data. (Default{0})",None
+ )
+
+ args_loader.LoadArgs()
+
+
+
+ model_trainer = args_loader.AdditionalArgs["model_trainer"]["value"]
+ if not model_trainer or model_trainer not in ['GLOVE','FASTTEXT']:
+ raise Exception("You must supply a model trainer with -emtr and it must be 'GLOVE' or 'FASTTEXT'")
+ model_version = args_loader.AdditionalArgs["model_version"]["value"]
+ if not model_version:
+ raise Exception("You must supply a model version")
+ model_path = args_loader.AdditionalArgs["model_path"]["value"]
+ if not model_path:
+ raise Exception("You must supply a model path")
+ description = args_loader.AdditionalArgs["description"]["value"]
+ if not description:
+ raise Exception("You must supply a model description")
+ corpus = args_loader.AdditionalArgs["corpus"]["value"]
+ if not corpus:
+ raise Exception("You must supply a model corpus")
+
+ logger.info(args_loader.string_to_print())
- surreal_model_insert("GLOVE",constants.GLOVE_PATH,logger)
- surreal_model_insert("CUST_FASTTEXT",constants.CUSTOM_FS_PATH,logger)
+
+
+ surreal_model_insert(model_trainer,model_version,model_path,description,corpus,logger)
diff --git a/surrealdb-rag/src/surrealdb_rag/insert_wiki.py b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py
index 516ed7f..40ba7ac 100644
--- a/surrealdb-rag/src/surrealdb_rag/insert_wiki.py
+++ b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py
@@ -15,25 +15,60 @@
db_params = DatabaseParams()
model_params = ModelParams()
args_loader = ArgsLoader("Input wiki data",db_params,model_params)
-args_loader.LoadArgs()
+TABLE_NAME = "embedded_wiki"
+DISPLAY_NAME = "Wikipedia"
+GET_EMBED_MODEL_DESCRIPTIONS = """
+ SELECT * FROM embedding_model_definition;
+"""
+
+
+EMBED_MODEL_DEFINITIONS = {
+ "GLOVE":{"field_name":"content_glove_vector","model_definition":[
+ 'GLOVE',
+ '6b 300d'
+ ]},
+ "OPENAI":{"field_name":"content_openai_vector","model_definition":[
+ 'OPENAI',
+ 'text-embedding-ada-002'
+ ]},
+ "FASTTEXT":{"field_name":"content_fasttext_vector","model_definition":[
+ 'FASTTEXT',
+ 'wiki'
+ ]},
+}
+
+
+
+
+UPDATE_CORPUS_TABLE_INFO = f"""
+ DELETE FROM corpus_table_model WHERE corpus_table = corpus_table:{TABLE_NAME};
+ FOR $model IN $embed_models {{
+ LET $model_definition = type::thing("embedding_model_definition",$model.model_id);
+ UPSERT corpus_table_model:[corpus_table:{TABLE_NAME},$model_definition] SET model = $model_definition,field_name = $model.field_name, corpus_table=corpus_table:{TABLE_NAME};
+ }};
+ UPSERT corpus_table:{TABLE_NAME} SET table_name = '{TABLE_NAME}', display_name = '{DISPLAY_NAME}',
+ embed_models = (SELECT value id FROM corpus_table_model WHERE corpus_table = corpus_table:{TABLE_NAME}) RETURN NONE;
+
+"""
+
-INSERT_WIKI_RECORDS = """
- FOR $row IN $records {
- CREATE type::thing("embedded_wiki",$row.url) CONTENT {
- url : $row.url,
- title: $row.title,
- text: $row.text,
- content_glove_vector: $row.content_glove_vector,
- content_openai_vector: $row.content_openai_vector,
- content_fasttext_vector: $row.content_fasttext_vector
- } RETURN NONE;
- };
+INSERT_WIKI_RECORDS = f"""
+ FOR $row IN $records {{
+ CREATE type::thing("{TABLE_NAME}",$row.url) CONTENT {{
+ url : $row.url,
+ title: $row.title,
+ text: $row.text,
+ content_glove_vector: $row.content_glove_vector,
+ content_openai_vector: $row.content_openai_vector,
+ content_fasttext_vector: $row.content_fasttext_vector
+ }} RETURN NONE;
+ }};
"""
-DELETE_WIKI_RECORDS = "DELETE embedded_wiki;"
+DELETE_WIKI_RECORDS = f"DELETE {TABLE_NAME};"
CHUNK_SIZE = 50
@@ -41,12 +76,34 @@
def surreal_wiki_insert() -> None:
+ args_loader.AddArg(
+ "embed_models",
+ "ems",
+ "embed_models",
+ "The embed models you'd like to calculate, separated by ',': can be GLOVE,FASTTEXT,OPENAI, e.g. -ems OPENAI,GLOVE will calculate both OPENAI and GLOVE. (default{})",
+ "GLOVE,FASTTEXT,OPENAI"
+ )
+ args_loader.LoadArgs()
+
"""Main entrypoint to insert Wikipedia embeddings into SurrealDB."""
logger = loggers.setup_logger("SurrealWikiInsert")
logger.info(args_loader.string_to_print())
+ embed_models_str = args_loader.AdditionalArgs["embed_models"]["value"]
+ embed_models = embed_models_str.split(",")
+
+ if len(embed_models)<1:
+ raise Exception("You must specify at least one valid model of GLOVE,FASTTEXT,OPENAI with the -ems flag")
+
+
+
+ for embed_model in embed_models:
+ if embed_model not in EMBED_MODEL_DEFINITIONS:
+ raise Exception(f"{embed_model} is invalid. You must specify at least one valid model of GLOVE,FASTTEXT,OPENAI with the -ems flag")
+
+
logger.info(f"Loading file {constants.WIKI_PATH}")
usecols=[
@@ -66,10 +123,32 @@ def surreal_wiki_insert() -> None:
1 if total_rows % CHUNK_SIZE else 0
)
with Surreal(db_params.DB_PARAMS.url) as connection:
+
+
connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password})
connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database)
logger.info("Connected to SurrealDB")
+
+
+ embed_model_mappings = []
+ for embed_model in embed_models:
+ if embed_model in EMBED_MODEL_DEFINITIONS:
+ field_name = EMBED_MODEL_DEFINITIONS[embed_model]["field_name"]
+ model_definition = EMBED_MODEL_DEFINITIONS[embed_model]["model_definition"]
+ embed_model_mappings.append({"model_id": model_definition, "field_name": field_name})
+
+
+ # logger.info(f"Updating corpus table info for {TABLE_NAME}")
+ # SurrealParams.ParseResponseForErrors( connection.query_raw(UPDATE_CORPUS_TABLE_INFO,params={"embed_models":embed_model_mappings}))
+ # return
+
+ with open("./schema/table_ddl.surql") as f:
+ surlql_to_execute = f.read()
+ surlql_to_execute = surlql_to_execute.format(corpus_table = "embedded_wiki")
+ SurrealParams.ParseResponseForErrors( connection.query_raw(surlql_to_execute))
+
+
logger.info("Deleting any existing wiki rows from SurrealDB")
#remove any data from the table
SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_WIKI_RECORDS))
@@ -99,6 +178,13 @@ def surreal_wiki_insert() -> None:
pbar.update(1)
+
+
+
+ logger.info(f"Updating corpus table info for {TABLE_NAME}")
+ SurrealParams.ParseResponseForErrors( connection.query_raw(UPDATE_CORPUS_TABLE_INFO,params={"embed_models":embed_model_mappings}))
+
+
+
if __name__ == "__main__":
diff --git a/surrealdb-rag/src/surrealdb_rag/llm_handler.py b/surrealdb-rag/src/surrealdb_rag/llm_handler.py
index 3c1d5ac..71a7a58 100644
--- a/surrealdb-rag/src/surrealdb_rag/llm_handler.py
+++ 
b/surrealdb-rag/src/surrealdb_rag/llm_handler.py @@ -8,114 +8,177 @@ from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader from surrealdb import AsyncSurreal - - - # LLM_MODELS = { - # "GEMINI-SURREAL": {"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None}, - # "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None}, - # "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None}, - # "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}, - # "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} - # # } - - # EMBED_MODELS = { - # "CUST_FASTTEXT": {"dimensions":100,"host":"SQL"}, - # "GLOVE": {"dimensions":300,"host":"SQL"}, - # "OPENAI": {"dimensions":1536,"host":"API"} - # } +import re class ModelListHandler(): def __init__(self, model_params, connection): self.LLM_MODELS = {} - self.EMBED_MODELS = {} + self.CORPUS_TABLES = {} self.model_params = model_params self.connection = connection - async def populate_models(self): - self.LLM_MODELS = {} - self.EMBED_MODELS = {} - - check_for_vectors = await self.connection.query( - """SELECT - content_openai_vector!=None AS has_openai_vectors, - content_glove_vector!=None AS has_glove_vectors, - content_fasttext_vector!=None AS has_fasttext_vectors - FROM embedded_wiki LIMIT 1;""") - check_for_vectors = check_for_vectors[0] - #you need the vector field populated for fasttext - if check_for_vectors["has_fasttext_vectors"] == True: - self.EMBED_MODELS["CUST_FASTTEXT"] = ModelParams.EMBED_MODELS["CUST_FASTTEXT"] - - #you need the vector field populated for glove - if check_for_vectors["has_glove_vectors"] == True: - self.EMBED_MODELS["GLOVE"] = ModelParams.EMBED_MODELS["GLOVE"] - - #you need an api key for gemini - if self.model_params.gemini_token: - genai.configure(api_key=self.model_params.gemini_token) - - for model in genai.list_models(): - #print(model) - if ( model.supported_generation_methods in - [ - ['generateContent', 'countTokens'] , - ['generateContent', 'countTokens', 'createCachedContent'] - ] - and "gemini" in model.name - and (model.display_name == model.description - or "stable" in model.description.lower()) ): - self.LLM_MODELS["GOOGLE - " + model.display_name] = {"model_version":model.name,"host":"API","platform":"GOOGLE","temperature":None} - self.LLM_MODELS["GOOGLE - " + model.display_name + " (surreal)"] = {"model_version":model.name, "host":"SQL","platform":"GOOGLE","temperature":None} - - #you need an api key for openai - if self.model_params.openai_token: - openai.api_key = self.model_params.openai_token - models = openai.models.list() - for model in models.data: - if(model.owned_by == "openai" and "gpt" in model.id): - #print(model) - self.LLM_MODELS["OPENAI - " + model.id] = {"model_version":model.id,"host":"API","platform":"OPENAI","temperature":0.5} - self.LLM_MODELS["OPENAI - " + model.id + " (surreal)"] = {"model_version":model.id,"host":"SQL","platform":"OPENAI","temperature":0.5} - - - # self.LLM_MODELS["OPENAI"] = ModelParams.LLM_MODELS["OPENAI"] - # self.LLM_MODELS["OPENAI-SURREAL"] = ModelParams.LLM_MODELS["OPENAI-SURREAL"] - #you need the vector field populated for openai - if check_for_vectors["has_openai_vectors"] == True: - self.EMBED_MODELS["OPENAI"] = ModelParams.EMBED_MODELS["OPENAI"] - - response: ollama.ListResponse = ollama.list() - - for model in response.models: - 
self.LLM_MODELS["OLLAMA " + model.model] = {"model_version":model.model,"host":"OLLAMA","platform":"local","temperature":None} - - # print('Name:', model.model) - # print(' Size (MB):', f'{(model.size.real / 1024 / 1024):.2f}') - # if model.details: - # print(' Format:', model.details.format) - # print(' Family:', model.details.family) - # print(' Parameter Size:', model.details.parameter_size) - # print(' Quantization Level:', model.details.quantization_level) - # print('\n') - - async def available_llm_models(self): if self.LLM_MODELS != {}: return self.LLM_MODELS else: - await self.populate_models() + self.LLM_MODELS = {} + + + #you need an api key for gemini + if self.model_params.gemini_token: + genai.configure(api_key=self.model_params.gemini_token) + + for model in genai.list_models(): + #print(model) + if ( model.supported_generation_methods in + [ + ['generateContent', 'countTokens'] , + ['generateContent', 'countTokens', 'createCachedContent'] + ] + and "gemini" in model.name + and (model.display_name == model.description + or "stable" in model.description.lower()) ): + self.LLM_MODELS["GOOGLE - " + model.display_name] = {"model_version":model.name,"host":"API","platform":"GOOGLE","temperature":0} + self.LLM_MODELS["GOOGLE - " + model.display_name + " (surreal)"] = {"model_version":model.name, "host":"SQL","platform":"GOOGLE","temperature":0} + + #you need an api key for openai + if self.model_params.openai_token: + openai.api_key = self.model_params.openai_token + models = openai.models.list() + for model in models.data: + if(model.owned_by == "openai" and "gpt" in model.id): + #print(model) + self.LLM_MODELS["OPENAI - " + model.id] = {"model_version":model.id,"host":"API","platform":"OPENAI","temperature":0.5} + self.LLM_MODELS["OPENAI - " + model.id + " (surreal)"] = {"model_version":model.id,"host":"SQL","platform":"OPENAI","temperature":0.5} + + + response: ollama.ListResponse = ollama.list() + + for model in response.models: + self.LLM_MODELS["OLLAMA " + model.model] = {"model_version":model.model,"host":"OLLAMA","platform":"OLLAMA","temperature":0} + + + return self.LLM_MODELS - async def available_embed_models(self): - if self.EMBED_MODELS != {}: - return self.EMBED_MODELS + async def available_corpus_tables(self): + if self.CORPUS_TABLES != {}: + return self.CORPUS_TABLES else: - await self.populate_models() - return self.EMBED_MODELS - + self.CORPUS_TABLES = {} + corpus_tables = await self.connection.query(""" + SELECT display_name,table_name,embed_models FROM corpus_table FETCH embed_models,embed_models.model; + """) + + #you need an api key for openai so remove openai from list if api is absent + + for corpus_table in corpus_tables: + # example record + # { + # display_name: 'Wikipedia', + # embed_models: [ + # { + # corpus_table: corpus_table:embedded_wiki, + # field_name: 'content_fasttext_vector', + # id: corpus_table_model:[ + # corpus_table:embedded_wiki, + # embedding_model_definition:[ + # 'FASTTEXT', + # 'wiki' + # ] + # ], + # model: { + # corpus: 'https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip', + # description: 'Custom trained model using fasttext based on OPENAI wiki example download', + # dimensions: 100, + # host: 'SQL', + # id: embedding_model_definition:[ + # 'FASTTEXT', + # 'wiki' + # ], + # model_trainer: 'FASTTEXT', + # version: 'wiki' + # } + # }, + # { + # corpus_table: corpus_table:embedded_wiki, + # field_name: 'content_glove_vector', + # id: corpus_table_model:[ + # corpus_table:embedded_wiki, + # 
embedding_model_definition:[ + # 'GLOVE', + # '6b 300d' + # ] + # ], + # model: { + # corpus: 'Wikipedia 2014 + Gigaword 5', + # description: 'Standard pretrained GLoVE model from https://nlp.stanford.edu/projects/glove/ 300 dimensions version', + # dimensions: 300, + # host: 'SQL', + # id: embedding_model_definition:[ + # 'GLOVE', + # '6b 300d' + # ], + # model_trainer: 'GLOVE', + # version: '6b 300d' + # } + # }, + # { + # corpus_table: corpus_table:embedded_wiki, + # field_name: 'content_openai_vector', + # id: corpus_table_model:[ + # corpus_table:embedded_wiki, + # embedding_model_definition:[ + # 'OPENAI', + # 'text-embedding-ada-002' + # ] + # ], + # model: { + # corpus: 'generic pretrained', + # description: 'The standard OPENAI embedding model', + # dimensions: 1536, + # host: 'API', + # id: embedding_model_definition:[ + # 'OPENAI', + # 'text-embedding-ada-002' + # ], + # model_trainer: 'OPENAI', + # version: 'text-embedding-ada-002' + # } + # } + # ], + # table_name: 'embedded_wiki' + # } + # create an dict item for table_name + table_name = corpus_table["table_name"] + self.CORPUS_TABLES[table_name] = {} + self.CORPUS_TABLES[table_name]["display_name"] = corpus_table["display_name"] + self.CORPUS_TABLES[table_name]["table_name"] = corpus_table["table_name"] + self.CORPUS_TABLES[table_name]["embed_models"] = [] + for model in corpus_table["embed_models"]: + model_def = model["model"] + model_def_id = model_def["id"].id + if model_def_id[0] != "OPENAI" or ( + model_def_id[0] == "OPENAI" and self.model_params.openai_token): + + self.CORPUS_TABLES[table_name]["embed_models"].append( + {"model":model_def_id, + "field_name":model["field_name"], + "corpus":model_def["corpus"], + "description":model_def["description"], + "host":model_def["host"], + "model_trainer":model_def["model_trainer"], + "version":model_def["version"], + "dimensions":model_def["dimensions"], + } + ) + return self.CORPUS_TABLES + + + class LLMModelHander(): @@ -129,7 +192,34 @@ def __init__(self,model_data:str,model_params:ModelParams,connection:AsyncSurrea self.connection = connection + def extract_plain_text(text): + """ + Extracts plain text from a string by removing content within tags. + + Args: + text (str): The input string containing tags. + + Returns: + str: The plain text with tags and their contents removed. + """ + # Use a regular expression to find and remove content within tags + clean_text = remove_think_tags(clean_text) + clean_text = re.sub(r'<[^>]*>', '', clean_text) + return clean_text + def remove_think_tags(text): + """ + Removes tags and their content from the given text, leaving only the text after the closing tag. + + Args: + text (str): The input string. + + Returns: + str: The string with tags and their content removed. 
+ """ + return re.sub(r'.*?\n*', '', text, flags=re.DOTALL | re.IGNORECASE).strip() + def get_short_plain_text_response(self,prompt_with_context:str,input:str): + return LLMModelHander.extract_plain_text(self.get_chat_response(prompt_with_context,input)) def get_chat_response(self,prompt_with_context:str,input:str): diff --git a/surrealdb-rag/static/style.css b/surrealdb-rag/static/style.css index ea12eb2..4aa4240 100644 --- a/surrealdb-rag/static/style.css +++ b/surrealdb-rag/static/style.css @@ -78,7 +78,7 @@ nav { color: #fff; } -.chat-id{ +.chat-header{ color: #fff; } button { diff --git a/surrealdb-rag/templates/chat.html b/surrealdb-rag/templates/chat.html new file mode 100644 index 0000000..22c275c --- /dev/null +++ b/surrealdb-rag/templates/chat.html @@ -0,0 +1,16 @@ +
+ {{chat.title}} (ID: {{chat.id}})
+
+ {% for message in messages %} + + {% include 'message.html' %} + + {% endfor %} +
+ +
+ + + +
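For reference, the new chat.html partial above is fed by the load_chat route shown earlier in app.py. A minimal sketch of that wiring, with placeholder data standing in for the fn::load_chat and fn::get_chat_title results (the route body and records here are illustrative, not the project's exact code):

import fastapi
from fastapi import responses
from fastapi.templating import Jinja2Templates

app = fastapi.FastAPI()
templates = Jinja2Templates(directory="templates")

@app.get("/chats/{chat_id}", response_class=responses.HTMLResponse)
async def load_chat_sketch(request: fastapi.Request, chat_id: str):
    # Stand-ins for the SurrealDB calls in app.py:
    #   RETURN fn::load_chat($chat_id)       -> message records
    #   RETURN fn::get_chat_title($chat_id)  -> title string
    messages = [{"role": "user", "content": "Hello", "timestamp": "2025-03-05T00:00:00Z"}]
    title = "Example chat"
    return templates.TemplateResponse(
        "chat.html",
        {"request": request, "messages": messages, "chat": {"id": chat_id, "title": title}},
    )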
diff --git a/surrealdb-rag/templates/chats.html b/surrealdb-rag/templates/chats.html index 29a10c6..9ce7f96 100644 --- a/surrealdb-rag/templates/chats.html +++ b/surrealdb-rag/templates/chats.html @@ -1,7 +1,8 @@ {% for chat in chats %}
-
+
Title: {{document.title}}
+ +
Text:
+
{{document.text}}
+
OPENAI vector
+
{{document.content_openai_vector}}
+
GLOVE vector
+
{{document.content_glove_vector}}
+
Custom FastText vector
+
{{document.content_fasttext_vector}}
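The document.html view above simply dumps the record returned by fn::load_document_detail, including all three vector fields. A sketch of the lookup it sits behind, using the same SDK calls app.py makes (the URL, namespace, and credentials below are placeholders):

from surrealdb import AsyncSurreal

async def fetch_document(corpus_table: str, document_id: str) -> dict:
    db = AsyncSurreal("ws://localhost:8000/rpc")  # placeholder URL
    await db.signin({"username": "root", "password": "root"})  # placeholder credentials
    await db.use("test", "test")  # placeholder namespace/database
    # fn::load_document_detail does a type::thing($corpus_table, $document_id) lookup
    result = await db.query(
        "RETURN fn::load_document_detail($corpus_table, $document_id)",
        params={"corpus_table": corpus_table, "document_id": document_id},
    )
    return result[0]  # the route renders result[0] into document.html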
diff --git a/surrealdb-rag/templates/index.html b/surrealdb-rag/templates/index.html index 4ded1d0..e620eeb 100644 --- a/surrealdb-rag/templates/index.html +++ b/surrealdb-rag/templates/index.html @@ -68,7 +68,120 @@ console.error("Modal or close button not found."); } }); + const available_llm_models_string = `{{ available_llm_models | safe }}`; // @ts-ignore - Ignore templating syntax + let available_llm_models = {}; + try { + available_llm_models = JSON.parse(available_llm_models_string); + } catch (e) { + console.error("Error parsing available_llm_models:", e); + } + + const available_corpus_tables_string = `{{ available_corpus_tables | safe }}`; // @ts-ignore - Ignore templating syntax + let available_corpus_tables = {}; + try { + available_corpus_tables = JSON.parse(available_corpus_tables_string); + } catch (e) { + console.error("Error parsing available_corpus_tables:", e); + } + + + + function updateLlmModelSelect(){ + const platformSelect = document.getElementById("platformSelect"); + const hostSelect = document.getElementById("hostSelect"); + const llmModelSelect = document.getElementById("llmModelSelect"); + + const selectedPlatform = platformSelect.value; + if(selectedPlatform=="OLLAMA"){ + hostSelect.value = "OLLAMA"; + } else { + if (hostSelect.value == "OLLAMA"){ + hostSelect.value = "API"; + } + } + const selectedHost = hostSelect.value; + + // Clear existing options + llmModelSelect.innerHTML = ""; + + + + // Create a default "Select a Model" option + // const defaultOption = document.createElement("option"); + // defaultOption.text = "Select a Model"; + // defaultOption.value = ""; // You can set this to an empty string or a specific value + // llmModelSelect.add(defaultOption); + + // Populate with matching models + for (const modelName in available_llm_models) { + const model = available_llm_models[modelName]; + if (model.platform === selectedPlatform && model.host === selectedHost) { + const option = document.createElement("option"); + option.text = model.model_version; + option.value = modelName; // Or you can set this to a different identifier if needed + llmModelSelect.add(option); + } + } + llmModelSelect.dispatchEvent(new Event('change')); + } + + function arrayToCsvString(arr) { + const quotedElements = arr.map(element => `"${element}"`); + return `[${quotedElements.join(',')}]`; + } + + function updateEmbedModelSelect(){ + const embedModelSelect = document.getElementById("embedModelSelect"); + const corpusTableSelect = document.getElementById("corpusTableSelect"); + const selectedCorpusTable = corpusTableSelect.value; + const corpusTable = available_corpus_tables[selectedCorpusTable]; + embedModelSelect.innerHTML = ""; + for (const embedModel in corpusTable.embed_models) { + const modelName = corpusTable.embed_models[embedModel].model.join(' - '); + const modelId = arrayToCsvString(corpusTable.embed_models[embedModel].model); + const option = document.createElement("option"); + console.log("updateEmbedModelSelect" + modelName ); + option.text = modelName; + option.value = modelId; + embedModelSelect.add(option); + } + + + } + function populateCorpusTableSelect(){ + const corpusTableSelect = document.getElementById("corpusTableSelect"); + + console.log("populateCorpusTableSelect") + console.log(available_corpus_tables) + + // Clear existing options + corpusTableSelect.innerHTML = ""; + for (const corpusTableID in available_corpus_tables) { + const option = document.createElement("option"); + option.text = available_corpus_tables[corpusTableID].display_name; + 
option.value = corpusTableID; // Or you can set this to a different identifier if needed + corpusTableSelect.add(option); + } + corpusTableSelect.add(document.createElement("option")); + + + corpusTableSelect.dispatchEvent(new Event('change')); + } + + document.addEventListener("DOMContentLoaded", function() { + updateLlmModelSelect(); + }); + + + document.addEventListener("DOMContentLoaded", function() { + populateCorpusTableSelect(); + }); + + document.addEventListener("DOMContentLoaded", function() { + updateEmbedModelSelect(); + }); + @@ -80,38 +193,57 @@
- - +
+ Data Set: + +
+ {% if default_corpus_table %} + Table: {{default_corpus_table['table_name']}} + {% endif %} + +
+

LLM:
+ +
- {% if default_llm_model %} - Version: {{default_llm_model['model_version']}}, Host: {{default_llm_model['host']}} - {% endif %} +

Embedding:

{% if default_embed_model %} - Dimensions: {{default_embed_model['dimensions']}}, Host: {{default_embed_model['host']}} + Dimensions: {{default_embed_model['dimensions']}}, Host: {{default_embed_model['host']}}
+ Corpus: {{default_embed_model['corpus']}}
+ Description: {{default_embed_model['description']}} + {% endif %}
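The embedding selector populated by updateEmbedModelSelect stores each option's value as a stringified two-part record id (built by arrayToCsvString), which the send-user-message handler decodes with ast.literal_eval before handing it to type::thing('embedding_model_definition', ...). A small illustration of that round trip, with an example value:

import ast

raw = '["GLOVE","6b 300d"]'  # example of what the select posts
embed_model = ast.literal_eval(raw)  # -> ['GLOVE', '6b 300d']
assert embed_model == ["GLOVE", "6b 300d"]
# SurrealQL side: type::thing('embedding_model_definition', $embedding_model)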

@@ -123,7 +255,9 @@ diff --git a/surrealdb-rag/templates/load_chat.html b/surrealdb-rag/templates/load_chat.html deleted file mode 100644 index 9fe35cb..0000000 --- a/surrealdb-rag/templates/load_chat.html +++ /dev/null @@ -1,24 +0,0 @@ -
Chat ID: {{chat_id}}
-
- {% for message in messages %} -
-
- {{ message.role | capitalize }} - - {{ message.timestamp | convert_timestamp_to_date }} - - -
-

{{ message.content | safe }}

-
- {% endfor %} -
- -
- - - -
diff --git a/surrealdb-rag/templates/message.html b/surrealdb-rag/templates/message.html new file mode 100644 index 0000000..c25dd1c --- /dev/null +++ b/surrealdb-rag/templates/message.html @@ -0,0 +1,61 @@ + + +
+{% else %} + > +{% endif %} +
+ {{ message.role | capitalize }} + + + {{ message.timestamp | convert_timestamp_to_date }} + + + +
+ + + {% if new_message %} + + {% if message.role=="user" %} + + {% endif %} + + + + {% if new_title %} + +
+ + + {% endif %} + {% endif %} +

{{ message.content | safe }}

+
diff --git a/surrealdb-rag/templates/load_message_detail.html b/surrealdb-rag/templates/message_detail.html similarity index 57% rename from surrealdb-rag/templates/load_message_detail.html rename to surrealdb-rag/templates/message_detail.html index a02e3b2..23da9bd 100644 --- a/surrealdb-rag/templates/load_message_detail.html +++ b/surrealdb-rag/templates/message_detail.html @@ -6,6 +6,7 @@
role: {{message.role}}
created_at: {{message.created_at}}
updated_at: {{message.updated_at}}
+
content: {{message.content}}
embedding_model: {{message.sent[0].embedding_model}}
llm_model: {{message.sent[0].llm_model}}
timestamp: {{message.sent[0].timestamp}}
@@ -15,11 +16,18 @@ Referenced Documents: {%for doc in message.sent[0].referenced_documents %}
- Score: {{doc.score}} - Doc: {{doc.doc}} + Doc: + {{ doc.doc }} +
{% endfor %}
{% endif %} -
\ No newline at end of file +
prompt: {{message.sent[0].prompt_text}}
+ + \ No newline at end of file diff --git a/surrealdb-rag/templates/send_system_message.html b/surrealdb-rag/templates/send_system_message.html deleted file mode 100644 index 32e2f77..0000000 --- a/surrealdb-rag/templates/send_system_message.html +++ /dev/null @@ -1,30 +0,0 @@ -
-
- System - {{ timestamp | convert_timestamp_to_date }} -
-

{{ content | safe }}

-
- -
- -{% if new_title %} - - - -{% endif %} \ No newline at end of file diff --git a/surrealdb-rag/templates/send_user_message.html b/surrealdb-rag/templates/send_user_message.html deleted file mode 100644 index fd9188f..0000000 --- a/surrealdb-rag/templates/send_user_message.html +++ /dev/null @@ -1,9 +0,0 @@ -
-
- User - {{ timestamp | convert_timestamp_to_date }} -
-

{{ content | safe }}

-
- From f87ecb96fba1aa17eeec3752d55dae8cc219a42e Mon Sep 17 00:00:00 2001 From: Alessandro Pireno Date: Mon, 10 Mar 2025 16:24:15 -0400 Subject: [PATCH 4/9] made prompt dynamic, got sql http calls working and ready for multiple corpuses --- surrealdb-rag/schema/function_ddl.surql | 295 ++++++++------ surrealdb-rag/schema/table_ddl.surql | 2 - surrealdb-rag/src/surrealdb_rag/app.py | 75 ++-- surrealdb-rag/src/surrealdb_rag/constants.py | 164 ++++---- .../src/surrealdb_rag/create_database.py | 28 +- .../src/surrealdb_rag/download_glove.py | 2 +- .../src/surrealdb_rag/download_wiki_data.py | 10 +- .../surrealdb_rag/insert_embedding_model.py | 38 +- .../src/surrealdb_rag/insert_wiki.py | 59 ++- .../src/surrealdb_rag/llm_handler.py | 360 +++++++++++++----- ...ain_fastText.py => train_wiki_fasttext.py} | 12 +- surrealdb-rag/static/style.css | 3 +- surrealdb-rag/templates/chat.html | 2 +- surrealdb-rag/templates/document.html | 2 +- surrealdb-rag/templates/index.html | 62 ++- surrealdb-rag/templates/message.html | 5 +- surrealdb-rag/templates/message_detail.html | 17 +- 17 files changed, 763 insertions(+), 373 deletions(-) rename surrealdb-rag/src/surrealdb_rag/{train_fastText.py => train_wiki_fasttext.py} (88%) diff --git a/surrealdb-rag/schema/function_ddl.surql b/surrealdb-rag/schema/function_ddl.surql index 3764a83..8ea8e94 100644 --- a/surrealdb-rag/schema/function_ddl.surql +++ b/surrealdb-rag/schema/function_ddl.surql @@ -4,6 +4,16 @@ This file defines the SurrealQL for the chat functionality of this project. and +/* +Calculates the mean vector for a sentence using a specified embedding model. + +Args: + sentence (string): The input sentence. + model (Record): The embedding model definition. + +Returns: + array|None: The mean vector of the sentence, or None if an error occurs. +*/ DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: Record) { #Pull the first row to determine the size of the vector (they should all be the same) @@ -33,18 +43,20 @@ DEFINE FUNCTION OVERWRITE fn::sentence_to_vector($sentence: string,$model: Recor }; - -/* Search for documents using embeddings. +/* +Searches for documents using embeddings. Args: - input_vector: embedding to search for within the embedding field - threshold: min threshold above and beyond the N returned - model: the name of the embedding model to use... "GLOVE", "OPENAI" or "FASTTEXT" + corpus_table (string): The name of the corpus table to search. + input_vector (array): The embedding vector to search for. + threshold (float): The minimum similarity threshold. + model (Record): The embedding model definition. + Returns: - array: Array of embeddings. + array<{score: float, doc: any}>: An array of documents with their similarity scores. */ - DEFINE FUNCTION OVERWRITE fn::search_for_documents($corpus_table: string, $input_vector: array, $threshold: float, $model: Record) { + LET $first_pass = (IF $model.model_trainer = "GLOVE" THEN @@ -64,7 +76,7 @@ DEFINE FUNCTION OVERWRITE fn::search_for_documents($corpus_table: string, $input WHERE content_fasttext_vector <|5,40|> $input_vector ); ELSE IF $model.model_trainer = "OPENAI" THEN - SELECT id FROM ( + ( SELECT id, vector::similarity::cosine(content_openai_vector, $input_vector) AS similarity_score @@ -79,50 +91,92 @@ DEFINE FUNCTION OVERWRITE fn::search_for_documents($corpus_table: string, $input -/* Get prompt for RAG. +/* +Generates a prompt with context and chat history. Args: - context: Context to add to the prompt. + prompt (string): The prompt template. 
+ documents (array): The documents to include in the context. + chat_history (Record): The chat history to include. Returns: - string: Prompt with context. + string: The generated prompt. */ -DEFINE FUNCTION OVERWRITE fn::get_prompt_with_context($documents: array) { +DEFINE FUNCTION OVERWRITE fn::get_prompt_with_context($prompt:string, $documents: array, $chat_history: Record) { - LET $prompt = "You are an AI assistant answering questions about anything from Simple English Wikipedia the context will provide you with the most relevant data from Simple English Wikipedia including the page title, url, and page content. + LET $context = (SELECT VALUE "\n ------------- \n URL: " + url + "\nTitle: " + title + "\n Content:\n" + text as content + FROM $documents).join("\n"); + LET $chat_history = (SELECT VALUE "\n ------------- \n Role: " + role + "\n Content:\n" + content as content + FROM $chat_history).join("\n"); - If referencing the text/context refer to it as Simple English Wikipedia. + RETURN string::replace(string::replace($prompt, '$context', $context),'$chat_history', $chat_history); +}; - Please provide your response in Markdown converted to HTML format. Include appropriate headings and lists where relevant. +/* +Retrieves the message history for a chat. - At the end of the response, add link a HTML link and replace the title and url with the associated title and url of the more relevant page from the context. +Args: + chat_id (string): The ID of the chat. + message_memory_length (int): The number of messages to retrieve. - The maximum number of links you can include is 1, do not provide any other references or annotations. +Returns: + array>: An array of message records. +*/ +DEFINE FUNCTION OVERWRITE fn::get_message_history($chat_id: string,$message_memory_length: int) { + LET $full_history = (SELECT VALUE out FROM (SELECT out,updated_at FROM type::record($chat_id)->sent ORDER BY updated_at DESC)); + RETURN array::slice($full_history, 1, $message_memory_length - 1); + +}; - Only reply with the context provided. If the context is an empty string, reply with 'I am sorry, I do not know the answer.'. - Do not use any prior knowledge that you have been trained on. +/* +Retrieves the last user message, its input, and the generated prompt. - - $context - "; - LET $context = (SELECT VALUE "\n ------------- \n URL: " + url + "\nTitle: " + title + "\n Content:\n" + text as content - FROM $documents).join("\n"); - RETURN string::replace($prompt, '$context', $context); -}; +Args: + chat_id (string): The ID of the chat. + prompt (string): The prompt template. + message_memory_length (int): The number of messages to include in the history. +Returns: + object: An object containing the content, prompt, and timestamp of the last user message. +*/ +DEFINE FUNCTION OVERWRITE fn::get_last_user_message_input_and_prompt($chat_id: string,$prompt:string,$message_memory_length: int) { -/* Create a message. + LET $message = + SELECT content,fn::get_prompt_with_context($prompt,docs,chat_history) as prompt_text + FROM + ( + SELECT + out.id, + out.content AS content, + referenced_documents.doc as docs, + fn::get_message_history($chat_id,$message_memory_length) as chat_history, + timestamp + FROM ONLY type::record($chat_id)->sent + WHERE out.role = "user" + ORDER BY timestamp DESC + LIMIT 1 + FETCH out + ); + + RETURN $message[0]; +}; + +/* +Creates a new message in a chat. Args: - chat_id: Record ID from the `chat` table that the message was sent in. - role: Role that sent the message. 
Allowed values are `user` or `system`. - content: Sent message content. + chat_id (string): The ID of the chat. + role (string): The role of the message sender (e.g., "user", "system"). + content (string): The message content. + documents (option>): Referenced documents. + embedding_model (option>): The embedding model used. + llm_model (option): The LLM model used. + prompt_text (option): The prompt text used. Returns: - oject: Content and timestamp. + object: The created message details. */ - DEFINE FUNCTION OVERWRITE fn::create_message( $chat_id: string, $role: string, @@ -161,43 +215,46 @@ DEFINE FUNCTION OVERWRITE fn::create_message( }; - -/* Create a user message. +/* +Creates a new user message in a chat. Args: - chat_id: Record ID from the `chat` table that the message was sent in. - content: Sent message content. - embedding_model: the embed model used to find docs - openai_token: token if using openai embeddings + chat_id (string): The ID of the chat. + corpus_table (string): The corpus table the document resides in. + content (string): The message content. + embedding_model (option>): The embedding model used. + openai_token (option): OpenAI API token (if applicable). Returns: - object: Content and timestamp. + object: The created user message details. */ DEFINE FUNCTION OVERWRITE fn::create_user_message($chat_id: string, $corpus_table: string, $content: string, $embedding_model: option>,$openai_token: option) { LET $threshold = 0.7; - LET $vector = IF $embedded_model == "OPENAI" THEN + LET $vector = IF $embedding_model.model_trainer == "OPENAI" THEN fn::openai_embeddings_complete($embedding_model.version, $content, $openai_token) ELSE fn::sentence_to_vector($content,$embedding_model) END; + LET $documents = fn::search_for_documents($corpus_table,$vector, $threshold ,$embedding_model); RETURN fn::create_message($chat_id, "user", $content, $documents,$embedding_model); }; - -/* Create a system message. +/* +Creates a new system message in a chat. Args: - chat_id: Record ID from the `chat` table that the message was sent in. - content: Sent message content. - llm_model: the llm model used to generate the content + chat_id (string): The ID of the chat. + content (string): The message content. + llm_model (string): The LLM model used. + prompt_text (string): The prompt used to generate the system message. Returns: - object: Content and timestamp. + object: The created system message details. */ DEFINE FUNCTION OVERWRITE fn::create_system_message($chat_id: string, $content: string, $llm_model: string,$prompt_text:string) { RETURN fn::create_message($chat_id, "system", $content, None,None,$llm_model,$prompt_text); @@ -207,39 +264,15 @@ DEFINE FUNCTION OVERWRITE fn::create_system_message($chat_id: string, $content: -/* Create get the last user message and the reference docs for generating a prompt. - -Args: - chat_id: Record ID from the `chat` table that the message was sent in\ - -Returns: - object: Content, referenced documents [{score,documents}], timestamp. -*/ - -DEFINE FUNCTION OVERWRITE fn::get_last_user_message_input_and_prompt($chat_id: string) { - LET $message = - SELECT content,fn::get_prompt_with_context(docs) as prompt_text FROM ( - SELECT - out.content AS content, - referenced_documents.doc as docs, - timestamp - FROM ONLY type::record($chat_id)->sent - WHERE out.role = "user" - ORDER BY timestamp DESC - LIMIT 1 - FETCH out); - - RETURN $message[0]; -}; -/* Generate get the user's message in a chat for generating a tile. +/* +Retrieves the first message in a chat. 
Args: - chat_id: Record ID from the `chat` table to generate a title for. + chat_id (string): The ID of the chat. Returns: - string: first chat content. + string|null: The content of the first message, or null if the chat is empty. */ - DEFINE FUNCTION OVERWRITE fn::get_first_message($chat_id: string) { # Get the `content` of the user's initial message. RETURN ( @@ -255,23 +288,26 @@ DEFINE FUNCTION OVERWRITE fn::get_first_message($chat_id: string) { }; -/* Create a new chat. +/* +Creates a new chat. Returns: - object: Object containing `id` and `title`. + object: The created chat object with `id` and `title`. */ DEFINE FUNCTION OVERWRITE fn::create_chat() { RETURN CREATE ONLY chat RETURN id, title; }; -/* Load a chat. + +/* +Loads the messages in a chat. Args: - chat_id: Record ID from the `chat` table to load. + chat_id (string): The ID of the chat. Returns: - array[objects]: Array of messages containing `role` and `content`. + array: An array of message objects. */ DEFINE FUNCTION OVERWRITE fn::load_chat($chat_id: string) { RETURN @@ -285,10 +321,11 @@ DEFINE FUNCTION OVERWRITE fn::load_chat($chat_id: string) { FETCH out; }; -/* Load all chats +/* +Loads all chats. Returns: - array[objects]: array of chats records containing `id`, `title`, and `created_at`. + array: An array of chat objects. */ DEFINE FUNCTION OVERWRITE fn::load_all_chats() { RETURN @@ -298,25 +335,28 @@ DEFINE FUNCTION OVERWRITE fn::load_all_chats() { ORDER BY created_at DESC; }; -/* Get chat title +/* +Retrieves the title of a chat. -Args: Record ID of the chat to get the title for. +Args: + chat_id (string): The ID of the chat. Returns: - string: Chat title. + string: The chat title. */ DEFINE FUNCTION OVERWRITE fn::get_chat_title($chat_id: string) { RETURN SELECT VALUE title FROM ONLY type::record($chat_id); }; -/* delete a chat and sent messages. +/* +Deletes a chat and its messages. -Args: Record ID of the chat to get the title for. +Args: + chat_id (string): The ID of the chat to delete. Returns: - string: chat id that was delete. + string: The ID of the deleted chat. */ - DEFINE FUNCTION OVERWRITE fn::delete_chat($chat_id:string){ $chat = type::record($chat_id); DELETE message WHERE id IN (SELECT ->sent->message FROM $chat); @@ -325,15 +365,16 @@ DEFINE FUNCTION OVERWRITE fn::delete_chat($chat_id:string){ RETURN $chat; }; +/* +Generates embeddings using the OpenAI API. -/* OpenAI embeddings complete. Args: - embeddings_model: Embedding model from OpenAI. - input: User input. - openai_token: the token used to authorize calling the API + embedding_model (string): The OpenAI embedding model to use. + input (string): The input text. + openai_token (string): The OpenAI API token. Returns: - array: Array of embeddings. + array: The generated embeddings. */ DEFINE FUNCTION OVERWRITE fn::openai_embeddings_complete($embedding_model: string, $input: string, $openai_token:string) { RETURN http::post( @@ -348,18 +389,20 @@ DEFINE FUNCTION OVERWRITE fn::openai_embeddings_complete($embedding_model: strin )["data"][0]["embedding"] }; - -/* OpenAI chat complete. +/* +Generates a chat completion using the OpenAI API. Args: - llm: Large Language Model to use for generation. - input: Initial user input. - prompt_with_context: Prompt with context for the system. + llm (string): The OpenAI LLM to use. + prompt_with_context (string): The prompt with context. + input (string): The user input. + temperature (float): The temperature for generation. + openai_token (string): The OpenAI API token. Returns: - string: Response from LLM. 
+ string: The generated chat response. */ -DEFINE FUNCTION OVERWRITE fn::openai_chat_complete($llm: string, $input: string, $prompt_with_context: string, $temperature: float, $openai_token:string) { +DEFINE FUNCTION OVERWRITE fn::openai_chat_complete($llm: string, $prompt_with_context: string, $input: string, $temperature: float, $openai_token:string) { LET $response = http::post( "https://api.openai.com/v1/chat/completions", { @@ -376,7 +419,7 @@ DEFINE FUNCTION OVERWRITE fn::openai_chat_complete($llm: string, $input: string, "temperature": $temperature }, { - "Authorization": $openai_token + "Authorization": "Bearer " + $openai_token } )["choices"][0]["message"]["content"]; @@ -384,31 +427,31 @@ DEFINE FUNCTION OVERWRITE fn::openai_chat_complete($llm: string, $input: string, RETURN string::replace($response, '"', ''); }; - -/* Gemini format for their endpoint has the model name and key in the query +/* +Generates the Gemini API URL. Args: - llm: Large Language Model to use for generation. - google_token: the API token for gemini + llm (string): The Gemini LLM to use. + google_token (string): The Google API token. + Returns: - string: path to query for LLM. + string: The Gemini API URL. */ DEFINE FUNCTION OVERWRITE fn::get_gemini_api_url($llm: string,$google_token:string){ return string::concat("https://generativelanguage.googleapis.com/v1beta/models/",$llm,":generateContent?key=",$google_token); }; - /* Gemini chat complete. Args: - llm: Large Language Model to use for generation. - input: Initial user input. - prompt_with_context: Prompt with context for the system. - google_token: the API token for gemini + llm (string): Large Language Model to use for generation. + input (string): Initial user input. + prompt_with_context (string): Prompt with context for the system. + google_token (string): The API token for Gemini. Returns: - string: Response from LLM. + object: Response from LLM. */ DEFINE FUNCTION OVERWRITE fn::gemini_chat_complete($llm: string, $prompt_with_context: string, $input: string,$google_token:string) { @@ -427,17 +470,43 @@ DEFINE FUNCTION OVERWRITE fn::gemini_chat_complete($llm: string, $prompt_with_co +/* +Loads the details of a document. + +Args: + corpus_table (string): The name of the corpus table. + document_id (string): The ID of the document. +Returns: + object: The document details. +*/ DEFINE FUNCTION OVERWRITE fn::load_document_detail($corpus_table:string,$document_id: string) { RETURN SELECT * FROM type::thing($corpus_table,$document_id); }; +/* +Loads the details of a message, including related sent data. + +Args: + message_id (string): The ID of the message. + +Returns: + object: The message details, including related sent data. +*/ DEFINE FUNCTION OVERWRITE fn::load_message_detail($message_id: string) { RETURN (SELECT *,<-sent.{referenced_documents,embedding_model,llm_model,timestamp,prompt_text} AS sent FROM type::record($message_id))[0]; }; +/* +Retrieves the vectors for each word in a sentence using the specified embedding model. + +Args: + sentence (string): The input sentence. + model (Record): The embedding model definition. -#these funtions calulates the mean vector for the tokens in a sentence using the glove Model +Returns: + array>: An array of embedding vectors, one for each word. +*/ DEFINE FUNCTION OVERWRITE fn::retrieve_vectors_for_sentence($sentence:string,$model:Record) { LET $sentence = $sentence.lowercase(). 
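The contract of fn::sentence_to_vector documented above is easier to see in plain Python: look up each word of the sentence among the embedding_model rows for the chosen model and average the hits, returning None when nothing matches. A rough equivalent, with a simple lowercase-and-split tokenizer standing in for the function's actual string cleanup:

def sentence_to_vector(sentence: str, dictionary: dict) -> list | None:
    """Mean of the per-word vectors found in `dictionary` (word -> list of floats)."""
    words = sentence.lower().split()
    vectors = [dictionary[w] for w in words if w in dictionary]
    if not vectors:
        return None  # mirrors the SurrealQL function returning None on no matches
    return [sum(parts) / len(vectors) for parts in zip(*vectors)]

# e.g. sentence_to_vector("Hello world", {"hello": [1.0, 0.0], "world": [0.0, 1.0]})
# -> [0.5, 0.5]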
diff --git a/surrealdb-rag/schema/table_ddl.surql b/surrealdb-rag/schema/table_ddl.surql index 573fd58..285c147 100644 --- a/surrealdb-rag/schema/table_ddl.surql +++ b/surrealdb-rag/schema/table_ddl.surql @@ -21,8 +21,6 @@ DEFINE FIELD IF NOT EXISTS updated_at ON TABLE chat TYPE datetime DEFINE TABLE IF NOT EXISTS message SCHEMAFULL; /* Field can only be populated with `user` or `system`. - -There are CSS and HTML that relies on these values. */ DEFINE FIELD IF NOT EXISTS role ON message TYPE string ASSERT $input IN ["user", "system"]; diff --git a/surrealdb-rag/src/surrealdb_rag/app.py b/surrealdb-rag/src/surrealdb_rag/app.py index 32b67f9..d2f6b75 100644 --- a/surrealdb-rag/src/surrealdb_rag/app.py +++ b/surrealdb-rag/src/surrealdb_rag/app.py @@ -14,13 +14,25 @@ import ast from urllib.parse import urlencode from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams + + +# Load configuration parameters db_params = DatabaseParams() model_params = ModelParams() args_loader = ArgsLoader("LLM Model Handler",db_params,model_params) args_loader.LoadArgs() +"""Format a SurrealDB RecordID for use in a URL. + + Replaces '/' with '|' and URL-encodes the ID. + + Args: + surrealdb_id: SurrealDB RecordID. + Returns: + Formatted string for use in a URL. + """ def format_url_id(surrealdb_id: RecordID) -> str: if RecordID == type(surrealdb_id): @@ -29,18 +41,20 @@ def format_url_id(surrealdb_id: RecordID) -> str: str_to_format = surrealdb_id return quote(str_to_format).replace("/","|") +"""Unformat a URL-encoded SurrealDB RecordID. + + Replaces '|' with '/'. + Args: + surrealdb_id: URL-encoded SurrealDB RecordID. + + Returns: + Unformatted string. + """ def unformat_url_id(surrealdb_id: str) -> str: return surrealdb_id.replace("|","/") - if RecordID == type(surrealdb_id): - str_to_format = surrealdb_id.id - else: - str_to_format = surrealdb_id - return quote(str_to_format) - -def extract_id(surrealdb_id: RecordID) -> str: - """Extract numeric ID from SurrealDB record ID. +"""Extract numeric ID from SurrealDB record ID. SurrealDB record ID comes in the form of `:`. CSS classes cannot be named with a `:` so for CSS we extract the ID. @@ -49,8 +63,10 @@ def extract_id(surrealdb_id: RecordID) -> str: surrealdb_id: SurrealDB record ID. Returns: - ID. + ID with ':' replaced by '-'. """ +def extract_id(surrealdb_id: RecordID) -> str: + if RecordID == type(surrealdb_id): #return surrealdb_id.id return surrealdb_id.id.replace(":","-") @@ -58,11 +74,7 @@ def extract_id(surrealdb_id: RecordID) -> str: return surrealdb_id.replace(":","-") - -def convert_timestamp_to_date(timestamp: str) -> str: - """Convert a SurrealDB `datetime` to a readable string. - - The result will be of the format: `April 05 2024, 15:30`. +"""Convert a SurrealDB `datetime` to a readable string. Args: timestamp: SurrealDB `datetime` value. @@ -70,6 +82,8 @@ def convert_timestamp_to_date(timestamp: str) -> str: Returns: Date as a string. """ +def convert_timestamp_to_date(timestamp: str) -> str: + # parsed_timestamp = datetime.datetime.fromisoformat(timestamp.rstrip("Z")) # return parsed_timestamp.strftime("%B %d %Y, %H:%M") return timestamp @@ -83,8 +97,10 @@ def convert_timestamp_to_date(timestamp: str) -> str: @contextlib.asynccontextmanager async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: - """FastAPI lifespan to create and destroy objects.""" + """FastAPI lifespan to create and destroy objects. + Initializes and closes the SurrealDB connection and loads LLM and corpus data. 
+ """ connection = AsyncSurreal(db_params.DB_PARAMS.url) await connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) await connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) @@ -102,6 +118,7 @@ async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: life_span.clear() +# Initialize FastAPI application app = fastapi.FastAPI(lifespan=lifespan) app.mount("/static", staticfiles.StaticFiles(directory="static"), name="static") @@ -109,6 +126,7 @@ async def lifespan(_: fastapi.FastAPI) -> AsyncGenerator: @app.get("/", response_class=responses.HTMLResponse) async def index(request: fastapi.Request) -> responses.HTMLResponse: + """Render the main chat interface.""" available_llm_models_json = json.dumps(life_span["llm_models"]) available_corpus_tables_json = json.dumps(life_span["corpus_tables"]) @@ -123,22 +141,25 @@ async def index(request: fastapi.Request) -> responses.HTMLResponse: "available_corpus_tables": available_corpus_tables_json, "default_llm_model": default_llm_model, "default_corpus_table": default_corpus_table, - "default_embed_model":default_embed_model + "default_embed_model":default_embed_model, + "default_prompt_text":LLMModelHander.DEFAULT_PROMPT_TEXT }) @app.get("/get_corpus_table_details") async def get_corpus_table_details(corpus_table: str = fastapi.Query(...)): + """Retrieve and return details of a corpus table.""" corpus_table_detail = life_span["corpus_tables"].get(corpus_table) if corpus_table_detail: s = f"Table: {corpus_table_detail['table_name']}" else: s = "Corpus table details not found." - return fastapi.Response(s, media_type="text/html") #Return response object + return fastapi.Response(s, media_type="text/html") @app.get("/get_llm_model_details") async def get_llm_model_details(llm_model: str = fastapi.Query(...)): + """Retrieve and return details of an LLM model.""" model_data = life_span["llm_models"].get(llm_model) if model_data: s = f" Platform: {model_data['platform']}, Host: {model_data['host']}
Version: {model_data['model_version']}" @@ -148,6 +169,7 @@ async def get_llm_model_details(llm_model: str = fastapi.Query(...)): @app.get("/get_embed_model_details") async def get_embed_model_details(corpus_table: str = fastapi.Query(...),embed_model: str = fastapi.Query(...)): + """Retrieve and return details of an embedding model.""" embed_models = life_span["corpus_tables"][corpus_table]["embed_models"] embed_model_detail = None embed_model = ast.literal_eval(embed_model) @@ -171,7 +193,7 @@ async def get_embed_model_details(corpus_table: str = fastapi.Query(...),embed_m @app.post("/chats", response_class=responses.HTMLResponse) async def create_chat(request: fastapi.Request) -> responses.HTMLResponse: - """Create a chat.""" + """Create a new chat.""" chat_record = await life_span["surrealdb"].query( """RETURN fn::create_chat();""" ) @@ -190,6 +212,7 @@ async def delete_chat( request: fastapi.Request, chat_id: str ) -> responses.HTMLResponse: + """Delete a chat and its messages.""" SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( """RETURN fn::delete_chat($chat_id)""",params = {"chat_id":chat_id} )) @@ -216,11 +239,13 @@ async def load_chat( "chat": {"id":chat_id,"title": title } }, ) + + @app.get("/messages/{message_id}", response_class=responses.HTMLResponse) async def load_message( request: fastapi.Request, message_id: str ) -> responses.HTMLResponse: - """Load a chat.""" + """Load a message.""" message = await life_span["surrealdb"].query( """RETURN fn::load_message_detail($message_id)""",params = {"message_id":message_id} ) @@ -308,15 +333,19 @@ async def send_user_message( ) async def send_system_message( request: fastapi.Request, chat_id: str, - llm_model: str = fastapi.Form(...) + llm_model: str = fastapi.Form(...), + prompt_template: str = fastapi.Form(...) ) -> responses.HTMLResponse: """Send system message.""" + # outcome = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( + # """RETURN fn::get_last_user_message_input_and_prompt($chat_id,$prompt,$message_memory_length);""",params = {"chat_id":chat_id,"prompt":LLMModelHander.DEFAULT_PROMPT_TEXT,"message_memory_length":5} + # )) outcome = SurrealParams.ParseResponseForErrors( await life_span["surrealdb"].query_raw( - """RETURN fn::get_last_user_message_input_and_prompt($chat_id);""",params = {"chat_id":chat_id} + """RETURN fn::get_last_user_message_input_and_prompt($chat_id,$prompt,$message_memory_length);""",params = {"chat_id":chat_id,"prompt":prompt_template,"message_memory_length":5} )) result = outcome["result"][0]["result"] prompt_text = result["prompt_text"] @@ -328,7 +357,7 @@ async def send_system_message( llm_handler = LLMModelHander(model_data,model_params,life_span["surrealdb"]) - llm_response = llm_handler.get_chat_response(prompt_text,content) + llm_response = await llm_handler.get_chat_response(prompt_text,content) #save the response in the DB outcome = SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( @@ -344,7 +373,7 @@ async def send_system_message( "RETURN fn::get_first_message($chat_id);",params={"chat_id":chat_id} ) system_prompt = "You are a conversation title generator for a ChatGPT type app. Respond only with a simple title using the user input." 
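+        # get_short_plain_text_response strips <think> tags and HTML and caps the title at 255 characters.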
- new_title = llm_handler.get_short_plain_text_response(system_prompt,first_message_text) + new_title = await llm_handler.get_short_plain_text_response(system_prompt,first_message_text) #update chat title in database SurrealParams.ParseResponseForErrors(await life_span["surrealdb"].query_raw( """UPDATE type::record($chat_id) SET title=$title;""",params = {"chat_id":chat_id,"title":new_title} diff --git a/surrealdb-rag/src/surrealdb_rag/constants.py b/surrealdb-rag/src/surrealdb_rag/constants.py index 5ac7e9a..91b5327 100644 --- a/surrealdb-rag/src/surrealdb_rag/constants.py +++ b/surrealdb-rag/src/surrealdb_rag/constants.py @@ -14,13 +14,29 @@ class SurrealParams(): + """ + Class to hold SurrealDB connection parameters. + """ def __init__(self = None, url = None,username = None, password = None, namespace = None, database = None): + self.username = username self.password = password self.namespace = namespace self.database = database self.url = url + """ + Parses the SurrealDB response for errors. + + Args: + outcome (dict): The SurrealDB response. + + Returns: + dict: The parsed response, or None if the outcome is None. + + Raises: + SystemError: If an error is found in the response. + """ @staticmethod def ParseResponseForErrors(outcome): @@ -40,43 +56,32 @@ def ParseResponseForErrors(outcome): class ModelParams(): - # GEMINI_MODELS = ["gemini-2.0-flash-lite","gemini-2.0-flash","gemini-1.5-flash","gemini-1.5-flash-8b","gemini-1.5-pro"] - # # OPENAI_MODELS = ["gemini-2.0-flash-lite","gemini-2.0-flash","gemini-1.5-flash","gemini-1.5-flash-8b","gemini-1.5-pro"] - # # LLM_MODELS = { - # # "GEMINI-SURREAL": {"model_version":"gemini-2.0-flash","host":"SQL","platform":"GOOGLE","temperature":None}, - # # "GEMINI": {"model_version":"gemini-2.0-flash","host":"API","platform":"GOOGLE","temperature":None}, - # # "DEEPSEEK": {"model_version":"deepseek-r1:1.5b","host":"OLLAMA","platform":"local","temperature":None}, - # # "OPENAI-SURREAL": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5}, - # # "OPENAI": {"model_version":"gpt-3.5-turbo","host":"API","platform":"OPENAI","temperature":0.5} - # # } - - # EMBED_MODELS = { - # "FASTTEXT": {"dimensions":100,"host":"SQL"}, - # "GLOVE": {"dimensions":300,"host":"SQL"}, - # "OPENAI": {"dimensions":1536,"host":"API"} - # } def __init__(self): + + #These are just the pointers to the env variables + #Don't put the actual tokens here + self.openai_token_env_var = "OPENAI_API_KEY" self.openai_token = None self.gemini_token_env_var = "GOOGLE_GENAI_API_KEY" self.gemini_token = None - # self.embedding_model_env_var = "SURREAL_RAG_EMBEDDING_MODEL" - # self.embedding_model = None - # self.llm_model_env_var = "SURREAL_RAG_LLM_MODEL" - # self.llm_model = None - # self.version = None - # self.host = None - # self.temperature = 0.5 + + """ + Adds command-line arguments for model parameters. + Args: + parser (argparse.ArgumentParser): The argument parser. 
+ """ def AddArgs(self, parser:argparse.ArgumentParser): parser.add_argument("-oenv","--openai_token_env", help="Your env variable for LLM openai_token (Default: {0} for ollama hosted ignore)".format(self.openai_token_env_var)) parser.add_argument("-genv","--gemini_token_env", help="Your env variable for LLM gemini_token (Default: {0} for ollama hosted ignore)".format(self.gemini_token_env_var)) - #parser.add_argument("-emenv","--embedding_model_env_var", help="Your env variable for embedding model value can be 'OPENAI' or 'GLOVE' (Default: {0})".format(self.embedding_model_env_var)) - #parser.add_argument("-em","--embedding_model", help="Embedding model value can be 'OPENAI' or 'GLOVE', if none it will use env var (Default: {0})".format("")) - # parser.add_argument("-llmenv","--llm_model_env_var", help="Your env variable for LLM model value can be 'OPENAI','DEEPSEEK' or 'GEMINI' (Default: {0})".format(self.llm_model_env_var)) - # parser.add_argument("-llm","--llm_model", help="LLM model value can be 'OPENAI'.'DEEPSEEK' or 'GEMINI', if none it will use env var (Default: {0})".format("")) - + """ + Sets model parameters from command-line arguments and environment variables. + + Args: + args (argparse.Namespace): The parsed command-line arguments. + """ def SetArgs(self,args:argparse.Namespace): if args.openai_token_env: self.openai_token_env_var = args.openai_token_env @@ -88,35 +93,19 @@ def SetArgs(self,args:argparse.Namespace): self.gemini_token = os.getenv(self.gemini_token_env_var) - # if args.embedding_model_env_var: - # self.embedding_model_env_var = args.embedding_model_env_var - # if args.llm_model_env_var: - # self.llm_model_env_var = args.llm_model_env_var - - - # if args.embedding_model: - # self.embedding_model = args.embedding_model - # else: - # self.embedding_model = os.getenv(self.embedding_model_env_var) - - # if self.embedding_model not in ["OPENAI","GLOVE"]: - # raise ValueError("Embedding model must be 'OPENAI' or 'GLOVE'") - # if args.llm_model: - # self.llm_model = args.llm_model - # else: - # self.llm_model = os.getenv(self.llm_model_env_var) - - # self.version = self.LLM_MODELS[self.llm_model]["model_version"] - # self.host = self.LLM_MODELS[self.llm_model]["host"] - class DatabaseParams(): def __init__(self): - #export SURREAL_CLOUD_TEST_USER=xxx - #export SURREAL_CLOUD_TEST_PASS=xxx + + #The path to your SurrealDB instance + #The the SurrealDB namespace and database + #For use in authenticating your database + #These are just the pointers to the env variables + #Don't put the actual passwords here + self.DB_USER_ENV_VAR = "SURREAL_RAG_USER" self.DB_PASS_ENV_VAR = "SURREAL_RAG_PASS" self.DB_URL_ENV_VAR = "SURREAL_RAG_DB_URL" @@ -124,13 +113,16 @@ def __init__(self): self.DB_DB_ENV_VAR = "SURREAL_RAG_DB_DB" - #The path to your SurrealDB instance - #The the SurrealDB namespace and database to upload the model to self.DB_PARAMS = SurrealParams() - #For use in authenticating your database in database.py - #These are just the pointers to the env variables - #Don't put the actual passwords here + + + """ + Sets model parameters from command-line arguments and environment variables. + + Args: + args (argparse.Namespace): The parsed command-line arguments. 
+ """ def AddArgs(self, parser:argparse.ArgumentParser): parser.add_argument("-urlenv","--url_env", help="Your env variable for Path to your SurrealDB instance (Default: {0})".format(self.DB_URL_ENV_VAR)) @@ -187,9 +179,18 @@ def SetArgs(self,args:argparse.Namespace): - +""" +Class to load and manage command-line arguments using argparse. +""" class ArgsLoader(): - + """ + Initializes the ArgsLoader with a description and parameter objects. + + Args: + description (str): Description of the program for the help message. + db_params (DatabaseParams): Instance of DatabaseParams to manage database arguments. + model_params (ModelParams): Instance of ModelParams to manage model arguments. + """ def __init__(self,description, db_params: DatabaseParams,model_params: ModelParams): self.parser = argparse.ArgumentParser(description=description) @@ -198,14 +199,25 @@ def __init__(self,description, self.model_params.AddArgs(self.parser) self.db_params.AddArgs(self.parser) self.AdditionalArgs = {} - + + """ + Adds a custom argument to the parser. + + Args: + name (str): Name of the argument. + flag (str): Short flag for the argument (e.g., 'f'). + action (str): Long flag for the argument (e.g., 'file'). + help (str): Help message for the argument. + default (str): Default value for the argument. + """ def AddArg(self,name:str,flag:str,action:str,help:str,default:str): self.parser.add_argument(f"-{flag}",f"--{action}", help=help.format(default)) self.AdditionalArgs[name] = {"flag":flag,"action":action,"value":default} - - + """ + Parses the command-line arguments and sets the parameter objects. + """ def LoadArgs(self): self.args = self.parser.parse_args() self.db_params.SetArgs(self.args) @@ -213,7 +225,12 @@ def LoadArgs(self): for key in self.AdditionalArgs.keys(): if getattr(self.args, self.AdditionalArgs[key]["action"]): self.AdditionalArgs[key]["value"] = getattr(self.args, self.AdditionalArgs[key]["action"]) + """ + Generates a formatted string containing all parsed arguments. + Returns: + str: Formatted string of arguments. + """ def string_to_print(self): ret_val = self.parser.description @@ -226,19 +243,30 @@ def string_to_print(self): ret_val += ArgsLoader.additional_args_dict_to_str(self.AdditionalArgs) return ret_val - ret_val += f"\n{self.db_params.DB_PARAMS.__dict__}" - ret_val += f"\n{self.model_params.__dict__}" - for key in self.AdditionalArgs.keys(): - ret_val += f"\n{key} : {self.AdditionalArgs[key]["value"]}" - return ret_val - + """ + Converts a dictionary of additional arguments to a formatted string. + + Args: + the_dict (dict): Dictionary of additional arguments. + + Returns: + str: Formatted string. + """ def additional_args_dict_to_str(the_dict:dict): ret_val = "" for key in the_dict.keys(): ret_val += f"\n{key} : {the_dict[key]["value"]}" return ret_val + """ + Converts a dictionary to a formatted string. + Args: + the_dict (dict): Dictionary to convert. + + Returns: + str: Formatted string. + """ def dict_to_str(the_dict:dict): ret_val = "" for key in the_dict.keys(): @@ -246,7 +274,9 @@ def dict_to_str(the_dict:dict): return ret_val - + """ + Prints the formatted string of arguments. 
+ """ def print(self): print(self.string_to_print()) diff --git a/surrealdb-rag/src/surrealdb_rag/create_database.py b/surrealdb-rag/src/surrealdb_rag/create_database.py index 6914ab7..30f58a7 100644 --- a/surrealdb-rag/src/surrealdb_rag/create_database.py +++ b/surrealdb-rag/src/surrealdb_rag/create_database.py @@ -1,5 +1,3 @@ -"""Insert Wikipedia data into SurrealDB.""" - from surrealdb import Surreal @@ -8,20 +6,27 @@ from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams +# Initialize parameter objects and argument loader db_params = DatabaseParams() model_params = ModelParams() args_loader = ArgsLoader("Input Embeddings Model",db_params,model_params) -args_loader.LoadArgs() + +""" +Creates a SurrealDB database and namespace, and executes schema DDL. +""" def surreal_create_database() -> None: """Create SurrealDB database for Wikipedia embeddings.""" logger = loggers.setup_logger("SurrealCreateDatabase") + args_loader.LoadArgs() # Parse command-line arguments + logger.info(args_loader.string_to_print()) with Surreal(db_params.DB_PARAMS.url) as connection: logger.info("Connected to SurrealDB") connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) logger.info("Creating database") + # SurrealQL query to create namespace and database query= f""" DEFINE NAMESPACE IF NOT EXISTS {db_params.DB_PARAMS.namespace}; @@ -31,6 +36,8 @@ def surreal_create_database() -> None: USE DATABASE {db_params.DB_PARAMS.database}; """ logger.info(query) + + # Execute the query and check for errors SurrealParams.ParseResponseForErrors(connection.query_raw( query )) @@ -39,23 +46,12 @@ def surreal_create_database() -> None: logger.info("Executing common function DDL") + # Read the schema DDL that holds the SurQL functions from file with open("./schema/function_ddl.surql") as f: surlql_to_execute = f.read() SurrealParams.ParseResponseForErrors( connection.query_raw(surlql_to_execute)) - # match model_params.EMBEDDING_MODEL: - # case "OPENAI": - # logger.info("Creating DDL for open ai model") - # with open("./schema/openai_embedding_ddl.surql") as f: - # surlql_to_execute = f.read() - # SurrealParams.ParseResponseForErrors( connection.query_raw(surlql_to_execute)) - # case "GLOVE": - # logger.info("Creating DDL for glove model") - # with open("./schema/glove_embedding_ddl.surql") as f: - # surlql_to_execute = f.read() - # SurrealParams.ParseResponseForErrors( connection.query_raw(surlql_to_execute)) - # case _: - # raise ValueError("Embedding model must be 'OPENAI' or 'GLOVE'") + if __name__ == "__main__": surreal_create_database() \ No newline at end of file diff --git a/surrealdb-rag/src/surrealdb_rag/download_glove.py b/surrealdb-rag/src/surrealdb_rag/download_glove.py index 1f477d2..c46cd34 100644 --- a/surrealdb-rag/src/surrealdb_rag/download_glove.py +++ b/surrealdb-rag/src/surrealdb_rag/download_glove.py @@ -1,4 +1,4 @@ -"""Download OpenAI Wikipedia data.""" +"""Download GLoVe pre trained model.""" import zipfile diff --git a/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py b/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py index 727fc0d..4e51757 100644 --- a/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py +++ b/surrealdb-rag/src/surrealdb_rag/download_wiki_data.py @@ -21,13 +21,7 @@ def download_data() -> None: logger = loggers.setup_logger("DownloadData") logger.info("Downloading Wikipedia") - # if not os.path.exists("data"): - # os.makedirs("data") - # wget.download( - # url=constants.WIKI_URL, - # 
out=constants.WIKI_ZIP_PATH, - # ) - + logger.info("Extracting to data directory") with zipfile.ZipFile( constants.WIKI_ZIP_PATH, "r" @@ -39,8 +33,6 @@ def download_data() -> None: logger.info("Loading Glove embedding model") - - try: gloveEmbeddingModel = WordEmbeddingModel(constants.GLOVE_PATH) except Exception as e: diff --git a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py index 960b1e3..e14c2bd 100644 --- a/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py +++ b/surrealdb-rag/src/surrealdb_rag/insert_embedding_model.py @@ -13,10 +13,12 @@ import surrealdb_rag.constants as constants +# Initialize database and model parameters, and argument loader db_params = DatabaseParams() model_params = ModelParams() args_loader = ArgsLoader("Input Glove embeddings model",db_params,model_params) +# SurrealQL queries for database operations INSERT_EMBEDDINGS = """ LET $model = type::thing('embedding_model_definition',[$model_trainer,$model_version]); FOR $row IN $embeddings { @@ -47,30 +49,46 @@ """ -CHUNK_SIZE = 1000 +CHUNK_SIZE = 1000 # Size of chunks for batch insertion +""" +Inserts word embeddings from a model file into SurrealDB. + +Args: + model_trainer (str): Name of the training algorithm (e.g., 'GLOVE', 'FASTTEXT'). + model_version (str): Version of the model (e.g., '300d', 'custom wiki'). + model_path (str): Path to the model file. + description (str): Description of the embedding model. + corpus (str): Description of the training data. + logger (logging.Logger): Logger instance. +""" def surreal_model_insert(model_trainer,model_version,model_path,description,corpus,logger): logger.info(f"Reading {model_trainer} {model_version} model") - embeddingModel = WordEmbeddingModel(model_path) + # Load the embedding model + embeddingModel = WordEmbeddingModel(model_path) + # Create DataFrame from model data. embeddings_df = pd.DataFrame({'word': embeddingModel.dictionary.keys(), 'embedding': embeddingModel.dictionary.values()}) + # Calculate number of chunks for batch processing. total_rows = len(embeddings_df) - total_chunks = (total_rows + CHUNK_SIZE - 1) // CHUNK_SIZE # ceiling division + total_chunks = (total_rows + CHUNK_SIZE - 1) // CHUNK_SIZE # Calculate number of chunks for batch processing with Surreal(db_params.DB_PARAMS.url) as connection: connection.signin({"username": db_params.DB_PARAMS.username, "password": db_params.DB_PARAMS.password}) connection.use(db_params.DB_PARAMS.namespace, db_params.DB_PARAMS.database) logger.info("Connected to SurrealDB") logger.info(f"Deleting any rows from {model_trainer} {model_version}") + # Delete existing embeddings for this particular model SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_EMBEDDINGS,params={"model_trainer":model_trainer,"model_version":model_version})) logger.info("Inserting rows into SurrealDB") - #remove any data from the table with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar: - + # Iterate through chunks of data. for i in range(0, total_rows, CHUNK_SIZE): + # Get a chunk of data. chunk = embeddings_df.iloc[i:i + CHUNK_SIZE] + # create an array of dicts to bulk load into surreal formatted_rows = [ { "word":str(row["word"]), @@ -79,12 +97,14 @@ def surreal_model_insert(model_trainer,model_version,model_path,description,corp for _, row in chunk.iterrows() ] - + # Insert the chunk. 
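+                # ParseResponseForErrors raises SystemError if any statement in the batch failed.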
SurrealParams.ParseResponseForErrors(connection.query_raw( INSERT_EMBEDDINGS, params={"embeddings": formatted_rows,"model_trainer":model_trainer,"model_version":model_version} )) + # Update progress bar. pbar.update(1) + # Update the model definition. SurrealParams.ParseResponseForErrors(connection.query_raw(UPDATE_EMBEDDING_MODEL_DEF, params={ "model_trainer":model_trainer, @@ -97,11 +117,13 @@ def surreal_model_insert(model_trainer,model_version,model_path,description,corp +"""Main entrypoint to insert glove embedding model into SurrealDB.""" def surreal_embeddings_insert() -> None: """Main entrypoint to insert glove embedding model into SurrealDB.""" logger = loggers.setup_logger("SurrealEmbeddingsInsert") + # Add command-line arguments specific to embedding insertion. args_loader.AddArg( "model_trainer","emtr","model_trainer","The name of the training algorithm: 'GLOVE' or 'FASTTEXT' (Default{0})",None ) @@ -118,9 +140,11 @@ def surreal_embeddings_insert() -> None: "corpus","cor","corpus","a description of the embedding model training data. (Default{0})",None ) + # Parse command-line arguments. args_loader.LoadArgs() + # Retrieve argument values. model_trainer = args_loader.AdditionalArgs["model_trainer"]["value"] if not model_trainer or not model_trainer in ['GLOVE','FASTTEXT']: @@ -139,9 +163,11 @@ def surreal_embeddings_insert() -> None: raise Exception("You must supply a model corpus") + # Log the parsed arguments. logger.info(args_loader.string_to_print()) + # Insert the embedding model. surreal_model_insert(model_trainer,model_version,model_path,description,corpus,logger) diff --git a/surrealdb-rag/src/surrealdb_rag/insert_wiki.py b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py index 40ba7ac..7eef07d 100644 --- a/surrealdb-rag/src/surrealdb_rag/insert_wiki.py +++ b/surrealdb-rag/src/surrealdb_rag/insert_wiki.py @@ -12,19 +12,22 @@ from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams +# Initialize database and model parameters, and argument loader db_params = DatabaseParams() model_params = ModelParams() args_loader = ArgsLoader("Input wiki data",db_params,model_params) +# Define table name and display name TABLE_NAME = "embedded_wiki" DISPLAY_NAME = "Wikipedia" - +# SurrealQL query to get embedding model definitions GET_EMBED_MODEL_DESCRIPTIONS = """ SELECT * FROM embedding_model_definition; """ +# Mapping of embedding models to their field names and definitions EMBED_MODEL_DEFINITIONS = { "GLOVE":{"field_name":"content_glove_vector","model_definition":[ 'GLOVE', @@ -42,7 +45,8 @@ - + +# SurrealQL query to update corpus table information UPDATE_CORPUS_TABLE_INFO = f""" DELETE FROM corpus_table_model WHERE corpus_table = corpus_table:{TABLE_NAME}; FOR $model IN $embed_models {{ @@ -55,6 +59,7 @@ """ +# SurrealQL query to insert Wikipedia records INSERT_WIKI_RECORDS = f""" FOR $row IN $records {{ CREATE type::thing("{TABLE_NAME}",$row.url) CONTENT {{ @@ -68,14 +73,18 @@ }}; """ +# SurrealQL query to delete all Wikipedia records DELETE_WIKI_RECORDS = f"DELETE {TABLE_NAME};" +# Chunk size for batch processing CHUNK_SIZE = 50 +"""Main entrypoint to insert Wikipedia embeddings into SurrealDB.""" def surreal_wiki_insert() -> None: + # Add command-line argument for embedding models args_loader.AddArg( "embed_models", "ems", @@ -83,29 +92,46 @@ def surreal_wiki_insert() -> None: "The Embed models you'd like to calculate sepearated by , can be GLOVE,FASTTEXT,OPENAI eg -ems OPENAI,GLOVE will calculate both OPENAI and GLOVE. 
(default{})", "GLOVE,FASTTEXT,OPENAI" ) + # Parse command-line arguments args_loader.LoadArgs() - """Main entrypoint to insert Wikipedia embeddings into SurrealDB.""" logger = loggers.setup_logger("SurrealWikiInsert") logger.info(args_loader.string_to_print()) + # Retrieve embedding models from arguments embed_models_str = args_loader.AdditionalArgs["embed_models"]["value"] embed_models = embed_models_str.split(",") + # Validate embedding models if len(embed_models)<1: raise Exception("You must specify at least one valid model of GLOVE,FASTTEXT,OPENAI with the -ems flag") - - + using_openai = False + using_fasttext = False + using_glove = False for embed_model in embed_models: + if embed_model=="OPENAI": + using_openai = True + if embed_model=="FASTTEXT": + using_fasttext = True + if embed_model=="GLOVE": + using_glove = True if embed_model not in EMBED_MODEL_DEFINITIONS: raise Exception(f"{embed_model} is invalid, You must specify at least one valid model of GLOVE,FASTTEXT,OPENAI with the -ems flag") + # Create embedding model mappings + embed_model_mappings = [] + for embed_model in embed_models: + if embed_model in EMBED_MODEL_DEFINITIONS: + field_name = EMBED_MODEL_DEFINITIONS[embed_model]["field_name"] + model_definition = EMBED_MODEL_DEFINITIONS[embed_model]["model_definition"] + embed_model_mappings.append({"model_id": model_definition, "field_name": field_name}) logger.info(f"Loading file {constants.WIKI_PATH}") + # Define columns to read from CSV usecols=[ "url", "title", @@ -115,9 +141,11 @@ def surreal_wiki_insert() -> None: "content_fasttext_vector" ] + # Read Wikipedia records from CSV wiki_records_df = pd.read_csv(constants.WIKI_PATH,usecols=usecols) + # Calculate total rows and chunks total_rows = len(wiki_records_df) total_chunks = total_rows // CHUNK_SIZE + ( 1 if total_rows % CHUNK_SIZE else 0 @@ -131,18 +159,9 @@ def surreal_wiki_insert() -> None: - embed_model_mappings = [] - for embed_model in embed_models: - if embed_model in EMBED_MODEL_DEFINITIONS: - field_name = EMBED_MODEL_DEFINITIONS[embed_model]["field_name"] - model_definition = EMBED_MODEL_DEFINITIONS[embed_model]["model_definition"] - embed_model_mappings.append({"model_id": model_definition, "field_name": field_name}) - # logger.info(f"Updating corpus table info for {TABLE_NAME}") - # SurrealParams.ParseResponseForErrors( connection.query_raw(UPDATE_CORPUS_TABLE_INFO,params={"embed_models":embed_model_mappings})) - # return - + # Read and execute table DDL that creates the tables and indexes if missing with open("./schema/table_ddl.surql") as f: surlql_to_execute = f.read() surlql_to_execute = surlql_to_execute.format(corpus_table = "embedded_wiki") @@ -150,10 +169,11 @@ def surreal_wiki_insert() -> None: logger.info("Deleting any existing wiki rows from SurrealDB") - #remove any data from the table + # Delete existing wiki records SurrealParams.ParseResponseForErrors(connection.query_raw(DELETE_WIKI_RECORDS)) logger.info("Inserting rows into SurrealDB") + # Iterate through chunks and insert records with tqdm.tqdm(total=total_chunks, desc="Inserting") as pbar: for i in range(0, total_rows, CHUNK_SIZE): chunk = wiki_records_df.iloc[i:i + CHUNK_SIZE] @@ -162,9 +182,9 @@ def surreal_wiki_insert() -> None: "url":str(row["url"]), "title":str(row["title"]), "text":str(row["text"]), - "content_openai_vector":ast.literal_eval(row["content_vector"]), - "content_glove_vector":ast.literal_eval(row["content_glove_vector"]), - "content_fasttext_vector":ast.literal_eval(row["content_fasttext_vector"]) + 
"content_openai_vector":ast.literal_eval(row["content_vector"]) if using_openai else None, + "content_glove_vector":ast.literal_eval(row["content_glove_vector"]) if using_glove else None, + "content_fasttext_vector":ast.literal_eval(row["content_fasttext_vector"]) if using_fasttext else None } for _, row in chunk.iterrows() ] @@ -180,6 +200,7 @@ def surreal_wiki_insert() -> None: + # Update corpus table information logger.info(f"Updating corpus table info for {TABLE_NAME}") SurrealParams.ParseResponseForErrors( connection.query_raw(UPDATE_CORPUS_TABLE_INFO,params={"embed_models":embed_model_mappings})) diff --git a/surrealdb-rag/src/surrealdb_rag/llm_handler.py b/surrealdb-rag/src/surrealdb_rag/llm_handler.py index 71a7a58..9a43f69 100644 --- a/surrealdb-rag/src/surrealdb_rag/llm_handler.py +++ b/surrealdb-rag/src/surrealdb_rag/llm_handler.py @@ -10,14 +10,31 @@ from surrealdb import AsyncSurreal import re + +""" +Handles the retrieval of available LLM models and corpus tables. +""" class ModelListHandler(): + + """ + Initializes the ModelListHandler. + Args: + model_params (ModelParams): Model parameters. + connection (AsyncSurreal): Asynchronous SurrealDB connection. + """ def __init__(self, model_params, connection): self.LLM_MODELS = {} self.CORPUS_TABLES = {} self.model_params = model_params self.connection = connection + + """ + Retrieves a dictionary of available LLM models. + Returns: + dict: Dictionary of available LLM models. + """ async def available_llm_models(self): if self.LLM_MODELS != {}: return self.LLM_MODELS @@ -25,7 +42,7 @@ async def available_llm_models(self): self.LLM_MODELS = {} - #you need an api key for gemini + # Configure Gemini API if API key is available if self.model_params.gemini_token: genai.configure(api_key=self.model_params.gemini_token) @@ -42,7 +59,7 @@ async def available_llm_models(self): self.LLM_MODELS["GOOGLE - " + model.display_name] = {"model_version":model.name,"host":"API","platform":"GOOGLE","temperature":0} self.LLM_MODELS["GOOGLE - " + model.display_name + " (surreal)"] = {"model_version":model.name, "host":"SQL","platform":"GOOGLE","temperature":0} - #you need an api key for openai + # Configure OpenAI API if API key is available if self.model_params.openai_token: openai.api_key = self.model_params.openai_token models = openai.models.list() @@ -53,6 +70,7 @@ async def available_llm_models(self): self.LLM_MODELS["OPENAI - " + model.id + " (surreal)"] = {"model_version":model.id,"host":"SQL","platform":"OPENAI","temperature":0.5} + # Retrieve models from Ollama response: ollama.ListResponse = ollama.list() for model in response.models: @@ -61,7 +79,12 @@ async def available_llm_models(self): return self.LLM_MODELS - + """ + Retrieves a dictionary of available corpus tables. + + Returns: + dict: Dictionary of available corpus tables. 
+ """ async def available_corpus_tables(self): if self.CORPUS_TABLES != {}: return self.CORPUS_TABLES @@ -71,85 +94,9 @@ async def available_corpus_tables(self): SELECT display_name,table_name,embed_models FROM corpus_table FETCH embed_models,embed_models.model; """) - #you need an api key for openai so remove openai from list if api is absent - + # Filter out OpenAI models if API key is absent for corpus_table in corpus_tables: - # example record - # { - # display_name: 'Wikipedia', - # embed_models: [ - # { - # corpus_table: corpus_table:embedded_wiki, - # field_name: 'content_fasttext_vector', - # id: corpus_table_model:[ - # corpus_table:embedded_wiki, - # embedding_model_definition:[ - # 'FASTTEXT', - # 'wiki' - # ] - # ], - # model: { - # corpus: 'https://cdn.openai.com/API/examples/data/vector_database_wikipedia_articles_embedded.zip', - # description: 'Custom trained model using fasttext based on OPENAI wiki example download', - # dimensions: 100, - # host: 'SQL', - # id: embedding_model_definition:[ - # 'FASTTEXT', - # 'wiki' - # ], - # model_trainer: 'FASTTEXT', - # version: 'wiki' - # } - # }, - # { - # corpus_table: corpus_table:embedded_wiki, - # field_name: 'content_glove_vector', - # id: corpus_table_model:[ - # corpus_table:embedded_wiki, - # embedding_model_definition:[ - # 'GLOVE', - # '6b 300d' - # ] - # ], - # model: { - # corpus: 'Wikipedia 2014 + Gigaword 5', - # description: 'Standard pretrained GLoVE model from https://nlp.stanford.edu/projects/glove/ 300 dimensions version', - # dimensions: 300, - # host: 'SQL', - # id: embedding_model_definition:[ - # 'GLOVE', - # '6b 300d' - # ], - # model_trainer: 'GLOVE', - # version: '6b 300d' - # } - # }, - # { - # corpus_table: corpus_table:embedded_wiki, - # field_name: 'content_openai_vector', - # id: corpus_table_model:[ - # corpus_table:embedded_wiki, - # embedding_model_definition:[ - # 'OPENAI', - # 'text-embedding-ada-002' - # ] - # ], - # model: { - # corpus: 'generic pretrained', - # description: 'The standard OPENAI embedding model', - # dimensions: 1536, - # host: 'API', - # id: embedding_model_definition:[ - # 'OPENAI', - # 'text-embedding-ada-002' - # ], - # model_trainer: 'OPENAI', - # version: 'text-embedding-ada-002' - # } - # } - # ], - # table_name: 'embedded_wiki' - # } + # create an dict item for table_name table_name = corpus_table["table_name"] self.CORPUS_TABLES[table_name] = {} @@ -180,11 +127,46 @@ async def available_corpus_tables(self): - +""" +Handles interactions with different LLM models. +""" class LLMModelHander(): + DEFAULT_PROMPT_TEXT = """ + You are an AI assistant answering questions about anything from the corpus of knowledge provided in the tags. + + You may also refer to the text in the tags but only for refining your understanding of what is being asked of you. Do not rely on the chat_history for answering the question! + + Please provide your response in Markdown converted to HTML format. Include appropriate headings and lists where relevant. + + At the end of the response, add any links as a HTML link and replace the title and url with the associated title and url of the more relevant page from the context. + + Only reply with the context provided. If the context is an empty string, reply with 'I am sorry, I do not know the answer.'. + + Do not use any prior knowledge that you have been trained on. 
+
+    <context>
+    $context
+    </context>
+
+    <chat_history>
+    $chat_history
+    </chat_history>
+    """
+
+    GEMINI_CHAT_COMPLETE = """RETURN fn::gemini_chat_complete($llm,$prompt_with_context,$input,$google_token);"""
+
+    OPENAI_CHAT_COMPLETE = """RETURN fn::openai_chat_complete($llm,$prompt_with_context, $input, $temperature, $openai_token);"""
+
+
+    """
+    Initializes the LLMModelHander.
+    Args:
+        model_data (str): Model data.
+        model_params (ModelParams): Model parameters.
+        connection (AsyncSurreal): Asynchronous SurrealDB connection.
+    """
    def __init__(self,model_data:str,model_params:ModelParams,connection:AsyncSurreal):
        self.model_data = model_data
@@ -192,8 +174,40 @@ def __init__(self,model_data:str,model_params:ModelParams,connection:AsyncSurrea
        self.connection = connection

-    def extract_plain_text(text):
+    """
+    Parses a string containing <think> tags and extracts the content.
+
+    Args:
+        input_string: The string to parse.
+
+    Returns:
+        A dictionary with two keys:
+        - "think": The content between the <think> tags.
+        - "content": The rest of the string.
+    """
+    def parse_deepseek_response(input_string):
        """
+        Parses a string containing <think> tags and extracts the content.
+
+        Args:
+            input_string: The string to parse.
+
+        Returns:
+            A dictionary with two keys:
+            - "think": The content between the <think> tags.
+            - "content": The rest of the string.
+        """
+
+        # Use regular expression to find the content within the <think> tags
+        think_match = re.search(r"<think>(.*?)</think>", input_string, re.DOTALL)
+        think_content = think_match.group(1).strip() if think_match else ""
+
+        # Remove the <think> section from the original string
+        content = re.sub(r"<think>.*?</think>\s*", "", input_string, flags=re.DOTALL).strip()
+
+        return {"think": think_content, "content": content}
+
+    """
    Extracts plain text from a string by removing content within <think> tags.

    Args:
@@ -202,12 +216,14 @@ def extract_plain_text(text):
    Returns:
        str: The plain text with tags and their contents removed.
    """
+    def extract_plain_text(text):
+
+        # Use a regular expression to find and remove content within <think> tags
-        clean_text = remove_think_tags(clean_text)
+        clean_text = LLMModelHander.remove_think_tags(text)
        clean_text = re.sub(r'<[^>]*>', '', clean_text)
        return clean_text

-    def remove_think_tags(text):
-        """
+
+    """
    Removes <think> tags and their content from the given text, leaving only the text after the closing </think> tag.

    Args:
@@ -216,52 +232,170 @@ def extract_plain_text(text):
    Returns:
        str: The string with <think> tags and their content removed.
    """
-    return re.sub(r'<think>.*?</think>\n*', '', text, flags=re.DOTALL | re.IGNORECASE).strip()
+    def remove_think_tags(text):
+        # Uses a regular expression to find and remove <think> tags and their content.
+
+        return re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
+
+
+    """
+    Gets a short plain text response from the chat model.
+
+    Args:
+        prompt_with_context (str): The prompt with context.
+        input (str): The user input.
+
+    Returns:
+        str: A short plain text response.
+    """
+    async def get_short_plain_text_response(self,prompt_with_context:str,input:str):
+        # Limits the response to 255 characters.
+        n = 255
+        # Extracts plain text from the chat response.
+        text = LLMModelHander.extract_plain_text(await self.get_chat_response(prompt_with_context,input))
+        if len(text) > n:
+            # Returns the first 255 characters if the response is longer.
+            return text[:n]
+        else:
+            # Returns the full response if it's shorter than 255 characters.
+            return text
+
+    """
+    Gets a chat response based on the model's host.
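+    Dispatches to a SurrealQL function ("SQL"), a direct provider API ("API"), or a local Ollama server ("OLLAMA").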
- def get_short_plain_text_response(self,prompt_with_context:str,input:str): - return LLMModelHander.extract_plain_text(self.get_chat_response(prompt_with_context,input)) + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. - def get_chat_response(self,prompt_with_context:str,input:str): + Returns: + str: The chat response. + """ + async def get_chat_response(self,prompt_with_context:str,input:str): + # Matches the model's host and calls the appropriate method. match self.model_data["host"]: case "SQL": - return self.get_chat_response_from_surreal(prompt_with_context,input) + # Gets response from SurrealDB. + return await self.get_chat_response_from_surreal(prompt_with_context,input) case "API": + # Gets response from API. return self.get_chat_response_from_api(prompt_with_context,input) case "OLLAMA": + # Gets response from Ollama. return self.get_chat_response_from_ollama(prompt_with_context,input) case _: raise SystemError(f"Invalid host method {self.model_data["host"]}") + """ + Gets a chat response from an API based on the model's platform. - + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + + Returns: + str: The chat response from the API. + """ def get_chat_response_from_api(self,prompt_with_context:str,input:str): + # Matches the model's platform and calls the appropriate method. match self.model_data["platform"]: case "OPENAI": + # Gets response from OpenAI API. return self.get_chat_openai_response_from_api(prompt_with_context,input) case "GOOGLE": + # Gets response from Google Gemini API. return self.get_chat_gemini_response_from_api(prompt_with_context,input) case _: raise SystemError(f"Error in outcome: Invalid model for API execution {self.model_data["platform"]}") + + + """ + Gets a chat response from SurrealDB based on the model's platform. + + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. - def get_chat_response_from_surreal(self,prompt_with_context:str,input:str): + Returns: + str: The chat response from SurrealDB. + """ + async def get_chat_response_from_surreal(self,prompt_with_context:str,input:str): + # Matches the model's platform and calls the appropriate method. match self.model_data["platform"]: case "OPENAI": - return self.get_chat_openai_response_from_surreal(prompt_with_context,input) + # Gets response from OpenAI using SurrealDB. + return await self.get_chat_openai_response_from_surreal(prompt_with_context,input) case "GOOGLE": - return self.get_chat_gemini_response_from_surreal(prompt_with_context,input) + # Gets response from Google Gemini using SurrealDB. + return await self.get_chat_gemini_response_from_surreal(prompt_with_context,input) case _: raise SystemError(f"Error in outcome: Invalid model for SQL execution {self.model_data["platform"]}") - def get_chat_openai_response_from_surreal(self,prompt_with_context:str,input:str): - return "get_chat_openai_response_from_surreal" + + """ + Gets a chat response from OpenAI using SurrealDB. + + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + + Returns: + str: The OpenAI chat response. + """ + async def get_chat_openai_response_from_surreal(self,prompt_with_context:str,input:str): + + # Executes the OpenAI chat completion query in SurrealDB. 
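+        # fn::openai_chat_complete is one of the SurQL functions loaded from schema/function_ddl.surql.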
+ model_version = self.model_data["model_version"] + openai_response = await self.connection.query( + LLMModelHander.OPENAI_CHAT_COMPLETE, params={ + "llm":model_version, + "prompt_with_context":prompt_with_context, + "input":input, + "temperature":self.model_data["temperature"], + "openai_token":self.model_params.openai_token + }) + + # Returns the content of the response. + return openai_response["choices"][0]["message"]["content"] + - def get_chat_gemini_response_from_surreal(self,prompt_with_context:str,input:str): - return "get_chat_gemini_response_from_surreal" + """ + Gets a chat response from Google Gemini using SurrealDB. + + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + Returns: + str: The Gemini chat response. + """ + async def get_chat_gemini_response_from_surreal(self,prompt_with_context:str,input:str): + + # Executes the Gemini chat completion query in SurrealDB. + model_version = self.model_data["model_version"].replace("models/","") + gemini_response = await self.connection.query( + LLMModelHander.GEMINI_CHAT_COMPLETE, params={ + "llm":model_version, + "prompt_with_context":prompt_with_context, + "input":input, + "google_token":self.model_params.gemini_token + } ) + # Returns the text content of the response. + return gemini_response["candidates"][0]["content"]["parts"][0]["text"] + + """ + Gets a chat response from the OpenAI API. + + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + Returns: + str: The OpenAI chat response. + """ def get_chat_openai_response_from_api(self,prompt_with_context:str,input:str): + # Constructs the messages for the OpenAI API. messages = [ { "role": "system", @@ -272,15 +406,18 @@ def get_chat_openai_response_from_api(self,prompt_with_context:str,input:str): "content": input } ] + # Sets the OpenAI API key. openai.api_key = self.model_params.openai_token if openai.api_key is None: raise ValueError("OPENAI_API_KEY environment variable not set.") try: + # Calls the OpenAI chat completions API. response = openai.chat.completions.create( model=self.model_data["model_version"], messages=messages, temperature=self.model_data["temperature"] ) + # Returns the content of the response. return response.choices[0].message.content except openai.error.OpenAIError as e: print(f"An error occurred: {e}") @@ -289,9 +426,19 @@ def get_chat_openai_response_from_api(self,prompt_with_context:str,input:str): print(f"An unexpected error occurred: {e}") return None + """ + Gets a chat response from the Google Gemini API. + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + + Returns: + str: The Gemini chat response. + """ def get_chat_gemini_response_from_api(self,prompt_with_context:str,input:str): + # Constructs the messages for the Gemini API. messages = [ { "text": prompt_with_context @@ -301,13 +448,26 @@ def get_chat_gemini_response_from_api(self,prompt_with_context:str,input:str): } ] genai.configure(api_key=self.model_params.gemini_token) + # Initializes the GenerativeModel with the model version. model = genai.GenerativeModel(self.model_data["model_version"]) + # Generates the content from the model. response = model.generate_content(messages) + # Returns the text content of the response. return response.text + """ + Gets a chat response from Ollama. + + Args: + prompt_with_context (str): The prompt with context. + input (str): The user input. + Returns: + str: The Ollama chat response. 
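+    Assumes a locally running Ollama server with the requested model already pulled.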
+ """ def get_chat_response_from_ollama(self,prompt_with_context:str,input:str): + # Constructs the messages for Ollama. messages = [ { "role": "system", @@ -318,7 +478,9 @@ def get_chat_response_from_ollama(self,prompt_with_context:str,input:str): "content": input } ] + # Generates the response from Ollama. response: GenerateResponse = generate(model=self.model_data["model_version"], prompt=str(messages)) + # Optional: Parse DeepSeek response if needed. #parsed_response = parse_deepseek_response(response.response) #return {"response":response, "think": parsed_response["think"],"content":parsed_response["content"]} return response.response diff --git a/surrealdb-rag/src/surrealdb_rag/train_fastText.py b/surrealdb-rag/src/surrealdb_rag/train_wiki_fasttext.py similarity index 88% rename from surrealdb-rag/src/surrealdb_rag/train_fastText.py rename to surrealdb-rag/src/surrealdb_rag/train_wiki_fasttext.py index 6ed2078..72c2f4f 100644 --- a/surrealdb-rag/src/surrealdb_rag/train_fastText.py +++ b/surrealdb-rag/src/surrealdb_rag/train_wiki_fasttext.py @@ -1,4 +1,4 @@ -"""Download OpenAI Wikipedia data.""" +"""Train a fasttext with the wiki data """ import fasttext @@ -19,7 +19,7 @@ def preprocess_text(text): return token -def train_fastText() -> None: +def train_wiki_fastText() -> None: logger = loggers.setup_logger("Train FastText Embedding Model") usecols=[ @@ -39,9 +39,9 @@ def train_fastText() -> None: logger.info(all_text.describe()) logger.info(len(all_text)) - traning_data_file = constants.CUSTOM_FS_PATH + "_train.txt" - model_bin_file = constants.CUSTOM_FS_PATH + ".bin" - model_txt_file = constants.CUSTOM_FS_PATH + traning_data_file = constants.FS_WIKI_PATH + "_train.txt" + model_bin_file = constants.FS_WIKI_PATH + ".bin" + model_txt_file = constants.FS_WIKI_PATH # Save the combined text to a file with open(traning_data_file, "w") as f: for text in all_text: @@ -70,4 +70,4 @@ def train_fastText() -> None: if __name__ == "__main__": - train_fastText() \ No newline at end of file + train_wiki_fastText() \ No newline at end of file diff --git a/surrealdb-rag/static/style.css b/surrealdb-rag/static/style.css index 4aa4240..9d9ed1c 100644 --- a/surrealdb-rag/static/style.css +++ b/surrealdb-rag/static/style.css @@ -56,7 +56,8 @@ nav { right: 15px; } -.close:hover, +.close:hover, +.doc:hover, .close:focus { color: black; text-decoration: none; diff --git a/surrealdb-rag/templates/chat.html b/surrealdb-rag/templates/chat.html index 22c275c..e090c4e 100644 --- a/surrealdb-rag/templates/chat.html +++ b/surrealdb-rag/templates/chat.html @@ -11,6 +11,6 @@
- +
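As an aside, the RecordID helpers defined in app.py above are easiest to read as a round trip. A minimal sketch follows; the surrealdb RecordID constructor arguments and the id values are assumptions made up for illustration:

from surrealdb import RecordID

chat_id = RecordID("chat", "abc/123")               # hypothetical record id containing a '/'
fragment = format_url_id(chat_id)                   # quote() leaves '/' intact, so this yields "abc|123"
assert unformat_url_id(fragment) == "abc/123"       # reversed before querying the database
assert extract_id("message:xyz") == "message-xyz"   # ':' is not legal in a CSS class name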
diff --git a/surrealdb-rag/templates/document.html b/surrealdb-rag/templates/document.html index 9d7b16c..e878353 100644 --- a/surrealdb-rag/templates/document.html +++ b/surrealdb-rag/templates/document.html @@ -4,7 +4,7 @@
Title:{{document.title}}
Text:
-
{{document.text}}
+
{{document.text}}
OPENAI vector
{{document.content_openai_vector}}
GLOVE vector
diff --git a/surrealdb-rag/templates/index.html b/surrealdb-rag/templates/index.html index e620eeb..3b73bf3 100644 --- a/surrealdb-rag/templates/index.html +++ b/surrealdb-rag/templates/index.html @@ -43,6 +43,22 @@ } }); + + function clearAllModalHTML(){ + const messageDetail = document.getElementById('message_detail'); + const docDetail = document.getElementById('doc_detail'); + const promptArea = document.getElementById('promptArea'); + if (messageDetail) { + messageDetail.innerHTML = ''; + } + if (docDetail) { + docDetail.innerHTML = ''; + } + if (promptArea) { + promptArea.style.display = 'none'; + } + + } document.addEventListener("DOMContentLoaded", function() { const modal = document.getElementById("myModal"); const closeBtn = document.querySelector(".close"); @@ -57,11 +73,13 @@ if (modal && closeBtn) { closeBtn.onclick = function() { modal.style.display = "none"; + clearAllModalHTML(); }; window.onclick = function(event) { if (event.target === modal) { modal.style.display = "none"; + clearAllModalHTML(); } }; } else { @@ -85,7 +103,19 @@ console.error("Error parsing available_corpus_tables:", e); } + function unescapeHTML(html) { + var temp = document.createElement("div"); + temp.innerHTML = html; + var result = temp.textContent || temp.innerText || ""; + return result; + } + const default_prompt_text = unescapeHTML(`{{default_prompt_text}}`); + + function resetPrompt(){ + const promptTemplateText = document.getElementById("promptTemplateText"); + promptTemplateText.value = default_prompt_text; + } function updateLlmModelSelect(){ const platformSelect = document.getElementById("platformSelect"); @@ -180,8 +210,22 @@ document.addEventListener("DOMContentLoaded", function() { updateEmbedModelSelect(); + + }); + document.addEventListener("DOMContentLoaded", function() { + const togglePrompt = document.getElementById('togglePrompt'); + const togglePromptArea = document.getElementById('togglePromptArea'); + + if (togglePrompt) { + togglePrompt.addEventListener('click', function() { + document.getElementById('myModal').style.display = 'block'; + document.getElementById("promptArea").style.display = "block"; + }); + } else { + console.error("togglePrompt button not found."); + } + }); - @@ -199,8 +243,7 @@
{% if default_corpus_table %} @@ -248,6 +291,10 @@

+
+
+ +
@@ -257,7 +304,14 @@ × - + diff --git a/surrealdb-rag/templates/message.html b/surrealdb-rag/templates/message.html index c25dd1c..f3be5c1 100644 --- a/surrealdb-rag/templates/message.html +++ b/surrealdb-rag/templates/message.html @@ -5,7 +5,7 @@ {% if message.role=="user" %} hx-trigger="load" hx-post="chats/{{ chat_id }}/send-system-message" hx-target=".messages" {% endif %} - hx-swap="beforeend" hx-include="#llmModelSelect, #embedModelSelect, #corpusTableSelect"> + hx-swap="beforeend" hx-include="#promptTemplateText, #llmModelSelect, #embedModelSelect, #corpusTableSelect"> {% else %} > {% endif %} @@ -14,7 +14,8 @@ {{ message.timestamp | convert_timestamp_to_date }} - diff --git a/surrealdb-rag/templates/message_detail.html b/surrealdb-rag/templates/message_detail.html index 23da9bd..4e81136 100644 --- a/surrealdb-rag/templates/message_detail.html +++ b/surrealdb-rag/templates/message_detail.html @@ -6,12 +6,11 @@
role: {{message.role}}
created_at: {{message.created_at}}
updated_at: {{message.updated_at}}
-
content: {{message.content}}
embedding_model: {{message.sent[0].embedding_model}}
llm_model: {{message.sent[0].llm_model}}
timestamp: {{message.sent[0].timestamp}}
- {% if message.sent[0].referenced_documents %} +----------------------
Referenced Documents: {% for doc in message.sent[0].referenced_documents %}
{% endif %} -
prompt: {{message.sent[0].prompt_text}}
+---------------------- +
+
content: 
+        {{message.content}}
+ + +{% if message.sent[0].prompt_text %} +---------------------- +
+
prompt: 
+    {{message.sent[0].prompt_text}}
+ +{% endif %} - + + + @@ -235,7 +273,7 @@
Data Set: @@ -289,6 +327,10 @@
+ +
+ +
@@ -299,9 +341,6 @@
+
+
diff --git a/surrealdb-rag/templates/query_results.html b/surrealdb-rag/templates/query_results.html
new file mode 100644
index 0000000..e488b60
--- /dev/null
+++ b/surrealdb-rag/templates/query_results.html
@@ -0,0 +1,21 @@
+<table>
+    <thead>
+        <tr>
+            <th>Value</th>
+            <th>Count</th>
+        </tr>
+    </thead>
+    <tbody>
+        {% for row in query_results %}
+        <tr>
+            <td>{{ row.val }}</td>
+            <td>{{ row.count }}</td>
+        </tr>
+        {% endfor %}
+    </tbody>
+</table>
+ \ No newline at end of file From 221bf746e0be90364b00b3bec2ba568f86f71ed2 Mon Sep 17 00:00:00 2001 From: Alessandro Pireno Date: Sun, 16 Mar 2025 10:44:34 -0400 Subject: [PATCH 9/9] Updated data extraction method --- surrealdb-rag/.vscode/launch.json | 4 +- surrealdb-rag/notes for readme.txt | 4 +- surrealdb-rag/pyproject.toml | 5 +- surrealdb-rag/requrements.txt | 5 +- .../src/surrealdb_rag/download_edgar_data.py | 69 ++---- .../edgar_build_csv_append_vectors.py | 5 +- .../src/surrealdb_rag/fin_data_extractor.py | 200 ++++++++++++++++++ .../src/surrealdb_rag/insert_edgar.py | 1 + surrealdb-rag/src/surrealdb_rag/scripts.py | 16 +- 9 files changed, 240 insertions(+), 69 deletions(-) create mode 100644 surrealdb-rag/src/surrealdb_rag/fin_data_extractor.py diff --git a/surrealdb-rag/.vscode/launch.json b/surrealdb-rag/.vscode/launch.json index c80ebdc..eb079ae 100644 --- a/surrealdb-rag/.vscode/launch.json +++ b/surrealdb-rag/.vscode/launch.json @@ -22,7 +22,7 @@ "python": "/Users/sandro/git_repos/examples-1/surrealdb-rag/.venv/bin/python", // Explicitly set your venv Python interpreter "args": [ // Array of command-line arguments "-fsv", - "EDGAR 10ks", + "EDGAR Data", "-ems", "GLOVE,FASTTEXT", "-tn","embedded_edgar", @@ -42,7 +42,5 @@ ] } - - ] } \ No newline at end of file diff --git a/surrealdb-rag/notes for readme.txt b/surrealdb-rag/notes for readme.txt index 02d45af..0812901 100644 --- a/surrealdb-rag/notes for readme.txt +++ b/surrealdb-rag/notes for readme.txt @@ -54,8 +54,8 @@ python ./src/surrealdb_rag/download_edgar_data.py python ./src/surrealdb_rag/edgar_train_fasttext.py -python ./src/surrealdb_rag/insert_embedding_model.py -emtr FASTTEXT -emv "EDGAR 10ks" -emp data/custom_fast_edgar_text.txt -des "Model trained on 10-K filings for 30 days prior to March 11 2025" -cor "10k filing data from https://www.sec.gov/edgar/search/" +python ./src/surrealdb_rag/insert_embedding_model.py -emtr FASTTEXT -emv "EDGAR Data" -emp data/custom_fast_edgar_text.txt -des "Model trained on 10-K filings for 30 days prior to March 11 2025" -cor "10k filing data from https://www.sec.gov/edgar/search/" python ./src/surrealdb_rag/edgar_build_csv_append_vectors.py -python ./src/surrealdb_rag/insert_edgar.py -fsv "EDGAR 10ks" -ems GLOVE,FASTTEXT \ No newline at end of file +python ./src/surrealdb_rag/insert_edgar.py -fsv "EDGAR Data" -ems GLOVE,FASTTEXT \ No newline at end of file diff --git a/surrealdb-rag/pyproject.toml b/surrealdb-rag/pyproject.toml index 46be5b5..5ce4821 100644 --- a/surrealdb-rag/pyproject.toml +++ b/surrealdb-rag/pyproject.toml @@ -25,7 +25,10 @@ dependencies = [ "openai", "fasttext", "edgartools", - "bs4" + "bs4", + "spacy", + "transformers", + "torch" ] [project.scripts] diff --git a/surrealdb-rag/requrements.txt b/surrealdb-rag/requrements.txt index 273c8f8..e47b548 100644 --- a/surrealdb-rag/requrements.txt +++ b/surrealdb-rag/requrements.txt @@ -14,4 +14,7 @@ google.generativeai openai fasttext edgartools -bs4 \ No newline at end of file +bs4 +spacy +transformers +torch \ No newline at end of file diff --git a/surrealdb-rag/src/surrealdb_rag/download_edgar_data.py b/surrealdb-rag/src/surrealdb_rag/download_edgar_data.py index 64dabe8..0c96f9f 100644 --- a/surrealdb-rag/src/surrealdb_rag/download_edgar_data.py +++ b/surrealdb-rag/src/surrealdb_rag/download_edgar_data.py @@ -1,8 +1,5 @@ """Download OpenAI Wikipedia data.""" -import zipfile - -import wget import os from surrealdb_rag import loggers @@ -11,18 +8,17 @@ from surrealdb_rag.embeddings import WordEmbeddingModel -import 
pandas as pd import tqdm import datetime -from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader, SurrealParams +from surrealdb_rag.constants import DatabaseParams, ModelParams, ArgsLoader import csv -from bs4 import BeautifulSoup -import ast - import edgar +from surrealdb_rag.fin_data_extractor import extract_text_from_edgar_html + + # Initialize database and model parameters, and argument loader db_params = DatabaseParams() model_params = ModelParams() @@ -31,46 +27,6 @@ #https://pypi.org/project/sec-api/ # gen an api key here https://sec-api.io/profile -def extract_text_from_html(html_content): - """ - Extracts text content from an HTML 10-K filing, mimicking browser rendering. - - Args: - html_content (str): The HTML content of the 10-K filing. - - Returns: - str: The extracted text content, with HTML tags removed and basic formatting preserved. - """ - soup = BeautifulSoup(html_content, 'html.parser') # Parse the HTML - - # 1. Extract text from the tag (most content is here) - body_text = soup.body.get_text(separator='\n', strip=True) - # separator='\n' puts a newline between different block-level elements, improving readability - # strip=True removes leading/trailing whitespace - - # 2. (Optional but Recommended) Further refine by focusing on main content areas - # 10-Ks often have header/footer and navigation that you might want to exclude. - # You might need to inspect the HTML of a few 10-Ks to identify consistent - # container divs or sections that hold the main content. - # For example, if you find a main content div with a specific ID or class, you could do: - # main_content_div = soup.find('div', {'id': 'document'}) # Example ID, inspect your HTML - # if main_content_div: - # body_text = main_content_div.get_text(separator='\n', strip=True) - # else: - # body_text = soup.body.get_text(separator='\n', strip=True) # Fallback to body if main content not found - - # 3. 
(Optional) Handle Tables more explicitly (if needed for your NLP tasks) - # If tables are important and you want to preserve their structure somewhat, you could: - table_texts = [] - for table in soup.find_all('table'): - table_text = "" - for row in table.find_all('tr'): - row_text = " | ".join([cell.get_text(strip=True) for cell in row.find_all(['td', 'th'])]) - table_text += row_text + "\n" - table_texts.append(table_text) - body_text += "\n\n" + "\n\n".join(table_texts) # Add table text back in, separated - - return body_text def file_name_from_url(url:str): return url.replace("https://","").replace("http://","").replace(".","_").replace("/","_") @@ -83,7 +39,7 @@ def process_filing(filing:edgar.Filing,dict_writer:csv.DictWriter): file_path = f"{constants.EDGAR_FOLDER}{file_name_from_url(filing.filing_url)}.txt" if not os.path.exists(file_path): html_file = filing.html() - text_content = extract_text_from_html(html_file) + text_content = extract_text_from_edgar_html(html_file,filing.form) with open(file_path, "w") as f: f.write(text_content) @@ -106,9 +62,14 @@ def process_filing(filing:edgar.Filing,dict_writer:csv.DictWriter): } dict_writer.writerow(row) return row - except: + except Exception as e: + try: + url = filing.filing_url + except Exception as e: + url = "Undeterminable" + row = { - "url":"error", + "url":url, "company_name":"", "cik":"", "form":"", @@ -122,7 +83,8 @@ def process_filing(filing:edgar.Filing,dict_writer:csv.DictWriter): "company.sic":"", "company.website":"", "filing_date":"", - "file_path":""} + "file_path":"", + "error":str(e)} dict_writer.writerow(row) return row @@ -300,7 +262,8 @@ def download_edgar_data() -> None: "company.sic":"", "company.website":"", "filing_date":"", - "file_path":""}.keys() + "file_path":"", + "error":""}.keys() with open(index_file,"w", newline='') as f: dict_writer = csv.DictWriter(f, file_keys) dict_writer.writeheader() diff --git a/surrealdb-rag/src/surrealdb_rag/edgar_build_csv_append_vectors.py b/surrealdb-rag/src/surrealdb_rag/edgar_build_csv_append_vectors.py index 0c74bec..de2335e 100644 --- a/surrealdb-rag/src/surrealdb_rag/edgar_build_csv_append_vectors.py +++ b/surrealdb-rag/src/surrealdb_rag/edgar_build_csv_append_vectors.py @@ -130,7 +130,7 @@ def create_csv_from_folder(logger,file_index_df: pd.DataFrame , output_file_path dict_writer.writeheader() for index, file in tqdm.tqdm(file_index_df.iterrows(), total=len(file_index_df), desc=f"Processing files"): - if file["url"] != "error" and os.path.exists(file["file_path"]): + if file["file_path"] and os.path.exists(file["file_path"]): with open(file["file_path"]) as source: file_contents = source.read() chunks = generate_chunks(file_contents,chunk_size) @@ -169,6 +169,9 @@ def create_csv_from_folder(logger,file_index_df: pd.DataFrame , output_file_path } chunk_number += 1 dict_writer.writerow(row) + else: + logger.error(f"File not found: '{file["file_path"]}'") + logger.info(f"CSV generation complete. Corpus saved to '{output_file_path}'.") diff --git a/surrealdb-rag/src/surrealdb_rag/fin_data_extractor.py b/surrealdb-rag/src/surrealdb_rag/fin_data_extractor.py new file mode 100644 index 0000000..49b6959 --- /dev/null +++ b/surrealdb-rag/src/surrealdb_rag/fin_data_extractor.py @@ -0,0 +1,200 @@ +from bs4 import BeautifulSoup, NavigableString +import re + +def extract_text_from_edgar_html(html_content, form_type): + """ + Extracts text from Edgar HTML, with improved table handling for FastText. + + Args: + html_content (str): HTML content. 
+    """
+    soup = BeautifulSoup(html_content, 'lxml')
+
+    for tag in soup.find_all(['script', 'style', 'head', 'meta', 'img']):
+        tag.decompose()
+
+    relevant_text = []
+    if form_type.upper() in ("10-K", "10-Q"):
+        # Standard 10-K/10-Q item headings, in filing order. Each entry is a
+        # regular expression (matched case-insensitively by find_section and
+        # extract_text_between), hence the raw strings and escaped dots.
+        sections_to_find = [
+            r"item 1\. business", r"item 1a\. risk factors", r"item 1b\. unresolved staff comments",
+            r"item 2\. properties", r"item 3\. legal proceedings",
+            r"item 4\. mine safety disclosures",
+            r"item 5\. market for registrant’s common equity, related stockholder matters and issuer purchases of equity securities",
+            r"item 6\. \[reserved\]",
+            r"item 7\. management’s discussion and analysis of financial condition and results of operations",
+            r"item 7a\. quantitative and qualitative disclosures about market risk",
+            r"item 8\. financial statements and supplementary data",
+            r"item 9\. changes in and disagreements with accountants on accounting and financial disclosure",
+            r"item 9a\. controls and procedures",
+            r"item 9b\. other information",
+            r"item 9b\. disclosure regarding foreign jurisdictions that prevent inspections",
+            r"item 10\. directors, executive officers and corporate governance",
+            r"item 11\. executive compensation",
+            r"item 12\. security ownership of certain beneficial owners and management and related stockholder matters",
+            r"item 13\. certain relationships and related transactions, and director independence",
+            r"item 14\. principal accountant fees and services",
+            "part i",
+            "part ii",
+            "part iii",
+            "part iv",
+        ]
+
+        for i in range(len(sections_to_find)):
+            start_section = sections_to_find[i]
+            start_tag = find_section(soup, start_section)
+
+            if start_tag:
+                end_tag_text = sections_to_find[i + 1] if i + 1 < len(sections_to_find) else None
+                section_text = extract_text_between(start_tag, end_tag_text)
+                relevant_text.append(section_text)
+
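+        # Consecutive headings bound each section: the r"item 1\. business" pass,
+        # for example, collects everything up to r"item 1a\. risk factors", while
+        # the final heading ("part iv") runs to the end of the document because
+        # end_tag_text is None.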
+    elif form_type.upper() in ("SC 13D", "SC 13G"):
+        items_to_find = [
+            r"item 1\. security and issuer",
+            r"item 2\. identity and background",
+            r"item 3\. source and amount of funds or other consideration",
+            r"item 4\. purpose of transaction",
+            r"item 5\. interest in securities of the issuer",
+            r"item 6\. contracts, arrangements, understandings or relationships with respect to securities of the issuer",
+            r"item 7\. material to be filed as exhibits",
+            "signature"
+        ]
+        for i in range(len(items_to_find)):
+            start_item = items_to_find[i]
+            start_tag = find_section(soup, start_item)
+            if start_tag:
+                end_tag_text = items_to_find[i+1] if i + 1 < len(items_to_find) else None
+                item_text = extract_text_between(start_tag, end_tag_text)
+                relevant_text.append(item_text)
+
+    elif form_type.upper() in ("S-1", "S-4"):
+        common_headings = [
+            "summary",
+            "risk factors",
+            "use of proceeds",
+            "dividend policy",
+            "capitalization",
+            "dilution",
+            "selected financial data",
+            "management’s discussion and analysis",
+            "business",
+            "management",
+            "certain relationships and related transactions",
+            "principal stockholders",
+            "description of securities",
+            "underwriting",
+            "legal matters",
+            "experts",
+            "where you can find more information",
+            "incorporation of certain information by reference",
+        ]
+        for heading in common_headings:
+            start_tag = find_section(soup, heading)
+            if start_tag:
+                next_heading_index = common_headings.index(heading) + 1
+                end_tag_text = common_headings[next_heading_index] if next_heading_index < len(common_headings) else None
+                section_text = extract_text_between(start_tag, end_tag_text)
+                relevant_text.append(section_text)
+
+    all_paragraphs = []
+    for p_tag in soup.find_all('p'):
+        paragraph_text = p_tag.get_text(separator=" ", strip=True)
+        if paragraph_text:
+            all_paragraphs.append(paragraph_text)
+
+    combined_text = "\n\n".join(relevant_text)
+    for paragraph in all_paragraphs:
+        if paragraph not in combined_text:
+            combined_text += "\n" + paragraph
+
+    combined_text = re.sub(r'\s+', ' ', combined_text)
+    combined_text = combined_text.strip()
+    return combined_text
+
+
+def find_section(soup_obj, section_start):
+    """Finds the element that starts a section heading, trying bare strings,
+    <b> tags, bold-styled <span>s, and <p> tags in turn. `section_start` is
+    treated as a regular expression (do not re.escape it; the heading lists
+    above already contain regex escapes)."""
+    section_start_lower = section_start.lower()
+    start_tag = soup_obj.find(string=re.compile(r'^\s*' + section_start_lower, re.IGNORECASE))
+    if start_tag:
+        return start_tag
+    start_tag = soup_obj.find('b', string=re.compile(r'^\s*' + section_start_lower, re.IGNORECASE))
+    if start_tag:
+        return start_tag
+    for span in soup_obj.find_all('span'):
+        if span.get('style') and 'font-weight' in span.get('style').lower() and 'bold' in span.get('style').lower():
+            if re.search(r'^\s*' + section_start_lower, span.get_text(), re.IGNORECASE):
+                return span
+    start_tag = soup_obj.find('p', string=re.compile(r'^\s*' + section_start_lower, re.IGNORECASE))
+    if start_tag:
+        return start_tag
+    return None
+
+def extract_text_between(start_tag, end_tag_text=None):
+    """Extracts text, with improved table handling."""
+    if not start_tag:
+        return ""
+
+    extracted_text = []
+    current_tag = start_tag.find_next()
+
+    while current_tag and (
+        end_tag_text is None
+        # NavigableStrings have no .find(); skip the end-of-section check for
+        # them (the tag that contains the end heading is visited first anyway).
+        or isinstance(current_tag, NavigableString)
+        or not current_tag.find(string=re.compile(end_tag_text, re.IGNORECASE))
+    ):
+        if isinstance(current_tag, NavigableString):
+            extracted_text.append(current_tag.strip())
+        elif current_tag.name == 'table':
+            # Improved Table Handling
+            table_sentences = process_table_to_sentences(current_tag)
+            extracted_text.extend(table_sentences)
+            # Jump to the table's last descendant so its cell strings aren't
+            # collected a second time by the find_next() traversal below.
+            descendants = list(current_tag.descendants)
+            if descendants:
+                current_tag = descendants[-1]
+        current_tag = current_tag.find_next()
+
+    return " ".join(extracted_text)
+
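+# Illustrative sketch of the two helpers above, using made-up HTML rather than
+# a real filing:
+#
+#   soup = BeautifulSoup(
+#       "<p>Item 1. Business</p><p>We make widgets.</p><p>Item 1A. Risk Factors</p>",
+#       'lxml')
+#   start = find_section(soup, r"item 1\. business")
+#   extract_text_between(start, r"item 1a\. risk factors")  # -> "We make widgets."
+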
+def process_table_to_sentences(table_tag):
+    """
+    Converts an HTML table to a list of natural language sentences.
+
+    Args:
+        table_tag (bs4.element.Tag): The BeautifulSoup <table> tag.
+
+    Returns:
+        list[str]: A list of sentences representing the table's content.
+    """
+    sentences = []
+    rows = table_tag.find_all('tr')
+    if not rows:
+        return sentences
+
+    # 1. Try to Extract Headers (if they exist and are reasonably formatted)
+    header_row = rows[0]
+    headers = [th.get_text(strip=True) for th in header_row.find_all(['th', 'td'])]
+    has_headers = len(headers) > 0
+
+    # 2. Process Data Rows
+    for row in rows[1 if has_headers else 0:]:  # Skip header row if we found headers
+        cells = [td.get_text(strip=True) for td in row.find_all('td')]
+
+        if not cells:  # Skip if empty row
+            continue
+
+        # --- Simple Strategy (Suitable for many tables) ---
+        if has_headers and len(headers) == len(cells):
+            # Create sentences like "Header1: Cell1, Header2: Cell2, ..."
+            sentence = ", ".join([f"{headers[i]}: {cells[i]}" for i in range(len(cells))])
+            sentences.append(sentence)
+        else:
+            # If no headers, or header/cell count mismatch, just join cells with spaces.
+            sentence = " ".join(cells)
+            sentences.append(sentence)
+
+    # 3. Number and Currency Handling
+    #    - Replace $ signs with "USD " (or your currency of choice).
+    final_sentences = []
+    for sentence in sentences:
+        sentence = re.sub(r'\$', 'USD ', sentence)  # Replace $ with "USD "
+        sentence = re.sub(r'(\d),(\d)', r'\1\2', sentence)  # 9,941 --> 9941: remove commas within numbers
+        final_sentences.append(sentence)
+
+    return final_sentences
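+
+# Optional manual smoke test (illustrative only; not wired into the pipeline
+# scripts). Assumes a filing saved to an arbitrary local path, e.g.:
+#   python fin_data_extractor.py data/some_filing.html 10-K
+if __name__ == "__main__":
+    import sys
+
+    with open(sys.argv[1], "r") as f:
+        html = f.read()
+    # Print only the first 2000 characters to keep output readable
+    print(extract_text_from_edgar_html(html, sys.argv[2])[:2000])
\ No newline at end of file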
"python", "./src/surrealdb_rag/insert_edgar.py", # Path to the script "-fsv", # Flag for fast_text_version - "EDGAR 10ks", # Value for fast_text_version (NO quotes needed) + "EDGAR Data", # Value for fast_text_version (NO quotes needed) "-ems", # Flag for embed_models "GLOVE,FASTTEXT", # Value for embed_models (NO quotes needed), "-tn","embedded_edgar", @@ -196,7 +196,7 @@ def insert_ai_industry_edgar(): "python", "./src/surrealdb_rag/insert_edgar.py", # Path to the script "-fsv", # Flag for fast_text_version - "EDGAR 10ks", # Value for fast_text_version (NO quotes needed) + "EDGAR Data", # Value for fast_text_version (NO quotes needed) "-ems", # Flag for embed_models "GLOVE,FASTTEXT", # Value for embed_models (NO quotes needed), "-tn","embedded_edgar_ai", @@ -226,7 +226,7 @@ def insert_large_chunk_edgar(): "python", "./src/surrealdb_rag/insert_edgar.py", # Path to the script "-fsv", # Flag for fast_text_version - "EDGAR 10ks", # Value for fast_text_version (NO quotes needed) + "EDGAR Data", # Value for fast_text_version (NO quotes needed) "-ems", # Flag for embed_models "GLOVE,FASTTEXT", # Value for embed_models (NO quotes needed), "-tn","embedded_edgar_lc", @@ -252,7 +252,7 @@ def add_large_chunk_edgar_data(): -# python ./src/surrealdb_rag/insert_edgar.py -fsv "EDGAR 10ks" -ems GLOVE,FASTTEXT +# python ./src/surrealdb_rag/insert_edgar.py -fsv "EDGAR Data" -ems GLOVE,FASTTEXT def app(): # Alias definition IN this file """Runs UX for the app.""" command = [