diff --git a/images/db_insert.png b/images/db_insert.png
new file mode 100644
index 000000000..7f8f80891
Binary files /dev/null and b/images/db_insert.png differ
diff --git a/notebooks/llms/langchain/readthedocs_rag_zilliz.ipynb b/notebooks/llms/langchain/readthedocs_rag_zilliz.ipynb
index 3f4c209a0..6325037e1 100755
--- a/notebooks/llms/langchain/readthedocs_rag_zilliz.ipynb
+++ b/notebooks/llms/langchain/readthedocs_rag_zilliz.ipynb
@@ -13,16 +13,16 @@
"id": "f6ffd11a",
"metadata": {},
"source": [
- "In this notebook, we are going to use Milvus documentation pages to create a chatbot about our product.\n",
+ "In this notebook, we are going to use Milvus documentation pages to create a chatbot about our product. The chatbot is going to follow RAG steps to retrieve chunks of data using Semantic Vector Search, then the Question + Context will be fed as a Prompt to a LLM to generate an answer.\n",
"\n",
- "A chatbot is going to follow RAG steps to retrieve chunks of data using Semantic Vector Search, then the Question + Context will be fed as a Prompt to a LLM to generate an answer.\n",
+ "Many RAG demos use OpenAI for the Embedding Model and ChatGPT for the Generative AI model. **In this notebook, we will demo a fully open source RAG stack.**\n",
+ "\n",
+ "Using open-source Q&A with retrieval saves money since we make free calls to our own data almost all the time - retrieval, evaluation, and development iterations. We only make a paid call to OpenAI once for the final chat generation step. \n",
"\n",
"
\n",
"
\n",
"
\n",
"\n",
- "Many RAG demos use OpenAI for the Embedding Model and ChatGPT for the Generative AI model. In this notebook, we will demo a fully open source RAG stack - open source embedding model available on HuggingFace, Milvus, and an open source LLM.\n",
- "\n",
"Let's get started!"
]
},
@@ -46,7 +46,13 @@
"id": "e059b674",
"metadata": {},
"source": [
- "## Download Milvus documentation to a local directory."
+ "## Download Milvus documentation to a local directory.\n",
+ "\n",
+ "The data we’ll use is our own product documentation web pages. ReadTheDocs is an open-source free software documentation hosting platform, where documentation is written with the Sphinx document generator.\n",
+ "\n",
+ "The code block below downloads the web pages into a local directory called `rtdocs`. \n",
+ "\n",
+ "I've already uploaded the `rtdocs` data folder to github, so you should see it if you cloned my repo."
]
},
{
@@ -56,7 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
- "# # Uncomment to download readthedocs page locally.\n",
+ "# # Uncomment to download readthedocs pages locally.\n",
"\n",
"# DOCS_PAGE=\"https://pymilvus.readthedocs.io/en/latest/\"\n",
"# !echo $DOCS_PAGE\n",
@@ -79,15 +85,10 @@
"metadata": {},
"source": [
"Code in this notebook uses fully-managed Milvus on [Ziliz Cloud free trial](https://cloud.zilliz.com/login). Choose the default \"Starter\" option when you provision > Create collection > Give it a name > Create cluster and collection.\n",
- "- pip install pymilvus\n",
"\n",
- "💡 **For production purposes**, use a local Milvus docker, Milvus clusters, or fully-managed Milvus on Zilliz Cloud.\n",
- "- [Local Milvus docker](https://milvus.io/docs/install_standalone-docker.md) requires local docker installed and running.\n",
- "- [Milvus clusters](https://milvus.io/docs/install_cluster-milvusoperator.md) requires a K8s cluster up and running.\n",
- "- [Milvus client](https://milvus.io/docs/using_milvusclient.md) with [Milvus lite](https://milvus.io/docs/milvus_lite.md), which runs a local server. ⛔️ Milvus lite is only meant for demos and local testing.\n",
+ "💡 Note: To keep your tokens private, best practice is to use an **env variable**. See [how to save api key in env variable](https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety).
\n",
"\n",
- "💡 Note: To keep your tokens private, best practice is to use an env variable.\n",
- "In Jupyter, need .env file (in same dir as notebooks) containing lines like this:\n",
+ "In Jupyter, you also need a .env file (in same dir as notebooks) containing lines like this:\n",
"- VARIABLE_NAME=value\n"
]
},
@@ -106,6 +107,7 @@
}
],
"source": [
+ "# !pip install pymilvus #python sdk for milvus\n",
"from pymilvus import connections, utility\n",
"\n",
"import os\n",
@@ -113,8 +115,9 @@
"load_dotenv()\n",
"TOKEN = os.getenv(\"ZILLIZ_API_KEY\")\n",
"\n",
- "# Connect to Zilliz cloud.\n",
- "CLUSTER_ENDPOINT=\"https://in03-e3348b7ab973336.api.gcp-us-west1.zillizcloud.com:443\"\n",
+ "# Connect to Zilliz cloud using enpoint URI and API key TOKEN.\n",
+ "# TODO change this before checking into github.\n",
+ "CLUSTER_ENDPOINT=\"https://in03-xxxx.api.gcp-us-west1.zillizcloud.com:443\"\n",
"connections.connect(\n",
" alias='default',\n",
" # Public endpoint obtained from Zilliz Cloud\n",
@@ -133,7 +136,7 @@
"metadata": {},
"source": [
"## Load the Embedding Model checkpoint and use it to create vector embeddings\n",
- "**Embedding model:** We will use the open-source [sentence transformers](https://www.sbert.net/docs/pretrained_models.html) available on HuggingFace to encode the documentation text. We will download the model from HuggingFace and run it locally. We'll save the model's generated embeedings to a pandas dataframe and then into the milvus database.\n",
+ "**Embedding model:** We will use the open-source [sentence transformers](https://www.sbert.net/docs/pretrained_models.html) available on HuggingFace to encode the documentation text. We will download the model from HuggingFace and run it locally. \n",
"\n",
"Two model parameters of note below:\n",
"1. EMBEDDING_LENGTH refers to the dimensionality or length of the embedding vector. In this case, the embeddings generated for EACH token in the input text will have the SAME length = 768. This size of embedding is often associated with BERT-based models, where the embeddings are used for downstream tasks such as classification, question answering, or text generation.
\n",
@@ -199,6 +202,9 @@
"You can think of a collection in Milvus like a \"table\" in SQL databases. The **collection** will contain the \n",
"- **Schema** (or no-schema Milvus Client). \n",
"💡 You'll need the vector `EMBEDDING_LENGTH` parameter from your embedding model.\n",
+ "Typical values are:\n",
+ " - 768 for sbert embedding models\n",
+ " - 1536 for ada-002 OpenAI embedding models\n",
"- **Vector index** for efficient vector search\n",
"- **Vector distance metric** for measuring nearest neighbor vectors\n",
"- **Consistency level**\n",
@@ -206,7 +212,6 @@
"\n",
"Some supported [data types](https://milvus.io/docs/schema.md) for Milvus schemas are:\n",
"- INT64 - primary key\n",
- "- VARCHAR - raw texts\n",
"- FLOAT_VECTOR - embedings = list of `numpy.ndarray` of `numpy.float32` numbers"
]
},
@@ -220,7 +225,7 @@
"output_type": "stream",
"text": [
"Embedding length: 768\n",
- "Created collection: MIlvusDocs\n",
+ "Successfully created collection: `MilvusDocs`\n",
"Schema: {'auto_id': True, 'description': 'The schema for docs pages', 'fields': [{'name': 'pk', 'description': '', 'type': , 'is_primary': True, 'auto_id': True}, {'name': 'vector', 'description': '', 'type': , 'params': {'dim': 768}}], 'enable_dynamic_field': True}\n"
]
}
@@ -231,26 +236,32 @@
" CollectionSchema, Collection)\n",
"\n",
"# 1. Name your collection.\n",
- "COLLECTION_NAME = \"MIlvusDocs\"\n",
+ "COLLECTION_NAME = \"MilvusDocs\"\n",
"\n",
"# 2. Use embedding length from the embedding model.\n",
"print(f\"Embedding length: {EMBEDDING_LENGTH}\")\n",
"\n",
- "# 3. Define minimum required fields.\n",
+ "# 3. Define a minimum expandable schema.\n",
"fields = [\n",
- " FieldSchema(name=\"pk\", dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
- " FieldSchema(name=\"vector\", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_LENGTH),\n",
+ " FieldSchema(\"pk\", DataType.INT64, is_primary=True, auto_id=True),\n",
+ " FieldSchema(\"vector\", DataType.FLOAT_VECTOR, dim=EMBEDDING_LENGTH),\n",
"]\n",
- "\n",
- "# 4. Create schema with dynamic field enabled.\n",
"schema = CollectionSchema(\n",
"\t\tfields,\n",
"\t\tdescription=\"The schema for docs pages\",\n",
"\t\tenable_dynamic_field=True\n",
")\n",
+ "\n",
+ "# 4. Check if collection already exists, if so drop it.\n",
+ "has = utility.has_collection(COLLECTION_NAME)\n",
+ "if has:\n",
+ " drop_result = utility.drop_collection(COLLECTION_NAME)\n",
+ " print(f\"Successfully dropped collection: `{COLLECTION_NAME}`\")\n",
+ "\n",
+ "# 5. Create the collection.\n",
"mc = Collection(COLLECTION_NAME, schema, consistency_level=\"Eventually\")\n",
"\n",
- "print(f\"Created collection: {COLLECTION_NAME}\")\n",
+ "print(f\"Successfully created collection: `{COLLECTION_NAME}`\")\n",
"print(f\"Schema: {mc.schema}\")"
]
},
@@ -260,7 +271,9 @@
"source": [
"## Add a Vector Index\n",
"\n",
- "The vector index determines the vector **search algorithm** used to find the closest vectors in your data to the query a user submits. Most vector indexes use different sets of parameters depending on whether the database is:\n",
+ "The vector index determines the vector **search algorithm** used to find the closest vectors in your data to the query a user submits. \n",
+ "\n",
+ "Most vector indexes use different sets of parameters depending on whether the database is:\n",
"- **inserting vectors** (creation mode) - vs - \n",
"- **searching vectors** (search mode) \n",
"\n",
@@ -348,12 +361,16 @@
"## Chunking\n",
"\n",
"Before embedding, it is necessary to decide your chunk strategy, chunk size, and chunk overlap. In this demo, I will use:\n",
- "- **Strategy** = Use markdown header hierarchies. Split markdown sections if too long.\n",
+ "- **Strategy** = Use markdown header hierarchies. Keep markdown sections together unless they are too long.\n",
"- **Chunk size** = Use the embedding model's parameter `MAX_SEQ_LENGTH`\n",
"- **Overlap** = Rule-of-thumb 10-15%\n",
"- **Function** = \n",
" - Langchain's `HTMLHeaderTextSplitter` to split markdown sections.\n",
- " - Langchain's `RecursiveCharacterTextSplitter` to split up long reviews recursively.\n"
+ " - Langchain's `RecursiveCharacterTextSplitter` to split up long reviews recursively.\n",
+ "\n",
+ "\n",
+ "Notice below, each chunk is grounded with the document source page.
\n",
+ "In addition, header titles are kept together with the chunk of markdown text."
]
},
{
@@ -365,13 +382,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "chunking time: 0.01805710792541504\n",
+ "chunking time: 0.023195981979370117\n",
"docs: 15, split into: 15\n",
"split into chunks: 159, type: list of \n",
"\n",
"Looking at a sample chunk...\n",
- "{'h1': 'Installation', 'h2': 'Installing via pip', 'source': 'rtdocs/pymilvus.readthedocs.io/en/latest/install.html'}\n",
- "demonstrate how to install and using PyMilvus in a virtual environment. See virtualenv for more info\n"
+ "Installation¶ Installing via pip¶ PyMilvus is in the Python Package Index. PyMilvus only support pyt\n",
+ "{'h1': 'Installation', 'h2': 'Installing via pip', 'source': 'rtdocs/pymilvus.readthedocs.io/en/latest/install.html'}\n"
]
}
],
@@ -382,7 +399,6 @@
"headers_to_split_on = [\n",
" (\"h1\", \"Header 1\"),\n",
" (\"h2\", \"Header 2\"),\n",
- " (\"h3\", \"Header 3\"),\n",
"]\n",
"# Create an instance of the HTMLHeaderTextSplitter\n",
"html_splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)\n",
@@ -430,13 +446,13 @@
"print(f\"docs: {len(docs)}, split into: {len(html_header_splits)}\")\n",
"print(f\"split into chunks: {len(chunks)}, type: list of {type(chunks[0])}\") \n",
"\n",
- "# Inspect chunks.\n",
+ "# Inspect a chunk.\n",
"print()\n",
"print(\"Looking at a sample chunk...\")\n",
- "print(chunks[1].metadata)\n",
- "print(chunks[1].page_content[:100])\n",
+ "print(chunks[0].page_content[:100])\n",
+ "print(chunks[0].metadata)\n",
"\n",
- "# TODO - remove this before saving in github.\n",
+ "# # TODO - remove this before saving in github.\n",
"# # Print the child splits with their associated header metadata\n",
"# print()\n",
"# for child in chunks:\n",
@@ -478,9 +494,22 @@
"source": [
"## Insert data into Milvus\n",
"\n",
- "Milvus and Milvus Lite support loading pandas dataframes directly.\n",
+ "For each original text chunk, we'll write the quadruplet (`vector, text, source, h1, h2`) into the database.\n",
"\n",
- "Milvus Client, however, requires conerting pandas df into a list of dictionaries first.\n"
+ "\n",
+ "
\n",
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Milvus and Milvus Lite support loading data from:\n",
+ "- pandas dataframes \n",
+ "- list of dictionaries\n",
+ "\n",
+ "Below, we use the embedding model provided by HuggingFace, download its checkpoint, and run it locally as the encoder. "
]
},
{
@@ -489,9 +518,11 @@
"metadata": {},
"outputs": [],
"source": [
- "# Convert chunks and embeddings to a list of dictionaries.\n",
+ "# Convert chunks to a list of dictionaries.\n",
"chunk_list = []\n",
"for chunk in chunks:\n",
+ "\n",
+ " # Generate embeddings using encoder from HuggingFace.\n",
" embeddings = torch.tensor(encoder.encode([chunk.page_content]))\n",
" embeddings = F.normalize(embeddings, p=2, dim=1)\n",
" converted_values = list(map(np.float32, embeddings))[0]\n",
@@ -501,9 +532,10 @@
" h2 = chunk.metadata['h2'][:50]\n",
" except:\n",
" h2 = \"\"\n",
+ " # Assemble embedding vector, original text chunk, metadata.\n",
" chunk_dict = {\n",
" 'vector': converted_values,\n",
- " 'chunk': chunk.page_content,\n",
+ " 'text': chunk.page_content,\n",
" 'source': chunk.metadata['source'],\n",
" 'h1': chunk.metadata['h1'][:50],\n",
" 'h2': h2,\n",
@@ -526,14 +558,13 @@
"output_type": "stream",
"text": [
"Start inserting entities\n",
- "Milvus insert time for 159 vectors: 1.0154786109924316 seconds\n",
- "(insert count: 159, delete count: 0, upsert count: 0, timestamp: 445785288603074562, success count: 159, err count: 0)\n",
- "[{\"name\":\"_default\",\"collection_name\":\"MIlvusDocs\",\"description\":\"\"}]\n"
+ "Milvus insert time for 159 vectors: 0.7005021572113037 seconds\n",
+ "[{\"name\":\"_default\",\"collection_name\":\"MilvusDocs\",\"description\":\"\"}]\n"
]
}
],
"source": [
- "# Insert a batch of data into the Milvus collection.\n",
+ "# Insert data into the Milvus collection.\n",
"\n",
"print(\"Start inserting entities\")\n",
"start_time = time.time()\n",
@@ -546,7 +577,6 @@
"mc.flush() \n",
"\n",
"# Inspect results.\n",
- "print(insert_result)\n",
"print(mc.partitions) # list[Partition] objects\n"
]
},
@@ -558,7 +588,6 @@
"## Run a Semantic Search\n",
"\n",
"Now we can search all the documentation embeddings to find the `TOP_K` documentation chunks with the closest embeddings to a user's query.\n",
- "- In this example, we'll ask about AUTOINDEX.\n",
"\n",
"💡 The same model should always be used for consistency for all the embeddings."
]
@@ -576,7 +605,7 @@
"\n",
"Next, you can ask a question about your custom data!\n",
"\n",
- "💡 In LLM lingo:\n",
+ "💡 In LLM vocabulary:\n",
"> **Query** is the generic term for user questions. \n",
"A query is a list of multiple individual questions, up to maybe 1000 different questions!\n",
"\n",
@@ -600,11 +629,11 @@
],
"source": [
"# Define a sample question about your data.\n",
- "question = \"what is the default distance metric used in AUTOINDEX?\"\n",
- "query = [question]\n",
+ "QUESTION = \"what is the default distance metric used in AUTOINDEX?\"\n",
+ "QUERY = [QUESTION]\n",
"\n",
"# Inspect the length of the query.\n",
- "QUERY_LENGTH = len(query[0])\n",
+ "QUERY_LENGTH = len(QUERY[0])\n",
"print(f\"query length: {QUERY_LENGTH}\")"
]
},
@@ -634,20 +663,20 @@
"output_type": "stream",
"text": [
"Loaded milvus collection into memory.\n",
- "Milvus search time: 0.06506514549255371 sec\n",
+ "Milvus search time: 0.0449519157409668 sec\n",
"type: , count: 5\n"
]
}
],
"source": [
- "# RETRIEVAL USING MILVUS.\n",
+ "# RETRIEVAL USING MILVUS API.\n",
"\n",
"# Before conducting a search based on a query, you need to load the data into memory.\n",
"mc.load()\n",
"print(\"Loaded milvus collection into memory.\")\n",
"\n",
- "# Embed the question using the same embedding model.\n",
- "embedded_question = torch.tensor(encoder.encode([question]))\n",
+ "# Embed the question using the same encoder.\n",
+ "embedded_question = torch.tensor(encoder.encode(QUERY))\n",
"# Normalize embeddings to unit length.\n",
"embedded_question = F.normalize(embedded_question, p=2, dim=1)\n",
"# Convert the embeddings to list of list of np.float32.\n",
@@ -663,9 +692,9 @@
" anns_field=\"vector\", \n",
" # No params for AUTOINDEX\n",
" param={},\n",
- " # Access dynamic fields in the boolean expression.\n",
+ " # Milvus can utilize metadata to enhance the search experience in boolean expressions.\n",
" # expr=\"\",\n",
- " output_fields=[\"h1\", \"h2\", \"chunk\", \"source\"], \n",
+ " output_fields=[\"h1\", \"h2\", \"text\", \"source\"], \n",
" limit=TOP_K,\n",
" consistency_level=\"Eventually\"\n",
" )\n",
@@ -695,23 +724,39 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "2267\n"
+ "0th query result\n",
+ "id: 445766022949312658, distance: 0.708217978477478, entity: {'text': \"index_file_size (int) – Segment size. See Storage Concepts. metric_type (MetricType) – Distance Metrics type. Valued form MetricType. See Distance Metrics. A demo is as follow: param={'collection_name': 'name', 'dimension': 16, 'index_file_size': 1024 # Optional, default 1024, 'metric_type': MetricType.L2 # Optional, default MetricType.L2 } timeout (float) – An optional duration of time in seconds to allow for the RPC. When timeout is set to None, client waits until server responses or error occurs.\", 'source': 'https://pymilvus.readthedocs.io/en/latest/api.html', 'h1': 'API reference', 'h2': 'Client'}\n",
+ "id: 445766022949312769, distance: 0.690363883972168, entity: {'text': 'A metric for binary vectors, only support :attr:`~milvus.IndexType.FLAT` index. #: See `Substructure `_. SUPERSTRUCTURE = 7 def __repr__(self): return \"\".format(self.__class__.__name__, self._name_) def __str__(self): return self._name_', 'source': 'https://pymilvus.readthedocs.io/en/latest/_modules/milvus/client/types.html', 'h1': 'Source code for milvus.client.types from enum impo', 'h2': ''}\n",
+ "id: 445766022949312619, distance: 0.688593864440918, entity: {'text': 'you can refer to Storage Concepts for more information about segments and index_file_size. metric_type:Milvus compute distance between two vectors, you can refer to Distance Metrics for more information. Now we can create a collection: >>> collection_name = \\'demo_film_tutorial\\' >>> collection_param = { ... \"collection_name\": collection_name, ... \"dimension\": 8, ... \"index_file_size\": 2048, ... \"metric_type\": MetricType.L2 ... } >>> client.create_collection(collection_param) Status(code=0, message=\\'Create', 'source': 'https://pymilvus.readthedocs.io/en/latest/tutorial.html', 'h1': 'Tutorial', 'h2': 'This is a basic introduction to Milvus by PyMilvus'}\n",
+ "id: 445766022949312621, distance: 0.6716917157173157, entity: {'text': \"metric_type=) The attributes of collection can be extracted from info. >>> info.collection_name 'demo_film_tutorial' >>> info.dimension 8 >>> info.index_file_size 2048 >>> info.metric_type This tutorial is a basic intro tutorial, building index won’t be covered by this tutorial. If you want to go further into Milvus with indexes, it’s recommended to check our index examples. If you’re already known about indexes from index examples, and you want a full lists of params supported by PyMilvus, you check out\", 'source': 'https://pymilvus.readthedocs.io/en/latest/tutorial.html', 'h1': 'Tutorial', 'h2': 'This is a basic introduction to Milvus by PyMilvus'}\n",
+ "id: 445766022949312711, distance: 0.6690606474876404, entity: {'text': \"-- Segment size. See `Storage Concepts `_. * *metric_type* (``MetricType``) -- Distance Metrics type. Valued form :class:`~milvus.MetricType`. See `Distance Metrics `_. A demo is as follow: .. code-block:: python param={'collection_name': 'name', 'dimension': 16, 'index_file_size': 1024 # Optional, default 1024, 'metric_type': MetricType.L2 # Optional, default MetricType.L2 } :param timeout: An optional duration of time in seconds to allow for the RPC. When timeout is set to None, client waits until\", 'source': 'https://pymilvus.readthedocs.io/en/latest/_modules/milvus/client/stub.html', 'h1': 'Source code for milvus.client.stub # -*- coding: U', 'h2': ''}\n",
+ "505\n"
]
}
],
"source": [
- "# # TODO - remove this before saving in github.\n",
- "# for n, hits in enumerate(results):\n",
- "# print(f\"{n}th query result\")\n",
- "# for hit in hits:\n",
- "# print(hit)\n",
+ "# TODO - remove this before saving in github.\n",
+ "for n, hits in enumerate(results):\n",
+ " print(f\"{n}th query result\")\n",
+ " for hit in hits:\n",
+ " print(hit)\n",
"\n",
"# Assemble the context as a stuffed string.\n",
"context = \"\"\n",
+ "i = 0\n",
"for r in results[0]:\n",
- " text = r.entity.chunk\n",
- " context += f\"{text} \"\n",
- "print(len(context))"
+ " text = r.entity.text\n",
+ " if i == 0: # only first result\n",
+ " context += f\"{text} \"\n",
+ " i += 1\n",
+ "print(len(context))\n",
+ "\n",
+ "# Also save the context metadata to retrieve along with the answer.\n",
+ "context_metadata = {\n",
+ " \"h1\": results[0][0].entity.h1,\n",
+ " \"h2\": results[0][0].entity.h2,\n",
+ " \"source\": results[0][0].entity.source,\n",
+ "}"
]
},
{
@@ -721,12 +766,33 @@
"source": [
"## Use an LLM to Generate a chat response to the user's question using the Retrieved Context.\n",
"\n",
- "Below, we're using an open, very tiny generative AI model, or LLM. Many demos use OpenAI as the LLM choice instead."
+ "Below, we'll use an open, very tiny generative AI model, or LLM, available on HuggingFace. Many demos use OpenAI as the LLM choice instead."
]
},
{
"cell_type": "code",
"execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def assemble_grounding_sources(answer, context_metadata):\n",
+ " \"\"\"Assemble the answer and grounding sources into a string\"\"\"\n",
+ " grounded_answer = f\"Answer: {answer}\\n\"\n",
+ " grounded_answer += \"Grounding sources and citations:\\n\"\n",
+ " try:\n",
+ " grounded_answer += f\"'h1': {context_metadata['h1']}, 'h2':{context_metadata['h2']}\\n\"\n",
+ " except:\n",
+ " pass\n",
+ " try:\n",
+ " grounded_answer += f\"'source': {context_metadata['source']}\"\n",
+ " except:\n",
+ " pass\n",
+ " return grounded_answer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
"id": "3e7fa0b6",
"metadata": {},
"outputs": [
@@ -745,21 +811,22 @@
"from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline\n",
"\n",
"# Load the Hugging Face auto-regressive LLM checkpoint.\n",
- "llm = \"deepset/tinyroberta-squad2\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(llm)\n",
+ "tiny_llm = \"deepset/tinyroberta-squad2\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(tiny_llm)\n",
"\n",
"# context cannot be empty so just put random text in it.\n",
"QA_input = {\n",
- " 'question': question,\n",
+ " 'question': QUESTION,\n",
" 'context': 'The quick brown fox jumped over the lazy dog'\n",
"}\n",
"\n",
"nlp = pipeline('question-answering', \n",
- " model=llm, \n",
+ " model=tiny_llm, \n",
" tokenizer=tokenizer)\n",
- "\n",
"result = nlp(QA_input)\n",
- "print(f\"Question: {question}\")\n",
+ "\n",
+ "# Print the question and answer.\n",
+ "print(f\"Question: {QUESTION}\")\n",
"print(f\"Answer: {result['answer']}\")\n",
"\n",
"# The baseline LLM chat is not very helpful."
@@ -767,7 +834,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 17,
"id": "a68e87b1",
"metadata": {},
"outputs": [
@@ -776,31 +843,135 @@
"output_type": "stream",
"text": [
"Question: what is the default distance metric used in AUTOINDEX?\n",
- "Answer: MetricType.L2\n"
+ "Answer: MetricType.L2\n",
+ "Grounding sources and citations:\n",
+ "'h1': API reference, 'h2':Client\n",
+ "'source': https://pymilvus.readthedocs.io/en/latest/api.html\n"
]
}
],
"source": [
"# NOW ASK THE SAME LLM THE SAME QUESTION USING THE RETRIEVED CONTEXT.\n",
"QA_input = {\n",
- " 'question': question,\n",
+ " 'question': QUESTION,\n",
" 'context': context,\n",
"}\n",
"\n",
"nlp = pipeline('question-answering', \n",
- " model=llm, \n",
+ " model=tiny_llm, \n",
" tokenizer=tokenizer)\n",
- "\n",
"result = nlp(QA_input)\n",
- "print(f\"Question: {question}\")\n",
- "print(f\"Answer: {result['answer']}\")\n",
+ "\n",
+ "# Print the question and answer along with grounding sources and citations.\n",
+ "answer = assemble_grounding_sources(result['answer'], context_metadata)\n",
+ "print(f\"Question: {QUESTION}\")\n",
+ "print(answer)\n",
"\n",
"# That answer looks a little better!"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Use OpenAI to generate a more human-like chat response to the user's question \n",
+ "\n",
+ "We've practiced retrieval for free on our own data using open-source LLMs.
\n",
+ "\n",
+ "Now let's make a call to the paid OpenAI GPT.\n",
+ "\n",
+ "💡 Note: We’re using a temperature of 0.0 to enable reproducible experiments. For use cases that need to always be factually grounded, use very low temperature values while more creative tasks can benefit from higher temperatures."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import openai\n",
+ "from dotenv import load_dotenv, find_dotenv\n",
+ "\n",
+ "# See how to save api key in env variable.\n",
+ "# https://help.openai.com/en/articles/5112595-best-practices-for-api-key-safety\n",
+ "_ = load_dotenv(find_dotenv())\n",
+ "openai.api_key = os.environ['OPENAI_API_KEY']\n",
+ "\n",
+ "# Define the generation llm model to use.\n",
+ "LLM_NAME = \"gpt-3.5-turbo-1106\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prepare_response(response):\n",
+ " return response[\"choices\"][-1][\"message\"][\"content\"]\n",
+ "\n",
+ "def generate_response(\n",
+ " llm, temperature=0.0, \n",
+ " grounding_sources=None,\n",
+ " system_content=\"\", assistant_content=\"\", user_content=\"\"):\n",
+ " \"\"\"Generate response from an LLM.\"\"\"\n",
+ "\n",
+ " try:\n",
+ " response = openai.ChatCompletion.create(\n",
+ " model=llm,\n",
+ " temperature=temperature,\n",
+ " api_key=openai.api_key,\n",
+ " messages=[\n",
+ " {\"role\": \"system\", \"content\": system_content},\n",
+ " {\"role\": \"assistant\", \"content\": assistant_content},\n",
+ " {\"role\": \"user\", \"content\": user_content},\n",
+ " ],\n",
+ " )\n",
+ " answer = prepare_response(response=response)\n",
+ " \n",
+ " # Add the grounding sources and citations.\n",
+ " answer = assemble_grounding_sources(answer, grounding_sources)\n",
+ " return answer\n",
+ "\n",
+ " except Exception as e:\n",
+ " print(f\"Exception: {e}\")\n",
+ " return \"\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Question: what is the default distance metric used in AUTOINDEX?\n",
+ "Answer: The default distance metric used in AUTOINDEX is L2.\n",
+ "Grounding sources and citations:\n",
+ "'h1': API reference, 'h2':Client\n",
+ "'source': https://pymilvus.readthedocs.io/en/latest/api.html\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Generate response\n",
+ "response = generate_response(\n",
+ " llm=LLM_NAME,\n",
+ " temperature=0.0,\n",
+ " grounding_sources=context_metadata,\n",
+ " system_content=\"Answer the question using the context provided. Be succinct.\",\n",
+ " user_content=f\"question: {QUESTION}, context: {context}\")\n",
+ "\n",
+ "# Print the question and answer along with grounding sources and citations.\n",
+ "print(f\"Question: {QUESTION}\")\n",
+ "print(response)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
"id": "d0e81e68",
"metadata": {},
"outputs": [],
@@ -811,7 +982,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 22,
"id": "c777937e",
"metadata": {},
"outputs": [
@@ -819,8 +990,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "The watermark extension is already loaded. To reload it, use:\n",
- " %reload_ext watermark\n",
"Author: Christy Bergman\n",
"\n",
"Python implementation: CPython\n",
@@ -832,6 +1001,7 @@
"sentence_transformers: 2.2.2\n",
"pymilvus : 2.3.3\n",
"langchain : 0.0.322\n",
+ "openai : 0.28.0\n",
"\n",
"conda environment: py310\n",
"\n"
@@ -843,7 +1013,7 @@
"# !pip install watermark\n",
"\n",
"%load_ext watermark\n",
- "%watermark -a 'Christy Bergman' -v -p torch,transformers,sentence_transformers,pymilvus,langchain --conda"
+ "%watermark -a 'Christy Bergman' -v -p torch,transformers,sentence_transformers,pymilvus,langchain,openai --conda"
]
}
],
diff --git a/notebooks/text/imdb_search_milvus_client.ipynb b/notebooks/text/imdb_search_milvus_client.ipynb
index 1e371d108..b9a0df3dd 100755
--- a/notebooks/text/imdb_search_milvus_client.ipynb
+++ b/notebooks/text/imdb_search_milvus_client.ipynb
@@ -55,7 +55,7 @@
"💡 **For production purposes**, use a local Milvus docker, Milvus clusters, or fully-managed Milvus on Zilliz Cloud.\n",
"- [Local Milvus docker](https://milvus.io/docs/install_standalone-docker.md) requires local docker installed and running.\n",
"- [Milvus clusters](https://milvus.io/docs/install_cluster-milvusoperator.md) requires a K8s cluster up and running.\n",
- "- [Ziliz Cloud free trial](https://cloud.zilliz.com/login) choose a \"free\" option when you provision.\n"
+ "- [Ziliz Cloud free trial](https://cloud.zilliz.com/login) choose a \"Default\" option when you provision.\n"
]
},
{
@@ -76,8 +76,7 @@
"source": [
"from milvus import default_server\n",
"from pymilvus import (\n",
- " connections, utility, \n",
- " MilvusClient,\n",
+ " connections, utility\n",
")\n",
"\n",
"# Cleanup previous data and stop server in case it is still running.\n",
@@ -945,7 +944,7 @@
"start_time = time.time()\n",
"insert_result = mc.insert(\n",
" COLLECTION_NAME,\n",
- " data=dict_list, \n",
+ " data=dict_list,\n",
" progress_bar=True)\n",
"end_time = time.time()\n",
"print(f\"Milvus insert time for {batch.shape[0]} vectors: {end_time - start_time} seconds\")\n",