Skip to content

Commit 549219e

Browse files
committed
code cleanup
1 parent 7ca0870 commit 549219e

File tree

1 file changed

+60
-73
lines changed

1 file changed

+60
-73
lines changed

app/app.py

Lines changed: 60 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,17 @@
3333

3434
def create_inference_client(
3535
model: Optional[str] = None, base_url: Optional[str] = None
36-
) -> InferenceClient:
36+
) -> InferenceClient | dict:
3737
"""Create an InferenceClient instance with the given model or environment settings.
3838
This function will run the model locally if ZERO_GPU is set to True.
3939
This function will run the model locally if ZERO_GPU is set to True.
4040
4141
Args:
4242
model: Optional model identifier to use. If not provided, will use environment settings.
43+
base_url: Optional base URL for the inference API.
4344
4445
Returns:
45-
InferenceClient: Configured client instance
46+
Either an InferenceClient instance or a dictionary with pipeline and tokenizer
4647
"""
4748
if ZERO_GPU:
4849
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
@@ -67,11 +68,17 @@ def create_inference_client(
6768
CLIENT = create_inference_client()
6869

6970

70-
def load_languages() -> dict[str, str]:
71-
"""Load languages from JSON file or persistent storage"""
72-
# First check if we have persistent storage available
73-
persistent_path = Path("/data/languages.json")
74-
local_path = Path(__file__).parent / "languages.json"
71+
def get_persistent_storage_path(filename: str) -> tuple[Path, bool]:
72+
"""Check if persistent storage is available and return the appropriate path.
73+
74+
Args:
75+
filename: The name of the file to check/create
76+
77+
Returns:
78+
A tuple containing (file_path, is_persistent)
79+
"""
80+
persistent_path = Path("/data") / filename
81+
local_path = Path(__file__).parent / filename
7582

7683
# Check if persistent storage is available and writable
7784
use_persistent = False
@@ -86,35 +93,44 @@ def load_languages() -> dict[str, str]:
8693
print("Persistent storage exists but is not writable, falling back to local storage")
8794
use_persistent = False
8895

89-
# Use persistent storage if available and writable, otherwise fall back to local file
90-
if use_persistent and persistent_path.exists():
91-
languages_path = persistent_path
92-
else:
96+
return (persistent_path if use_persistent else local_path, use_persistent)
97+
98+
99+
def load_languages() -> dict[str, str]:
100+
"""Load languages from JSON file or persistent storage"""
101+
languages_path, use_persistent = get_persistent_storage_path("languages.json")
102+
local_path = Path(__file__).parent / "languages.json"
103+
104+
# If persistent storage is available but file doesn't exist yet,
105+
# copy the local file to persistent storage
106+
if use_persistent and not languages_path.exists():
107+
try:
108+
if local_path.exists():
109+
import shutil
110+
# Copy the file to persistent storage
111+
shutil.copy(local_path, languages_path)
112+
print(f"Copied languages to persistent storage at {languages_path}")
113+
else:
114+
# Create an empty languages file in persistent storage
115+
with open(languages_path, "w", encoding="utf-8") as f:
116+
json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
117+
print(f"Created new languages file in persistent storage at {languages_path}")
118+
except Exception as e:
119+
print(f"Error setting up persistent storage: {e}")
120+
languages_path = local_path # Fall back to local path if any error occurs
121+
122+
# If the file doesn't exist at the chosen path but exists at the local path, use local
123+
if not languages_path.exists() and local_path.exists():
93124
languages_path = local_path
94-
95-
# If persistent storage is available and writable but file doesn't exist yet,
96-
# copy the local file to persistent storage
97-
if use_persistent:
98-
try:
99-
# Ensure local file exists first
100-
if local_path.exists():
101-
import shutil
102-
# Copy the file to persistent storage
103-
shutil.copy(local_path, persistent_path)
104-
languages_path = persistent_path
105-
print(f"Copied languages to persistent storage at {persistent_path}")
106-
else:
107-
# Create an empty languages file in persistent storage
108-
with open(persistent_path, "w", encoding="utf-8") as f:
109-
json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
110-
languages_path = persistent_path
111-
print(f"Created new languages file in persistent storage at {persistent_path}")
112-
except Exception as e:
113-
print(f"Error setting up persistent storage: {e}")
114-
languages_path = local_path # Fall back to local path if any error occurs
115125

116-
with open(languages_path, "r", encoding="utf-8") as f:
117-
return json.load(f)
126+
# If the file exists, load it
127+
if languages_path.exists():
128+
with open(languages_path, "r", encoding="utf-8") as f:
129+
return json.load(f)
130+
else:
131+
# Return a default if no file exists
132+
default_languages = {"English": "You are a helpful assistant."}
133+
return default_languages
118134

119135

120136
# Initial load
@@ -257,6 +273,7 @@ def add_fake_like_data(
257273

258274
@spaces.GPU
259275
def call_pipeline(messages: list, language: str):
276+
"""Call the appropriate model pipeline based on configuration"""
260277
if ZERO_GPU:
261278
# Format the messages using the tokenizer's chat template
262279
tokenizer = CLIENT["tokenizer"]
@@ -274,28 +291,27 @@ def call_pipeline(messages: list, language: str):
274291
)
275292

276293
# Extract the generated content
277-
content = response[0]["generated_text"]
278-
return content
294+
return response[0]["generated_text"]
279295
else:
280296
response = CLIENT(
281297
messages,
282298
clean_up_tokenization_spaces=False,
283299
max_length=2000,
284300
)
285-
content = response[0]["generated_text"][-1]["content"]
286-
return content
301+
return response[0]["generated_text"][-1]["content"]
287302

288303

289304
def respond(
290305
history: list,
291306
language: str,
292307
temperature: Optional[float] = None,
293308
seed: Optional[int] = None,
294-
) -> list: # -> list:
309+
) -> list:
295310
"""Respond to the user message with a system message
296311
297312
Return the history with the new message"""
298313
messages = format_history_as_messages(history)
314+
299315
if ZERO_GPU:
300316
content = call_pipeline(messages, language)
301317
else:
@@ -307,17 +323,7 @@ def respond(
307323
temperature=temperature,
308324
)
309325
content = response.choices[0].message.content
310-
if ZERO_GPU:
311-
content = call_pipeline(messages, language)
312-
else:
313-
response = CLIENT.chat.completions.create(
314-
messages=messages,
315-
max_tokens=2000,
316-
stream=False,
317-
seed=seed,
318-
temperature=temperature,
319-
)
320-
content = response.choices[0].message.content
326+
321327
message = gr.ChatMessage(role="assistant", content=content)
322328
history.append(message)
323329
return history
@@ -510,26 +516,10 @@ def save_new_language(lang_name, system_prompt):
510516
"""Save the new language and system prompt to persistent storage if available, otherwise to local file."""
511517
global LANGUAGES # Access the global variable
512518

