Skip to content

Commit a752c9d

Browse files
committed
refactor: database URL handling
1 parent 96d5e9e commit a752c9d

File tree

5 files changed

+27
-36
lines changed

5 files changed

+27
-36
lines changed

openagent/core/database/engine.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -70,27 +70,29 @@ def _create_postgres_engine(db_url: str) -> Engine:
7070
return sa_create_engine(db_url)
7171

7272

73-
def create_engine(db_type: Literal["sqlite", "postgres"] = "sqlite", db_url: str = None) -> Engine:
73+
def create_engine(db_url: str) -> Engine:
7474
"""
7575
Create a database engine based on the provided configuration.
76+
Database type is automatically detected from the URL.
7677
7778
Args:
78-
db_type: Type of database ('sqlite' or 'postgres')
7979
db_url: Database URL. For postgres: postgresql://user:password@host:port/database,
8080
for sqlite: sqlite:///path/to/file.db
8181
8282
Returns:
8383
SQLAlchemy engine instance
8484
8585
Raises:
86-
ValueError: If an unsupported database type is specified or if db_url is missing
86+
ValueError: If an unsupported database type is detected or if db_url is missing
8787
"""
8888
if not db_url:
8989
raise ValueError("Database URL is required")
90-
91-
if db_type == "sqlite":
90+
91+
# Auto-detect database type from URL
92+
if db_url.startswith('sqlite:'):
9293
return _create_sqlite_engine(db_url)
93-
elif db_type == "postgres":
94+
elif db_url.startswith('postgresql:'):
9495
return _create_postgres_engine(db_url)
9596
else:
96-
raise ValueError(f"Unsupported database type: {db_type}")
97+
raise ValueError(f"Could not detect database type from URL: {db_url}. "
98+
"Supported URL formats: 'sqlite:///path/to/file.db' or 'postgresql://user:password@host:port/database'")

openagent/tools/pendle/market_analysis.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ class PendleMarketConfig(BaseModel):
7171
default=None,
7272
description="Model configuration for LLM. If not provided, will use agent's core model",
7373
)
74-
database: Optional[DatabaseConfig] = Field(
74+
db_url: Optional[str] = Field(
7575
default=None,
76-
description="Database configuration. If not provided, will use SQLite with default path",
76+
description="Database URL. For postgres: postgresql://user:password@host:port/database, for sqlite: sqlite:///path/to/file.db"
7777
)
7878

7979

@@ -87,21 +87,14 @@ def __init__(self, core_model=None):
8787
self.tool_prompt = None
8888
self.DBSession = None
8989

90-
def _init_database(self, config: Optional[DatabaseConfig] = None) -> None:
90+
def _init_database(self, db_url: Optional[str]) -> None:
9191
"""Initialize database connection based on configuration"""
9292
# Set default configuration if not provided
93-
db_type = "sqlite"
94-
db_url = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
95-
96-
if config:
97-
db_type = config.type
98-
db_url = config.url
93+
if not db_url:
94+
db_url = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
9995

10096
# Create engine using the database module's create_engine function
101-
engine = create_engine(
102-
db_type=db_type,
103-
db_url=db_url,
104-
)
97+
engine = create_engine(db_url)
10598

10699
# Create tables and initialize session factory
107100
Base.metadata.create_all(engine)
@@ -119,7 +112,7 @@ async def setup(self, config: PendleMarketConfig) -> None:
119112
"""Setup the analysis tool with model and prompt"""
120113

121114
# Initialize database
122-
self._init_database(config.database)
115+
self._init_database(config.db_url)
123116

124117
# Initialize the model
125118
model_config = config.model if config.model else self.core_model

openagent/tools/pendle/voter_apy_analysis.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(self, core_model=None):
4747
self.core_model = core_model
4848
self.tool_model = None
4949
self.tool_prompt = None
50-
db_path = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
51-
self.engine = create_engine("sqlite", db_path)
50+
db_url = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
51+
self.engine = create_engine(db_url)
5252
Base.metadata.create_all(self.engine)
5353
session = sessionmaker(bind=self.engine)
5454
self.session = session()

openagent/tools/twitter/feed.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def __init__(self):
8080
self.retry_delay = 1
8181

8282
# Initialize database
83-
db_path = 'sqlite:///'+os.path.join(os.getcwd(), "storage", f"{self.name}.db")
84-
self.engine = create_engine('sqlite',db_path)
83+
db_url = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
84+
self.engine = create_engine(db_url)
8585
Base.metadata.create_all(self.engine)
8686
Session = sessionmaker(bind=self.engine)
8787
self.session = Session()
@@ -99,7 +99,7 @@ async def setup(self, config: TwitterFeedConfig) -> None:
9999
self.config = config
100100

101101
async def _fetch_single_handle(
102-
self, client: httpx.AsyncClient, handle: str
102+
self, client: httpx.AsyncClient, handle: str
103103
) -> List[dict]:
104104
"""Fetch tweets for a single handle with retry logic"""
105105
params = {"limit": self.config.limit if self.config else 50}

test/tools/test_pendle_market.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22

3+
from sqlalchemy.testing.config import db_url
4+
35
from openagent.agent.config import ModelConfig
46
from openagent.tools.pendle.market_analysis import PendleMarketTool, PendleMarketConfig, DatabaseConfig
57

@@ -12,10 +14,7 @@ async def test_pendle_market():
1214
name="gpt-4",
1315
temperature=0.7
1416
),
15-
database=DatabaseConfig(
16-
type="sqlite",
17-
url="sqlite:///storage/test_pendle_market.db"
18-
)
17+
db_url='sqlite:///storage/test_pendle_market.db'
1918
)
2019

2120
# Initialize the tool
@@ -36,10 +35,7 @@ async def test_pendle_market_postgres():
3635
name="gpt-4",
3736
temperature=0.7
3837
),
39-
database=DatabaseConfig(
40-
type="postgres",
41-
url="postgresql://postgres:password@localhost:5434/pendle_market_test"
42-
)
38+
db_url='postgresql://postgres:password@localhost:5434/pendle_market_test'
4339
)
4440

4541
# Initialize the tool
@@ -54,5 +50,5 @@ async def test_pendle_market_postgres():
5450

5551
if __name__ == "__main__":
5652
# Run both tests
57-
asyncio.run(test_pendle_market())
58-
# asyncio.run(test_pendle_market_postgres())
53+
# asyncio.run(test_pendle_market())
54+
asyncio.run(test_pendle_market_postgres())

0 commit comments

Comments
 (0)