-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_utils.py
159 lines (121 loc) · 4.39 KB
/
db_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
"""
Functions to work with the database
"""
import os
from datetime import datetime
from typing import Any, Dict, Optional, Tuple, Union

from cryptography.fernet import Fernet
from discord import Interaction
from sqlmodel import JSON, Column, Field, Session, SQLModel, create_engine, select
# SQLite database file. The URL below uses a relative path, so the file is
# resolved against the current working directory of the process.
SQLITE_FILE_NAME = "database.db"
SQLITE_URL = "sqlite:///" + SQLITE_FILE_NAME

# Single module-level engine shared by every session this module opens.
engine = create_engine(SQLITE_URL)
class CommandContext(SQLModel, table=True):
    """
    Table for storing information about the command interaction.

    One row per invoked command: who ran it, in which guild, which command,
    the (JSON-serialised) parameters it was called with, and when.
    """

    # Auto-assigned row id (None until the row has been inserted).
    id: Optional[int] = Field(default=None, primary_key=True)
    guild_id: int = Field(index=True)
    user_id: int = Field(index=True)
    # User name at the time of the command.
    user: str
    command_name: str
    # Stored in a JSON column; callers in this module read params["topic"].
    params: Dict[str, Any] = Field(default_factory=dict, sa_column=Column(JSON))
    # Naive local timestamp taken when the object is constructed.
    timestamp: datetime = Field(default_factory=datetime.now)

    async def save(self) -> bool:
        """
        Insert (or update) this row in its own session and commit it.

        After returning, the instance's attributes, including the database
        generated ``id``, remain readable even though the session is closed.

        Returns:
            Always ``True``; database errors propagate as exceptions.
        """
        with get_session() as session:
            session.add(self)
            session.commit()
            # commit() expires every attribute on the instance. Reload them while
            # the session is still open; otherwise any later attribute access on
            # this now-detached object (e.g. self.params) would fail.
            session.refresh(self)
        return True
class Key(SQLModel, table=True):
    """
    Table for storing OpenAI API keys.

    One row per guild. ``api_key`` holds a Fernet-encrypted token, which
    ``get_api_key`` decrypts with the key from the FERNET_KEY environment
    variable.
    """

    # Discord guild id; primary key. Annotated ``int`` but defaulting to None --
    # within this module it is always set explicitly (the __main__ seeding code
    # passes it in).
    guild_id: int = Field(default=None, primary_key=True)
    # Human-readable guild name; not read back anywhere in the visible module.
    guild_name: str
    # Fernet-encrypted API key (ciphertext, not the plaintext key).
    api_key: str
class Chat(SQLModel, table=True):
    """
    Table for storing OpenAI Response IDs.

    Tracks the latest response id per (guild_id, topic) pair so a later
    command on the same topic can continue the conversation. ``update_chat``
    inserts a row only when none exists for the pair; otherwise it overwrites
    ``response_id`` and ``updated`` on the existing row.
    """

    # OpenAI response id. It is the primary key, yet update_chat rewrites it in
    # place each time the conversation continues. Annotated ``str`` with a None
    # default; callers in this module always supply it.
    response_id: str = Field(default=None, primary_key=True)
    # Conversation topic, taken from CommandContext.params["topic"].
    topic: str
    # Guild the conversation belongs to.
    guild_id: int
    # When the row was last written (naive local time from datetime.now()).
    updated: datetime
    # NOTE(review): there is no unique constraint or index on (guild_id, topic),
    # but rows are looked up by that pair with one_or_none(), which raises if
    # duplicates ever appear (e.g. two bot processes racing on an insert).
    # Consider a composite unique constraint.
async def create_command_context(interaction: Interaction, **kwargs) -> CommandContext:
    """
    Build (but do not persist) a CommandContext describing ``interaction``.

    The guild id, invoking user's id and name, and the command name are taken
    from the interaction. Any extra keyword arguments (for example ``params``)
    are forwarded to the model constructor unchanged. Passing one of the
    interaction-derived fields again raises TypeError.
    """
    invoker = interaction.user
    derived_fields = {
        "guild_id": interaction.guild_id,
        "user_id": invoker.id,
        "user": invoker.name,
        "command_name": interaction.command.name,
    }
    return CommandContext(**derived_fields, **kwargs)
