37
37
38
38
ConfigType : TypeAlias = "dict[str, Any]"
39
39
40
- # Environment setup adapted from HF transformers
41
40
@_operator_call
42
41
def _jinja_env () -> ImmutableSandboxedEnvironment :
42
+ # Environment setup adapted from HF transformers
43
43
def raise_exception (message : str ) -> NoReturn :
44
44
raise jinja2 .exceptions .TemplateError (message )
45
45
@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
56
56
return env
57
57
58
58
59
class Message(TypedDict):
    """A single message exchanged in a chat with a GPT4All model."""

    # Who the message is attributed to; this file creates messages with the
    # roles "system", "user" and "assistant".
    role: str
    # The text of the message.
    content: str
62
64
63
65
64
class _ChatSession(NamedTuple):
    """State kept for one active chat session."""

    # Compiled Jinja template that renders the message history into a prompt.
    template: jinja2.Template
    # The template text that `template` was compiled from.
    template_source: str
    # Messages of the session in order. The tuple field cannot be rebound, so
    # callers update this list in place (append / slice assignment).
    history: list[Message]
68
70
69
71
70
72
class Embed4All :
@@ -195,7 +197,8 @@ class GPT4All:
195
197
"""
196
198
197
199
# Matches markers that suggest a system prompt was written with hand-rolled
# prompt formatting instead of being plain instruction text.
RE_LEGACY_SYSPROMPT = re.compile(
    # "### System", "System:" or "SYSTEM:" at the start of a line or after whitespace
    r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
    # special tokens of the form <|...|>, e.g. <|im_start|> or <|eot_id|>
    r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|"
    # the literal <<SYS>> tag
    r"<<SYS>>",
    flags=re.MULTILINE,
)
201
204
@@ -244,7 +247,7 @@ def __init__(
244
247
"""
245
248
246
249
self .model_type = model_type
247
- self ._chat_session : ChatSession | None = None
250
+ self ._chat_session : _ChatSession | None = None
248
251
249
252
device_init = None
250
253
if sys .platform == "darwin" :
@@ -303,11 +306,12 @@ def device(self) -> str | None:
303
306
return self .model .device
304
307
305
308
@property
def current_chat_session(self) -> list[Message] | None:
    """Return the message history of the active chat session.

    Returns:
        The session's history list itself (not a copy), or None when no chat
        session is active.
    """
    session = self._chat_session
    if session is None:
        return None
    return session.history
308
312
309
313
@current_chat_session.setter
def current_chat_session(self, history: list[Message]) -> None:
    """Replace the message history of the active chat session.

    Args:
        history: The messages that become the session's new history.

    Raises:
        ValueError: If there is no active chat session.
    """
    session = self._chat_session
    if session is None:
        raise ValueError("current_chat_session may only be set when there is an active chat session")
    # Slice assignment swaps the contents of the existing list in place; the
    # session object keeps referring to the same list.
    session.history[:] = history
@@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
585
589
last_msg_rendered = prompt
586
590
if self ._chat_session is not None :
587
591
session = self ._chat_session
588
- def render (messages : list [MessageType ]) -> str :
592
+ def render (messages : list [Message ]) -> str :
589
593
return session .template .render (
590
594
messages = messages ,
591
595
add_generation_prompt = True ,
592
596
** self .model .special_tokens_map ,
593
597
)
594
- session .history .append (MessageType (role = "user" , content = prompt ))
598
+ session .history .append (Message (role = "user" , content = prompt ))
595
599
prompt = render (session .history )
596
600
if len (session .history ) > 1 :
597
601
last_msg_rendered = render (session .history [- 1 :])
@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
606
610
def stream () -> Iterator [str ]:
607
611
yield from self .model .prompt_model_streaming (prompt , _callback_wrapper , ** generate_kwargs )
608
612
if self ._chat_session is not None :
609
- self ._chat_session .history .append (MessageType (role = "assistant" , content = full_response ))
613
+ self ._chat_session .history .append (Message (role = "assistant" , content = full_response ))
610
614
return stream ()
611
615
612
616
self .model .prompt_model (prompt , _callback_wrapper , ** generate_kwargs )
613
617
if self ._chat_session is not None :
614
- self ._chat_session .history .append (MessageType (role = "assistant" , content = full_response ))
618
+ self ._chat_session .history .append (Message (role = "assistant" , content = full_response ))
615
619
return full_response
616
620
617
- @classmethod
618
- def is_legacy_chat_template (cls , tmpl : str ) -> bool :
619
- """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
620
- return bool (re .search (r"%[12]\b" , tmpl ) or not cls .RE_JINJA_LIKE .search (tmpl )
621
- or not re .search (r"\bcontent\b" , tmpl ))
622
-
623
621
@contextmanager
624
622
def chat_session (
625
623
self ,
@@ -632,10 +630,14 @@ def chat_session(
632
630
Context manager to hold an inference optimized chat session with a GPT4All model.
633
631
634
632
Args:
635
- system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
633
+ system_message: An initial instruction for the model, None to use the model default, or False to disable.
634
+ Defaults to None.
636
635
chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
637
- """
636
+ warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
638
637
638
+ Raises:
639
+ ValueError: If no valid chat template was found.
640
+ """
639
641
if system_message is None :
640
642
system_message = self .config .get ("systemMessage" , False )
641
643
elif system_message is not False and warn_legacy and (m := self .RE_LEGACY_SYSPROMPT .search (system_message )):
@@ -662,7 +664,7 @@ def chat_session(
662
664
msg += " If this is a built-in model, consider setting allow_download to True."
663
665
raise ValueError (msg ) from None
664
666
raise
665
- elif warn_legacy and self .is_legacy_chat_template (chat_template ):
667
+ elif warn_legacy and self ._is_legacy_chat_template (chat_template ):
666
668
print (
667
669
"Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
668
670
"templates are no longer supported.\n To disable this warning, pass warn_legacy=False." ,
@@ -671,8 +673,8 @@ def chat_session(
671
673
672
674
history = []
673
675
if system_message is not False :
674
- history .append (MessageType (role = "system" , content = system_message ))
675
- self ._chat_session = ChatSession (
676
+ history .append (Message (role = "system" , content = system_message ))
677
+ self ._chat_session = _ChatSession (
676
678
template = _jinja_env .from_string (chat_template ),
677
679
template_source = chat_template ,
678
680
history = history ,
@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
692
694
"""
693
695
return LLModel .list_gpus ()
694
696
697
@classmethod
def _is_legacy_chat_template(cls, tmpl: str) -> bool:
    """Heuristically decide whether `tmpl` does not look like a Jinja template.

    Args:
        tmpl: The chat template text to inspect.

    Returns:
        True if any of the following holds, otherwise False:
        the text contains `%1` or `%2` (presumably the placeholders of
        old-style prompt templates), `cls.RE_JINJA_LIKE` finds no match,
        or the word `content` does not occur in it.
    """
    if re.search(r"%[12]\b", tmpl):
        return True
    if not cls.RE_JINJA_LIKE.search(tmpl):
        return True
    return not re.search(r"\bcontent\b", tmpl)
702
+
695
703
696
704
def append_extension_if_missing (model_name ):
697
705
if not model_name .endswith ((".bin" , ".gguf" )):
0 commit comments