Skip to content

Commit

Permalink
visualization file formatted
Browse files Browse the repository at this point in the history
  • Loading branch information
ericljx2020-gmail committed Oct 18, 2024
1 parent 6bc094a commit 5f9b743
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions bootcamp/tutorials/quickstart/rag_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" with open(file_path, \"r\") as file:\n",
" file_text = file.read()\n",
"\n",
" text_lines += file_text.split(\"# \")\n"
" text_lines += file_text.split(\"# \")"
]
},
{
Expand All @@ -171,7 +171,7 @@
"source": [
"from openai import OpenAI\n",
"\n",
"openai_client = OpenAI()\n"
"openai_client = OpenAI()"
]
},
{
Expand Down Expand Up @@ -265,7 +265,7 @@
"outputs": [],
"source": [
"if milvus_client.has_collection(collection_name):\n",
" milvus_client.drop_collection(collection_name)\n"
" milvus_client.drop_collection(collection_name)"
]
},
{
Expand All @@ -288,7 +288,7 @@
" dimension=embedding_dim,\n",
" metric_type=\"IP\", # Inner product distance\n",
" consistency_level=\"Strong\", # Strong consistency level\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -333,7 +333,7 @@
"for i, line in enumerate(tqdm(text_lines, desc=\"Creating embeddings\")):\n",
" data.append({\"id\": i, \"vector\": emb_text(line), \"text\": line})\n",
"\n",
"milvus_client.insert(collection_name=collection_name, data=data)\n"
"milvus_client.insert(collection_name=collection_name, data=data)"
]
},
{
Expand Down Expand Up @@ -382,7 +382,7 @@
" limit=10, # Return top 10 results\n",
" search_params={\"metric_type\": \"IP\", \"params\": {}}, # Inner product distance\n",
" output_fields=[\"text\"], # Return the text field\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -452,7 +452,7 @@
"retrieved_lines_with_distances = [\n",
" (res[\"entity\"][\"text\"], res[\"distance\"]) for res in search_res[0]\n",
"]\n",
"print(json.dumps(retrieved_lines_with_distances, indent=4))\n"
"print(json.dumps(retrieved_lines_with_distances, indent=4))"
]
},
{
Expand Down Expand Up @@ -582,16 +582,16 @@
"import numpy as np\n",
"from sklearn.manifold import TSNE\n",
"\n",
"data.append({'id': len(data), 'vector': emb_text(question), 'text': question})\n",
"data.append({\"id\": len(data), \"vector\": emb_text(question), \"text\": question})\n",
"embeddings = []\n",
"for gp in data:\n",
" embeddings.append(gp['vector'])\n",
" embeddings.append(gp[\"vector\"])\n",
"\n",
"X = np.array(embeddings, dtype=np.float32)\n",
"tsne = TSNE(random_state=0, max_iter=1000)\n",
"tsne_results = tsne.fit_transform(X)\n",
"\n",
"df_tsne = pd.DataFrame(tsne_results, columns=['TSNE1', 'TSNE2'])\n",
"df_tsne = pd.DataFrame(tsne_results, columns=[\"TSNE1\", \"TSNE2\"])\n",
"df_tsne"
]
},
Expand Down Expand Up @@ -623,7 +623,7 @@
"import seaborn as sns\n",
"\n",
"# Extract similar ids from search results\n",
"similar_ids = [gp['id'] for gp in search_res[0]]\n",
"similar_ids = [gp[\"id\"] for gp in search_res[0]]\n",
"\n",
"df_norm = df_tsne[:-1]\n",
"\n",
Expand All @@ -636,23 +636,34 @@
"fig, ax = plt.subplots(figsize=(8, 6)) # Set figsize\n",
"\n",
"# Set the style of the plot\n",
"sns.set_style('darkgrid', {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})\n",
"sns.set_style(\"darkgrid\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})\n",
"\n",
"# Plot all points in blue\n",
"sns.scatterplot(data=df_tsne, x='TSNE1', y='TSNE2', color='blue', label='All knowledge', ax=ax)\n",
"sns.scatterplot(\n",
" data=df_tsne, x=\"TSNE1\", y=\"TSNE2\", color=\"blue\", label=\"All knowledge\", ax=ax\n",
")\n",
"\n",
"# Overlay similar points in red\n",
"sns.scatterplot(data=similar_points, x='TSNE1', y='TSNE2', color='red', label='Similar knowledge', ax=ax)\n",
"\n",
"sns.scatterplot(data=df_query, x='TSNE1', y='TSNE2', color='green', label='Query', ax=ax)\n",
"sns.scatterplot(\n",
" data=similar_points,\n",
" x=\"TSNE1\",\n",
" y=\"TSNE2\",\n",
" color=\"red\",\n",
" label=\"Similar knowledge\",\n",
" ax=ax,\n",
")\n",
"\n",
"sns.scatterplot(\n",
" data=df_query, x=\"TSNE1\", y=\"TSNE2\", color=\"green\", label=\"Query\", ax=ax\n",
")\n",
"\n",
"# Set plot titles and labels\n",
"plt.title('Scatter plot of knowledge using t-SNE')\n",
"plt.xlabel('TSNE1')\n",
"plt.ylabel('TSNE2')\n",
"plt.title(\"Scatter plot of knowledge using t-SNE\")\n",
"plt.xlabel(\"TSNE1\")\n",
"plt.ylabel(\"TSNE2\")\n",
"\n",
"# Set axis to be equal\n",
"plt.axis('equal')\n",
"plt.axis(\"equal\")\n",
"\n",
"# Display the legend\n",
"plt.legend()\n",
Expand Down

0 comments on commit 5f9b743

Please sign in to comment.