From 12770b551bd0cb2f4261dce32d71164aed1a58f3 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 17:49:56 -0700 Subject: [PATCH 01/21] add function to get default search model config async --- src/khoj/database/adapters/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 248a78e88..25ba85d1d 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -574,6 +574,16 @@ def get_default_search_model() -> SearchModelConfig: return SearchModelConfig.objects.first() +async def aget_default_search_model() -> SearchModelConfig: + default_search_model = await SearchModelConfig.objects.filter(name="default").afirst() + + if default_search_model: + return default_search_model + elif await SearchModelConfig.objects.acount() == 0: + await SearchModelConfig.objects.acreate() + return await SearchModelConfig.objects.afirst() + + def get_or_create_search_models(): search_models = SearchModelConfig.objects.all() if search_models.count() == 0: From be42ac1f2f4c42703c52d52ac7f5f9ab5fa0d428 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 17:51:31 -0700 Subject: [PATCH 02/21] add a basic user memory object for storage the user memory is comprised of the user to whom it belongs, the generated embeddings, the raw text field, and the search model used to create the embeddings remove irrelevant datastore object --- src/khoj/database/models/__init__.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index bd49aa8cd..90940d421 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -772,8 +772,12 @@ def __str__(self): return f"{self.slug} - {self.identifier} at {self.created_at}" -class DataStore(DbBaseModel): - key = models.CharField(max_length=200, unique=True) - value = 
models.JSONField(default=dict) - private = models.BooleanField(default=False) - owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) +class UserMemory(DbBaseModel): + """ + A class to represent a memory storage model for longer term memories + """ + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + embeddings = VectorField(dimensions=None) + raw = models.TextField() + search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) From dec90b4bcb652b0f9efe6862e28c4a258cc3ac3e Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 17:51:50 -0700 Subject: [PATCH 03/21] add associated migrations for new memory obj --- .../0090_usermemory_delete_datastore.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 src/khoj/database/migrations/0090_usermemory_delete_datastore.py diff --git a/src/khoj/database/migrations/0090_usermemory_delete_datastore.py b/src/khoj/database/migrations/0090_usermemory_delete_datastore.py new file mode 100644 index 000000000..eb93ed1c6 --- /dev/null +++ b/src/khoj/database/migrations/0090_usermemory_delete_datastore.py @@ -0,0 +1,42 @@ +# Generated by Django 5.1.8 on 2025-04-21 23:44 + +import django.db.models.deletion +import pgvector.django +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0089_chatmodel_price_tier_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="UserMemory", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("embeddings", pgvector.django.VectorField()), + ("raw", models.TextField()), + ( + "search_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + 
on_delete=django.db.models.deletion.SET_NULL, + to="database.searchmodelconfig", + ), + ), + ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)), + ], + options={ + "abstract": False, + }, + ), + migrations.DeleteModel( + name="DataStore", + ), + ] From 0b1851197516cdc057948e791e0b2fd127700876 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 17:52:40 -0700 Subject: [PATCH 04/21] add user memory to the admin page and adapters adapters provide various methods for creating, updating, deleting the user memory objects --- src/khoj/database/adapters/__init__.py | 79 ++++++++++++++++++++++++++ src/khoj/database/admin.py | 2 + 2 files changed, 81 insertions(+) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 25ba85d1d..cdfe89775 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -59,6 +59,7 @@ Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserTextToImageModelConfig, UserVoiceModelConfig, @@ -1999,3 +2000,81 @@ def delete_automation(user: KhojUser, automation_id: str): automation.remove() return automation_metadata + + +class UserMemoryAdapters: + @staticmethod + @require_valid_user + async def pull_memories(user: KhojUser, window=10, limit=5) -> list[UserMemory]: + """ + Pulls memories from the database for a given user. Medium term memory. + """ + time_frame = datetime.now(timezone.utc) - timedelta(days=window) + memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] + return memories + + @staticmethod + @require_valid_user + async def save_memory(user: KhojUser, memory: str) -> UserMemory: + """ + Saves a memory to the database for a given user. 
+ """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + + embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory) + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model + ) + + return memory_instance + + @staticmethod + @require_valid_user + async def search_memories(user: KhojUser, query: str) -> list[UserMemory]: + """ + Searches for memories in the database for a given user. Long term memory. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + + max_distance = model.bi_encoder_confidence_threshold or math.inf + + embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query) + + relevant_memories = ( + UserMemory.objects.filter(user=user) + .annotate(distance=CosineDistance("embeddings", embedded_query)) + .order_by("distance") + ) + + relevant_memories = relevant_memories.filter(distance__lte=max_distance) + + return relevant_memories[:10] + + @staticmethod + @require_valid_user + async def delete_memory(user: KhojUser, memory_id: str) -> bool: + """ + Deletes a memory from the database for a given user. + """ + try: + memory = await UserMemory.objects.aget(user=user, id=memory_id) + await memory.adelete() + return True + except UserMemory.DoesNotExist: + return False + + @staticmethod + def convert_memories_to_dict(memories: List[UserMemory]) -> List[dict]: + """ + Converts a list of Memory objects to a list of dictionaries. 
+ """ + return [ + { + "id": memory.id, + "raw": memory.raw, + "updated_at": memory.updated_at, + } + for memory in memories + ] diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 7297ce118..e66ab1144 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -33,6 +33,7 @@ Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserVoiceModelConfig, VoiceModelOption, @@ -181,6 +182,7 @@ def get_email_login_url(self, request, queryset): admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin) admin.site.register(UserRequests, unfold_admin.ModelAdmin) admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin) +admin.site.register(UserMemory, unfold_admin.ModelAdmin) @admin.register(Agent) From 600052219828c80367f0112c5a5801555f7602e4 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 17:54:35 -0700 Subject: [PATCH 05/21] connect associated logic for getting and updating memories add logic to automatically create new memories and delete existing memories when saving conversation to log. the logic prevents the agent from editing existing memories, as a create + delete can suffice for this case and reduces the complexity. --- src/khoj/processor/conversation/prompts.py | 62 +++++++++++++++++++ src/khoj/processor/conversation/utils.py | 16 ++++- src/khoj/routers/api_chat.py | 13 +++- src/khoj/routers/helpers.py | 72 ++++++++++++++++++++++ 4 files changed, 159 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index b0cec27b3..f6e79e915 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1385,3 +1385,65 @@ User's Name: {name} """.strip() ) + +extract_facts_from_query = PromptTemplate.from_template( + """ +Given a query, extract the facts *related to the user* from the query. 
This is in order to construct a robust memory of who the user is, their interests, their life circumstances, events in their life, their personal motivations. + +You will be provided a subset of the existing facts that are already stored for the user, and potentially relevant to the query. You have two possible actions: +1. Create new facts +2. Delete existing facts + +You may use the existing facts to enhance the new facts that you're creating. You may also choose to delete existing facts that are no longer relevant. You cannot update existing facts; you can only create new facts or delete existing ones. + +To create a new fact, add it to the create array. Do not create an ID. If you have nothing to create, leave the create array empty. Use first person perspective when creating new facts. + +To delete a fact, specify the fact's ID in the delete array. If you have nothing to delete, leave the delete array empty. You must delete anything that is no longer relevant or true about the user. + +# Example +Existing Facts: +{{ + "facts": [ + {{ + "id": "abc", + "raw": "I am not interested in sports", + "updated_at": "2023-10-01T12:00:00Z" + }}, + {{ + "id": "def", + "raw": "I am a software engineer", + "updated_at": "2023-10-31T14:00:00Z" + }}, + {{ + "id": "ghi", + "raw": "My mother works at the hospital", + "updated_at": "2023-10-02T17:00:00Z" + }} + ] +}} + +Input Query: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. In between coding, I took my cat Whiskers out for a walk and played a game of football. My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. 
+ +Response: +{{ + "create": [ + "I am interested in AI and machine learning", + "I have a pet cat named Whiskers", + "I enjoy playing football", + "My mother works at the hospital and is a doctor" + ], + "delete": [ + "abc", + "ghi" + ] +}} + +# Input +These are some potentially related facts: +{matched_facts} + +Conversation History: +{chat_history} + +""".strip() ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b8eea9075..bf3dd859b 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -4,14 +4,12 @@ import math import mimetypes import os -import queue import re import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum from io import BytesIO -from time import perf_counter from typing import Any, Callable, Dict, List, Optional import PIL.Image @@ -24,7 +22,7 @@ from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters -from khoj.database.models import ChatModel, ClientApplication, KhojUser +from khoj.database.models import ChatModel, ClientApplication, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.search_filter.base_filter import BaseFilter @@ -232,6 +230,7 @@ async def save_to_conversation_log( client_application: ClientApplication = None, conversation_id: str = None, automation_id: str = None, + matching_memories: List[UserMemory] = [], query_images: List[str] = None, raw_query_files: List[FileAttachment] = [], generated_images: List[str] = [], @@ -240,6 +239,8 @@ async def save_to_conversation_log( train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): + from khoj.routers.helpers import ai_update_memories + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) @@ -278,6 
+279,15 @@ async def save_to_conversation_log( user_message=q, ) + if not automation_id: + # Don't update memories from automations, as this could get noisy. + await ai_update_memories( + user=user, + conversation_history={"chat": updated_conversation}, + memories=matching_memories, + tracer=tracer, + ) + if is_promptrace_enabled(): merge_message_into_conversation_trace(q, chat_response, tracer) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index dd9512380..ddd0c9655 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -20,6 +20,7 @@ ConversationAdapters, EntryAdapters, PublicConversationAdapters, + UserMemoryAdapters, aget_user_name, ) from khoj.database.models import Agent, KhojUser @@ -89,7 +90,6 @@ trial_rate_limit=20, subscribed_rate_limit=75, slug="command" ) - api_chat = APIRouter() @@ -833,6 +833,16 @@ def collect_telemetry(): user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log + # Get most recent memories and long term relevant memories + recent_memories = await UserMemoryAdapters.pull_memories(user) + recent_memories = await sync_to_async(list)(recent_memories) + + long_term_memories = await UserMemoryAdapters.search_memories(user=user, query=q) + long_term_memories = await sync_to_async(list)(long_term_memories) + + # Create a de-duped set of memories + relevant_memories = set(recent_memories + long_term_memories) + researched_results = "" online_results: Dict = dict() code_results: Dict = dict() @@ -1286,6 +1296,7 @@ def collect_telemetry(): user, request.user.client_app, conversation_id, + relevant_memories, location, user_name, researched_results, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a0baffb9c..b4cdfaaee 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -43,6 +43,7 @@ ConversationAdapters, EntryAdapters, FileObjectAdapters, + UserMemoryAdapters, aget_user_by_email, 
ais_user_subscribed, create_khoj_token, @@ -50,6 +51,7 @@ get_user_name, get_user_notion_config, get_user_subscription_state, + require_valid_user, run_with_process_lock, ) from khoj.database.models import ( @@ -64,6 +66,7 @@ RateLimitRecord, Subscription, TextToImageModelConfig, + UserMemory, UserRequests, ) from khoj.processor.content.docx.docx_to_entries import DocxToEntries @@ -928,6 +931,73 @@ async def generate_excalidraw_diagram_from_description( return response +class ExtractedFacts(BaseModel): + create: List[str] = Field(..., min_items=0) + delete: List[str] = Field(..., min_items=0) + + +async def extract_facts_from_query( + user: KhojUser, + conversation_history: dict, + existing_facts: List[UserMemory] = None, + tracer: dict = {}, +) -> ExtractedFacts: + """ + Extract facts from the given query + """ + chat_history = construct_chat_history(conversation_history, n=2) + + formatted_memories = UserMemoryAdapters.convert_memories_to_dict(existing_facts) if existing_facts else [] + + extract_facts_prompt = prompts.extract_facts_from_query.format( + chat_history=chat_history, + matched_facts=formatted_memories, + ) + + with timer("Chat actor: Extract facts from query", logger): + response = await send_message_to_model_wrapper(extract_facts_prompt, user=user, tracer=tracer) + response = response.strip() + # JSON parse the list of strings + try: + response = clean_json(response) + response = json.loads(response) + parsed_response = ExtractedFacts(**response) + if not isinstance(parsed_response, ExtractedFacts): + raise ValueError(f"Invalid response for extracting facts: {response}") + return parsed_response + + except Exception: + logger.error(f"Invalid response for extracting facts: {response}") + return ExtractedFacts(create=[], delete=[]) + + +@require_valid_user +async def ai_update_memories( + user: KhojUser, conversation_history: dict, memories: List[UserMemory], tracer: dict = {} +) -> List[UserMemory]: + """ + Updates the memories for a given user, 
based on their latest input query. + """ + new_data = await extract_facts_from_query( + user=user, conversation_history=conversation_history, existing_facts=memories, tracer=tracer + ) + + if not new_data: + return [] + + # Save the new data to the database + created_memories = new_data.create + deleted_memories = new_data.delete + + for m in created_memories: + logger.info(f"Creating memory: {m}") + await UserMemoryAdapters.save_memory(user, m) + + for m in deleted_memories: + logger.info(f"Deleting memory: {m}") + await UserMemoryAdapters.delete_memory(user, m) + + async def generate_mermaidjs_diagram( q: str, conversation_history: Dict[str, Any], @@ -1418,6 +1488,7 @@ async def agenerate_chat_response( user: KhojUser = None, client_application: ClientApplication = None, conversation_id: str = None, + matching_memories: List[UserMemory] = [], location_data: LocationData = None, user_name: Optional[str] = None, meta_research: str = "", @@ -1452,6 +1523,7 @@ async def agenerate_chat_response( inferred_queries=inferred_queries, client_application=client_application, conversation_id=conversation_id, + matching_memories=matching_memories, query_images=query_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, From f85b74bf58e66d13e3876a19ddac8bf6c457845c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 18:51:19 -0700 Subject: [PATCH 06/21] add API endpoints for managing memories add backend APIs exposing basic CRUD capabilities on memories --- src/khoj/configure.py | 2 + src/khoj/routers/api_memories.py | 114 +++++++++++++++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 src/khoj/routers/api_memories.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 40d61a888..0c275a851 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -314,6 +314,7 @@ def configure_routes(app): from khoj.routers.api_agents import api_agents from khoj.routers.api_chat import api_chat from 
khoj.routers.api_content import api_content + from khoj.routers.api_memories import api_memories from khoj.routers.api_model import api_model from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client @@ -322,6 +323,7 @@ def configure_routes(app): app.include_router(api_chat, prefix="/api/chat") app.include_router(api_agents, prefix="/api/agents") app.include_router(api_model, prefix="/api/model") + app.include_router(api_memories, prefix="/api/memories") app.include_router(api_content, prefix="/api/content") app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py new file mode 100644 index 000000000..d75278279 --- /dev/null +++ b/src/khoj/routers/api_memories.py @@ -0,0 +1,114 @@ +import json +import logging +from typing import Optional + +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Request +from fastapi.responses import Response +from pydantic import BaseModel +from starlette.authentication import requires + +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import UserMemory + +api_memories = APIRouter() +logger = logging.getLogger(__name__) + + +@api_memories.get("") +@requires(["authenticated"]) +async def get_memories( + request: Request, + client: Optional[str] = None, +): + """Get all memories for the authenticated user""" + user = request.user.object + + memories = UserMemory.objects.filter(user=user) + all_memories = await sync_to_async(list)(memories) + + # Convert memories to a list of dictionaries + formatted_memories = [ + { + "id": memory.id, + "raw": memory.raw, + "created_at": memory.created_at.isoformat(), + } + for memory in all_memories + ] + + return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200) + + +@api_memories.delete("/{memory_id}") +@requires(["authenticated"]) +async def delete_memory( + 
request: Request, + memory_id: int, + client: Optional[str] = None, +): + """Delete a specific memory by ID""" + user = request.user.object + + # Verify memory belongs to user before deleting + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + await memory.adelete() + + return Response(status_code=204) + + +class UpdateMemoryBody(BaseModel): + """Request model for updating a memory""" + + raw: str + + +@api_memories.put("/{memory_id}") +@requires(["authenticated"]) +async def update_memory( + request: Request, + body: UpdateMemoryBody, + memory_id: int, + client: Optional[str] = None, +): + """Update a specific memory's content""" + user = request.user.object + + # Get the memory and verify it belongs to the user + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + new_content = body.raw + if not new_content: + return Response( + content=json.dumps({"error": "Missing required field 'raw'"}), + media_type="application/json", + status_code=400, + ) + + await memory.adelete() + + # Create a new memory with the updated content + new_memory = await UserMemoryAdapters.save_memory( + user=user, + memory=new_content, + ) + + return Response( + content=json.dumps( + { + "id": new_memory.id, + "raw": new_memory.raw, + } + ), + media_type="application/json", + status_code=200, + ) From ff79f267da490c6bd6058389f031e3550b34d05c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 18:52:00 -0700 Subject: [PATCH 07/21] add support in the settings page for browsing, viewing memories add a section to the settings page that allows you to browse, view, edit existing memories the backend has about you --- 
.../app/components/userMemory/userMemory.tsx | 90 +++++++++++++ src/interface/web/app/settings/page.tsx | 122 +++++++++++++++++- 2 files changed, 209 insertions(+), 3 deletions(-) create mode 100644 src/interface/web/app/components/userMemory/userMemory.tsx diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx new file mode 100644 index 000000000..db360e7bc --- /dev/null +++ b/src/interface/web/app/components/userMemory/userMemory.tsx @@ -0,0 +1,90 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react"; +import { useToast } from "@/components/ui/use-toast"; + +export interface UserMemorySchema { + id: number; + raw: string; + created_at: string; +} + +interface UserMemoryProps { + memory: UserMemorySchema; + onDelete: (id: number) => void; + onUpdate: (id: number, raw: string) => void; +} + +export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) { + const [isEditing, setIsEditing] = useState(false); + const [content, setContent] = useState(memory.raw); + const { toast } = useToast(); + + const handleUpdate = () => { + onUpdate(memory.id, content); + setIsEditing(false); + toast({ + title: "Memory Updated", + description: "Your memory has been successfully updated.", + }); + }; + + const handleDelete = () => { + onDelete(memory.id); + toast({ + title: "Memory Deleted", + description: "Your memory has been successfully deleted.", + }); + }; + + return ( +
+ {isEditing ? ( + <> + setContent(e.target.value)} + className="flex-1" + /> + + + + ) : ( + <> + + + + + )} +
+ ); +} diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index 5983c5913..e14274132 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -15,6 +15,7 @@ import { Button } from "@/components/ui/button"; import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp"; import { Input } from "@/components/ui/input"; import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card"; + import { DropdownMenu, DropdownMenuContent, @@ -23,9 +24,25 @@ import { DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { - AlertDialog, AlertDialogAction, AlertDialogCancel, - AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, AlertDialogTrigger + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, + AlertDialogTrigger } from "@/components/ui/alert-dialog"; + +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog"; + import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table"; import { @@ -67,6 +84,7 @@ import Loading from "../components/loading/loading"; import IntlTelInput from "intl-tel-input/react"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; +import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory"; import { Separator } from "@/components/ui/separator"; import { KhojLogoType } from "../components/logo/khojLogo"; import { Progress } from "@/components/ui/progress"; @@ -308,6 +326,7 @@ export default function SettingsView() { const [numberValidationState, setNumberValidationState] = useState( PhoneNumberValidationState.Verified, ); + const [memories, setMemories] = 
useState([]); const [isExporting, setIsExporting] = useState(false); const [exportProgress, setExportProgress] = useState(0); const [exportedConversations, setExportedConversations] = useState(0); @@ -649,6 +668,65 @@ export default function SettingsView() { } }; + const fetchMemories = async () => { + try { + console.log("Fetching memories..."); + const response = await fetch('/api/memories/'); + if (!response.ok) throw new Error('Failed to fetch memories'); + const data = await response.json(); + setMemories(data); + } catch (error) { + console.error('Error fetching memories:', error); + toast({ + title: "Error", + description: "Failed to fetch memories. Please try again.", + variant: "destructive" + }); + } + }; + + const handleDeleteMemory = async (id: number) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'DELETE' + }); + if (!response.ok) throw new Error('Failed to delete memory'); + setMemories(memories.filter(memory => memory.id !== id)); + } catch (error) { + console.error('Error deleting memory:', error); + toast({ + title: "Error", + description: "Failed to delete memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleUpdateMemory = async (id: number, raw: string) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ raw, memory_id: id }), + }); + if (!response.ok) throw new Error('Failed to update memory'); + const updatedMemory: UserMemorySchema = await response.json(); + setMemories(memories.map(memory => + memory.id === id ? updatedMemory : memory + )); + } catch (error) { + console.error('Error updating memory:', error); + toast({ + title: "Error", + description: "Failed to update memory. 
Please try again.", + variant: "destructive" + }); + } + }; + + const syncContent = async (type: string) => { try { const response = await fetch(`/api/content?t=${type}`, { @@ -1237,7 +1315,45 @@ export default function SettingsView() { - + + + + Memories + + +

+ View and manage your long-term memories +

+
+ + open && fetchMemories()}> + + + + + + Your Memories + +
+ {memories.map((memory) => ( + + ))} + {memories.length === 0 && ( +

No memories found

+ )} +
+
+
+
+
From af0a32cf4fc4caa12b03e5da46385dc88d28dd69 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 18:55:25 -0700 Subject: [PATCH 08/21] fix type of relevant_memories and remove return in ai_update_memories --- src/khoj/routers/api_chat.py | 2 +- src/khoj/routers/helpers.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index ddd0c9655..373ce790d 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1296,7 +1296,7 @@ def collect_telemetry(): user, request.user.client_app, conversation_id, - relevant_memories, + list(relevant_memories), location, user_name, researched_results, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index b4cdfaaee..a8e59e7cd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -972,9 +972,7 @@ async def extract_facts_from_query( @require_valid_user -async def ai_update_memories( - user: KhojUser, conversation_history: dict, memories: List[UserMemory], tracer: dict = {} -) -> List[UserMemory]: +async def ai_update_memories(user: KhojUser, conversation_history: dict, memories: List[UserMemory], tracer: dict = {}): """ Updates the memories for a given user, based on their latest input query. """ @@ -983,7 +981,7 @@ async def ai_update_memories( ) if not new_data: - return [] + return # Save the new data to the database created_memories = new_data.create From e952165d3e5f01515b916df8900b874e084fc343 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 20:34:17 -0700 Subject: [PATCH 09/21] wire up relevant_memories to helper methods where relevant wire up the relevant_memories data to helper functions as needed. 
e.g., image generation, online search, extract webpage content, run code, reserach mode --- src/khoj/processor/conversation/openai/gpt.py | 2 +- src/khoj/processor/conversation/utils.py | 15 +++++++- src/khoj/processor/image/generate.py | 4 ++- src/khoj/processor/tools/online_search.py | 34 ++++++++++++++++--- src/khoj/processor/tools/run_code.py | 5 ++- src/khoj/routers/api_chat.py | 11 ++++-- src/khoj/routers/helpers.py | 29 +++++++++++++++- src/khoj/routers/research.py | 15 +++++--- 8 files changed, 99 insertions(+), 16 deletions(-) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index b5fbdcf22..33dffbc87 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -7,7 +7,7 @@ from openai.lib._pydantic import _ensure_strict_json_schema from pydantic import BaseModel -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index bf3dd859b..9ad7849b5 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -302,7 +302,11 @@ async def save_to_conversation_log( def construct_structured_message( - message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str = None + message: str, + images: list[str], + model_type: str, + vision_enabled: bool, + attached_file_context: str = None, ): """ Format messages into appropriate multimedia format for supported chat model types @@ -359,6 +363,7 @@ def generate_chatml_messages_with_context( model_type="", context_message="", query_files: str = None, + relevant_memories: List[UserMemory] = None, generated_files: List[FileAttachment] = 
None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], @@ -453,6 +458,14 @@ def generate_chatml_messages_with_context( ) ) + if not is_none_or_empty(relevant_memories): + memory_context = "Here are some relevant memories about me stored in the system context:\n\n" + for memory in relevant_memories: + friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + memory_context += f"- {memory.raw} ({friendly_dt})\n" + memory_context += "\n" + messages.append(ChatMessage(content=memory_context, role="user")) + if not is_none_or_empty(user_message): messages.append( ChatMessage( diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index f1b84431f..0c773381d 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -10,7 +10,7 @@ from google.genai import types as gtypes from khoj.database.adapters import ConversationAdapters -from khoj.database.models import Agent, KhojUser, TextToImageModelConfig +from khoj.database.models import Agent, KhojUser, TextToImageModelConfig, UserMemory from khoj.routers.helpers import ChatEvent, generate_better_image_prompt from khoj.routers.storage import upload_generated_image_to_bucket from khoj.utils import state @@ -31,6 +31,7 @@ async def text_to_image( query_images: Optional[List[str]] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): status_code = 200 @@ -72,6 +73,7 @@ async def text_to_image( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 0564b65c2..d347928bc 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -10,7 +10,13 @@ from markdownify import markdownify from khoj.database.adapters import ConversationAdapters -from 
khoj.database.models import Agent, KhojUser, ServerChatSettings, WebScraper +from khoj.database.models import ( + Agent, + KhojUser, + ServerChatSettings, + UserMemory, + WebScraper, +) from khoj.processor.conversation import prompts from khoj.routers.helpers import ( ChatEvent, @@ -69,6 +75,7 @@ async def search_online( previous_subqueries: Set = set(), agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): query += " ".join(custom_filters) @@ -85,8 +92,9 @@ async def search_online( user, query_images=query_images, agent=agent, - tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, + tracer=tracer, ) subqueries = list(new_subqueries - previous_subqueries) response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {} @@ -161,7 +169,13 @@ async def search_online( yield {ChatEvent.STATUS: event} tasks = [ read_webpage_and_extract_content( - data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer + data["queries"], + link, + data.get("content"), + user=user, + agent=agent, + relevant_memories=relevant_memories, + tracer=tracer, ) for link, data in webpages.items() ] @@ -367,6 +381,7 @@ async def read_webpages( agent: Agent = None, max_webpages_to_read: int = 1, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -380,6 +395,7 @@ async def read_webpages( query_images, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) @@ -388,7 +404,14 @@ async def read_webpages( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls] + + tasks = [ + read_webpage_and_extract_content( + 
{query}, url, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer + ) + for url in urls + ] + results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -419,6 +442,7 @@ async def read_webpage_and_extract_content( content: str = None, user: KhojUser = None, agent: Agent = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Tuple[set[str], str, Union[None, str]]: # Select the web scrapers to use for reading the web page @@ -442,7 +466,7 @@ async def read_webpage_and_extract_content( if is_none_or_empty(extracted_info): with timer(f"Extracting relevant information from web page at '{url}' took", logger): extracted_info = await extract_relevant_info( - subqueries, content, user=user, agent=agent, tracer=tracer + subqueries, content, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer ) # If we successfully extracted information, break the loop diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index e188da053..57a8af3b4 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -20,7 +20,7 @@ ) from khoj.database.adapters import FileObjectAdapters -from khoj.database.models import Agent, FileObject, KhojUser +from khoj.database.models import Agent, FileObject, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -59,6 +59,7 @@ async def run_code( agent: Agent = None, sandbox_url: str = SANDBOX_URL, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): # Generate Code @@ -124,6 +125,7 @@ async def generate_python_code( agent: Agent = None, tracer: dict = {}, query_files: str = None, + relevant_memories: List[UserMemory] = None, ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if 
user.get_full_name() else "" @@ -158,6 +160,7 @@ async def generate_python_code( user=user, tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, ) # Extract python code wrapped in markdown code blocks from the response diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 373ce790d..a76eecaae 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -841,7 +841,7 @@ def collect_telemetry(): long_term_memories = await sync_to_async(list)(long_term_memories) # Create a de-duped set of memories - relevant_memories = set(recent_memories + long_term_memories) + relevant_memories = list(set(recent_memories + long_term_memories)) researched_results = "" online_results: Dict = dict() @@ -868,6 +868,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ) except ValueError as e: @@ -909,6 +910,7 @@ def collect_telemetry(): location=location, file_filters=conversation.file_filters if conversation else [], query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(research_result, InformationCollectionIteration): @@ -1091,6 +1093,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1118,6 +1121,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1159,6 +1163,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1201,6 +1206,7 @@ def 
collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1246,6 +1252,7 @@ def collect_telemetry(): agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1296,7 +1303,7 @@ def collect_telemetry(): user, request.user.client_app, conversation_id, - list(relevant_memories), + relevant_memories, location, user_name, researched_results, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a8e59e7cd..b38141e8b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -347,6 +347,7 @@ async def aget_data_sources_and_output_format( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Dict[str, Any]: """ @@ -410,6 +411,7 @@ class PickTools(BaseModel): response_schema=PickTools, user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -462,6 +464,7 @@ async def infer_webpage_urls( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> List[str]: """ @@ -499,6 +502,7 @@ class WebpageUrls(BaseModel): response_schema=WebpageUrls, user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -526,6 +530,7 @@ async def generate_online_subqueries( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Set[str]: """ @@ -564,6 +569,7 @@ class OnlineQueries(BaseModel): response_schema=OnlineQueries, 
user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -641,7 +647,12 @@ async def aschedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} + qs: set[str], + corpus: str, + user: KhojUser = None, + agent: Agent = None, + relevant_memories: List[UserMemory] = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -667,6 +678,7 @@ async def extract_relevant_info( prompts.system_prompt_extract_relevant_information, user=user, agent_chat_model=agent_chat_model, + relevant_memories=relevant_memories, tracer=tracer, ) return response.strip() @@ -785,6 +797,7 @@ async def generate_excalidraw_diagram( agent: Agent = None, send_status_func: Optional[Callable] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): if send_status_func: @@ -801,6 +814,7 @@ async def generate_excalidraw_diagram( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) @@ -836,6 +850,7 @@ async def generate_better_diagram_description( user: KhojUser = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> str: """ @@ -880,6 +895,7 @@ async def generate_better_diagram_description( query_images=query_images, user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -1007,6 +1023,7 @@ async def generate_mermaidjs_diagram( agent: Agent = None, send_status_func: Optional[Callable] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): if send_status_func: @@ -1023,6 +1040,7 @@ async def generate_mermaidjs_diagram( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, 
tracer=tracer, ) @@ -1052,6 +1070,7 @@ async def generate_better_mermaidjs_diagram_description( user: KhojUser = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[Dict] = None, tracer: dict = {}, ) -> str: """ @@ -1096,6 +1115,7 @@ async def generate_better_mermaidjs_diagram_description( query_images=query_images, user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -1141,6 +1161,7 @@ async def generate_better_image_prompt( user: KhojUser = None, agent: Agent = None, query_files: str = "", + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> str: """ @@ -1199,6 +1220,7 @@ async def generate_better_image_prompt( query_images=query_images, user=user, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -1219,6 +1241,7 @@ async def send_message_to_model_wrapper( query_images: List[str] = None, context: str = "", query_files: str = None, + relevant_memories: List[UserMemory] = None, agent_chat_model: ChatModel = None, tracer: dict = {}, ): @@ -1260,6 +1283,7 @@ async def send_message_to_model_wrapper( vision_enabled=vision_available, model_type=chat_model.model_type, query_files=query_files, + relevant_memories=relevant_memories, ) return send_message_to_model_offline( @@ -1287,6 +1311,7 @@ async def send_message_to_model_wrapper( query_images=query_images, model_type=chat_model.model_type, query_files=query_files, + relevant_memories=relevant_memories, ) return send_message_to_model( @@ -1313,6 +1338,7 @@ async def send_message_to_model_wrapper( query_images=query_images, model_type=chat_model.model_type, query_files=query_files, + relevant_memories=relevant_memories, ) return anthropic_send_message_to_model( @@ -1338,6 +1364,7 @@ async def send_message_to_model_wrapper( query_images=query_images, model_type=chat_model.model_type, query_files=query_files, + 
relevant_memories=relevant_memories, ) return gemini_send_message_to_model( diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index fa855b9c9..943a13f0d 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field from khoj.database.adapters import AgentAdapters, EntryAdapters -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( InformationCollectionIteration, @@ -85,8 +85,9 @@ async def apick_next_tool( previous_iterations: List[InformationCollectionIteration] = [], max_iterations: int = 5, send_status_func: Optional[Callable] = None, - tracer: dict = {}, query_files: str = None, + relevant_memories: List[UserMemory] = [], + tracer: dict = {}, ): """Given a query, determine which of the available tools the agent should use in order to answer appropriately.""" @@ -144,6 +145,7 @@ async def apick_next_tool( user=user, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, agent_chat_model=agent_chat_model, tracer=tracer, ) @@ -203,6 +205,7 @@ async def execute_information_collection( user_name: str = None, location: LocationData = None, file_filters: List[str] = [], + relevant_memories: List[UserMemory] = [], tracer: dict = {}, query_files: str = None, ): @@ -227,8 +230,9 @@ async def execute_information_collection( previous_iterations, MAX_ITERATIONS, send_status_func, - tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -304,6 +308,7 @@ async def execute_information_collection( query_images=query_images, previous_subqueries=previous_subqueries, agent=agent, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and 
ChatEvent.STATUS in result: @@ -328,8 +333,9 @@ async def execute_information_collection( max_webpages_to_read=1, query_images=query_images, agent=agent, - tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -362,6 +368,7 @@ async def execute_information_collection( query_images=query_images, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: From 8acbd1119c93525fdef710cbb230eb04406e2916 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 20:35:51 -0700 Subject: [PATCH 10/21] fix typing of relevant_memories in generate diagram --- src/khoj/routers/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index b38141e8b..c8683020f 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1070,7 +1070,7 @@ async def generate_better_mermaidjs_diagram_description( user: KhojUser = None, agent: Agent = None, query_files: str = None, - relevant_memories: List[Dict] = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> str: """ From 0b336bb3e5dc77b36d7c3c9e3f795cf64002f39c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 21:23:16 -0700 Subject: [PATCH 11/21] pass memory context to extract references and final response agent --- .../conversation/anthropic/anthropic_chat.py | 6 ++++- .../conversation/google/gemini_chat.py | 6 ++++- .../conversation/offline/chat_model.py | 6 ++++- src/khoj/processor/conversation/openai/gpt.py | 4 ++++ src/khoj/processor/conversation/utils.py | 9 ++++++-- src/khoj/routers/api.py | 22 ++++++++++++++++++- src/khoj/routers/api_chat.py | 1 + src/khoj/routers/helpers.py | 4 ++++ src/khoj/routers/research.py | 3 ++- 9 files changed, 54 insertions(+), 7 
deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 5bad38ef5..5e44c219d 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -5,7 +5,7 @@ import pyjson5 from langchain.schema import ChatMessage -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.utils import ( anthropic_chat_completion_with_backoff, @@ -41,6 +41,7 @@ def extract_questions_anthropic( vision_enabled: bool = False, personality_context: Optional[str] = None, query_files: str = None, + memory_context: Optional[str] = None, tracer: dict = {}, ): """ @@ -88,6 +89,7 @@ def extract_questions_anthropic( model_type=ChatModel.ModelType.ANTHROPIC, vision_enabled=vision_enabled, attached_file_context=query_files, + relevant_memories_context=memory_context, ) messages = [ChatMessage(content=content, role="user")] @@ -156,6 +158,7 @@ async def converse_anthropic( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + relevant_memories: List[UserMemory] = None, generated_files: List[FileAttachment] = None, program_execution_context: Optional[List[str]] = None, generated_asset_results: Dict[str, Dict] = {}, @@ -226,6 +229,7 @@ async def converse_anthropic( vision_enabled=vision_available, model_type=ChatModel.ModelType.ANTHROPIC, query_files=query_files, + relevant_memories=relevant_memories, generated_files=generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 3c42ef067..7b52ba0ce 100644 --- 
a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -6,7 +6,7 @@ from langchain.schema import ChatMessage from pydantic import BaseModel, Field -from khoj.database.models import Agent, ChatModel, KhojUser +from khoj.database.models import Agent, ChatModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( gemini_chat_completion_with_backoff, @@ -42,6 +42,7 @@ def extract_questions_gemini( vision_enabled: bool = False, personality_context: Optional[str] = None, query_files: str = None, + memory_context: Optional[str] = None, tracer: dict = {}, ): """ @@ -89,6 +90,7 @@ def extract_questions_gemini( model_type=ChatModel.ModelType.GOOGLE, vision_enabled=vision_enabled, attached_file_context=query_files, + relevant_memories_context=memory_context, ) messages = [] @@ -180,6 +182,7 @@ async def converse_gemini( query_images: Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + relevant_memories: List[UserMemory] = None, generated_files: List[FileAttachment] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, @@ -251,6 +254,7 @@ async def converse_gemini( vision_enabled=vision_available, model_type=ChatModel.ModelType.GOOGLE, query_files=query_files, + relevant_memories=relevant_memories, generated_files=generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index b7f89c8d3..4969c3654 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -10,7 +10,7 @@ from langchain.schema import ChatMessage from llama_cpp import Llama -from khoj.database.models import Agent, ChatModel, KhojUser +from 
khoj.database.models import Agent, ChatModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( @@ -46,6 +46,7 @@ def extract_questions_offline( temperature: float = 0.7, personality_context: Optional[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> List[str]: """ @@ -97,6 +98,7 @@ def extract_questions_offline( max_prompt_size=max_prompt_size, model_type=ChatModel.ModelType.OFFLINE, query_files=query_files, + relevant_memories=relevant_memories, ) state.chat_lock.acquire() @@ -163,6 +165,7 @@ async def converse_offline( user_name: str = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, generated_files: List[FileAttachment] = None, additional_context: List[str] = None, generated_asset_results: Dict[str, Dict] = {}, @@ -240,6 +243,7 @@ async def converse_offline( tokenizer_name=tokenizer_name, model_type=ChatModel.ModelType.OFFLINE, query_files=query_files, + relevant_memories=relevant_memories, generated_files=generated_files, generated_asset_results=generated_asset_results, program_execution_context=additional_context, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 33dffbc87..663e97259 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -44,6 +44,7 @@ def extract_questions( vision_enabled: bool = False, personality_context: Optional[str] = None, query_files: str = None, + memory_context: str = None, tracer: dict = {}, ): """ @@ -89,6 +90,7 @@ def extract_questions( model_type=ChatModel.ModelType.OPENAI, vision_enabled=vision_enabled, attached_file_context=query_files, + relevant_memories_context=memory_context, ) messages = [] @@ -182,6 +184,7 @@ async def converse_openai( query_images: 
Optional[list[str]] = None, vision_available: bool = False, query_files: str = None, + relevant_memories: List[UserMemory] = None, generated_files: List[FileAttachment] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, @@ -254,6 +257,7 @@ async def converse_openai( vision_enabled=vision_available, model_type=ChatModel.ModelType.OPENAI, query_files=query_files, + relevant_memories=relevant_memories, generated_files=generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 9ad7849b5..a1e6d7e69 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -307,6 +307,7 @@ def construct_structured_message( model_type: str, vision_enabled: bool, attached_file_context: str = None, + relevant_memories_context: str = None, ): """ Format messages into appropriate multimedia format for supported chat model types @@ -316,13 +317,17 @@ def construct_structured_message( ChatModel.ModelType.GOOGLE, ChatModel.ModelType.ANTHROPIC, ]: - if not attached_file_context and not (vision_enabled and images): + if not any([images and vision_enabled, attached_file_context, relevant_memories_context]): return message constructed_messages: List[Any] = [{"type": "text", "text": message}] if not is_none_or_empty(attached_file_context): constructed_messages.append({"type": "text", "text": attached_file_context}) + + if not is_none_or_empty(relevant_memories_context): + constructed_messages.append({"type": "text", "text": relevant_memories_context}) + if vision_enabled and images: for image in images: constructed_messages.append({"type": "image_url", "image_url": {"url": image}}) @@ -459,7 +464,7 @@ def generate_chatml_messages_with_context( ) if not is_none_or_empty(relevant_memories): - memory_context = "Here are some relevant memories about 
me stored in the system context:\n\n" + memory_context = "Here are some relevant memories about me stored in the system context. You can ignore them if they are not relevant to the query:\n\n" for memory in relevant_memories: friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") memory_context += f"- {memory.raw} ({friendly_dt})\n" diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index efba17ecb..3d9f18c0c 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -26,10 +26,17 @@ AutomationAdapters, ConversationAdapters, EntryAdapters, + UserMemoryAdapters, get_default_search_model, get_user_photo, ) -from khoj.database.models import Agent, ChatModel, KhojUser, SpeechToTextModelOptions +from khoj.database.models import ( + Agent, + ChatModel, + KhojUser, + SpeechToTextModelOptions, + UserMemory, +) from khoj.processor.conversation import prompts from khoj.processor.conversation.anthropic.anthropic_chat import ( extract_questions_anthropic, @@ -365,6 +372,7 @@ async def extract_references_and_questions( previous_inferred_queries: Set = set(), agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): # Initialize Variables @@ -413,6 +421,14 @@ async def extract_references_and_questions( personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" + memory_context = UserMemoryAdapters.convert_memories_to_dict(relevant_memories) if relevant_memories else None + if memory_context: + memory_context = "Here are some relevant memories about me stored in the system context:\n\n" + for memory in relevant_memories: + friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + memory_context += f"- {memory.raw} ({friendly_dt})\n" + logger.debug(memory_context) + # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model 
is enabled. @@ -439,6 +455,7 @@ async def extract_references_and_questions( max_prompt_size=chat_model.max_prompt_size, personality_context=personality_context, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.OPENAI: @@ -457,6 +474,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, query_files=query_files, + memory_context=memory_context, tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.ANTHROPIC: @@ -475,6 +493,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, query_files=query_files, + memory_context=memory_context, tracer=tracer, ) elif chat_model.model_type == ChatModel.ModelType.GOOGLE: @@ -494,6 +513,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, query_files=query_files, + memory_context=memory_context, tracer=tracer, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a76eecaae..ac937b209 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1047,6 +1047,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c8683020f..2e0f91efc 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1592,6 +1592,7 @@ async def agenerate_chat_response( user_name=user_name, agent=agent, query_files=query_files, + relevant_memories=matching_memories, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, tracer=tracer, @@ -1620,6 +1621,7 @@ async def agenerate_chat_response( agent=agent, 
vision_available=vision_available, query_files=query_files, + relevant_memories=matching_memories, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, @@ -1649,6 +1651,7 @@ async def agenerate_chat_response( agent=agent, vision_available=vision_available, query_files=query_files, + relevant_memories=matching_memories, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, @@ -1677,6 +1680,7 @@ async def agenerate_chat_response( query_images=query_images, vision_available=vision_available, query_files=query_files, + relevant_memories=matching_memories, generated_files=raw_generated_files, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 943a13f0d..dcf0347a0 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -266,8 +266,9 @@ async def execute_information_collection( query_images, previous_inferred_queries=previous_inferred_queries, agent=agent, - tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] From 90d193095543f047ed6d9ac15831e86c1f223cb3 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 21:25:29 -0700 Subject: [PATCH 12/21] fix mismatched type of memory_context --- src/khoj/routers/api.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 3d9f18c0c..c284b22cb 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -421,10 +421,11 @@ async def extract_references_and_questions( personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" - memory_context = 
UserMemoryAdapters.convert_memories_to_dict(relevant_memories) if relevant_memories else None - if memory_context: + dict_memories = UserMemoryAdapters.convert_memories_to_dict(relevant_memories) if relevant_memories else None + memory_context = None + if dict_memories: memory_context = "Here are some relevant memories about me stored in the system context:\n\n" - for memory in relevant_memories: + for memory in dict_memories: friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") memory_context += f"- {memory.raw} ({friendly_dt})\n" logger.debug(memory_context) From 35f7e5c3be3ee9047262aca08d0903fead05726f Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 21:27:29 -0700 Subject: [PATCH 13/21] remove unnecessary preprocessing or memories todict --- src/khoj/routers/api.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index c284b22cb..965217c28 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -421,12 +421,11 @@ async def extract_references_and_questions( personality_context = prompts.personality_context.format(personality=agent.personality) if agent else "" - dict_memories = UserMemoryAdapters.convert_memories_to_dict(relevant_memories) if relevant_memories else None memory_context = None - if dict_memories: + if relevant_memories: memory_context = "Here are some relevant memories about me stored in the system context:\n\n" - for memory in dict_memories: - friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + for memory in relevant_memories: + friendly_dt = memory.updated_at.strftime("%Y-%m-%d %H:%M:%S") memory_context += f"- {memory.raw} ({friendly_dt})\n" logger.debug(memory_context) From 908e6edcbcf157d790a01f521bd9dbb1a8525545 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 21 Apr 2025 21:31:42 -0700 Subject: [PATCH 14/21] run final completion func steps asynchronously to avoid blocking the main thread --- 
src/khoj/processor/conversation/anthropic/anthropic_chat.py | 3 ++- src/khoj/processor/conversation/google/gemini_chat.py | 3 ++- src/khoj/processor/conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 3 ++- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 5e44c219d..254acb683 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -1,3 +1,4 @@ +import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -255,4 +256,4 @@ async def converse_anthropic( # Call completion_func once finish streaming and we have the full response if completion_func: - await completion_func(chat_response=full_response) + asyncio.create_task(completion_func(chat_response=full_response)) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7b52ba0ce..005a10a4a 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -1,3 +1,4 @@ +import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -279,4 +280,4 @@ async def converse_gemini( # Call completion_func once finish streaming and we have the full response if completion_func: - await completion_func(chat_response=full_response) + asyncio.create_task(completion_func(chat_response=full_response)) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 4969c3654..b2194c3f4 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -321,7 +321,7 @@ def 
_sync_llm_thread(): # Call the completion function after streaming is done if completion_func: - await completion_func(chat_response=aggregated_response_container["response"]) + asyncio.create_task(completion_func(chat_response=aggregated_response_container["response"])) def send_message_to_model_offline( diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 663e97259..57fa76381 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -1,3 +1,4 @@ +import asyncio import logging from datetime import datetime, timedelta from typing import AsyncGenerator, Dict, List, Optional @@ -281,7 +282,7 @@ async def converse_openai( # Call completion_func once finish streaming and we have the full response if completion_func: - await completion_func(chat_response=full_response) + asyncio.create_task(completion_func(chat_response=full_response)) def clean_response_schema(schema: BaseModel | dict) -> dict: From 7ada2916303f190bdcceead7466b0927a2f46687 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Wed, 27 Aug 2025 17:26:47 -0700 Subject: [PATCH 15/21] Reduce research tools shown user memories to doc search and researcher Other tools do not strictly require it. Reduce memory access to keep agent context in check. 
Remove unused memories context passed to construct_structured_message func --- src/khoj/processor/conversation/utils.py | 3 --- src/khoj/routers/research.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 2864b90c1..60b7933cf 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -532,7 +532,6 @@ def construct_structured_message( model_type: str = None, vision_enabled: bool = True, attached_file_context: str = None, - relevant_memories_context: str = None, ): """ Format messages into appropriate multimedia format for supported chat model types. @@ -550,8 +549,6 @@ def construct_structured_message( if vision_enabled and images: for image in images: constructed_messages += [{"type": "image_url", "image_url": {"url": image}}] - if not is_none_or_empty(relevant_memories_context): - constructed_messages.append({"type": "text", "text": relevant_memories_context}) return constructed_messages diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 1a0582d16..03302e32c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -373,7 +373,6 @@ async def research( max_online_searches=max_online_searches, max_webpages_to_read=0, query_images=query_images, - relevant_memories=relevant_memories, previous_subqueries=previous_subqueries, agent=agent, tracer=tracer, @@ -397,7 +396,6 @@ async def research( send_status_func=send_status_func, # max_webpages_to_read=max_webpages_to_read, agent=agent, - relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -433,7 +431,6 @@ async def research( query_images=query_images, agent=agent, query_files=query_files, - relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: From 08b42c597caab07c4b7197ca637567316cfab6c1 Mon Sep 17 00:00:00 
2001 From: Debanjum Date: Wed, 27 Aug 2025 17:22:11 -0700 Subject: [PATCH 16/21] Scope memories separately for each agent Only default agent has access to memories across agents. Each agent has memory scoped to only its interaction with user. This allows you to keep memories formed from interacting with the health agent vs finance agent vs work agent vs personal agent separated --- src/khoj/database/adapters/__init__.py | 56 ++++++++++++------- .../0095_usermemory_delete_datastore.py | 12 +++- src/khoj/database/models/__init__.py | 3 +- src/khoj/processor/conversation/utils.py | 1 + src/khoj/routers/api_chat.py | 4 +- src/khoj/routers/helpers.py | 27 +++++---- 6 files changed, 70 insertions(+), 33 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 5152ac6db..f641b8089 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1521,12 +1521,17 @@ async def save_conversation( ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = await Conversation.objects.filter( - user=user, client=client_application, id=conversation_id - ).afirst() + conversation = ( + await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) + .prefetch_related("agent", "agent__chat_model") + .afirst() + ) else: conversation = ( - await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + await Conversation.objects.filter(user=user, client=client_application) + .prefetch_related("agent", "agent__chat_model") + .order_by("-updated_at") + .afirst() ) existing_messages = conversation.messages if conversation else [] @@ -2144,51 +2149,64 @@ def delete_automation(user: KhojUser, automation_id: str): class UserMemoryAdapters: @staticmethod @require_valid_user - async def pull_memories(user: KhojUser, window=10, limit=5) -> list[UserMemory]: + async def pull_memories(user: 
KhojUser, agent: Agent = None, window=10, limit=5) -> list[UserMemory]: """ Pulls memories from the database for a given user. Medium term memory. """ time_frame = datetime.now(timezone.utc) - timedelta(days=window) - memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by( + "-created_at" + )[:limit] + else: + memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] return await sync_to_async(list)(memories) @staticmethod @require_valid_user - async def save_memory(user: KhojUser, memory: str) -> UserMemory: + async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory: """ Saves a memory to the database for a given user. """ embeddings_model = state.embeddings_model model = await aget_default_search_model() - embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory) - memory_instance = await UserMemory.objects.acreate( - user=user, embeddings=embeddings, raw=memory, search_model=model - ) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent + ) + else: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model + ) return memory_instance @staticmethod @require_valid_user - async def search_memories(user: KhojUser, query: str) -> list[UserMemory]: + async def search_memories(user: KhojUser, query: str, agent: Agent = None) -> list[UserMemory]: """ Searches for memories in the database for a given user. Long term memory. 
""" embeddings_model = state.embeddings_model model = await aget_default_search_model() - max_distance = model.bi_encoder_confidence_threshold or math.inf - embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query) + default_agent = await AgentAdapters.aget_default_agent() + + if agent and agent != default_agent: + relevant_memories = UserMemory.objects.filter(user=user, agent=agent) + else: + relevant_memories = UserMemory.objects.filter(user=user) relevant_memories = ( - UserMemory.objects.filter(user=user) - .annotate(distance=CosineDistance("embeddings", embedded_query)) + relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query)) .order_by("distance") + .filter(distance__lte=max_distance) ) - relevant_memories = relevant_memories.filter(distance__lte=max_distance) - return await sync_to_async(list)(relevant_memories[:10]) @staticmethod @@ -2205,7 +2223,7 @@ async def delete_memory(user: KhojUser, memory_id: str) -> bool: return False @staticmethod - def convert_memories_to_dict(memories: List[UserMemory]) -> List[dict]: + def to_dict(memories: List[UserMemory]) -> List[dict]: """ Converts a list of Memory objects to a list of dictionaries. 
""" diff --git a/src/khoj/database/migrations/0095_usermemory_delete_datastore.py b/src/khoj/database/migrations/0095_usermemory_delete_datastore.py index 9af81513a..e7ab4d1a5 100644 --- a/src/khoj/database/migrations/0095_usermemory_delete_datastore.py +++ b/src/khoj/database/migrations/0095_usermemory_delete_datastore.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1.10 on 2025-08-27 22:44 +# Generated by Django 5.1.10 on 2025-08-28 00:21 import django.db.models.deletion import pgvector.django @@ -28,6 +28,16 @@ class Migration(migrations.Migration): ("updated_at", models.DateTimeField(auto_now=True)), ("embeddings", pgvector.django.VectorField()), ("raw", models.TextField()), + ( + "agent", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.agent", + ), + ), ( "search_model", models.ForeignKey( diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 56640ee59..ae08531bc 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -801,10 +801,11 @@ def __str__(self): class UserMemory(DbBaseModel): """ - A class to represent a memory storage model for longer term memories + Long term memory store derived from conversation between user and agent. 
""" user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True) embeddings = VectorField(dimensions=None) raw = models.TextField() search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 60b7933cf..3e2663b77 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -510,6 +510,7 @@ async def save_to_conversation_log( user=user, conversation_history=new_messages or [], memories=relevant_memories, + agent=db_conversation.agent if db_conversation else None, tracer=tracer, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 30426e11b..e1c9cf2c7 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -964,8 +964,8 @@ def collect_telemetry(): chat_history = conversation.messages # Get most recent memories and long term relevant memories - recent_memories = await UserMemoryAdapters.pull_memories(user) - long_term_memories = await UserMemoryAdapters.search_memories(user=user, query=q) + recent_memories = await UserMemoryAdapters.pull_memories(user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(user=user, query=q, agent=agent) # Create a de-duped set of memories relevant_memories = list(set(recent_memories + long_term_memories)) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2be463fb9..44ba66c6b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -973,6 +973,7 @@ async def extract_facts_from_query( user: KhojUser, conversation_history: List[ChatMessageModel], existing_facts: List[UserMemory] = None, + agent: Agent = None, tracer: dict = {}, ) -> ExtractedFacts: """ @@ -980,7 +981,7 @@ async def extract_facts_from_query( """ 
chat_history = construct_chat_history(conversation_history, n=2) - formatted_memories = UserMemoryAdapters.convert_memories_to_dict(existing_facts) if existing_facts else [] + formatted_memories = UserMemoryAdapters.to_dict(existing_facts) if existing_facts else [] extract_facts_prompt = prompts.extract_facts_from_query.format( chat_history=chat_history, @@ -988,7 +989,9 @@ async def extract_facts_from_query( ) with timer("Chat actor: Extract facts from query", logger): - response = await send_message_to_model_wrapper(extract_facts_prompt, user=user, tracer=tracer) + response = await send_message_to_model_wrapper( + extract_facts_prompt, user=user, agent_chat_model=agent.chat_model, tracer=tracer + ) response = response.text.strip() # JSON parse the list of strings try: @@ -1006,13 +1009,17 @@ async def extract_facts_from_query( @require_valid_user async def ai_update_memories( - user: KhojUser, conversation_history: List[ChatMessageModel], memories: List[UserMemory], tracer: dict = {} + user: KhojUser, + conversation_history: List[ChatMessageModel], + memories: List[UserMemory], + agent: Agent, + tracer: dict = {}, ): """ Updates the memories for a given user, based on their latest input query. 
""" new_data = await extract_facts_from_query( - user=user, conversation_history=conversation_history, existing_facts=memories, tracer=tracer + user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer ) if not new_data: @@ -1022,13 +1029,13 @@ async def ai_update_memories( created_memories = new_data.create deleted_memories = new_data.delete - for m in created_memories: - logger.info(f"Creating memory: {m}") - await UserMemoryAdapters.save_memory(user, m) + for memory in created_memories: + logger.info(f"Creating memory: {memory}") + await UserMemoryAdapters.save_memory(user, memory, agent=agent) - for m in deleted_memories: - logger.info(f"Deleting memory: {m}") - await UserMemoryAdapters.delete_memory(user, m) + for memory in deleted_memories: + logger.info(f"Deleting memory: {memory}") + await UserMemoryAdapters.delete_memory(user, memory) async def generate_mermaidjs_diagram( From fd58d95dcb30d26569d377aa0acf9b21e55b627c Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 28 Aug 2025 18:48:26 -0700 Subject: [PATCH 17/21] Improve memory manager prompt, enforce json schema for output - Make example and actual data in prompt more aligned with each other - Use json schema enforcement to have consistent output from model - Tune prompt associated with showing relevant user memories to model --- src/khoj/database/adapters/__init__.py | 4 +- src/khoj/processor/conversation/prompts.py | 69 ++++++++++++---------- src/khoj/processor/conversation/utils.py | 6 +- src/khoj/routers/helpers.py | 38 ++++++------ 4 files changed, 63 insertions(+), 54 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index f641b8089..93a0bf133 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -2229,9 +2229,9 @@ def to_dict(memories: List[UserMemory]) -> List[dict]: """ return [ { - "id": memory.id, + "id": f"{memory.id}", "raw": memory.raw, - 
"updated_at": memory.updated_at, + "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"), } for memory in memories ] diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index b7ab4f2e6..f84f37c94 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1306,41 +1306,47 @@ extract_facts_from_query = PromptTemplate.from_template( """ -Given a query, extract the facts *related to the user* from the query. This is in order to construct a robust memory of who the user is, their interests, their life circumstances, events in their life, their personal motivations. +You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user. +This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember. -You will be provided a subset of the existing facts that are already stored for the user, and potentially relevant to the query. You have two possible actions: +You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action: 1. Create new facts 2. Delete existing facts -You may use the existing facts to enhance the new facts that you're creating. You may also choose to delete existing facts that are no longer relevant. You cannot update existing facts; you can only create new facts or delete existing ones. +You should delete existing facts that are no longer true. +You can enhance new facts with information from existing facts. +You cannot update existing facts directly, instead create new facts and delete related existing ones to update them. -To create a new fact, add it to the create array. Do not create an ID. If you have nothing to create, leave the create array empty. 
Use first person perspective when creating new facts. - -To delete a fact, specify the fact's ID in the delete array. If you have nothing to delete, leave the delete array empty. You must delete anything that is no longer relevant or true about the user. +Your output should be a JSON object with two lists: create and delete. +- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective. +- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true. +- Leave the create or delete list empty if you have nothing important to add or remove. # Example Existing Facts: -{{ - "facts": [ - {{ - "id": "abc", - "raw": "I am not interested in sports", - "updated_at": "2023-10-01T12:00:00Z" - }}, - {{ - "id": "def", - "raw": "I am a software engineer" - "updated_at": "2023-10-31T14:00:00Z" - }}, - {{ - "id": "ghi", - "raw": "My mother works at the hospital", - "updated_at": "2023-10-02T17:00:00Z" - }} - ] -}} - -Input Query: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. In between coding, I took my cat Whiskers out for a walk and played a game of football. My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. +[ + {{ + "id": "5283", + "raw": "I am not interested in sports", + "updated_at": "2023-10-01T12:00:00+00:00" + }}, + {{ + "id": "22", + "raw": "I am a software engineer", + "updated_at": "2023-10-31T14:00:00+00:00" + }}, + {{ + "id": "651", + "raw": "My mother works at the hospital", + "updated_at": "2023-10-02T17:00:00+00:00" + }} +] + +Latest Chat Session: +- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. +In between coding, I took my cat Whiskers out for a walk and played a game of football. 
+My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. +- AI: That's great to hear! Response: {{ @@ -1351,17 +1357,16 @@ "My mother works at the hospital and is a doctor" ], "delete": [ - "abc", - "ghi" + "5283", + "651" ], }} # Input -These are some potentially related facts: +Existing Facts: {matched_facts} -Conversation History: +Latest Chat Session: {chat_history} - """.strip() ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 3e2663b77..8419f3609 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -717,11 +717,11 @@ def generate_chatml_messages_with_context( ) if not is_none_or_empty(relevant_memories): - memory_context = "Here are some relevant memories about me stored in the system context. You can ignore them if they are not relevant to the query:\n\n" + memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. 
Ignore them if they are not relevant to the query.\n\n" for memory in relevant_memories: friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") - memory_context += f"- {memory.raw} ({friendly_dt})\n" - memory_context += "\n" + memory_context += f"- [{friendly_dt}]: {memory.raw}\n" + memory_context += "" messages.append(ChatMessage(content=memory_context, role="user")) if not is_none_or_empty(user_message): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 44ba66c6b..018913b45 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -964,9 +964,11 @@ async def generate_excalidraw_diagram_from_description( return response -class ExtractedFacts(BaseModel): - create: List[str] = Field(..., min_items=0) - delete: List[str] = Field(..., min_items=0) +class MemoryUpdates(BaseModel): + """Facts to add or remove from memory.""" + + create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.") + delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.") async def extract_facts_from_query( @@ -975,13 +977,13 @@ async def extract_facts_from_query( existing_facts: List[UserMemory] = None, agent: Agent = None, tracer: dict = {}, -) -> ExtractedFacts: +) -> MemoryUpdates: """ Extract facts from the given query """ chat_history = construct_chat_history(conversation_history, n=2) - formatted_memories = UserMemoryAdapters.to_dict(existing_facts) if existing_facts else [] + formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else [] extract_facts_prompt = prompts.extract_facts_from_query.format( chat_history=chat_history, @@ -990,21 +992,26 @@ async def extract_facts_from_query( with timer("Chat actor: Extract facts from query", logger): response = await send_message_to_model_wrapper( - extract_facts_prompt, user=user, agent_chat_model=agent.chat_model, tracer=tracer + extract_facts_prompt, + 
response_schema=MemoryUpdates, + user=user, + fast_model=False, + agent_chat_model=agent.chat_model, + tracer=tracer, + ) response = response.text.strip() # JSON parse the list of strings try: response = clean_json(response) response = json.loads(response) - parsed_response = ExtractedFacts(**response) - if not isinstance(parsed_response, ExtractedFacts): + parsed_response = MemoryUpdates(**response) + if not isinstance(parsed_response, MemoryUpdates): raise ValueError(f"Invalid response for extracting facts: {response}") return parsed_response except Exception: logger.error(f"Invalid response for extracting facts: {response}") - return ExtractedFacts(create=[], delete=[]) + return MemoryUpdates(create=[], delete=[]) @require_valid_user @@ -1018,22 +1025,19 @@ async def ai_update_memories( """ Updates the memories for a given user, based on their latest input query. """ - new_data = await extract_facts_from_query( + memory_update = await extract_facts_from_query( user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer ) - if not new_data: + if not memory_update: return - # Save the new data to the database - created_memories = new_data.create - deleted_memories = new_data.delete - - for memory in created_memories: + # Save the memory updates to the database + for memory in memory_update.create: logger.info(f"Creating memory: {memory}") await UserMemoryAdapters.save_memory(user, memory, agent=agent) - for memory in deleted_memories: + for memory in memory_update.delete: logger.info(f"Deleting memory: {memory}") await UserMemoryAdapters.delete_memory(user, memory) From 4c497e90bb13e94c2fc36ff819a174794ea34700 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 28 Aug 2025 18:57:21 -0700 Subject: [PATCH 18/21] Fix passing retrieved memories to memory manager for update context The merge of master branch had broken passing retrieved memories to the memory manager.
This is required for the manager to correctly create, delete and update memories --- src/khoj/routers/api_chat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index e1c9cf2c7..5e42a4cac 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -1451,6 +1451,7 @@ def collect_telemetry(): query_images=uploaded_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, + relevant_memories=relevant_memories, generated_images=generated_images, generated_mermaidjs_diagram=generated_mermaidjs_diagram, tracer=tracer, From ef737fd6acf2b0e327988871a9073abcb55f16dd Mon Sep 17 00:00:00 2001 From: Debanjum Date: Thu, 28 Aug 2025 19:13:23 -0700 Subject: [PATCH 19/21] Increase memories limit. Make it configurable. Improve memory dedupe --- src/khoj/database/adapters/__init__.py | 6 +++--- src/khoj/routers/api_chat.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 93a0bf133..66a521f78 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -2149,7 +2149,7 @@ def delete_automation(user: KhojUser, automation_id: str): class UserMemoryAdapters: @staticmethod @require_valid_user - async def pull_memories(user: KhojUser, agent: Agent = None, window=10, limit=5) -> list[UserMemory]: + async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]: """ Pulls memories from the database for a given user. Medium term memory. 
""" @@ -2186,7 +2186,7 @@ async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserM @staticmethod @require_valid_user - async def search_memories(user: KhojUser, query: str, agent: Agent = None) -> list[UserMemory]: + async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]: """ Searches for memories in the database for a given user. Long term memory. """ @@ -2207,7 +2207,7 @@ async def search_memories(user: KhojUser, query: str, agent: Agent = None) -> li .filter(distance__lte=max_distance) ) - return await sync_to_async(list)(relevant_memories[:10]) + return await sync_to_async(list)(relevant_memories[:limit]) @staticmethod @require_valid_user diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5e42a4cac..7e930dfc0 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -964,11 +964,11 @@ def collect_telemetry(): chat_history = conversation.messages # Get most recent memories and long term relevant memories - recent_memories = await UserMemoryAdapters.pull_memories(user, agent=agent) - long_term_memories = await UserMemoryAdapters.search_memories(user=user, query=q, agent=agent) + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) # Create a de-duped set of memories - relevant_memories = list(set(recent_memories + long_term_memories)) + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) # If interrupted message in DB if last_message := await conversation.pop_message(interrupted=True): From f79dfbbc77f8c6ba66b699075a8fa47b2ae77012 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 29 Aug 2025 15:57:47 -0700 Subject: [PATCH 20/21] Do not delete data store in memory feature as unrelated --- ...5_usermemory_delete_datastore.py => 0095_usermemory.py} | 5 +---- 
src/khoj/database/models/__init__.py | 7 +++++++ 2 files changed, 8 insertions(+), 4 deletions(-) rename src/khoj/database/migrations/{0095_usermemory_delete_datastore.py => 0095_usermemory.py} (94%) diff --git a/src/khoj/database/migrations/0095_usermemory_delete_datastore.py b/src/khoj/database/migrations/0095_usermemory.py similarity index 94% rename from src/khoj/database/migrations/0095_usermemory_delete_datastore.py rename to src/khoj/database/migrations/0095_usermemory.py index e7ab4d1a5..5fdba2eb4 100644 --- a/src/khoj/database/migrations/0095_usermemory_delete_datastore.py +++ b/src/khoj/database/migrations/0095_usermemory.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1.10 on 2025-08-28 00:21 +# Generated by Django 5.1.10 on 2025-08-29 22:57 import django.db.models.deletion import pgvector.django @@ -60,7 +60,4 @@ class Migration(migrations.Migration): "abstract": False, }, ), - migrations.DeleteModel( - name="DataStore", - ), ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index ae08531bc..7d06c6abd 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -799,6 +799,13 @@ def __str__(self): return f"{self.slug} - {self.identifier} at {self.created_at}" +class DataStore(DbBaseModel): + key = models.CharField(max_length=200, unique=True) + value = models.JSONField(default=dict) + private = models.BooleanField(default=False) + owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) + + class UserMemory(DbBaseModel): """ Long term memory store derived from conversation between user and agent. 
From 89f467a9f2edbf7d425e7a60f90f1ebfbc2dd450 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Fri, 29 Aug 2025 17:20:31 -0700 Subject: [PATCH 21/21] Create management job to delete, generate memories from old chats - Control lookback days, users to generate memories for - Make job resumable to handle generating memories from lots of chat - Process newer chats after oldest ones to maintain chat dependency order - Support both generating and deleting memories Usage: --apply to actually apply changes, default is to dry-run --users=user1@email1.com,user2@email2.com to limit to specific users --lookback-days=7 to lookback N days --force to run for previously processed turns, else resumes from last processed Examples - Memory Generation - Generate for last 7 days (default) python src/khoj/manage.py manage_memories --apply - Generate for specific period python src/khoj/manage.py manage_memories --lookback-days=30 --apply - Memory Deletion - Delete ALL memories python src/khoj/manage.py manage_memories --delete --apply - Delete only recent memories (last 30 days) python src/khoj/manage.py manage_memories --delete --lookback-days=30 --apply --- .../management/commands/manage_memories.py | 384 ++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 src/khoj/database/management/commands/manage_memories.py diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py new file mode 100644 index 000000000..62e14c978 --- /dev/null +++ b/src/khoj/database/management/commands/manage_memories.py @@ -0,0 +1,384 @@ +import asyncio +from datetime import datetime, timedelta +from typing import List, Optional + +from django.core.management.base import BaseCommand +from django.db.models import Q +from django.utils import timezone + +from khoj.configure import initialize_server +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import ( + Conversation, + DataStore, + KhojUser, +) 
from khoj.routers.helpers import extract_facts_from_query


class Command(BaseCommand):
    """Generate user memories from past conversations, or delete existing memories.

    The generation job is resumable: progress is checkpointed to the DataStore
    table after every processed conversation, so an interrupted run can continue
    with --resume. Conversations are processed oldest-first to maintain the
    chronological dependency order between chat turns.
    """

    help = "Manage user memories - generate from conversations or delete existing memories"

    def add_arguments(self, parser):
        parser.add_argument(
            "--lookback-days",
            type=int,
            default=None,
            help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories",
        )
        parser.add_argument(
            "--users",
            type=str,
            help="Process specific users (comma-separated usernames or emails)",
        )
        parser.add_argument(
            "--batch-size",
            type=int,
            default=10,
            help="Number of conversations to process in each batch (default: 10)",
        )
        parser.add_argument(
            "--apply",
            action="store_true",
            help="Actually perform the operation. Without this flag, only shows what would be processed.",
        )
        parser.add_argument(
            "--delete",
            action="store_true",
            help="Delete all memories for specified users instead of generating new ones",
        )
        parser.add_argument(
            "--resume",
            action="store_true",
            help="Resume from last checkpoint if process was interrupted",
        )
        parser.add_argument(
            "--force",
            action="store_true",
            help="Force regenerate memories even if already processed",
        )

    def handle(self, *args, **options):
        """Main entry point for the command"""
        initialize_server()
        asyncio.run(self.async_handle(*args, **options))

    async def async_handle(self, *args, **options):
        """Async handler for memory management: dispatches to deletion or generation."""
        lookback_days = options["lookback_days"]
        usernames = options["users"]
        batch_size = options["batch_size"]
        apply = options["apply"]
        delete = options["delete"]
        resume = options["resume"]
        force = options["force"]

        mode = "APPLY" if apply else "DRY RUN"

        # Handle deletion mode
        if delete:
            # For deletion, only use cutoff_date if lookback_days is explicitly provided;
            # without it ALL memories for the selected users are deleted.
            cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None
            await self.handle_delete_memories(usernames, cutoff_date, apply)
            return

        # Handle generation mode. Default to a 7 day lookback if not specified.
        if lookback_days is None:
            lookback_days = 7
        cutoff_date = timezone.now() - timedelta(days=lookback_days)
        self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days")

        # Get users to process
        users = await self.get_users_to_process(usernames)
        if not users:
            self.stdout.write("No users found to process")
            return

        self.stdout.write(f"Found {len(users)} users to process")

        # Initialize or retrieve checkpoint
        checkpoint = await self.get_or_create_checkpoint(resume)

        total_conversations = 0
        total_memories = 0

        for user in users:
            # Skip users already completed in a previous (interrupted) run
            if not force and user.id in checkpoint.get("processed_users", []):
                self.stdout.write(f"Skipping already processed user: {user.username}")
                continue

            self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})")

            # Get conversations for this user
            conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force)

            if not conversations:
                self.stdout.write(f" No conversations to process for {user.username}")
                # Mark user as processed so a resumed run skips them
                if apply:
                    await self.update_checkpoint(checkpoint, user_id=user.id)
                continue

            self.stdout.write(f" Found {len(conversations)} conversations to process")

            # Process conversations in batches
            user_memories = 0
            for start in range(0, len(conversations), batch_size):
                batch = conversations[start : start + batch_size]
                batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint)
                user_memories += batch_memories
                total_conversations += len(batch)

                # Report progress after each batch
                progress = min(start + batch_size, len(conversations))
                self.stdout.write(
                    f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories"
                )

            total_memories += user_memories
            self.stdout.write(
                f" Completed user {user.username}: "
                f"processed {len(conversations)} conversations, "
                f"generated {user_memories} memories"
            )

            # Mark user as processed
            if apply:
                await self.update_checkpoint(checkpoint, user_id=user.id)

        # Clear checkpoint on successful completion
        if apply:
            await self.clear_checkpoint()

        action = "Generated" if apply else "Would generate"
        self.stdout.write(
            self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations")
        )

    async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]:
        """Resolve users to process from a comma-separated list of usernames or emails.

        When no list is given, every user that has at least one conversation is returned.
        """
        if users_str:
            identifiers = [u.strip() for u in users_str.split(",") if u.strip()]
            # Process specific users, matched by username or email
            return [
                user
                async for user in KhojUser.objects.filter(Q(username__in=identifiers) | Q(email__in=identifiers))
            ]
        # Process all users with conversations
        return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()]

    async def get_user_conversations(
        self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool
    ) -> List[Conversation]:
        """Get conversations for a user that still need processing, oldest first.

        Oldest-first ordering maintains the chronological dependency between turns.
        """
        if cutoff_date is None:
            query = Conversation.objects.filter(user=user).order_by("updated_at")
        else:
            query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at")

        # Skip conversations already processed, unless forced. Checkpoint keys are
        # stringified user ids: the checkpoint round-trips through JSON in DataStore,
        # and JSON object keys are always strings, so int keys would never match on resume.
        processed_ids = checkpoint.get("processed_conversations", {}).get(str(user.id), [])
        if not force and processed_ids:
            query = query.exclude(id__in=processed_ids)

        return [conv async for conv in query]

    @staticmethod
    def _extract_query_text(message) -> str:
        """Flatten a user chat message (str or list of content parts) into plain query text."""
        if isinstance(message, str):
            return message
        if isinstance(message, list):
            return "\n\n".join(
                content.get("text", "")
                for content in message
                if isinstance(content, dict) and content.get("text")
            )
        # None or any other shape yields no query text
        return ""

    async def process_conversation_batch(
        self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict
    ) -> int:
        """Process a batch of conversations and generate memories.

        Returns the number of memories generated, or — in dry-run mode — a rough
        estimate of one memory per user-assistant turn.
        """
        from asgiref.sync import sync_to_async

        total_memories = 0

        for conversation in conversations:
            try:
                # conversation.messages / conversation.agent are ORM-backed properties
                # and must be accessed from a sync context.
                messages = await sync_to_async(lambda: conversation.messages)()
                if not messages:
                    continue

                # Get agent if conversation has one
                agent = await sync_to_async(lambda: conversation.agent)()

                # Process each conversation turn
                conversation_memories = 0
                i = 0
                while i + 1 < len(messages):
                    # Only a user ("you") message followed by an assistant ("khoj")
                    # message counts as a valid turn for memory extraction.
                    if messages[i].by != "you" or messages[i + 1].by != "khoj":
                        i += 1
                        continue

                    # Extract user query text for memory search
                    q = self._extract_query_text(messages[i].message)
                    if not q or not q.strip():
                        i += 1
                        continue

                    # Conversation history up to and including this turn
                    history = messages[: i + 2]

                    if apply:
                        # Get unique recent and long term relevant memories as context.
                        # Only fetched when applying: dry-run never uses them.
                        recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
                        long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
                        relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())

                        # Ensure agent is fully loaded with its chat_model relation
                        if agent:

                            def load_agent_with_chat_model():
                                _ = agent.chat_model  # force load the relationship
                                return agent

                            agent = await sync_to_async(load_agent_with_chat_model)()

                        # Update memories based on latest conversation turn
                        memory_updates = await extract_facts_from_query(
                            user=user,
                            conversation_history=history,
                            existing_facts=relevant_memories,
                            agent=agent,
                            tracer={},
                        )

                        # Save new memories
                        for memory in memory_updates.create:
                            await UserMemoryAdapters.save_memory(user, memory, agent=agent)
                            conversation_memories += 1
                            self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...")

                        # Delete outdated memories
                        for memory in memory_updates.delete:
                            await UserMemoryAdapters.delete_memory(user, memory)
                            self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...")
                    else:
                        # Dry run - estimate memories that would be created
                        conversation_memories += 1  # Rough estimate

                    # Move to next conversation turn pair
                    i += 2

                total_memories += conversation_memories

                # Checkpoint after each conversation so an interrupted run can resume
                if apply:
                    await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id))
            except Exception as e:
                import traceback

                self.stderr.write(
                    f"Error processing conversation {conversation.id} for user {user.id}: {e}\n"
                    f"Traceback: {traceback.format_exc()}"
                )
                continue

        return total_memories

    async def get_or_create_checkpoint(self, resume: bool) -> dict:
        """Load the persisted checkpoint when resuming, else start a fresh one."""
        checkpoint_key = "memory_generation_checkpoint"

        if resume:
            # Try to retrieve existing checkpoint
            checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst()
            if checkpoint_store:
                self.stdout.write("Resuming from checkpoint...")
                return checkpoint_store.value

        # Create new checkpoint
        return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}}

    async def update_checkpoint(
        self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None
    ):
        """Record progress in the checkpoint and persist it to the DataStore.

        Per-user conversation ids are keyed by str(user_id): the checkpoint is stored
        as JSON and JSON object keys are always strings, so int keys would not survive
        a save/resume round-trip.
        """
        if user_id and user_id not in checkpoint["processed_users"]:
            checkpoint["processed_users"].append(user_id)

        if user_id and conversation_id:
            processed = checkpoint["processed_conversations"].setdefault(str(user_id), [])
            if conversation_id not in processed:
                processed.append(conversation_id)

        # Save checkpoint to database
        await DataStore.objects.aupdate_or_create(
            key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True}
        )

    async def clear_checkpoint(self):
        """Clear checkpoint after successful completion"""
        await DataStore.objects.filter(key="memory_generation_checkpoint").adelete()
        self.stdout.write("Checkpoint cleared")

    async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool):
        """Delete memories for the selected users, optionally limited to those created after cutoff_date."""
        from khoj.database.models import UserMemory

        # Get users to process
        users = await self.get_users_to_process(usernames)
        if not users:
            self.stdout.write("No users found to process")
            return

        mode = "APPLY" if apply else "DRY RUN"
        if cutoff_date:
            # Report the window in days for readability
            days_back = (timezone.now() - cutoff_date).days
            self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users")
        else:
            self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users")

        total_deleted = 0
        for user in users:
            # Count memories for this user (with date filter if specified)
            if cutoff_date is None:
                user_memories = UserMemory.objects.filter(user=user)
            else:
                user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date)

            memories_count = await user_memories.acount()
            if memories_count == 0:
                self.stdout.write(f" User {user.username} has no memories to delete")
                continue

            self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories")

            if apply:
                deleted_count, _ = await user_memories.adelete()
                self.stdout.write(f" Deleted {deleted_count} memories")
                total_deleted += deleted_count
            else:
                self.stdout.write(f" Would delete {memories_count} memories")
                total_deleted += memories_count

        action = "Deleted" if apply else "Would delete"
        self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total"))