diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx new file mode 100644 index 000000000..db360e7bc --- /dev/null +++ b/src/interface/web/app/components/userMemory/userMemory.tsx @@ -0,0 +1,90 @@ +import { useState } from "react"; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react"; +import { useToast } from "@/components/ui/use-toast"; + +export interface UserMemorySchema { + id: number; + raw: string; + created_at: string; +} + +interface UserMemoryProps { + memory: UserMemorySchema; + onDelete: (id: number) => void; + onUpdate: (id: number, raw: string) => void; +} + +export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) { + const [isEditing, setIsEditing] = useState(false); + const [content, setContent] = useState(memory.raw); + const { toast } = useToast(); + + const handleUpdate = () => { + onUpdate(memory.id, content); + setIsEditing(false); + toast({ + title: "Memory Updated", + description: "Your memory has been successfully updated.", + }); + }; + + const handleDelete = () => { + onDelete(memory.id); + toast({ + title: "Memory Deleted", + description: "Your memory has been successfully deleted.", + }); + }; + + return ( +
+ {isEditing ? ( + <> + setContent(e.target.value)} + className="flex-1" + /> + + + + ) : ( + <> + + + + + )} +
+ ); +} diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index 92d13ecf5..fffd726fd 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -15,6 +15,7 @@ import { Button } from "@/components/ui/button"; import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp"; import { Input } from "@/components/ui/input"; import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card"; + import { DropdownMenu, DropdownMenuContent, @@ -33,6 +34,15 @@ import { AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; + +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog"; + import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table"; import { @@ -74,6 +84,7 @@ import Loading from "../components/loading/loading"; import IntlTelInput from "intl-tel-input/react"; import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar"; import { AppSidebar } from "../components/appSidebar/appSidebar"; +import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory"; import { Separator } from "@/components/ui/separator"; import { KhojLogoType } from "../components/logo/khojLogo"; import { Progress } from "@/components/ui/progress"; @@ -323,6 +334,7 @@ export default function SettingsView() { const [numberValidationState, setNumberValidationState] = useState( PhoneNumberValidationState.Verified, ); + const [memories, setMemories] = useState([]); const [isExporting, setIsExporting] = useState(false); const [exportProgress, setExportProgress] = useState(0); const [exportedConversations, setExportedConversations] = useState(0); @@ -666,6 +678,65 @@ export default function SettingsView() { } }; + const fetchMemories = async () => { + try { + console.log("Fetching memories..."); + const response = await fetch('/api/memories/'); + if (!response.ok) throw new Error('Failed to fetch memories'); + const data = await response.json(); + setMemories(data); + } catch (error) { + console.error('Error fetching memories:', error); + toast({ + title: "Error", + description: "Failed to fetch memories. Please try again.", + variant: "destructive" + }); + } + }; + + const handleDeleteMemory = async (id: number) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'DELETE' + }); + if (!response.ok) throw new Error('Failed to delete memory'); + setMemories(memories.filter(memory => memory.id !== id)); + } catch (error) { + console.error('Error deleting memory:', error); + toast({ + title: "Error", + description: "Failed to delete memory. Please try again.", + variant: "destructive" + }); + } + }; + + const handleUpdateMemory = async (id: number, raw: string) => { + try { + const response = await fetch(`/api/memories/${id}`, { + method: 'PUT', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ raw, memory_id: id }), + }); + if (!response.ok) throw new Error('Failed to update memory'); + const updatedMemory: UserMemorySchema = await response.json(); + setMemories(memories.map(memory => + memory.id === id ? updatedMemory : memory + )); + } catch (error) { + console.error('Error updating memory:', error); + toast({ + title: "Error", + description: "Failed to update memory. Please try again.", + variant: "destructive" + }); + } + }; + + const syncContent = async (type: string) => { try { const response = await fetch(`/api/content?t=${type}`, { @@ -1269,7 +1340,45 @@ export default function SettingsView() { - + + + + Memories + + +

+ View and manage your long-term memories +

+
+ + open && fetchMemories()}> + + + + + + Your Memories + +
+ {memories.map((memory) => ( + + ))} + {memories.length === 0 && ( +

No memories found

+ )} +
+
+
+
+
diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 40d1eeb5a..50b2dd313 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -323,6 +323,7 @@ def configure_routes(app): from khoj.routers.api_automation import api_automation from khoj.routers.api_chat import api_chat from khoj.routers.api_content import api_content + from khoj.routers.api_memories import api_memories from khoj.routers.api_model import api_model from khoj.routers.notion import notion_router from khoj.routers.web_client import web_client @@ -332,6 +333,7 @@ def configure_routes(app): app.include_router(api_agents, prefix="/api/agents") app.include_router(api_automation, prefix="/api/automation") app.include_router(api_model, prefix="/api/model") + app.include_router(api_memories, prefix="/api/memories") app.include_router(api_content, prefix="/api/content") app.include_router(notion_router, prefix="/api/notion") app.include_router(web_client) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index ea9c773f5..66a521f78 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -61,6 +61,7 @@ Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserTextToImageModelConfig, UserVoiceModelConfig, @@ -584,6 +585,16 @@ def get_default_search_model() -> SearchModelConfig: return SearchModelConfig.objects.first() +async def aget_default_search_model() -> SearchModelConfig: + default_search_model = await SearchModelConfig.objects.filter(name="default").afirst() + + if default_search_model: + return default_search_model + elif await SearchModelConfig.objects.count() == 0: + await SearchModelConfig.objects.acreate() + return await SearchModelConfig.objects.afirst() + + def get_or_create_search_models(): search_models = SearchModelConfig.objects.all() if search_models.count() == 0: @@ -1510,12 +1521,17 @@ async def save_conversation( ): slug = user_message.strip()[:200] if user_message else None if conversation_id: - conversation = await Conversation.objects.filter( - user=user, client=client_application, id=conversation_id - ).afirst() + conversation = ( + await Conversation.objects.filter(user=user, client=client_application, id=conversation_id) + .prefetch_related("agent", "agent__chat_model") + .afirst() + ) else: conversation = ( - await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + await Conversation.objects.filter(user=user, client=client_application) + .prefetch_related("agent", "agent__chat_model") + .order_by("-updated_at") + .afirst() ) existing_messages = conversation.messages if conversation else [] @@ -2128,3 +2144,94 @@ def delete_automation(user: KhojUser, automation_id: str): automation.remove() return automation_metadata + + +class UserMemoryAdapters: + @staticmethod + @require_valid_user + async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]: + """ + Pulls memories from the database for a given user. Medium term memory. + """ + time_frame = datetime.now(timezone.utc) - timedelta(days=window) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by( + "-created_at" + )[:limit] + else: + memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit] + return await sync_to_async(list)(memories) + + @staticmethod + @require_valid_user + async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory: + """ + Saves a memory to the database for a given user. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory) + default_agent = await AgentAdapters.aget_default_agent() + if agent and agent != default_agent: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent + ) + else: + memory_instance = await UserMemory.objects.acreate( + user=user, embeddings=embeddings, raw=memory, search_model=model + ) + + return memory_instance + + @staticmethod + @require_valid_user + async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]: + """ + Searches for memories in the database for a given user. Long term memory. + """ + embeddings_model = state.embeddings_model + model = await aget_default_search_model() + max_distance = model.bi_encoder_confidence_threshold or math.inf + embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query) + default_agent = await AgentAdapters.aget_default_agent() + + if agent and agent != default_agent: + relevant_memories = UserMemory.objects.filter(user=user, agent=agent) + else: + relevant_memories = UserMemory.objects.filter(user=user) + + relevant_memories = ( + relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query)) + .order_by("distance") + .filter(distance__lte=max_distance) + ) + + return await sync_to_async(list)(relevant_memories[:limit]) + + @staticmethod + @require_valid_user + async def delete_memory(user: KhojUser, memory_id: str) -> bool: + """ + Deletes a memory from the database for a given user. + """ + try: + memory = await UserMemory.objects.aget(user=user, id=memory_id) + await memory.adelete() + return True + except UserMemory.DoesNotExist: + return False + + @staticmethod + def to_dict(memories: List[UserMemory]) -> List[dict]: + """ + Converts a list of Memory objects to a list of dictionaries. + """ + return [ + { + "id": f"{memory.id}", + "raw": memory.raw, + "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"), + } + for memory in memories + ] diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 5400c89e4..1730608a5 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -33,6 +33,7 @@ Subscription, TextToImageModelConfig, UserConversationConfig, + UserMemory, UserRequests, UserVoiceModelConfig, VoiceModelOption, @@ -181,6 +182,7 @@ def get_email_login_url(self, request, queryset): admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin) admin.site.register(UserRequests, unfold_admin.ModelAdmin) admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin) +admin.site.register(UserMemory, unfold_admin.ModelAdmin) @admin.register(Agent) diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py new file mode 100644 index 000000000..62e14c978 --- /dev/null +++ b/src/khoj/database/management/commands/manage_memories.py @@ -0,0 +1,384 @@ +import asyncio +from datetime import datetime, timedelta +from typing import List, Optional + +from django.core.management.base import BaseCommand +from django.db.models import Q +from django.utils import timezone + +from khoj.configure import initialize_server +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import ( + Conversation, + DataStore, + KhojUser, +) +from khoj.routers.helpers import extract_facts_from_query + + +class Command(BaseCommand): + help = "Manage user memories - generate from conversations or delete existing memories" + + def add_arguments(self, parser): + parser.add_argument( + "--lookback-days", + type=int, + default=None, + help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories", + ) + parser.add_argument( + "--users", + type=str, + help="Process specific users (comma-separated usernames or emails)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=10, + help="Number of conversations to process in each batch (default: 10)", + ) + parser.add_argument( + "--apply", + action="store_true", + help="Actually perform the operation. Without this flag, only shows what would be processed.", + ) + parser.add_argument( + "--delete", + action="store_true", + help="Delete all memories for specified users instead of generating new ones", + ) + parser.add_argument( + "--resume", + action="store_true", + help="Resume from last checkpoint if process was interrupted", + ) + parser.add_argument( + "--force", + action="store_true", + help="Force regenerate memories even if already processed", + ) + + def handle(self, *args, **options): + """Main entry point for the command""" + initialize_server() + asyncio.run(self.async_handle(*args, **options)) + + async def async_handle(self, *args, **options): + """Async handler for memory management""" + lookback_days = options["lookback_days"] + usernames = options["users"] + batch_size = options["batch_size"] + apply = options["apply"] + delete = options["delete"] + resume = options["resume"] + force = options["force"] + + mode = "APPLY" if apply else "DRY RUN" + + # Handle deletion mode + if delete: + # For deletion, only use cutoff_date if lookback_days is explicitly provided + cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days else None + await self.handle_delete_memories(usernames, cutoff_date, apply) + return + + # Handle generation mode + # For generation, default to 7 days if not specified + if lookback_days is None: + lookback_days = 7 + cutoff_date = timezone.now() - timedelta(days=lookback_days) + self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days") + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + self.stdout.write(f"Found {len(users)} users to process") + + # Initialize or retrieve checkpoint + checkpoint = await self.get_or_create_checkpoint(resume) + + total_conversations = 0 + total_memories = 0 + + for user in users: + # Check if user already processed in checkpoint + if not force and user.id in checkpoint.get("processed_users", []): + self.stdout.write(f"Skipping already processed user: {user.username}") + continue + + self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})") + + # Get conversations for this user + conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force) + + if not conversations: + self.stdout.write(f" No conversations to process for {user.username}") + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + continue + + self.stdout.write(f" Found {len(conversations)} conversations to process") + + # Process conversations in batches + user_memories = 0 + for i in range(0, len(conversations), batch_size): + batch = conversations[i : i + batch_size] + batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint) + user_memories += batch_memories + total_conversations += len(batch) + + # Update progress + progress = min(i + batch_size, len(conversations)) + self.stdout.write( + f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories" + ) + + total_memories += user_memories + self.stdout.write( + f" Completed user {user.username}: " + f"processed {len(conversations)} conversations, " + f"generated {user_memories} memories" + ) + + # Mark user as processed + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id) + + # Clear checkpoint on successful completion + if apply: + await self.clear_checkpoint() + + action = "Generated" if apply else "Would generate" + self.stdout.write( + self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations") + ) + + async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]: + """Get list of users to comma separated usernames or emails to process""" + if users_str: + usernames = [u.strip() for u in users_str.split(",") if u.strip()] + # Process specific users + users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))] + return users + else: + # Process all users with conversations + return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()] + + async def get_user_conversations( + self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool + ) -> List[Conversation]: + """Get conversations for a user that need processing""" + if cutoff_date is None: + query = Conversation.objects.filter(user=user).order_by("updated_at") + else: + query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at") + + # Filter out already processed conversations if resuming + if not force and user.id in checkpoint.get("processed_conversations", {}): + processed_ids = checkpoint["processed_conversations"][user.id] + query = query.exclude(id__in=processed_ids) + + return [conv async for conv in query] + + async def process_conversation_batch( + self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict + ) -> int: + """Process a batch of conversations and generate memories""" + total_memories = 0 + + for conversation in conversations: + try: + # Get conversation messages using sync_to_async for property access + from asgiref.sync import sync_to_async + + # Access conversation_log synchronously + @sync_to_async + def get_messages(): + return conversation.messages + + messages = await get_messages() + if not messages: + continue + + # Get agent if conversation has one + @sync_to_async + def get_agent(): + return conversation.agent + + agent = await get_agent() + + # Get existing memories for context + # Process each conversation turn + conversation_memories = 0 + i = 0 + while i + 1 < len(messages): + # Only process user-assistant pairs as a valid turn for memory extraction + if messages[i].by != "you" or messages[i + 1].by != "khoj": + i += 1 + continue + + # Get the conversation history up to this point + history = messages[: i + 2] + + # Extract user query text for memory search + q = "" + if messages[i].message is None: + i += 1 + continue + elif isinstance(messages[i].message, str): + q = messages[i].message + elif isinstance(messages[i].message, list): + q = "\n\n".join( + content.get("text", "") + for content in messages[i].message + if isinstance(content, dict) and content.get("text") + ) + + if not q or not q.strip(): + i += 1 + continue + + # Get unique recent and long term relevant memories + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + + if apply: + # Ensure agent is fully loaded with its chat_model + if agent: + + @sync_to_async + def load_agent_with_chat_model(): + # Force load the chat_model relationship + _ = agent.chat_model + return agent + + agent = await load_agent_with_chat_model() + + # Update memories based on latest conversation turn + memory_updates = await extract_facts_from_query( + user=user, + conversation_history=history, + existing_facts=relevant_memories, + agent=agent, + tracer={}, + ) + + # Save new memories + for memory in memory_updates.create: + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + conversation_memories += 1 + self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...") + + # Delete outdated memories + for memory in memory_updates.delete: + await UserMemoryAdapters.delete_memory(user, memory) + self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...") + else: + # Dry run - estimate memories that would be created + conversation_memories += 1 # Rough estimate + + # Move to next conversation turn pair + i += 2 + + total_memories += conversation_memories + + # Update checkpoint after each conversation + if apply: + await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id)) + except Exception as e: + import traceback + + self.stderr.write( + f"Error processing conversation {conversation.id} for user {user.id}: {e}\n" + f"Traceback: {traceback.format_exc()}" + ) + continue + + return total_memories + + async def get_or_create_checkpoint(self, resume: bool) -> dict: + """Get or create checkpoint for resumable processing""" + checkpoint_key = "memory_generation_checkpoint" + + if resume: + # Try to retrieve existing checkpoint + checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst() + if checkpoint_store: + self.stdout.write("Resuming from checkpoint...") + return checkpoint_store.value + + # Create new checkpoint + return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}} + + async def update_checkpoint( + self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None + ): + """Update checkpoint with progress""" + if user_id and user_id not in checkpoint["processed_users"]: + checkpoint["processed_users"].append(user_id) + + if user_id and conversation_id: + if user_id not in checkpoint["processed_conversations"]: + checkpoint["processed_conversations"][user_id] = [] + if conversation_id not in checkpoint["processed_conversations"][user_id]: + checkpoint["processed_conversations"][user_id].append(conversation_id) + + # Save checkpoint to database + await DataStore.objects.aupdate_or_create( + key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True} + ) + + async def clear_checkpoint(self): + """Clear checkpoint after successful completion""" + await DataStore.objects.filter(key="memory_generation_checkpoint").adelete() + self.stdout.write("Checkpoint cleared") + + async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool): + """Handle deletion of user memories""" + from khoj.database.models import UserMemory + + # Get users to process + users = await self.get_users_to_process(usernames) + if not users: + self.stdout.write("No users found to process") + return + + mode = "APPLY" if apply else "DRY RUN" + if cutoff_date: + # Calculate days from cutoff date + days_back = (timezone.now() - cutoff_date).days + self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users") + else: + self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users") + + total_deleted = 0 + for user in users: + # Count memories for this user + if cutoff_date is None: + user_memories = UserMemory.objects.filter(user=user) + else: + user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date) + + memories_count = await user_memories.acount() + if memories_count == 0: + self.stdout.write(f" User {user.username} has no memories to delete") + continue + + self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories") + + if apply: + # Delete memories for this user (with date filter if specified) + deleted_count, _ = await user_memories.adelete() + self.stdout.write(f" Deleted {deleted_count} memories") + total_deleted += deleted_count + else: + self.stdout.write(f" Would delete {memories_count} memories") + total_deleted += memories_count + + action = "Deleted" if apply else "Would delete" + self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total")) diff --git a/src/khoj/database/migrations/0095_usermemory.py b/src/khoj/database/migrations/0095_usermemory.py new file mode 100644 index 000000000..5fdba2eb4 --- /dev/null +++ b/src/khoj/database/migrations/0095_usermemory.py @@ -0,0 +1,63 @@ +# Generated by Django 5.1.10 on 2025-08-29 22:57 + +import django.db.models.deletion +import pgvector.django +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0094_serverchatsettings_think_free_deep_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="UserMemory", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("embeddings", pgvector.django.VectorField()), + ("raw", models.TextField()), + ( + "agent", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.agent", + ), + ), + ( + "search_model", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="database.searchmodelconfig", + ), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 90ed67a83..7d06c6abd 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -804,3 +804,15 @@ class DataStore(DbBaseModel): value = models.JSONField(default=dict) private = models.BooleanField(default=False) owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) + + +class UserMemory(DbBaseModel): + """ + Long term memory store derived from conversation between user and agent. + """ + + user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) + agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True) + embeddings = VectorField(dimensions=None) + raw = models.TextField() + search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 29ea56eca..f84f37c94 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -1303,3 +1303,70 @@ User's Name: {name} """.strip() ) + +extract_facts_from_query = PromptTemplate.from_template( + """ +You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user. +This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember. + +You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action: +1. Create new facts +2. Delete existing facts + +You should delete existing facts that are no longer true. +You can enhance new facts with information from existing facts. +You cannot update existing facts directly, instead create new facts and delete related existing ones to update them. + +Your output should be a JSON object with two lists: create and delete. +- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective. +- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true. +- Leave the create or delete list empty if you have nothing important to add or remove. + +# Example +Existing Facts: +[ + {{ + "id": "5283", + "raw": "I am not interested in sports", + "updated_at": "2023-10-01T12:00:00+00:00" + }}, + {{ + "id": "22", + "raw": "I am a software engineer", + "updated_at": "2023-10-31T14:00:00+00:00" + }}, + {{ + "id": "651", + "raw": "My mother works at the hospital", + "updated_at": "2023-10-02T17:00:00+00:00" + }} +] + +Latest Chat Session: +- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline. +In between coding, I took my cat Whiskers out for a walk and played a game of football. +My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat. +- AI: That's great to hear! + +Response: +{{ + "create": [ + "I am interested in AI and machine learning", + "I have a pet cat named Whiskers", + "I enjoy playing football", + "My mother works at the hospital and is a doctor" + ], + "delete": [ + "5283", + "651" + ], +}} + +# Input +Existing Facts: +{matched_facts} + +Latest Chat Session: +{chat_history} +""".strip() +) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 46a03b2c8..8419f3609 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -26,6 +26,7 @@ ClientApplication, Intent, KhojUser, + UserMemory, ) from khoj.processor.conversation import prompts from khoj.search_filter.base_filter import BaseFilter @@ -446,6 +447,7 @@ async def save_to_conversation_log( client_application: ClientApplication = None, conversation_id: str = None, automation_id: str = None, + relevant_memories: List[UserMemory] = [], query_images: List[str] = None, raw_query_files: List[FileAttachment] = [], generated_images: List[str] = [], @@ -454,6 +456,8 @@ async def save_to_conversation_log( train_of_thought: List[Any] = [], tracer: Dict[str, Any] = {}, ): + from khoj.routers.helpers import ai_update_memories + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") turn_id = tracer.get("mid") or str(uuid.uuid4()) @@ -500,6 +504,16 @@ async def save_to_conversation_log( user_message=q, ) + if not automation_id: + # Don't update memories from automations, as this could get noisy. + await ai_update_memories( + user=user, + conversation_history=new_messages or [], + memories=relevant_memories, + agent=db_conversation.agent if db_conversation else None, + tracer=tracer, + ) + if is_promptrace_enabled(): merge_message_into_conversation_trace(q, chat_response, tracer) @@ -561,6 +575,7 @@ def generate_chatml_messages_with_context( query_files: str = None, query_images=None, context_message="", + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = [], chat_history: list[ChatMessageModel] = [], @@ -701,6 +716,14 @@ def generate_chatml_messages_with_context( ), ) + if not is_none_or_empty(relevant_memories): + memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n\n" + for memory in relevant_memories: + friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S") + memory_context += f"- [{friendly_dt}]: {memory.raw}\n" + memory_context += "" + messages.append(ChatMessage(content=memory_context, role="user")) + if not is_none_or_empty(user_message): messages.append( ChatMessage( diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 7ff838299..ab87a7af7 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -25,6 +25,7 @@ Intent, KhojUser, TextToImageModelConfig, + UserMemory, ) from khoj.processor.conversation.google.utils import _is_retryable_error from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt @@ -47,6 +48,7 @@ async def text_to_image( query_images: Optional[List[str]] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): status_code = 200 @@ -91,6 +93,7 @@ async def text_to_image( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) image_prompt = image_prompt_response["description"] diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py index 0aa12ca4a..ccd780aa4 100644 --- a/src/khoj/processor/operator/__init__.py +++ b/src/khoj/processor/operator/__init__.py @@ -5,7 +5,7 @@ from typing import Callable, List, Optional from khoj.database.adapters import AgentAdapters, ConversationAdapters -from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory from khoj.processor.conversation.utils import ( AgentMessage, OperatorRun, @@ -42,6 +42,7 @@ async def operate_environment( query_images: Optional[List[str]] = None, # TODO: Handle query images agent: Agent = None, query_files: str = None, # TODO: Handle query files + relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, abort_message: Optional[str] = ChatEvent.END_EVENT.value, diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 130bff302..72c8a7358 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -15,6 +15,7 @@ ChatMessageModel, KhojUser, ServerChatSettings, + UserMemory, WebScraper, ) from khoj.processor.conversation import prompts @@ -77,6 +78,7 @@ async def search_online( max_webpages_to_read: int = 1, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, previous_subqueries: Set = set(), agent: Agent = None, tracer: dict = {}, @@ -95,6 +97,7 @@ async def search_online( user, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_queries=max_online_searches, agent=agent, tracer=tracer, @@ -172,7 +175,13 @@ async def search_online( yield {ChatEvent.STATUS: event} tasks = [ read_webpage_and_extract_content( - data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer + data["queries"], + link, + data.get("content"), + user=user, + agent=agent, + relevant_memories=relevant_memories, + tracer=tracer, ) for link, data in webpages.items() ] @@ -382,6 +391,7 @@ async def read_webpages( agent: Agent = None, max_webpages_to_read: int = 1, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" @@ -395,6 +405,7 @@ async def read_webpages( query_images, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) async for result in read_webpages_content( @@ -414,6 +425,7 @@ async def read_webpages_content( user: KhojUser, send_status_func: Optional[Callable] = None, agent: Agent = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): logger.info(f"Reading web pages at: {urls}") @@ -421,7 +433,14 @@ async def read_webpages_content( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls] + + tasks = [ + read_webpage_and_extract_content( + {query}, url, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer + ) + for url in urls + ] + results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -452,6 +471,7 @@ async def read_webpage_and_extract_content( content: str = None, user: KhojUser = None, agent: Agent = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Tuple[set[str], str, Union[None, str]]: # Select the web scrapers to use for reading the web page @@ -475,7 +495,7 @@ async def read_webpage_and_extract_content( if is_none_or_empty(extracted_info): with timer(f"Extracting relevant information from web page at '{url}' took", logger): extracted_info = await extract_relevant_info( - subqueries, content, user=user, agent=agent, tracer=tracer + subqueries, content, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer ) # If we successfully extracted information, break the loop diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 26edad27d..e98d5871b 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -20,7 +20,7 @@ ) from khoj.database.adapters import AgentAdapters, FileObjectAdapters -from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser +from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ChatEvent, @@ -59,6 +59,7 @@ async def run_code( agent: Agent = None, sandbox_url: str = SANDBOX_URL, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): # Generate Code @@ -77,6 +78,7 @@ async def run_code( agent, tracer, query_files, + relevant_memories, ) except Exception as e: raise ValueError(f"Failed to generate code for {instructions} with error: {e}") @@ -124,6 +126,7 @@ async def generate_python_code( agent: Agent = None, tracer: dict = {}, query_files: str = None, + relevant_memories: List[UserMemory] = None, ) -> GeneratedCode: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" @@ -158,6 +161,7 @@ async def generate_python_code( code_generation_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 68e7ea433..7e930dfc0 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -29,6 +29,7 @@ ConversationAdapters, EntryAdapters, PublicConversationAdapters, + UserMemoryAdapters, aget_user_name, ) from khoj.database.models import Agent, KhojUser @@ -101,7 +102,6 @@ trial_rate_limit=20, subscribed_rate_limit=75, slug="command" ) - api_chat = APIRouter() @@ -963,6 +963,13 @@ def collect_telemetry(): location = LocationData(city=city, region=region, country=country, country_code=country_code) chat_history = conversation.messages + # Get most recent memories and long term relevant memories + recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent) + long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent) + + # Create a de-duped set of memories + relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values()) + # If interrupted message in DB if last_message := await conversation.pop_message(interrupted=True): # Populate context from interrupted message @@ -987,6 +994,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ) except ValueError as e: @@ -1027,6 +1035,7 @@ def collect_telemetry(): send_status_func=partial(send_event, ChatEvent.STATUS), user_name=user_name, location=location, + relevant_memories=relevant_memories, query_files=attached_file_context, cancellation_event=cancellation_event, interrupt_queue=child_interrupt_queue, @@ -1082,6 +1091,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1130,6 +1140,7 @@ def collect_telemetry(): max_online_searches=3, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, agent=agent, tracer=tracer, ): @@ -1158,6 +1169,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1199,6 +1211,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1223,6 +1236,7 @@ def collect_telemetry(): list(operator_results)[-1] if operator_results else None, query_images=uploaded_images, query_files=attached_file_context, + relevant_memories=relevant_memories, send_status_func=partial(send_event, ChatEvent.STATUS), agent=agent, cancellation_event=cancellation_event, @@ -1277,6 +1291,7 @@ def collect_telemetry(): query_images=uploaded_images, agent=agent, query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1320,6 +1335,7 @@ def collect_telemetry(): agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), query_files=attached_file_context, + relevant_memories=relevant_memories, tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -1375,6 +1391,7 @@ def collect_telemetry(): user_name, uploaded_images, attached_file_context, + relevant_memories, program_execution_context, generated_asset_results, is_subscribed, @@ -1434,6 +1451,7 @@ def collect_telemetry(): query_images=uploaded_images, train_of_thought=train_of_thought, raw_query_files=raw_query_files, + relevant_memories=relevant_memories, generated_images=generated_images, generated_mermaidjs_diagram=generated_mermaidjs_diagram, tracer=tracer, diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py new file mode 100644 index 000000000..d75278279 --- /dev/null +++ b/src/khoj/routers/api_memories.py @@ -0,0 +1,114 @@ +import json +import logging +from typing import Optional + +from asgiref.sync import sync_to_async +from fastapi import APIRouter, Request +from fastapi.responses import Response +from pydantic import BaseModel +from starlette.authentication import requires + +from khoj.database.adapters import UserMemoryAdapters +from khoj.database.models import UserMemory + +api_memories = APIRouter() +logger = logging.getLogger(__name__) + + +@api_memories.get("") +@requires(["authenticated"]) +async def get_memories( + request: Request, + client: Optional[str] = None, +): + """Get all memories for the authenticated user""" + user = request.user.object + + memories = UserMemory.objects.filter(user=user) + all_memories = await sync_to_async(list)(memories) + + # Convert memories to a list of dictionaries + formatted_memories = [ + { + "id": memory.id, + "raw": memory.raw, + "created_at": memory.created_at.isoformat(), + } + for memory in all_memories + ] + + return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200) + + +@api_memories.delete("/{memory_id}") +@requires(["authenticated"]) +async def delete_memory( + request: Request, + memory_id: int, + client: Optional[str] = None, +): + """Delete a specific memory by ID""" + user = request.user.object + + # Verify memory belongs to user before deleting + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + await memory.adelete() + + return Response(status_code=204) + + +class UpdateMemoryBody(BaseModel): + """Request model for updating a memory""" + + raw: str + + +@api_memories.put("/{memory_id}") +@requires(["authenticated"]) +async def update_memory( + request: Request, + body: UpdateMemoryBody, + memory_id: int, + client: Optional[str] = None, +): + """Update a specific memory's content""" + user = request.user.object + + # Get the memory and verify it belongs to the user + memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst() + if not memory: + return Response( + content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404 + ) + + new_content = body.raw + if not new_content: + return Response( + content=json.dumps({"error": "Missing required field 'raw'"}), + media_type="application/json", + status_code=400, + ) + + await memory.adelete() + + # Create a new memory with the updated content + new_memory = await UserMemoryAdapters.save_memory( + user=user, + memory=new_content, + ) + + return Response( + content=json.dumps( + { + "id": new_memory.id, + "raw": new_memory.raw, + } + ), + media_type="application/json", + status_code=200, + ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6c3eac0fb..018913b45 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -46,6 +46,7 @@ ConversationAdapters, EntryAdapters, FileObjectAdapters, + UserMemoryAdapters, aget_user_by_email, create_khoj_token, get_default_search_model, @@ -53,6 +54,7 @@ get_user_name, get_user_notion_config, get_user_subscription_state, + require_valid_user, run_with_process_lock, ) from khoj.database.models import ( @@ -68,6 +70,7 @@ RateLimitRecord, Subscription, TextToImageModelConfig, + UserMemory, UserRequests, ) from khoj.processor.content.docx.docx_to_entries import DocxToEntries @@ -349,6 +352,7 @@ async def aget_data_sources_and_output_format( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> Dict[str, Any]: """ @@ -406,6 +410,7 @@ class PickTools(BaseModel): relevant_tools_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=PickTools, fast_model=False, @@ -462,6 +467,7 @@ async def infer_webpage_urls( query_images: List[str] = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> List[str]: """ @@ -496,6 +502,7 @@ class WebpageUrls(BaseModel): online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=WebpageUrls, fast_model=False, @@ -526,6 +533,7 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, max_queries: int = 3, agent: Agent = None, tracer: dict = {}, @@ -562,6 +570,7 @@ class OnlineQueries(BaseModel): online_queries_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, response_type="json_object", response_schema=OnlineQueries, fast_model=False, @@ -648,7 +657,12 @@ async def aschedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} + qs: set[str], + corpus: str, + user: KhojUser = None, + agent: Agent = None, + relevant_memories: List[UserMemory] = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -672,6 +686,7 @@ async def extract_relevant_info( response = await send_message_to_model_wrapper( extract_relevant_information, system_message=prompts.system_prompt_extract_relevant_information, + relevant_memories=relevant_memories, fast_model=True, agent_chat_model=agent_chat_model, user=user, @@ -794,6 +809,7 @@ async def generate_excalidraw_diagram( agent: Agent = None, send_status_func: Optional[Callable] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): if send_status_func: @@ -810,6 +826,7 @@ async def generate_excalidraw_diagram( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) @@ -845,6 +862,7 @@ async def generate_better_diagram_description( user: KhojUser = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> str: """ @@ -888,6 +906,7 @@ async def generate_better_diagram_description( improve_diagram_description_prompt, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -945,6 +964,84 @@ async def generate_excalidraw_diagram_from_description( return response +class MemoryUpdates(BaseModel): + """Facts to add or remove from memory.""" + + create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.") + delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.") + + +async def extract_facts_from_query( + user: KhojUser, + conversation_history: List[ChatMessageModel], + existing_facts: List[UserMemory] = None, + agent: Agent = None, + tracer: dict = {}, +) -> MemoryUpdates: + """ + Extract facts from the given query + """ + chat_history = construct_chat_history(conversation_history, n=2) + + formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else [] + + extract_facts_prompt = prompts.extract_facts_from_query.format( + chat_history=chat_history, + matched_facts=formatted_memories, + ) + + with timer("Chat actor: Extract facts from query", logger): + response = await send_message_to_model_wrapper( + extract_facts_prompt, + response_schema=MemoryUpdates, + user=user, + fast_model=False, + agent_chat_model=agent.chat_model, + tracer=tracer, + ) + response = response.text.strip() + # JSON parse the list of strings + try: + response = clean_json(response) + response = json.loads(response) + parsed_response = MemoryUpdates(**response) + if not isinstance(parsed_response, MemoryUpdates): + raise ValueError(f"Invalid response for extracting facts: {response}") + return parsed_response + + except Exception: + logger.error(f"Invalid response for extracting facts: {response}") + return MemoryUpdates(create=[], delete=[]) + + +@require_valid_user +async def ai_update_memories( + user: KhojUser, + conversation_history: List[ChatMessageModel], + memories: List[UserMemory], + agent: Agent, + tracer: dict = {}, +): + """ + Updates the memories for a given user, based on their latest input query. + """ + memory_update = await extract_facts_from_query( + user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer + ) + + if not memory_update: + return + + # Save the memory updates to the database + for memory in memory_update.create: + logger.info(f"Creating memory: {memory}") + await UserMemoryAdapters.save_memory(user, memory, agent=agent) + + for memory in memory_update.delete: + logger.info(f"Deleting memory: {memory}") + await UserMemoryAdapters.delete_memory(user, memory) + + async def generate_mermaidjs_diagram( q: str, chat_history: List[ChatMessageModel], @@ -956,6 +1053,7 @@ async def generate_mermaidjs_diagram( agent: Agent = None, send_status_func: Optional[Callable] = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): if send_status_func: @@ -972,6 +1070,7 @@ async def generate_mermaidjs_diagram( user=user, agent=agent, query_files=query_files, + relevant_memories=relevant_memories, tracer=tracer, ) @@ -1001,6 +1100,7 @@ async def generate_better_mermaidjs_diagram_description( user: KhojUser = None, agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> str: """ @@ -1044,6 +1144,7 @@ async def generate_better_mermaidjs_diagram_description( improve_diagram_description_prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, fast_model=False, agent_chat_model=agent_chat_model, user=user, @@ -1095,6 +1196,7 @@ async def generate_better_image_prompt( user: KhojUser = None, agent: Agent = None, query_files: str = "", + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ) -> dict: """ @@ -1137,6 +1239,7 @@ class ImagePromptResponse(BaseModel): q, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, chat_history=conversation_history, system_message=enhance_image_system_message, response_type="json_object", @@ -1171,6 +1274,7 @@ async def search_documents( previous_inferred_queries: Set = set(), agent: Agent = None, query_files: str = None, + relevant_memories: List[UserMemory] = None, tracer: dict = {}, ): # Initialize Variables @@ -1218,6 +1322,7 @@ async def search_documents( user=user, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, personality_context=personality_context, location_data=location_data, chat_history=chat_history, @@ -1270,6 +1375,7 @@ async def extract_questions( query_files: str = None, query_images: Optional[List[str]] = None, personality_context: str = "", + relevant_memories: List[UserMemory] = None, location_data: LocationData = None, chat_history: List[ChatMessageModel] = [], max_queries: int = 5, @@ -1324,6 +1430,7 @@ class DocumentQueries(BaseModel): query=prompt, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=system_prompt, response_type="json_object", response_schema=DocumentQueries, @@ -1446,6 +1553,7 @@ async def send_message_to_model_wrapper( query_files: str = None, query_images: List[str] = None, context: str = "", + relevant_memories: List[UserMemory] = None, chat_history: list[ChatMessageModel] = [], system_message: str = "", # Model Config @@ -1484,6 +1592,7 @@ async def send_message_to_model_wrapper( query_files=query_files, query_images=query_images, context_message=context, + relevant_memories=relevant_memories, chat_history=chat_history, system_message=system_message, model_name=chat_model_name, @@ -1612,6 +1721,7 @@ def build_conversation_context( operator_results: List[OperatorRun], query_files: str = None, query_images: Optional[List[str]] = None, + relevant_memories: List[UserMemory] = None, generated_asset_results: Dict[str, Dict] = {}, program_execution_context: List[str] = None, chat_history: List[ChatMessageModel] = [], @@ -1688,6 +1798,7 @@ def build_conversation_context( query_files=query_files, query_images=query_images, context_message=context_message, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, @@ -1716,6 +1827,7 @@ async def agenerate_chat_response( user_name: Optional[str] = None, query_images: Optional[List[str]] = None, query_files: str = None, + relevant_memories: List[UserMemory] = [], program_execution_context: List[str] = [], generated_asset_results: Dict[str, Dict] = {}, is_subscribed: bool = False, @@ -1758,6 +1870,7 @@ async def agenerate_chat_response( operator_results=operator_results, query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, generated_asset_results=generated_asset_results, program_execution_context=program_execution_context, chat_history=chat_history, diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 43abaaf0c..03302e32c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -8,7 +8,7 @@ import yaml from khoj.database.adapters import AgentAdapters, EntryAdapters -from khoj.database.models import Agent, ChatMessageModel, KhojUser +from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( OperatorRun, @@ -56,6 +56,7 @@ async def apick_next_tool( max_iterations: int = 5, query_images: List[str] = [], query_files: str = None, + relevant_memories: List[UserMemory] = [], max_document_searches: int = 7, max_online_searches: int = 3, max_webpages_to_read: int = 3, @@ -161,6 +162,7 @@ async def apick_next_tool( query="", query_files=query_files, query_images=query_images, + relevant_memories=relevant_memories, system_message=function_planning_prompt, chat_history=chat_and_research_history, tools=tools, @@ -218,6 +220,7 @@ async def research( send_status_func: Optional[Callable] = None, user_name: str = None, location: LocationData = None, + relevant_memories: List[UserMemory] = [], query_files: str = None, cancellation_event: Optional[asyncio.Event] = None, interrupt_queue: Optional[asyncio.Queue] = None, @@ -280,6 +283,7 @@ async def research( MAX_ITERATIONS, query_images=query_images, query_files=query_files, + relevant_memories=relevant_memories, max_document_searches=max_document_searches, max_online_searches=max_online_searches, max_webpages_to_read=max_webpages_to_read, @@ -323,8 +327,9 @@ async def research( query_images=query_images, previous_inferred_queries=previous_inferred_queries, agent=agent, - tracer=tracer, query_files=query_files, + relevant_memories=relevant_memories, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS]