Commit

resolve some flake8 complaints about new code

Signed-off-by: Jared Van Bortel <[email protected]>
cebtenzzre committed Dec 5, 2024
1 parent b3cc860 commit 598f62a
Showing 2 changed files with 32 additions and 24 deletions.

2 changes: 1 addition & 1 deletion gpt4all-bindings/python/gpt4all/_pyllmodel.py

@@ -9,7 +9,7 @@
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload

 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources

54 changes: 31 additions & 23 deletions gpt4all-bindings/python/gpt4all/gpt4all.py

@@ -37,9 +37,9 @@

 ConfigType: TypeAlias = "dict[str, Any]"

+# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
-    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)

@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
     return env


-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str


-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template: jinja2.Template
     template_source: str
-    history: list[MessageType]
+    history: list[Message]


 class Embed4All:
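
Note: this hunk renames the public message TypedDict to Message and makes the session tuple private as _ChatSession. As a quick illustration (not part of this commit), a TypedDict is an ordinary dict at runtime; it only adds static type checking:

# illustrative sketch: a TypedDict behaves like a plain dict at runtime
from typing import TypedDict

class Message(TypedDict):
    role: str
    content: str

msg = Message(role="user", content="Hello!")  # constructs an ordinary dict
assert msg == {"role": "user", "content": "Hello!"}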

@@ -195,7 +197,8 @@ class GPT4All:
     """

     RE_LEGACY_SYSPROMPT = re.compile(
-        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
         re.MULTILINE,
     )
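
Note: the line split above relies on Python's compile-time concatenation of adjacent string literals, so the compiled pattern is unchanged. A quick sanity check:

# adjacent string literals are joined at compile time, so the split
# pattern compiles to exactly the same regex as the original one-liner
import re

single = re.compile(
    r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
    re.MULTILINE,
)
split = re.compile(
    r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
    r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
    re.MULTILINE,
)
assert single.pattern == split.pattern
assert split.search("<|im_start|>system")  # matches a legacy ChatML marker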

@@ -244,7 +247,7 @@ def __init__(
         """

         self.model_type = model_type
-        self._chat_session: ChatSession | None = None
+        self._chat_session: _ChatSession | None = None

         device_init = None
         if sys.platform == "darwin":

@@ -303,11 +306,12 @@ def device(self) -> str | None:
         return self.model.device

     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history

     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
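
Note: a usage sketch for the property and setter documented above. The model name is only an example; it assumes the file is available locally or that allow_download is left enabled:

# hedged usage sketch for current_chat_session (example model name)
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")
with model.chat_session():
    model.generate("What is 2 + 2?", max_tokens=16)
    print(model.current_chat_session)  # list of Message dicts
    # the setter replaces the history of the active session in place
    model.current_chat_session = [{"role": "user", "content": "Hi"}]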
Expand Down Expand Up @@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
last_msg_rendered = prompt
if self._chat_session is not None:
session = self._chat_session
def render(messages: list[MessageType]) -> str:
def render(messages: list[Message]) -> str:
return session.template.render(
messages=messages,
add_generation_prompt=True,
**self.model.special_tokens_map,
)
session.history.append(MessageType(role="user", content=prompt))
session.history.append(Message(role="user", content=prompt))
prompt = render(session.history)
if len(session.history) > 1:
last_msg_rendered = render(session.history[-1:])
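
Note: a standalone sketch of what the render() closure does, using a toy ChatML-style template rather than any model's actual default template. add_generation_prompt=True appends the cue for the assistant's next turn:

# minimal Jinja chat-template rendering, analogous to render() above
import jinja2

template = jinja2.Environment().from_string(
    "{% for m in messages %}<|im_start|>{{ m.role }}\n{{ m.content }}<|im_end|>\n{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
print(template.render(
    messages=[{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
))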

@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
             def stream() -> Iterator[str]:
                 yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
                 if self._chat_session is not None:
-                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                    self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()

         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response

-    @classmethod
-    def is_legacy_chat_template(cls, tmpl: str) -> bool:
-        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
-        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
-                    or not re.search(r"\bcontent\b", tmpl))
-
     @contextmanager
     def chat_session(
         self,

@@ -632,10 +630,14 @@ def chat_session(
         Context manager to hold an inference optimized chat session with a GPT4All model.

         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
-        """
+            warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
+
+        Raises:
+            ValueError: If no valid chat template was found.
+        """

         if system_message is None:
             system_message = self.config.get("systemMessage", False)
         elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):

@@ -662,7 +664,7 @@ def chat_session(
                     msg += " If this is a built-in model, consider setting allow_download to True."
                 raise ValueError(msg) from None
             raise
-        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
             print(
                 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
                 "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",

@@ -671,8 +673,8 @@ def chat_session(

         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
             template_source=chat_template,
             history=history,
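
Note: a minimal usage sketch of chat_session() with the parameters documented above (the model name is an example):

# hedged sketch: disable the system prompt, keep the model's default template
from gpt4all import GPT4All

model = GPT4All("Llama-3.2-1B-Instruct-Q4_0.gguf")
with model.chat_session(system_message=False, chat_template=None):
    print(model.generate("Name one prime number.", max_tokens=8))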

@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
         """
         return LLModel.list_gpus()

+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+

 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
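
Note: a sketch of the relocated heuristic in isolation. RE_JINJA_LIKE is defined elsewhere in the class and is not shown in this diff; the stand-in pattern below is an assumption for illustration only:

import re

RE_JINJA_LIKE = re.compile(r"\{%.*%\}|\{\{.*\}\}")  # assumed stand-in pattern

def is_legacy(tmpl: str) -> bool:
    # legacy if it uses %1/%2 placeholders, lacks Jinja syntax,
    # or never references a message's content
    return bool(re.search(r"%[12]\b", tmpl) or not RE_JINJA_LIKE.search(tmpl)
                or not re.search(r"\bcontent\b", tmpl))

assert is_legacy("### Human: %1\n### Assistant:")  # old-style placeholder
assert not is_legacy("{% for m in messages %}{{ m.content }}{% endfor %}")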