def get_session() -> Session:
    """
    Open a fresh Session bound to the module-level ``engine``.

    The caller owns the session; use it as ``with get_session() as session:``
    so it is closed when the block exits.
    """
    new_session = Session(bind=engine)
    return new_session
async def get_response_id(context: CommandContext) -> Union[str, None]:
    """
    Look up the stored response id for this guild and topic, if one exists.

    The topic is read from ``context.params["topic"]``.

    Returns:
        The matching Chat row's ``response_id``, or ``None`` when no row matches.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one Chat row matches
            (via ``one_or_none``).
    """
    with get_session() as session:
        query = select(Chat).where(
            Chat.guild_id == context.guild_id,
            Chat.topic == context.params.get("topic"),
        )
        record = session.exec(query).one_or_none()
        if record is None:
            return None
        return record.response_id
async def update_chat(response_id: str, context: CommandContext) -> None:
    """
    Record ``response_id`` as the latest response for this guild and topic.

    If a Chat row already exists for (``context.guild_id``,
    ``context.params["topic"]``), its ``response_id`` and ``updated`` fields
    are overwritten. Otherwise a new row is inserted. The change is committed
    before returning.
    """
    topic = context.params.get("topic")
    with get_session() as session:
        query = select(Chat).where(Chat.guild_id == context.guild_id).where(Chat.topic == topic)
        record = session.exec(query).one_or_none()
        if record is None:
            record = Chat(
                response_id=response_id,
                topic=topic,
                guild_id=context.guild_id,
                updated=datetime.now(),
            )
        else:
            # Continue the existing conversation: point it at the newest response.
            record.response_id = response_id
            record.updated = datetime.now()
        session.add(record)
        session.commit()
async def get_api_key(guild_id: int) -> str:
    """
    Decrypt and return the OpenAI API key stored for ``guild_id``.

    The key is stored Fernet-encrypted in the Key table. It is decrypted with
    the Fernet key from the FERNET_KEY environment variable.

    Raises:
        ValueError: if FERNET_KEY is unset or empty, or if no Key row exists
            for ``guild_id``.
        cryptography.fernet.InvalidToken: if the stored value cannot be
            decrypted with FERNET_KEY.
    """
    secret = os.getenv("FERNET_KEY")
    if not secret:
        raise ValueError("FERNET_KEY environment variable not set!")
    fernet = Fernet(secret.encode())

    with get_session() as session:
        record = session.exec(select(Key).where(Key.guild_id == guild_id)).first()
        if not record:
            raise ValueError(f"No API token found for guild_id: {guild_id}")
        ciphertext = record.api_key.encode()
        return fernet.decrypt(ciphertext).decode()
def parse_key_row(row: str) -> Tuple[int, str, str]:
    """
    Split one ``guild_id,guild_name,encrypted_api_key`` line into its fields.

    The trailing line break is removed, so it is not stored as part of the key.
    The guild id is taken from before the first comma and the key from after
    the last comma, so a guild name that itself contains commas is kept intact.
    Fernet tokens are URL-safe base64 and never contain commas. Surrounding
    whitespace is stripped from the key.

    Raises:
        ValueError: if the line has fewer than three comma-separated fields,
            or if the guild id is not an integer.
    """
    line = row.rstrip("\r\n")
    guild_id_text, remainder = line.split(",", 1)
    guild_name, api_key = remainder.rsplit(",", 1)
    return int(guild_id_text), guild_name, api_key.strip()


if __name__ == "__main__":
    # One-off seeding script: create the tables, then load the encrypted keys.
    SQLModel.metadata.create_all(engine)
    with get_session() as db_session:
        with open("encrypted_api_keys.txt", mode="r", encoding="UTF-8") as f:
            for row in f:
                # Skip blank lines (e.g. an empty final line) instead of crashing on int("").
                if not row.strip():
                    continue
                guild_id, guild_name, api_key = parse_key_row(row)
                db_session.add(Key(guild_id=guild_id, guild_name=guild_name, api_key=api_key))
        db_session.commit()