
def create_inference_client(
    model: Optional[str] = None, base_url: Optional[str] = None
-) -> InferenceClient:
+) -> InferenceClient | dict:
    """Create an InferenceClient instance with the given model or environment settings.
    This function will run the model locally if ZERO_GPU is set to True.

    Args:
        model: Optional model identifier to use. If not provided, will use environment settings.
+        base_url: Optional base URL for the inference API.

    Returns:
-        InferenceClient: Configured client instance
+        Either an InferenceClient instance or a dictionary with pipeline and tokenizer
    """
    if ZERO_GPU:
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
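With the annotation widened to `InferenceClient | dict`, callers have to branch on the deployment mode: under ZERO_GPU the client is a dict holding a local transformers pipeline and its tokenizer, otherwise it is a regular `InferenceClient`. A minimal sketch of consuming both shapes, assuming the module-level `ZERO_GPU` flag and the `pipeline`/`tokenizer` keys implied by this diff (prompts and kwargs are illustrative, not part of this PR):

# Sketch only: how downstream code might branch on the union return type.
client = create_inference_client()
if ZERO_GPU:
    # Local mode: dict with a transformers pipeline and its tokenizer.
    prompt = client["tokenizer"].apply_chat_template(
        [{"role": "user", "content": "Hello"}], tokenize=False
    )
    output = client["pipeline"](prompt, max_length=100)
else:
    # Remote mode: huggingface_hub.InferenceClient with an OpenAI-style chat API.
    output = client.chat.completions.create(
        messages=[{"role": "user", "content": "Hello"}], max_tokens=100
    )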
@@ -67,11 +68,17 @@ def create_inference_client(
CLIENT = create_inference_client()


-def load_languages() -> dict[str, str]:
-    """Load languages from JSON file or persistent storage"""
-    # First check if we have persistent storage available
-    persistent_path = Path("/data/languages.json")
-    local_path = Path(__file__).parent / "languages.json"
+def get_persistent_storage_path(filename: str) -> tuple[Path, bool]:
+    """Check if persistent storage is available and return the appropriate path.
+
+    Args:
+        filename: The name of the file to check/create
+
+    Returns:
+        A tuple containing (file_path, is_persistent)
+    """
+    persistent_path = Path("/data") / filename
+    local_path = Path(__file__).parent / filename

    # Check if persistent storage is available and writable
    use_persistent = False
@@ -86,35 +93,44 @@ def load_languages() -> dict[str, str]:
            print("Persistent storage exists but is not writable, falling back to local storage")
            use_persistent = False

-    # Use persistent storage if available and writable, otherwise fall back to local file
-    if use_persistent and persistent_path.exists():
-        languages_path = persistent_path
-    else:
+    return (persistent_path if use_persistent else local_path, use_persistent)
+
+
+def load_languages() -> dict[str, str]:
+    """Load languages from JSON file or persistent storage"""
+    languages_path, use_persistent = get_persistent_storage_path("languages.json")
+    local_path = Path(__file__).parent / "languages.json"
+
+    # If persistent storage is available but file doesn't exist yet,
+    # copy the local file to persistent storage
+    if use_persistent and not languages_path.exists():
+        try:
+            if local_path.exists():
+                import shutil
+                # Copy the file to persistent storage
+                shutil.copy(local_path, languages_path)
+                print(f"Copied languages to persistent storage at {languages_path}")
+            else:
+                # Create an empty languages file in persistent storage
+                with open(languages_path, "w", encoding="utf-8") as f:
+                    json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
+                print(f"Created new languages file in persistent storage at {languages_path}")
+        except Exception as e:
+            print(f"Error setting up persistent storage: {e}")
+            languages_path = local_path  # Fall back to local path if any error occurs
+
+    # If the file doesn't exist at the chosen path but exists at the local path, use local
+    if not languages_path.exists() and local_path.exists():
        languages_path = local_path
-
-    # If persistent storage is available and writable but file doesn't exist yet,
-    # copy the local file to persistent storage
-    if use_persistent:
-        try:
-            # Ensure local file exists first
-            if local_path.exists():
-                import shutil
-                # Copy the file to persistent storage
-                shutil.copy(local_path, persistent_path)
-                languages_path = persistent_path
-                print(f"Copied languages to persistent storage at {persistent_path}")
-            else:
-                # Create an empty languages file in persistent storage
-                with open(persistent_path, "w", encoding="utf-8") as f:
-                    json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
-                languages_path = persistent_path
-                print(f"Created new languages file in persistent storage at {persistent_path}")
-        except Exception as e:
-            print(f"Error setting up persistent storage: {e}")
-            languages_path = local_path  # Fall back to local path if any error occurs

-    with open(languages_path, "r", encoding="utf-8") as f:
-        return json.load(f)
+    # If the file exists, load it
+    if languages_path.exists():
+        with open(languages_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    else:
+        # Return a default if no file exists
+        default_languages = {"English": "You are a helpful assistant."}
+        return default_languages


# Initial load
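The probe for a writable `/data` mount now lives in one place, so `load_languages` (and later `save_new_language`) simply ask the helper where to read and write. A minimal usage sketch, assuming the Spaces-style `/data` persistent volume this module targets:

# Sketch only: exercising the refactored storage helpers.
languages_path, is_persistent = get_persistent_storage_path("languages.json")
print(f"languages file: {languages_path} (persistent={is_persistent})")

languages = load_languages()  # falls back to {"English": ...} when no file exists
print(sorted(languages))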
@@ -257,6 +273,7 @@ def add_fake_like_data(

@spaces.GPU
def call_pipeline(messages: list, language: str):
+    """Call the appropriate model pipeline based on configuration"""
    if ZERO_GPU:
        # Format the messages using the tokenizer's chat template
        tokenizer = CLIENT["tokenizer"]
@@ -274,28 +291,27 @@ def call_pipeline(messages: list, language: str):
        )

        # Extract the generated content
-        content = response[0]["generated_text"]
-        return content
+        return response[0]["generated_text"]
    else:
        response = CLIENT(
            messages,
            clean_up_tokenization_spaces=False,
            max_length=2000,
        )
-        content = response[0]["generated_text"][-1]["content"]
-        return content
+        return response[0]["generated_text"][-1]["content"]


def respond(
    history: list,
    language: str,
    temperature: Optional[float] = None,
    seed: Optional[int] = None,
-) -> list:  # -> list:
+) -> list:
    """Respond to the user message with a system message

    Return the history with the new message"""
    messages = format_history_as_messages(history)
+
    if ZERO_GPU:
        content = call_pipeline(messages, language)
    else:
@@ -307,17 +323,7 @@ def respond(
            temperature=temperature,
        )
        content = response.choices[0].message.content
-    if ZERO_GPU:
-        content = call_pipeline(messages, language)
-    else:
-        response = CLIENT.chat.completions.create(
-            messages=messages,
-            max_tokens=2000,
-            stream=False,
-            seed=seed,
-            temperature=temperature,
-        )
-        content = response.choices[0].message.content
+
    message = gr.ChatMessage(role="assistant", content=content)
    history.append(message)
    return history
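Removing the duplicated block means `respond` now makes exactly one generation call per turn: `call_pipeline` in ZERO_GPU mode, the remote chat-completions endpoint otherwise, and in both cases the reply is appended as a `gr.ChatMessage`. A minimal driving sketch, assuming history entries use the role/content shape the rest of this app already passes around (the message text is illustrative):

# Sketch only: one conversational turn through respond().
history = [gr.ChatMessage(role="user", content="Say hello in French.")]
history = respond(history, language="French", temperature=0.7, seed=42)
print(history[-1].content)  # assistant reply appended by respond()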
@@ -510,26 +516,10 @@ def save_new_language(lang_name, system_prompt):
510516 """Save the new language and system prompt to persistent storage if available, otherwise to local file."""
511517 global LANGUAGES # Access the global variable
512518
513- # First determine where to save the file
514- persistent_path = Path ( "/data/ languages.json" )
519+ # Get the appropriate path
520+ languages_path , use_persistent = get_persistent_storage_path ( " languages.json" )
515521 local_path = Path (__file__ ).parent / "languages.json"
516522
517- # Check if persistent storage is available and writable
518- use_persistent = False
519- if Path ("/data" ).exists () and Path ("/data" ).is_dir ():
520- try :
521- # Test if we can write to the directory
522- test_file = Path ("/data/write_test.tmp" )
523- test_file .touch ()
524- test_file .unlink () # Remove the test file
525- use_persistent = True
526- except (PermissionError , OSError ):
527- print ("Persistent storage exists but is not writable, falling back to local storage" )
528- use_persistent = False
529-
530- # Use persistent storage if available and writable, otherwise fall back to local file
531- languages_path = persistent_path if use_persistent else local_path
532-
533523 # Load existing languages
534524 if languages_path .exists ():
535525 with open (languages_path , "r" , encoding = "utf-8" ) as f :
@@ -545,7 +535,7 @@ def save_new_language(lang_name, system_prompt):
        json.dump(data, f, ensure_ascii=False, indent=2)

    # If we're using persistent storage, also update the local file as backup
-    if use_persistent and local_path != persistent_path:
+    if use_persistent and local_path != languages_path:
        try:
            with open(local_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
@@ -555,11 +545,8 @@ def save_new_language(lang_name, system_prompt):
    # Update the global LANGUAGES variable with the new data
    LANGUAGES.update({lang_name: system_prompt})

-    # Update the dropdown choices
-    new_choices = list(LANGUAGES.keys())
-
    # Return a message that will trigger a JavaScript refresh
-    return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=new_choices)
+    return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=list(LANGUAGES.keys()))


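`save_new_language` now goes through the same helper, so its write target always matches what `load_languages` will read back, and the dropdown choices are rebuilt straight from the updated `LANGUAGES` dict. A minimal sketch of calling it outside the UI, assuming the three Gradio components it returns are wired to the add-language form (the language and prompt are illustrative):

# Sketch only: persisting a new language entry programmatically.
group, reload_html, dropdown = save_new_language(
    "Esperanto", "You are a helpful assistant. Respond in Esperanto."
)
assert "Esperanto" in LANGUAGES  # global dict is updated in place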
css = """