513-
# First determine where to save the file
514-
persistent_path = Path("/data/languages.json")
519+
# Get the appropriate path
520+
languages_path, use_persistent = get_persistent_storage_path("languages.json")
515521
local_path = Path(__file__).parent / "languages.json"
516522

517-
# Check if persistent storage is available and writable
518-
use_persistent = False
519-
if Path("/data").exists() and Path("/data").is_dir():
520-
try:
521-
# Test if we can write to the directory
522-
test_file = Path("/data/write_test.tmp")
523-
test_file.touch()
524-
test_file.unlink() # Remove the test file
525-
use_persistent = True
526-
except (PermissionError, OSError):
527-
print("Persistent storage exists but is not writable, falling back to local storage")
528-
use_persistent = False
529-
530-
# Use persistent storage if available and writable, otherwise fall back to local file
531-
languages_path = persistent_path if use_persistent else local_path
532-
533523
# Load existing languages
534524
if languages_path.exists():
535525
with open(languages_path, "r", encoding="utf-8") as f:
@@ -545,7 +535,7 @@ def save_new_language(lang_name, system_prompt):
545535
json.dump(data, f, ensure_ascii=False, indent=2)
546536

547537
# If we're using persistent storage, also update the local file as backup
548-
if use_persistent and local_path != persistent_path:
538+
if use_persistent and local_path != languages_path:
549539
try:
550540
with open(local_path, "w", encoding="utf-8") as f:
551541
json.dump(data, f, ensure_ascii=False, indent=2)
@@ -555,11 +545,8 @@ def save_new_language(lang_name, system_prompt):
555545
# Update the global LANGUAGES variable with the new data
556546
LANGUAGES.update({lang_name: system_prompt})
557547

558-
# Update the dropdown choices
559-
new_choices = list(LANGUAGES.keys())
560-
561548
# Return a message that will trigger a JavaScript refresh
562-
return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=new_choices)
549+
return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=list(LANGUAGES.keys()))
563550

564551

565552
css = """

0 commit comments

Comments
 (0)