
Commit 598f62a

resolve some flake8 complaints about new code
Signed-off-by: Jared Van Bortel <[email protected]>
1 parent b3cc860 commit 598f62a

File tree

2 files changed: +32 -24 lines changed


gpt4all-bindings/python/gpt4all/_pyllmodel.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@
 import threading
 from enum import Enum
 from queue import Queue
-from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Iterator, Literal, NoReturn, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, Literal, NoReturn, TypeVar, overload

 if sys.version_info >= (3, 9):
     import importlib.resources as importlib_resources
```
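
For context, the complaints resolved here appear to be flake8's F401 (`Iterable` imported but unused, fixed above) and E501 (line too long, fixed below by splitting the `RE_LEGACY_SYSPROMPT` pattern). A sketch of how such complaints surface; the invocation path and reported numbers are illustrative assumptions, not output captured from a real run:

```
$ flake8 gpt4all-bindings/python/gpt4all
gpt4all/_pyllmodel.py:12:1: F401 'typing.Iterable' imported but unused
gpt4all/gpt4all.py:198:121: E501 line too long (134 > 120 characters)
```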

gpt4all-bindings/python/gpt4all/gpt4all.py
Lines changed: 31 additions & 23 deletions

```diff
@@ -37,9 +37,9 @@

 ConfigType: TypeAlias = "dict[str, Any]"

-# Environment setup adapted from HF transformers
 @_operator_call
 def _jinja_env() -> ImmutableSandboxedEnvironment:
+    # Environment setup adapted from HF transformers
     def raise_exception(message: str) -> NoReturn:
         raise jinja2.exceptions.TemplateError(message)

```
```diff
@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
     return env


-class MessageType(TypedDict):
+class Message(TypedDict):
+    """A message in a chat with a GPT4All model."""
+
     role: str
     content: str


-class ChatSession(NamedTuple):
+class _ChatSession(NamedTuple):
     template: jinja2.Template
     template_source: str
-    history: list[MessageType]
+    history: list[Message]


 class Embed4All:
```
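
The rename of `MessageType` to `Message` (and `ChatSession` to the private `_ChatSession`) only changes names: a `TypedDict` is an ordinary dict at runtime, so code that passes plain dicts keeps working. A minimal standalone sketch, not the library's own test code:

```python
from typing import TypedDict

class Message(TypedDict):
    """A message in a chat with a GPT4All model."""

    role: str
    content: str

# Calling the TypedDict and writing a dict literal produce equal objects.
m1 = Message(role="user", content="Hello!")
m2: Message = {"role": "user", "content": "Hello!"}
assert m1 == m2
```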
```diff
@@ -195,7 +197,8 @@ class GPT4All:
     """

     RE_LEGACY_SYSPROMPT = re.compile(
-        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
+        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
+        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
         re.MULTILINE,
     )

```
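Splitting the pattern over two lines relies on Python's compile-time concatenation of adjacent string literals, so the regex itself is unchanged; only the line length shrinks. A quick standalone check, with sample prompts that are illustrative only:

```python
import re

RE_LEGACY_SYSPROMPT = re.compile(
    r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
    r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
    re.MULTILINE,
)

# Matches several legacy system-prompt markers:
assert RE_LEGACY_SYSPROMPT.search("### System\nYou are a helpful assistant.")
assert RE_LEGACY_SYSPROMPT.search("<|im_start|>system")
assert RE_LEGACY_SYSPROMPT.search("<<SYS>>Be helpful.<</SYS>>")
# A plain message is left alone:
assert not RE_LEGACY_SYSPROMPT.search("You are a helpful assistant.")
```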

```diff
@@ -244,7 +247,7 @@ def __init__(
         """

         self.model_type = model_type
-        self._chat_session: ChatSession | None = None
+        self._chat_session: _ChatSession | None = None

         device_init = None
         if sys.platform == "darwin":
```
```diff
@@ -303,11 +306,12 @@ def device(self) -> str | None:
         return self.model.device

     @property
-    def current_chat_session(self) -> list[MessageType] | None:
+    def current_chat_session(self) -> list[Message] | None:
+        """The message history of the current chat session."""
         return None if self._chat_session is None else self._chat_session.history

     @current_chat_session.setter
-    def current_chat_session(self, history: list[MessageType]) -> None:
+    def current_chat_session(self, history: list[Message]) -> None:
         if self._chat_session is None:
             raise ValueError("current_chat_session may only be set when there is an active chat session")
         self._chat_session.history[:] = history
```
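
The new docstring documents the property's behavior: reading it returns the live history (or None outside a session), and the setter replaces that history in place. A hypothetical usage sketch; the model file name is an assumption:

```python
from gpt4all import GPT4All

model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")
with model.chat_session():
    model.generate("Name three primary colors.", max_tokens=64)
    # Each entry is a Message dict with "role" and "content" keys.
    for msg in model.current_chat_session:
        print(f'{msg["role"]}: {msg["content"][:40]}')
    # The setter mutates the existing history in place; it raises
    # ValueError when no chat session is active.
    model.current_chat_session = model.current_chat_session[:1]
```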
```diff
@@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
         last_msg_rendered = prompt
         if self._chat_session is not None:
             session = self._chat_session
-            def render(messages: list[MessageType]) -> str:
+            def render(messages: list[Message]) -> str:
                 return session.template.render(
                     messages=messages,
                     add_generation_prompt=True,
                     **self.model.special_tokens_map,
                 )
-            session.history.append(MessageType(role="user", content=prompt))
+            session.history.append(Message(role="user", content=prompt))
             prompt = render(session.history)
             if len(session.history) > 1:
                 last_msg_rendered = render(session.history[-1:])
```
```diff
@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
             def stream() -> Iterator[str]:
                 yield from self.model.prompt_model_streaming(prompt, _callback_wrapper, **generate_kwargs)
                 if self._chat_session is not None:
-                    self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+                    self._chat_session.history.append(Message(role="assistant", content=full_response))
             return stream()

         self.model.prompt_model(prompt, _callback_wrapper, **generate_kwargs)
         if self._chat_session is not None:
-            self._chat_session.history.append(MessageType(role="assistant", content=full_response))
+            self._chat_session.history.append(Message(role="assistant", content=full_response))
         return full_response

-    @classmethod
-    def is_legacy_chat_template(cls, tmpl: str) -> bool:
-        """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
-        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
-                    or not re.search(r"\bcontent\b", tmpl))
-
     @contextmanager
     def chat_session(
         self,
```
```diff
@@ -632,10 +630,14 @@ def chat_session(
         Context manager to hold an inference optimized chat session with a GPT4All model.

         Args:
-            system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
+            system_message: An initial instruction for the model, None to use the model default, or False to disable.
+                Defaults to None.
             chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
-        """
+            warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.

+        Raises:
+            ValueError: If no valid chat template was found.
+        """
         if system_message is None:
             system_message = self.config.get("systemMessage", False)
         elif system_message is not False and warn_legacy and (m := self.RE_LEGACY_SYSPROMPT.search(system_message)):
```
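The expanded docstring also covers the `warn_legacy` parameter and the ValueError raised when no chat template can be found. A hypothetical call for illustration; the model file name is an assumption:

```python
from gpt4all import GPT4All

model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf")
# system_message=None uses the model default; False disables it entirely.
with model.chat_session(system_message="You are a terse assistant."):
    print(model.generate("What is 2 + 2?", max_tokens=16))
```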
```diff
@@ -662,7 +664,7 @@ def chat_session(
                 msg += " If this is a built-in model, consider setting allow_download to True."
                 raise ValueError(msg) from None
             raise
-        elif warn_legacy and self.is_legacy_chat_template(chat_template):
+        elif warn_legacy and self._is_legacy_chat_template(chat_template):
             print(
                 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
                 "templates are no longer supported.\nTo disable this warning, pass warn_legacy=False.",
```
```diff
@@ -671,8 +673,8 @@ def chat_session(

         history = []
         if system_message is not False:
-            history.append(MessageType(role="system", content=system_message))
-        self._chat_session = ChatSession(
+            history.append(Message(role="system", content=system_message))
+        self._chat_session = _ChatSession(
             template=_jinja_env.from_string(chat_template),
             template_source=chat_template,
             history=history,
```
```diff
@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
         """
         return LLModel.list_gpus()

+    @classmethod
+    def _is_legacy_chat_template(cls, tmpl: str) -> bool:
+        # check if tmpl does not look like a Jinja template
+        return bool(re.search(r"%[12]\b", tmpl) or not cls.RE_JINJA_LIKE.search(tmpl)
+                    or not re.search(r"\bcontent\b", tmpl))
+

 def append_extension_if_missing(model_name):
     if not model_name.endswith((".bin", ".gguf")):
```
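
The heuristic itself is unchanged by the move to the end of the class: it treats a template as legacy if it uses the old `%1`/`%2` placeholders, lacks Jinja syntax, or never references `content`. A standalone sketch; the `RE_JINJA_LIKE` pattern below is an assumption for illustration, since its definition is not part of this diff:

```python
import re

# Assumed stand-in for gpt4all's RE_JINJA_LIKE: looks for Jinja markers.
RE_JINJA_LIKE = re.compile(r"\{%-?\s*\w+|\{\{")

def is_legacy_chat_template(tmpl: str) -> bool:
    # check if tmpl does not look like a Jinja template
    return bool(re.search(r"%[12]\b", tmpl) or not RE_JINJA_LIKE.search(tmpl)
                or not re.search(r"\bcontent\b", tmpl))

assert is_legacy_chat_template("### Human:\n%1\n### Assistant:\n%2")
assert not is_legacy_chat_template(
    "{% for m in messages %}{{ m['content'] }}{% endfor %}")
```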
