diff --git a/src/interface/web/app/components/userMemory/userMemory.tsx b/src/interface/web/app/components/userMemory/userMemory.tsx
new file mode 100644
index 000000000..db360e7bc
--- /dev/null
+++ b/src/interface/web/app/components/userMemory/userMemory.tsx
@@ -0,0 +1,90 @@
+import { useState } from "react";
+import { Input } from "@/components/ui/input";
+import { Button } from "@/components/ui/button";
+import { Pencil, TrashSimple, FloppyDisk, X } from "@phosphor-icons/react";
+import { useToast } from "@/components/ui/use-toast";
+
+export interface UserMemorySchema {
+ id: number;
+ raw: string;
+ created_at: string;
+}
+
+interface UserMemoryProps {
+ memory: UserMemorySchema;
+ onDelete: (id: number) => void;
+ onUpdate: (id: number, raw: string) => void;
+}
+
+export function UserMemory({ memory, onDelete, onUpdate }: UserMemoryProps) {
+ const [isEditing, setIsEditing] = useState(false);
+ const [content, setContent] = useState(memory.raw);
+ const { toast } = useToast();
+
+ const handleUpdate = () => {
+ onUpdate(memory.id, content);
+ setIsEditing(false);
+ toast({
+ title: "Memory Updated",
+ description: "Your memory has been successfully updated.",
+ });
+ };
+
+ const handleDelete = () => {
+ onDelete(memory.id);
+ toast({
+ title: "Memory Deleted",
+ description: "Your memory has been successfully deleted.",
+ });
+ };
+
+    return (
+        <div className="flex items-center justify-between gap-2 rounded-lg border p-2">
+            {isEditing ? (
+                <>
+                    <Input
+                        value={content}
+                        onChange={(e) => setContent(e.target.value)}
+                        className="flex-1"
+                    />
+                    <Button variant="ghost" size="icon" onClick={handleUpdate}>
+                        <FloppyDisk />
+                    </Button>
+                    <Button variant="ghost" size="icon" onClick={() => setIsEditing(false)}>
+                        <X />
+                    </Button>
+                </>
+            ) : (
+                <>
+                    <p className="flex-1">{memory.raw}</p>
+                    <Button variant="ghost" size="icon" onClick={() => setIsEditing(true)}>
+                        <Pencil />
+                    </Button>
+                    <Button variant="ghost" size="icon" onClick={handleDelete}>
+                        <TrashSimple />
+                    </Button>
+                </>
+            )}
+        </div>
+    );
+}
diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx
index 92d13ecf5..fffd726fd 100644
--- a/src/interface/web/app/settings/page.tsx
+++ b/src/interface/web/app/settings/page.tsx
@@ -15,6 +15,7 @@ import { Button } from "@/components/ui/button";
import { InputOTP, InputOTPGroup, InputOTPSlot } from "@/components/ui/input-otp";
import { Input } from "@/components/ui/input";
import { Card, CardContent, CardFooter, CardHeader } from "@/components/ui/card";
+
import {
DropdownMenu,
DropdownMenuContent,
@@ -33,6 +34,15 @@ import {
AlertDialogTitle,
AlertDialogTrigger,
} from "@/components/ui/alert-dialog";
+
+import {
+ Dialog,
+ DialogContent,
+ DialogHeader,
+ DialogTitle,
+ DialogTrigger
+} from "@/components/ui/dialog";
+
import { Table, TableBody, TableCell, TableRow } from "@/components/ui/table";
import {
@@ -74,6 +84,7 @@ import Loading from "../components/loading/loading";
import IntlTelInput from "intl-tel-input/react";
import { SidebarInset, SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
import { AppSidebar } from "../components/appSidebar/appSidebar";
+import { UserMemory, UserMemorySchema } from "../components/userMemory/userMemory";
import { Separator } from "@/components/ui/separator";
import { KhojLogoType } from "../components/logo/khojLogo";
import { Progress } from "@/components/ui/progress";
@@ -323,6 +334,7 @@ export default function SettingsView() {
const [numberValidationState, setNumberValidationState] = useState(
PhoneNumberValidationState.Verified,
);
+    const [memories, setMemories] = useState<UserMemorySchema[]>([]);
const [isExporting, setIsExporting] = useState(false);
const [exportProgress, setExportProgress] = useState(0);
const [exportedConversations, setExportedConversations] = useState(0);
@@ -666,6 +678,65 @@ export default function SettingsView() {
}
};
+ const fetchMemories = async () => {
+ try {
+ console.log("Fetching memories...");
+ const response = await fetch('/api/memories/');
+ if (!response.ok) throw new Error('Failed to fetch memories');
+ const data = await response.json();
+ setMemories(data);
+ } catch (error) {
+ console.error('Error fetching memories:', error);
+ toast({
+ title: "Error",
+ description: "Failed to fetch memories. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+ const handleDeleteMemory = async (id: number) => {
+ try {
+ const response = await fetch(`/api/memories/${id}`, {
+ method: 'DELETE'
+ });
+ if (!response.ok) throw new Error('Failed to delete memory');
+ setMemories(memories.filter(memory => memory.id !== id));
+ } catch (error) {
+ console.error('Error deleting memory:', error);
+ toast({
+ title: "Error",
+ description: "Failed to delete memory. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+ const handleUpdateMemory = async (id: number, raw: string) => {
+ try {
+ const response = await fetch(`/api/memories/${id}`, {
+ method: 'PUT',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ body: JSON.stringify({ raw, memory_id: id }),
+ });
+ if (!response.ok) throw new Error('Failed to update memory');
+ const updatedMemory: UserMemorySchema = await response.json();
+ setMemories(memories.map(memory =>
+ memory.id === id ? updatedMemory : memory
+ ));
+ } catch (error) {
+ console.error('Error updating memory:', error);
+ toast({
+ title: "Error",
+ description: "Failed to update memory. Please try again.",
+ variant: "destructive"
+ });
+ }
+ };
+
+
const syncContent = async (type: string) => {
try {
const response = await fetch(`/api/content?t=${type}`, {
@@ -1269,7 +1340,45 @@ export default function SettingsView() {
-
+
+
+
+ Memories
+
+
+
+ View and manage your long-term memories
+
+
+
+
+
+
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 40d1eeb5a..50b2dd313 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -323,6 +323,7 @@ def configure_routes(app):
from khoj.routers.api_automation import api_automation
from khoj.routers.api_chat import api_chat
from khoj.routers.api_content import api_content
+ from khoj.routers.api_memories import api_memories
from khoj.routers.api_model import api_model
from khoj.routers.notion import notion_router
from khoj.routers.web_client import web_client
@@ -332,6 +333,7 @@ def configure_routes(app):
app.include_router(api_agents, prefix="/api/agents")
app.include_router(api_automation, prefix="/api/automation")
app.include_router(api_model, prefix="/api/model")
+ app.include_router(api_memories, prefix="/api/memories")
app.include_router(api_content, prefix="/api/content")
app.include_router(notion_router, prefix="/api/notion")
app.include_router(web_client)
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index ea9c773f5..66a521f78 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -61,6 +61,7 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
+ UserMemory,
UserRequests,
UserTextToImageModelConfig,
UserVoiceModelConfig,
@@ -584,6 +585,16 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()
+async def aget_default_search_model() -> SearchModelConfig:
+ default_search_model = await SearchModelConfig.objects.filter(name="default").afirst()
+
+ if default_search_model:
+ return default_search_model
+    elif await SearchModelConfig.objects.acount() == 0:
+ await SearchModelConfig.objects.acreate()
+ return await SearchModelConfig.objects.afirst()
+
+
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
@@ -1510,12 +1521,17 @@ async def save_conversation(
):
slug = user_message.strip()[:200] if user_message else None
if conversation_id:
- conversation = await Conversation.objects.filter(
- user=user, client=client_application, id=conversation_id
- ).afirst()
+ conversation = (
+ await Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
+ .prefetch_related("agent", "agent__chat_model")
+ .afirst()
+ )
else:
conversation = (
- await Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
+ await Conversation.objects.filter(user=user, client=client_application)
+ .prefetch_related("agent", "agent__chat_model")
+ .order_by("-updated_at")
+ .afirst()
)
existing_messages = conversation.messages if conversation else []
@@ -2128,3 +2144,94 @@ def delete_automation(user: KhojUser, automation_id: str):
automation.remove()
return automation_metadata
+
+
+class UserMemoryAdapters:
+ @staticmethod
+ @require_valid_user
+ async def pull_memories(user: KhojUser, agent: Agent = None, limit=10, window=7) -> list[UserMemory]:
+ """
+ Pulls memories from the database for a given user. Medium term memory.
+ """
+ time_frame = datetime.now(timezone.utc) - timedelta(days=window)
+ default_agent = await AgentAdapters.aget_default_agent()
+ if agent and agent != default_agent:
+ memories = UserMemory.objects.filter(user=user, agent=agent, updated_at__gte=time_frame).order_by(
+ "-created_at"
+ )[:limit]
+ else:
+ memories = UserMemory.objects.filter(user=user, updated_at__gte=time_frame).order_by("-created_at")[:limit]
+ return await sync_to_async(list)(memories)
+
+ @staticmethod
+ @require_valid_user
+ async def save_memory(user: KhojUser, memory: str, agent: Agent = None) -> UserMemory:
+ """
+ Saves a memory to the database for a given user.
+ """
+ embeddings_model = state.embeddings_model
+ model = await aget_default_search_model()
+ embeddings = await sync_to_async(embeddings_model[model.name].embed_query)(memory)
+ default_agent = await AgentAdapters.aget_default_agent()
+ if agent and agent != default_agent:
+ memory_instance = await UserMemory.objects.acreate(
+ user=user, embeddings=embeddings, raw=memory, search_model=model, agent=agent
+ )
+ else:
+ memory_instance = await UserMemory.objects.acreate(
+ user=user, embeddings=embeddings, raw=memory, search_model=model
+ )
+
+ return memory_instance
+
+ @staticmethod
+ @require_valid_user
+ async def search_memories(query: str, user: KhojUser, agent: Agent = None, limit: int = 10) -> list[UserMemory]:
+ """
+ Searches for memories in the database for a given user. Long term memory.
+ """
+ embeddings_model = state.embeddings_model
+ model = await aget_default_search_model()
+ max_distance = model.bi_encoder_confidence_threshold or math.inf
+ embedded_query = await sync_to_async(embeddings_model[model.name].embed_query)(query)
+ default_agent = await AgentAdapters.aget_default_agent()
+
+ if agent and agent != default_agent:
+ relevant_memories = UserMemory.objects.filter(user=user, agent=agent)
+ else:
+ relevant_memories = UserMemory.objects.filter(user=user)
+
+ relevant_memories = (
+ relevant_memories.annotate(distance=CosineDistance("embeddings", embedded_query))
+ .order_by("distance")
+ .filter(distance__lte=max_distance)
+ )
+
+ return await sync_to_async(list)(relevant_memories[:limit])
+
+ @staticmethod
+ @require_valid_user
+ async def delete_memory(user: KhojUser, memory_id: str) -> bool:
+ """
+ Deletes a memory from the database for a given user.
+ """
+ try:
+ memory = await UserMemory.objects.aget(user=user, id=memory_id)
+ await memory.adelete()
+ return True
+ except UserMemory.DoesNotExist:
+ return False
+
+ @staticmethod
+ def to_dict(memories: List[UserMemory]) -> List[dict]:
+ """
+ Converts a list of Memory objects to a list of dictionaries.
+ """
+ return [
+ {
+ "id": f"{memory.id}",
+ "raw": memory.raw,
+ "updated_at": memory.updated_at.astimezone(timezone.utc).isoformat(timespec="seconds"),
+ }
+ for memory in memories
+ ]
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 5400c89e4..1730608a5 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -33,6 +33,7 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
+ UserMemory,
UserRequests,
UserVoiceModelConfig,
VoiceModelOption,
@@ -181,6 +182,7 @@ def get_email_login_url(self, request, queryset):
admin.site.register(VoiceModelOption, unfold_admin.ModelAdmin)
admin.site.register(UserRequests, unfold_admin.ModelAdmin)
admin.site.register(RateLimitRecord, unfold_admin.ModelAdmin)
+admin.site.register(UserMemory, unfold_admin.ModelAdmin)
@admin.register(Agent)
diff --git a/src/khoj/database/management/commands/manage_memories.py b/src/khoj/database/management/commands/manage_memories.py
new file mode 100644
index 000000000..62e14c978
--- /dev/null
+++ b/src/khoj/database/management/commands/manage_memories.py
@@ -0,0 +1,384 @@
+import asyncio
+from datetime import datetime, timedelta
+from typing import List, Optional
+
+from django.core.management.base import BaseCommand
+from django.db.models import Q
+from django.utils import timezone
+
+from khoj.configure import initialize_server
+from khoj.database.adapters import UserMemoryAdapters
+from khoj.database.models import (
+ Conversation,
+ DataStore,
+ KhojUser,
+)
+from khoj.routers.helpers import extract_facts_from_query
+
+
+class Command(BaseCommand):
+ help = "Manage user memories - generate from conversations or delete existing memories"
+
+ def add_arguments(self, parser):
+ parser.add_argument(
+ "--lookback-days",
+ type=int,
+ default=None,
+ help="Number of days to look back. For generation: defaults to 7 days. For deletion: if not specified, deletes ALL memories",
+ )
+ parser.add_argument(
+ "--users",
+ type=str,
+ help="Process specific users (comma-separated usernames or emails)",
+ )
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=10,
+ help="Number of conversations to process in each batch (default: 10)",
+ )
+ parser.add_argument(
+ "--apply",
+ action="store_true",
+ help="Actually perform the operation. Without this flag, only shows what would be processed.",
+ )
+ parser.add_argument(
+ "--delete",
+ action="store_true",
+ help="Delete all memories for specified users instead of generating new ones",
+ )
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Resume from last checkpoint if process was interrupted",
+ )
+ parser.add_argument(
+ "--force",
+ action="store_true",
+ help="Force regenerate memories even if already processed",
+ )
+
+ def handle(self, *args, **options):
+ """Main entry point for the command"""
+ initialize_server()
+ asyncio.run(self.async_handle(*args, **options))
+
+ async def async_handle(self, *args, **options):
+ """Async handler for memory management"""
+ lookback_days = options["lookback_days"]
+ usernames = options["users"]
+ batch_size = options["batch_size"]
+ apply = options["apply"]
+ delete = options["delete"]
+ resume = options["resume"]
+ force = options["force"]
+
+ mode = "APPLY" if apply else "DRY RUN"
+
+ # Handle deletion mode
+ if delete:
+ # For deletion, only use cutoff_date if lookback_days is explicitly provided
+            cutoff_date = timezone.now() - timedelta(days=lookback_days) if lookback_days is not None else None
+ await self.handle_delete_memories(usernames, cutoff_date, apply)
+ return
+
+ # Handle generation mode
+ # For generation, default to 7 days if not specified
+ if lookback_days is None:
+ lookback_days = 7
+ cutoff_date = timezone.now() - timedelta(days=lookback_days)
+ self.stdout.write(f"[{mode}] Generating memories for conversations from the last {lookback_days} days")
+
+ # Get users to process
+ users = await self.get_users_to_process(usernames)
+ if not users:
+ self.stdout.write("No users found to process")
+ return
+
+ self.stdout.write(f"Found {len(users)} users to process")
+
+ # Initialize or retrieve checkpoint
+ checkpoint = await self.get_or_create_checkpoint(resume)
+
+ total_conversations = 0
+ total_memories = 0
+
+ for user in users:
+ # Check if user already processed in checkpoint
+ if not force and user.id in checkpoint.get("processed_users", []):
+ self.stdout.write(f"Skipping already processed user: {user.username}")
+ continue
+
+ self.stdout.write(f"\nProcessing user: {user.username} (ID: {user.id})")
+
+ # Get conversations for this user
+ conversations = await self.get_user_conversations(user, cutoff_date, checkpoint, force)
+
+ if not conversations:
+ self.stdout.write(f" No conversations to process for {user.username}")
+ # Mark user as processed
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id)
+ continue
+
+ self.stdout.write(f" Found {len(conversations)} conversations to process")
+
+ # Process conversations in batches
+ user_memories = 0
+ for i in range(0, len(conversations), batch_size):
+ batch = conversations[i : i + batch_size]
+ batch_memories = await self.process_conversation_batch(user, batch, apply, checkpoint)
+ user_memories += batch_memories
+ total_conversations += len(batch)
+
+ # Update progress
+ progress = min(i + batch_size, len(conversations))
+ self.stdout.write(
+ f" Processed {progress}/{len(conversations)} conversations, generated {batch_memories} memories"
+ )
+
+ total_memories += user_memories
+ self.stdout.write(
+ f" Completed user {user.username}: "
+ f"processed {len(conversations)} conversations, "
+ f"generated {user_memories} memories"
+ )
+
+ # Mark user as processed
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id)
+
+ # Clear checkpoint on successful completion
+ if apply:
+ await self.clear_checkpoint()
+
+ action = "Generated" if apply else "Would generate"
+ self.stdout.write(
+ self.style.SUCCESS(f"\n{action} {total_memories} memories from {total_conversations} conversations")
+ )
+
+ async def get_users_to_process(self, users_str: Optional[str]) -> List[KhojUser]:
+ """Get list of users to comma separated usernames or emails to process"""
+ if users_str:
+ usernames = [u.strip() for u in users_str.split(",") if u.strip()]
+ # Process specific users
+ users = [user async for user in KhojUser.objects.filter(Q(username__in=usernames) | Q(email__in=usernames))]
+ return users
+ else:
+ # Process all users with conversations
+ return [user async for user in KhojUser.objects.filter(conversation__isnull=False).distinct()]
+
+ async def get_user_conversations(
+ self, user: KhojUser, cutoff_date: Optional[datetime], checkpoint: dict, force: bool
+ ) -> List[Conversation]:
+ """Get conversations for a user that need processing"""
+ if cutoff_date is None:
+ query = Conversation.objects.filter(user=user).order_by("updated_at")
+ else:
+ query = Conversation.objects.filter(user=user, updated_at__gte=cutoff_date).order_by("updated_at")
+
+ # Filter out already processed conversations if resuming
+        if not force and str(user.id) in checkpoint.get("processed_conversations", {}):
+            processed_ids = checkpoint["processed_conversations"][str(user.id)]
+ query = query.exclude(id__in=processed_ids)
+
+ return [conv async for conv in query]
+
+ async def process_conversation_batch(
+ self, user: KhojUser, conversations: List[Conversation], apply: bool, checkpoint: dict
+ ) -> int:
+ """Process a batch of conversations and generate memories"""
+ total_memories = 0
+
+ for conversation in conversations:
+ try:
+ # Get conversation messages using sync_to_async for property access
+ from asgiref.sync import sync_to_async
+
+ # Access conversation_log synchronously
+ @sync_to_async
+ def get_messages():
+ return conversation.messages
+
+ messages = await get_messages()
+ if not messages:
+ continue
+
+ # Get agent if conversation has one
+ @sync_to_async
+ def get_agent():
+ return conversation.agent
+
+ agent = await get_agent()
+
+ # Get existing memories for context
+ # Process each conversation turn
+ conversation_memories = 0
+ i = 0
+ while i + 1 < len(messages):
+ # Only process user-assistant pairs as a valid turn for memory extraction
+ if messages[i].by != "you" or messages[i + 1].by != "khoj":
+ i += 1
+ continue
+
+ # Get the conversation history up to this point
+ history = messages[: i + 2]
+
+ # Extract user query text for memory search
+ q = ""
+ if messages[i].message is None:
+ i += 1
+ continue
+ elif isinstance(messages[i].message, str):
+ q = messages[i].message
+ elif isinstance(messages[i].message, list):
+ q = "\n\n".join(
+ content.get("text", "")
+ for content in messages[i].message
+ if isinstance(content, dict) and content.get("text")
+ )
+
+ if not q or not q.strip():
+ i += 1
+ continue
+
+ # Get unique recent and long term relevant memories
+ recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
+ long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
+ relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
+
+ if apply:
+ # Ensure agent is fully loaded with its chat_model
+ if agent:
+
+ @sync_to_async
+ def load_agent_with_chat_model():
+ # Force load the chat_model relationship
+ _ = agent.chat_model
+ return agent
+
+ agent = await load_agent_with_chat_model()
+
+ # Update memories based on latest conversation turn
+ memory_updates = await extract_facts_from_query(
+ user=user,
+ conversation_history=history,
+ existing_facts=relevant_memories,
+ agent=agent,
+ tracer={},
+ )
+
+ # Save new memories
+ for memory in memory_updates.create:
+ await UserMemoryAdapters.save_memory(user, memory, agent=agent)
+ conversation_memories += 1
+ self.stdout.write(f"Created memory for user {user.id}: {memory[:50]}...")
+
+ # Delete outdated memories
+ for memory in memory_updates.delete:
+ await UserMemoryAdapters.delete_memory(user, memory)
+ self.stdout.write(f"Deleted memory for user {user.id}: {memory[:50]}...")
+ else:
+ # Dry run - estimate memories that would be created
+ conversation_memories += 1 # Rough estimate
+
+ # Move to next conversation turn pair
+ i += 2
+
+ total_memories += conversation_memories
+
+ # Update checkpoint after each conversation
+ if apply:
+ await self.update_checkpoint(checkpoint, user_id=user.id, conversation_id=str(conversation.id))
+ except Exception as e:
+ import traceback
+
+ self.stderr.write(
+ f"Error processing conversation {conversation.id} for user {user.id}: {e}\n"
+ f"Traceback: {traceback.format_exc()}"
+ )
+ continue
+
+ return total_memories
+
+ async def get_or_create_checkpoint(self, resume: bool) -> dict:
+ """Get or create checkpoint for resumable processing"""
+ checkpoint_key = "memory_generation_checkpoint"
+
+ if resume:
+ # Try to retrieve existing checkpoint
+ checkpoint_store = await DataStore.objects.filter(key=checkpoint_key, private=True).afirst()
+ if checkpoint_store:
+ self.stdout.write("Resuming from checkpoint...")
+ return checkpoint_store.value
+
+ # Create new checkpoint
+ return {"started_at": timezone.now().isoformat(), "processed_users": [], "processed_conversations": {}}
+
+ async def update_checkpoint(
+ self, checkpoint: dict, user_id: Optional[int] = None, conversation_id: Optional[str] = None
+ ):
+ """Update checkpoint with progress"""
+        if user_id and user_id not in checkpoint["processed_users"]:
+            checkpoint["processed_users"].append(user_id)
+
+        if user_id and conversation_id:
+            if str(user_id) not in checkpoint["processed_conversations"]:
+                checkpoint["processed_conversations"][str(user_id)] = []
+            if conversation_id not in checkpoint["processed_conversations"][str(user_id)]:
+                checkpoint["processed_conversations"][str(user_id)].append(conversation_id)
+
+ # Save checkpoint to database
+ await DataStore.objects.aupdate_or_create(
+ key="memory_generation_checkpoint", defaults={"value": checkpoint, "private": True}
+ )
+
+ async def clear_checkpoint(self):
+ """Clear checkpoint after successful completion"""
+ await DataStore.objects.filter(key="memory_generation_checkpoint").adelete()
+ self.stdout.write("Checkpoint cleared")
+
+ async def handle_delete_memories(self, usernames: Optional[str], cutoff_date: Optional[datetime], apply: bool):
+ """Handle deletion of user memories"""
+ from khoj.database.models import UserMemory
+
+ # Get users to process
+ users = await self.get_users_to_process(usernames)
+ if not users:
+ self.stdout.write("No users found to process")
+ return
+
+ mode = "APPLY" if apply else "DRY RUN"
+ if cutoff_date:
+ # Calculate days from cutoff date
+ days_back = (timezone.now() - cutoff_date).days
+ self.stdout.write(f"[{mode}] Deleting memories created in the last {days_back} days for {len(users)} users")
+ else:
+ self.stdout.write(f"[{mode}] Deleting ALL memories for {len(users)} users")
+
+ total_deleted = 0
+ for user in users:
+ # Count memories for this user
+ if cutoff_date is None:
+ user_memories = UserMemory.objects.filter(user=user)
+ else:
+ user_memories = UserMemory.objects.filter(user=user, created_at__gte=cutoff_date)
+
+ memories_count = await user_memories.acount()
+ if memories_count == 0:
+ self.stdout.write(f" User {user.username} has no memories to delete")
+ continue
+
+ self.stdout.write(f"\n User {user.username} (ID: {user.id}): {memories_count} memories")
+
+ if apply:
+ # Delete memories for this user (with date filter if specified)
+ deleted_count, _ = await user_memories.adelete()
+ self.stdout.write(f" Deleted {deleted_count} memories")
+ total_deleted += deleted_count
+ else:
+ self.stdout.write(f" Would delete {memories_count} memories")
+ total_deleted += memories_count
+
+ action = "Deleted" if apply else "Would delete"
+ self.stdout.write(self.style.SUCCESS(f"\n{action} {total_deleted} memories total"))
diff --git a/src/khoj/database/migrations/0095_usermemory.py b/src/khoj/database/migrations/0095_usermemory.py
new file mode 100644
index 000000000..5fdba2eb4
--- /dev/null
+++ b/src/khoj/database/migrations/0095_usermemory.py
@@ -0,0 +1,63 @@
+# Generated by Django 5.1.10 on 2025-08-29 22:57
+
+import django.db.models.deletion
+import pgvector.django
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0094_serverchatsettings_think_free_deep_and_more"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="UserMemory",
+ fields=[
+ (
+ "id",
+ models.BigAutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("embeddings", pgvector.django.VectorField()),
+ ("raw", models.TextField()),
+ (
+ "agent",
+ models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ to="database.agent",
+ ),
+ ),
+ (
+ "search_model",
+ models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ to="database.searchmodelconfig",
+ ),
+ ),
+ (
+ "user",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 90ed67a83..7d06c6abd 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -804,3 +804,15 @@ class DataStore(DbBaseModel):
value = models.JSONField(default=dict)
private = models.BooleanField(default=False)
owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
+
+
+class UserMemory(DbBaseModel):
+ """
+ Long term memory store derived from conversation between user and agent.
+ """
+
+ user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+ agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
+ embeddings = VectorField(dimensions=None)
+ raw = models.TextField()
+ search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 29ea56eca..f84f37c94 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -1303,3 +1303,70 @@
User's Name: {name}
""".strip()
)
+
+extract_facts_from_query = PromptTemplate.from_template(
+ """
+You are Muninn, the user's memory manager. Construct and maintain an accurate, up-to-date set of facts about and on behalf of the user.
+This can include who the user is, their interests, their life circumstances, events in their life, their personal motivations and any facts that the user explicitly asks you to remember.
+
+You are given the latest chat session and some previously stored facts about the user. You can take two kinds of action:
+1. Create new facts
+2. Delete existing facts
+
+You should delete existing facts that are no longer true.
+You can enhance new facts with information from existing facts.
+You cannot update existing facts directly, instead create new facts and delete related existing ones to update them.
+
+Your output should be a JSON object with two lists: create and delete.
+- The create list should contain important, new facts *related to the user* to be added. Each fact should be atomic, self-contained and written in the user's first person perspective.
+- The delete list should contain IDs of existing facts to be deleted. You must delete all facts that are no longer relevant or true.
+- Leave the create or delete list empty if you have nothing important to add or remove.
+
+# Example
+Existing Facts:
+[
+ {{
+ "id": "5283",
+ "raw": "I am not interested in sports",
+ "updated_at": "2023-10-01T12:00:00+00:00"
+ }},
+ {{
+ "id": "22",
+ "raw": "I am a software engineer",
+ "updated_at": "2023-10-31T14:00:00+00:00"
+ }},
+ {{
+ "id": "651",
+ "raw": "My mother works at the hospital",
+ "updated_at": "2023-10-02T17:00:00+00:00"
+ }}
+]
+
+Latest Chat Session:
+- User: I had an amazing day today! I was replicating this core AI paper, but ran into some issues with the training pipeline.
+In between coding, I took my cat Whiskers out for a walk and played a game of football.
+My mom called me in between her shift at the hospital (she's a doctor), so we had a nice chat.
+- AI: That's great to hear!
+
+Response:
+{{
+ "create": [
+ "I am interested in AI and machine learning",
+ "I have a pet cat named Whiskers",
+ "I enjoy playing football",
+ "My mother works at the hospital and is a doctor"
+ ],
+ "delete": [
+ "5283",
+ "651"
+    ]
+}}
+
+# Input
+Existing Facts:
+{matched_facts}
+
+Latest Chat Session:
+{chat_history}
+""".strip()
+)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 46a03b2c8..8419f3609 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -26,6 +26,7 @@
ClientApplication,
Intent,
KhojUser,
+ UserMemory,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.base_filter import BaseFilter
@@ -446,6 +447,7 @@ async def save_to_conversation_log(
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
+ relevant_memories: List[UserMemory] = [],
query_images: List[str] = None,
raw_query_files: List[FileAttachment] = [],
generated_images: List[str] = [],
@@ -454,6 +456,8 @@ async def save_to_conversation_log(
train_of_thought: List[Any] = [],
tracer: Dict[str, Any] = {},
):
+ from khoj.routers.helpers import ai_update_memories
+
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
turn_id = tracer.get("mid") or str(uuid.uuid4())
@@ -500,6 +504,16 @@ async def save_to_conversation_log(
user_message=q,
)
+ if not automation_id:
+ # Don't update memories from automations, as this could get noisy.
+ await ai_update_memories(
+ user=user,
+ conversation_history=new_messages or [],
+ memories=relevant_memories,
+ agent=db_conversation.agent if db_conversation else None,
+ tracer=tracer,
+ )
+
if is_promptrace_enabled():
merge_message_into_conversation_trace(q, chat_response, tracer)
@@ -561,6 +575,7 @@ def generate_chatml_messages_with_context(
query_files: str = None,
query_images=None,
context_message="",
+ relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = [],
chat_history: list[ChatMessageModel] = [],
@@ -701,6 +716,14 @@ def generate_chatml_messages_with_context(
),
)
+ if not is_none_or_empty(relevant_memories):
+ memory_context = "Your memory system retrieved the following memories about me based on our previous conversations. Ignore them if they are not relevant to the query.\n\n"
+ for memory in relevant_memories:
+ friendly_dt = memory.created_at.strftime("%Y-%m-%d %H:%M:%S")
+ memory_context += f"- [{friendly_dt}]: {memory.raw}\n"
+ # No trailing separator needed after the last memory entry
+ messages.append(ChatMessage(content=memory_context, role="user"))
+
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 7ff838299..ab87a7af7 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -25,6 +25,7 @@
Intent,
KhojUser,
TextToImageModelConfig,
+ UserMemory,
)
from khoj.processor.conversation.google.utils import _is_retryable_error
from khoj.routers.helpers import ChatEvent, ImageShape, generate_better_image_prompt
@@ -47,6 +48,7 @@ async def text_to_image(
query_images: Optional[List[str]] = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
status_code = 200
@@ -91,6 +93,7 @@ async def text_to_image(
user=user,
agent=agent,
query_files=query_files,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
image_prompt = image_prompt_response["description"]
diff --git a/src/khoj/processor/operator/__init__.py b/src/khoj/processor/operator/__init__.py
index 0aa12ca4a..ccd780aa4 100644
--- a/src/khoj/processor/operator/__init__.py
+++ b/src/khoj/processor/operator/__init__.py
@@ -5,7 +5,7 @@
from typing import Callable, List, Optional
from khoj.database.adapters import AgentAdapters, ConversationAdapters
-from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, ChatModel, KhojUser, UserMemory
from khoj.processor.conversation.utils import (
AgentMessage,
OperatorRun,
@@ -42,6 +42,7 @@ async def operate_environment(
query_images: Optional[List[str]] = None, # TODO: Handle query images
agent: Agent = None,
query_files: str = None, # TODO: Handle query files
+ relevant_memories: Optional[List[UserMemory]] = None, # TODO: Handle relevant memories
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
abort_message: Optional[str] = ChatEvent.END_EVENT.value,
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 130bff302..72c8a7358 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -15,6 +15,7 @@
ChatMessageModel,
KhojUser,
ServerChatSettings,
+ UserMemory,
WebScraper,
)
from khoj.processor.conversation import prompts
@@ -77,6 +78,7 @@ async def search_online(
max_webpages_to_read: int = 1,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
tracer: dict = {},
@@ -95,6 +97,7 @@ async def search_online(
user,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
max_queries=max_online_searches,
agent=agent,
tracer=tracer,
@@ -172,7 +175,13 @@ async def search_online(
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(
- data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
+ data["queries"],
+ link,
+ data.get("content"),
+ user=user,
+ agent=agent,
+ relevant_memories=relevant_memories,
+ tracer=tracer,
)
for link, data in webpages.items()
]
@@ -382,6 +391,7 @@ async def read_webpages(
agent: Agent = None,
max_webpages_to_read: int = 1,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
@@ -395,6 +405,7 @@ async def read_webpages(
query_images,
agent=agent,
query_files=query_files,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
async for result in read_webpages_content(
@@ -414,6 +425,7 @@ async def read_webpages_content(
user: KhojUser,
send_status_func: Optional[Callable] = None,
agent: Agent = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
logger.info(f"Reading web pages at: {urls}")
@@ -421,7 +433,14 @@ async def read_webpages_content(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
- tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
+
+ tasks = [
+ read_webpage_and_extract_content(
+ {query}, url, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer
+ )
+ for url in urls
+ ]
+
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -452,6 +471,7 @@ async def read_webpage_and_extract_content(
content: str = None,
user: KhojUser = None,
agent: Agent = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
@@ -475,7 +495,7 @@ async def read_webpage_and_extract_content(
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(
- subqueries, content, user=user, agent=agent, tracer=tracer
+ subqueries, content, user=user, agent=agent, relevant_memories=relevant_memories, tracer=tracer
)
# If we successfully extracted information, break the loop
diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py
index 26edad27d..e98d5871b 100644
--- a/src/khoj/processor/tools/run_code.py
+++ b/src/khoj/processor/tools/run_code.py
@@ -20,7 +20,7 @@
)
from khoj.database.adapters import AgentAdapters, FileObjectAdapters
-from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, FileObject, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
@@ -59,6 +59,7 @@ async def run_code(
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
# Generate Code
@@ -77,6 +78,7 @@ async def run_code(
agent,
tracer,
query_files,
+ relevant_memories,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {instructions} with error: {e}")
@@ -124,6 +126,7 @@ async def generate_python_code(
agent: Agent = None,
tracer: dict = {},
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
) -> GeneratedCode:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -158,6 +161,7 @@ async def generate_python_code(
code_generation_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 68e7ea433..7e930dfc0 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -29,6 +29,7 @@
ConversationAdapters,
EntryAdapters,
PublicConversationAdapters,
+ UserMemoryAdapters,
aget_user_name,
)
from khoj.database.models import Agent, KhojUser
@@ -101,7 +102,6 @@
trial_rate_limit=20, subscribed_rate_limit=75, slug="command"
)
-
api_chat = APIRouter()
@@ -963,6 +963,13 @@ def collect_telemetry():
location = LocationData(city=city, region=region, country=country, country_code=country_code)
chat_history = conversation.messages
+ # Get most recent memories and long term relevant memories
+ recent_memories = await UserMemoryAdapters.pull_memories(user=user, agent=agent)
+ long_term_memories = await UserMemoryAdapters.search_memories(query=q, user=user, agent=agent)
+
+ # Create a de-duped set of memories
+ relevant_memories = list({m.id: m for m in recent_memories + long_term_memories}.values())
+
# If interrupted message in DB
if last_message := await conversation.pop_message(interrupted=True):
# Populate context from interrupted message
@@ -987,6 +994,7 @@ def collect_telemetry():
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
except ValueError as e:
@@ -1027,6 +1035,7 @@ def collect_telemetry():
send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name,
location=location,
+ relevant_memories=relevant_memories,
query_files=attached_file_context,
cancellation_event=cancellation_event,
interrupt_queue=child_interrupt_queue,
@@ -1082,6 +1091,7 @@ def collect_telemetry():
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1130,6 +1140,7 @@ def collect_telemetry():
max_online_searches=3,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
agent=agent,
tracer=tracer,
):
@@ -1158,6 +1169,7 @@ def collect_telemetry():
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1199,6 +1211,7 @@ def collect_telemetry():
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1223,6 +1236,7 @@ def collect_telemetry():
list(operator_results)[-1] if operator_results else None,
query_images=uploaded_images,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
send_status_func=partial(send_event, ChatEvent.STATUS),
agent=agent,
cancellation_event=cancellation_event,
@@ -1277,6 +1291,7 @@ def collect_telemetry():
query_images=uploaded_images,
agent=agent,
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1320,6 +1335,7 @@ def collect_telemetry():
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
query_files=attached_file_context,
+ relevant_memories=relevant_memories,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1375,6 +1391,7 @@ def collect_telemetry():
user_name,
uploaded_images,
attached_file_context,
+ relevant_memories,
program_execution_context,
generated_asset_results,
is_subscribed,
@@ -1434,6 +1451,7 @@ def collect_telemetry():
query_images=uploaded_images,
train_of_thought=train_of_thought,
raw_query_files=raw_query_files,
+ relevant_memories=relevant_memories,
generated_images=generated_images,
generated_mermaidjs_diagram=generated_mermaidjs_diagram,
tracer=tracer,
diff --git a/src/khoj/routers/api_memories.py b/src/khoj/routers/api_memories.py
new file mode 100644
index 000000000..d75278279
--- /dev/null
+++ b/src/khoj/routers/api_memories.py
@@ -0,0 +1,114 @@
+import json
+import logging
+from typing import Optional
+
+from asgiref.sync import sync_to_async
+from fastapi import APIRouter, Request
+from fastapi.responses import Response
+from pydantic import BaseModel
+from starlette.authentication import requires
+
+from khoj.database.adapters import UserMemoryAdapters
+from khoj.database.models import UserMemory
+
+api_memories = APIRouter()
+logger = logging.getLogger(__name__)
+
+
+@api_memories.get("")
+@requires(["authenticated"])
+async def get_memories(
+ request: Request,
+ client: Optional[str] = None,
+):
+ """Get all memories for the authenticated user"""
+ user = request.user.object
+
+ memories = UserMemory.objects.filter(user=user)
+ all_memories = await sync_to_async(list)(memories)
+
+ # Convert memories to a list of dictionaries
+ formatted_memories = [
+ {
+ "id": memory.id,
+ "raw": memory.raw,
+ "created_at": memory.created_at.isoformat(),
+ }
+ for memory in all_memories
+ ]
+
+ return Response(content=json.dumps(formatted_memories), media_type="application/json", status_code=200)
+
+
+@api_memories.delete("/{memory_id}")
+@requires(["authenticated"])
+async def delete_memory(
+ request: Request,
+ memory_id: int,
+ client: Optional[str] = None,
+):
+ """Delete a specific memory by ID"""
+ user = request.user.object
+
+ # Verify memory belongs to user before deleting
+ memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
+ if not memory:
+ return Response(
+ content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
+ )
+
+ await memory.adelete()
+
+ return Response(status_code=204)
+
+
+class UpdateMemoryBody(BaseModel):
+ """Request model for updating a memory"""
+
+ raw: str
+
+
+@api_memories.put("/{memory_id}")
+@requires(["authenticated"])
+async def update_memory(
+ request: Request,
+ body: UpdateMemoryBody,
+ memory_id: int,
+ client: Optional[str] = None,
+):
+ """Update a specific memory's content"""
+ user = request.user.object
+
+ # Get the memory and verify it belongs to the user
+ memory = await UserMemory.objects.filter(id=memory_id, user=user).afirst()
+ if not memory:
+ return Response(
+ content=json.dumps({"error": "Memory not found"}), media_type="application/json", status_code=404
+ )
+
+ new_content = body.raw
+ if not new_content:
+ return Response(
+ content=json.dumps({"error": "Missing required field 'raw'"}),
+ media_type="application/json",
+ status_code=400,
+ )
+
+ # Create the replacement memory first so the content is not lost if saving fails
+ new_memory = await UserMemoryAdapters.save_memory(
+ user=user,
+ memory=new_content,
+ )
+
+ await memory.adelete()
+
+ return Response(
+ content=json.dumps(
+ {
+ "id": new_memory.id,
+ "raw": new_memory.raw,
+ }
+ ),
+ media_type="application/json",
+ status_code=200,
+ )
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 6c3eac0fb..018913b45 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -46,6 +46,7 @@
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
+ UserMemoryAdapters,
aget_user_by_email,
create_khoj_token,
get_default_search_model,
@@ -53,6 +54,7 @@
get_user_name,
get_user_notion_config,
get_user_subscription_state,
+ require_valid_user,
run_with_process_lock,
)
from khoj.database.models import (
@@ -68,6 +70,7 @@
RateLimitRecord,
Subscription,
TextToImageModelConfig,
+ UserMemory,
UserRequests,
)
from khoj.processor.content.docx.docx_to_entries import DocxToEntries
@@ -349,6 +352,7 @@ async def aget_data_sources_and_output_format(
query_images: List[str] = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> Dict[str, Any]:
"""
@@ -406,6 +410,7 @@ class PickTools(BaseModel):
relevant_tools_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=PickTools,
fast_model=False,
@@ -462,6 +467,7 @@ async def infer_webpage_urls(
query_images: List[str] = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> List[str]:
"""
@@ -496,6 +502,7 @@ class WebpageUrls(BaseModel):
online_queries_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=WebpageUrls,
fast_model=False,
@@ -526,6 +533,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
max_queries: int = 3,
agent: Agent = None,
tracer: dict = {},
@@ -562,6 +570,7 @@ class OnlineQueries(BaseModel):
online_queries_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
response_type="json_object",
response_schema=OnlineQueries,
fast_model=False,
@@ -648,7 +657,12 @@ async def aschedule_query(
async def extract_relevant_info(
- qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
+ qs: set[str],
+ corpus: str,
+ user: KhojUser = None,
+ agent: Agent = None,
+ relevant_memories: List[UserMemory] = None,
+ tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -672,6 +686,7 @@ async def extract_relevant_info(
response = await send_message_to_model_wrapper(
extract_relevant_information,
system_message=prompts.system_prompt_extract_relevant_information,
+ relevant_memories=relevant_memories,
fast_model=True,
agent_chat_model=agent_chat_model,
user=user,
@@ -794,6 +809,7 @@ async def generate_excalidraw_diagram(
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
if send_status_func:
@@ -810,6 +826,7 @@ async def generate_excalidraw_diagram(
user=user,
agent=agent,
query_files=query_files,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
@@ -845,6 +862,7 @@ async def generate_better_diagram_description(
user: KhojUser = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> str:
"""
@@ -888,6 +906,7 @@ async def generate_better_diagram_description(
improve_diagram_description_prompt,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -945,6 +964,84 @@ async def generate_excalidraw_diagram_from_description(
return response
+class MemoryUpdates(BaseModel):
+ """Facts to add or remove from memory."""
+
+ create: List[str] = Field(..., min_items=0, description="List of facts to add to memory.")
+ delete: List[str] = Field(..., min_items=0, description="List of facts to remove from memory.")
+
+
+async def extract_facts_from_query(
+ user: KhojUser,
+ conversation_history: List[ChatMessageModel],
+ existing_facts: List[UserMemory] = None,
+ agent: Agent = None,
+ tracer: dict = {},
+) -> MemoryUpdates:
+ """
+ Extract facts from the given query
+ """
+ chat_history = construct_chat_history(conversation_history, n=2)
+
+ formatted_memories = json.dumps(UserMemoryAdapters.to_dict(existing_facts), indent=2) if existing_facts else "[]"
+
+ extract_facts_prompt = prompts.extract_facts_from_query.format(
+ chat_history=chat_history,
+ matched_facts=formatted_memories,
+ )
+
+ with timer("Chat actor: Extract facts from query", logger):
+ response = await send_message_to_model_wrapper(
+ extract_facts_prompt,
+ response_schema=MemoryUpdates,
+ user=user,
+ fast_model=False,
+ agent_chat_model=agent.chat_model if agent else None,
+ tracer=tracer,
+ )
+ response = response.text.strip()
+ # JSON parse the list of strings
+ try:
+ response = clean_json(response)
+ response = json.loads(response)
+ # Pydantic raises a ValidationError on malformed data, which is
+ # handled by the except clause below; no isinstance check is needed.
+ parsed_response = MemoryUpdates(**response)
+ return parsed_response
+
+ except Exception:
+ logger.error(f"Invalid response for extracting facts: {response}")
+ return MemoryUpdates(create=[], delete=[])
+
+
+@require_valid_user
+async def ai_update_memories(
+ user: KhojUser,
+ conversation_history: List[ChatMessageModel],
+ memories: List[UserMemory],
+ agent: Agent,
+ tracer: dict = {},
+):
+ """
+ Updates the memories for a given user, based on their latest input query.
+ """
+ memory_update = await extract_facts_from_query(
+ user=user, conversation_history=conversation_history, existing_facts=memories, agent=agent, tracer=tracer
+ )
+
+ if not memory_update:
+ return
+
+ # Save the memory updates to the database
+ for memory in memory_update.create:
+ logger.info(f"Creating memory: {memory}")
+ await UserMemoryAdapters.save_memory(user, memory, agent=agent)
+
+ for memory in memory_update.delete:
+ logger.info(f"Deleting memory: {memory}")
+ await UserMemoryAdapters.delete_memory(user, memory)
+
+
async def generate_mermaidjs_diagram(
q: str,
chat_history: List[ChatMessageModel],
@@ -956,6 +1053,7 @@ async def generate_mermaidjs_diagram(
agent: Agent = None,
send_status_func: Optional[Callable] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
if send_status_func:
@@ -972,6 +1070,7 @@ async def generate_mermaidjs_diagram(
user=user,
agent=agent,
query_files=query_files,
+ relevant_memories=relevant_memories,
tracer=tracer,
)
@@ -1001,6 +1100,7 @@ async def generate_better_mermaidjs_diagram_description(
user: KhojUser = None,
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> str:
"""
@@ -1044,6 +1144,7 @@ async def generate_better_mermaidjs_diagram_description(
improve_diagram_description_prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
fast_model=False,
agent_chat_model=agent_chat_model,
user=user,
@@ -1095,6 +1196,7 @@ async def generate_better_image_prompt(
user: KhojUser = None,
agent: Agent = None,
query_files: str = "",
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
) -> dict:
"""
@@ -1137,6 +1239,7 @@ class ImagePromptResponse(BaseModel):
q,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
chat_history=conversation_history,
system_message=enhance_image_system_message,
response_type="json_object",
@@ -1171,6 +1274,7 @@ async def search_documents(
previous_inferred_queries: Set = set(),
agent: Agent = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = None,
tracer: dict = {},
):
# Initialize Variables
@@ -1218,6 +1322,7 @@ async def search_documents(
user=user,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
personality_context=personality_context,
location_data=location_data,
chat_history=chat_history,
@@ -1270,6 +1375,7 @@ async def extract_questions(
query_files: str = None,
query_images: Optional[List[str]] = None,
personality_context: str = "",
+ relevant_memories: List[UserMemory] = None,
location_data: LocationData = None,
chat_history: List[ChatMessageModel] = [],
max_queries: int = 5,
@@ -1324,6 +1430,7 @@ class DocumentQueries(BaseModel):
query=prompt,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
system_message=system_prompt,
response_type="json_object",
response_schema=DocumentQueries,
@@ -1446,6 +1553,7 @@ async def send_message_to_model_wrapper(
query_files: str = None,
query_images: List[str] = None,
context: str = "",
+ relevant_memories: List[UserMemory] = None,
chat_history: list[ChatMessageModel] = [],
system_message: str = "",
# Model Config
@@ -1484,6 +1592,7 @@ async def send_message_to_model_wrapper(
query_files=query_files,
query_images=query_images,
context_message=context,
+ relevant_memories=relevant_memories,
chat_history=chat_history,
system_message=system_message,
model_name=chat_model_name,
@@ -1612,6 +1721,7 @@ def build_conversation_context(
operator_results: List[OperatorRun],
query_files: str = None,
query_images: Optional[List[str]] = None,
+ relevant_memories: List[UserMemory] = None,
generated_asset_results: Dict[str, Dict] = {},
program_execution_context: List[str] = None,
chat_history: List[ChatMessageModel] = [],
@@ -1688,6 +1798,7 @@ def build_conversation_context(
query_files=query_files,
query_images=query_images,
context_message=context_message,
+ relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
@@ -1716,6 +1827,7 @@ async def agenerate_chat_response(
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
query_files: str = None,
+ relevant_memories: List[UserMemory] = [],
program_execution_context: List[str] = [],
generated_asset_results: Dict[str, Dict] = {},
is_subscribed: bool = False,
@@ -1758,6 +1870,7 @@ async def agenerate_chat_response(
operator_results=operator_results,
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
generated_asset_results=generated_asset_results,
program_execution_context=program_execution_context,
chat_history=chat_history,
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index 43abaaf0c..03302e32c 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -8,7 +8,7 @@
import yaml
from khoj.database.adapters import AgentAdapters, EntryAdapters
-from khoj.database.models import Agent, ChatMessageModel, KhojUser
+from khoj.database.models import Agent, ChatMessageModel, KhojUser, UserMemory
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
OperatorRun,
@@ -56,6 +56,7 @@ async def apick_next_tool(
max_iterations: int = 5,
query_images: List[str] = [],
query_files: str = None,
+ relevant_memories: List[UserMemory] = [],
max_document_searches: int = 7,
max_online_searches: int = 3,
max_webpages_to_read: int = 3,
@@ -161,6 +162,7 @@ async def apick_next_tool(
query="",
query_files=query_files,
query_images=query_images,
+ relevant_memories=relevant_memories,
system_message=function_planning_prompt,
chat_history=chat_and_research_history,
tools=tools,
@@ -218,6 +220,7 @@ async def research(
send_status_func: Optional[Callable] = None,
user_name: str = None,
location: LocationData = None,
+ relevant_memories: List[UserMemory] = [],
query_files: str = None,
cancellation_event: Optional[asyncio.Event] = None,
interrupt_queue: Optional[asyncio.Queue] = None,
@@ -280,6 +283,7 @@ async def research(
MAX_ITERATIONS,
query_images=query_images,
query_files=query_files,
+ relevant_memories=relevant_memories,
max_document_searches=max_document_searches,
max_online_searches=max_online_searches,
max_webpages_to_read=max_webpages_to_read,
@@ -323,8 +327,9 @@ async def research(
query_images=query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
- tracer=tracer,
query_files=query_files,
+ relevant_memories=relevant_memories,
+ tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]