Skip to content

Commit

Permalink
refactor: refactored prompts and get_codeblock
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBjare committed Nov 5, 2023
1 parent 32df015 commit d0a2245
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 142 deletions.
27 changes: 12 additions & 15 deletions gptme/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
to do so, it needs to be able to store and query past conversations in a database.
"""
# The above may be used as a prompt for the agent.

import atexit
import errno
import importlib.metadata
Expand All @@ -44,7 +43,7 @@
from .llm import init_llm, reply
from .logmanager import LogManager, _conversations
from .message import Message
from .prompts import initial_prompt_single_message
from .prompts import get_prompt
from .tabcomplete import register_tabcomplete
from .tools import execute_msg, init_tools
from .util import epoch_to_age, generate_unique_name
Expand Down Expand Up @@ -166,29 +165,27 @@ def main(
if no_confirm:
logger.warning("Skipping all confirmation prompts.")

if prompt_system in ["full", "short"]:
promptmsgs = [initial_prompt_single_message(short=prompt_system == "short")]
else:
promptmsgs = [Message("system", prompt_system)]

# we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
logfile = get_logfile(
name, interactive=(not prompts and interactive) and sys.stdin.isatty()
)
print(f"Using logdir {logfile.parent}")
log = LogManager.load(logfile, initial_msgs=promptmsgs, show_hidden=show_hidden)
# get initial system prompt
prompt_msgs = [get_prompt(prompt_system)]

# if stdin is not a tty, we're getting piped input
# if stdin is not a tty, we're getting piped input, which we should include in the prompt
if not sys.stdin.isatty():
# fetch prompt from stdin
prompt_stdin = _read_stdin()
if prompt_stdin:
promptmsgs += [Message("system", prompt_stdin)]
prompt_msgs += [Message("system", f"```stdin\n{prompt_stdin}\n```")]

# Attempt to switch to interactive mode
sys.stdin.close()
sys.stdin = open("/dev/tty")

# we need to run this before checking stdin, since the interactive doesn't work with the switch back to interactive mode
logfile = get_logfile(
name, interactive=(not prompts and interactive) and sys.stdin.isatty()
)
print(f"Using logdir {logfile.parent}")
log = LogManager.load(logfile, initial_msgs=prompt_msgs, show_hidden=show_hidden)

# print log
log.print()
print("--- ^^^ past messages ^^^ ---")
Expand Down
14 changes: 5 additions & 9 deletions gptme/logmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .constants import CMDFIX, LOGSDIR
from .message import Message, print_msg
from .prompts import initial_prompt
from .prompts import get_prompt
from .tools.reduce import limit_log, reduce_log
from .util import len_tokens

Expand Down Expand Up @@ -111,7 +111,7 @@ def prepare_messages(self) -> list[Message]:
def load(
cls,
logfile: PathLike,
initial_msgs: list[Message] = list(initial_prompt()),
initial_msgs: list[Message] = [get_prompt()],
**kwargs,
) -> "LogManager":
"""Loads a conversation log."""
Expand Down Expand Up @@ -146,13 +146,9 @@ def get_last_code_block(
msgs = msgs[-history:]

for msg in msgs[::-1]:
# check if message contains a code block
backtick_count = msg.content.count("```")
if backtick_count >= 2:
if content:
return msg.content.split("```")[-2].split("\n", 1)[-1]
else:
return msg.content
codeblocks = msg.get_codeblocks(content=content)
if codeblocks:
return codeblocks[-1]
return None

def rename(self, name: str, keep_date=False) -> None:
Expand Down
29 changes: 29 additions & 0 deletions gptme/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,35 @@ def __repr__(self):
content = textwrap.shorten(self.content, 20, placeholder="...")
return f"<Message role={self.role} content={content}>"

def get_codeblocks(self, content=False) -> list[str]:
"""
Get all codeblocks.
If `content` set, return the content of the code block, else return the whole message.
"""
codeblocks = []
content_str = self.content
# prepend newline to make sure we get the first codeblock
if not content_str.startswith("\n"):
content_str = "\n" + content_str

# check if message contains a code block
backtick_count = content_str.count("\n```")
if backtick_count < 2:
return []
for i in range(1, backtick_count, 2):
codeblock_str = content_str.split("\n```")[i]
# get codeblock language or filename from first line
lang_or_fn = codeblock_str.split("\n")[0]
codeblock_str = "\n".join(codeblock_str.split("\n")[1:])

if content:
codeblocks.append(codeblock_str)
else:
full_codeblock = f"```{lang_or_fn}\n{codeblock_str}\n```"
codeblocks.append(full_codeblock)

return codeblocks


def format_msgs(
msgs: list[Message],
Expand Down
Loading

0 comments on commit d0a2245

Please sign in to comment.