Skip to content

Commit ccb2a5e

Browse files
committed
chore: Add database filtering
1 parent d005747 commit ccb2a5e

File tree

1 file changed

+55
-3
lines changed

1 file changed

+55
-3
lines changed

openagent/tools/twitter/feed.py

+55-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,29 @@
11
from typing import Optional, List
2-
2+
import os
33
from loguru import logger
44
from pydantic import BaseModel
55
import httpx
66
from datetime import datetime, timedelta, UTC
77
import re
88
import asyncio
99

10+
from sqlalchemy import Column, Integer, String, DateTime
11+
from sqlalchemy.ext.declarative import declarative_base
12+
from sqlalchemy.orm import sessionmaker
13+
from openagent.core.database import sqlite
1014
from openagent.core.tool import Tool
1115

16+
Base = declarative_base()
17+
18+
19+
class ProcessedTweet(Base):
    """ORM model for the ``processed_tweets`` table.

    Each row records one tweet that has already been handed out, so later
    runs can skip it.
    """

    __tablename__ = "processed_tweets"

    # Surrogate primary key.
    id = Column(Integer, primary_key=True)
    # Tweet identifier; the unique constraint allows at most one row per tweet.
    tweet_id = Column(String, unique=True)
    # Value of the tweet's "handle" field at the time it was recorded.
    handle = Column(String)
    # Pass a callable so the timestamp is computed for every INSERT.
    # A plain ``datetime.now(UTC)`` here would be evaluated once, when the
    # class body runs, and every row would share that import-time value.
    created_at = Column(DateTime, default=lambda: datetime.now(UTC))
26+
1227

1328
class Tweet(BaseModel):
1429
"""Model representing a tweet"""
@@ -64,6 +79,13 @@ def __init__(self):
6479
self.max_retries = 5
6580
self.retry_delay = 1
6681

82+
# Initialize database
83+
db_path = os.path.join(os.getcwd(), "storage", f"{self.name}.db")
84+
self.engine = sqlite.create_engine(db_path)
85+
Base.metadata.create_all(self.engine)
86+
Session = sessionmaker(bind=self.engine)
87+
self.session = Session()
88+
6789
@property
def name(self) -> str:
    """Return the fixed identifier of this tool."""
    tool_name = "get_twitter_feed"
    return tool_name
@@ -134,6 +156,34 @@ def _apply_time_filter(self, tweets: List[dict]) -> List[dict]:
134156

135157
return filtered
136158

159+
def _filter_processed_tweets(self, tweets: List[dict]) -> List[dict]:
    """Filter out tweets that have already been processed.

    Each tweet whose ``tweet_id`` has no row in ``processed_tweets`` is
    recorded there and kept in the result; tweets that already have a row
    are dropped. All new rows are committed together at the end.

    Note that a tweet is marked as processed here, at filter time, before
    the caller has done anything with it.

    Args:
        tweets: Tweet dicts; each must contain ``"tweet_id"`` and
            ``"handle"`` keys.

    Returns:
        The tweets that were not previously recorded, in input order.

    Raises:
        Exception: Any error raised while querying, reading the tweet
            dicts, or committing is re-raised after the session has been
            rolled back.
    """
    filtered = []
    try:
        for tweet in tweets:
            # Look the tweet up in the database by its identifier.
            exists = (
                self.session.query(ProcessedTweet)
                .filter_by(tweet_id=tweet["tweet_id"])
                .first()
            )
            if exists:
                continue

            # New tweet: record it and keep it for the caller.
            self.session.add(
                ProcessedTweet(
                    tweet_id=tweet["tweet_id"],
                    handle=tweet["handle"],
                )
            )
            filtered.append(tweet)

        # Commit new tweets to database
        self.session.commit()
    except Exception:
        # self.session outlives this call. Without a rollback, a failed
        # flush/commit (or rows left pending by a mid-loop error) would
        # break every later use of the session.
        self.session.rollback()
        logger.exception("Failed to record processed tweets; session rolled back")
        raise

    if not filtered:
        logger.info("All tweets have been processed before")

    return filtered
186+
137187
def _format_output(self, tweets: List[Tweet]) -> str:
138188
"""Format tweets into readable output"""
139189
if not tweets:
@@ -160,10 +210,12 @@ async def __call__(self) -> str:
160210
for handle in self.config.handles:
161211
# Fetch and filter tweets
162212
raw_tweets = await self._fetch_single_handle(client, handle)
163-
filtered_tweets = self._apply_time_filter(raw_tweets)
213+
time_filtered_tweets = self._apply_time_filter(raw_tweets)
214+
# Add database filtering
215+
unprocessed_tweets = self._filter_processed_tweets(time_filtered_tweets)
164216

165217
# Convert to Tweet objects
166-
for tweet in filtered_tweets:
218+
for tweet in unprocessed_tweets:
167219
all_tweets.append(Tweet(**tweet))
168220

169221
result = self._format_output(all_tweets)

0 commit comments

Comments
 (0)