Skip to content

Commit ea6e77d

Browse files
authored
Make the code more like PEP8 for readability (oobabooga#862)
1 parent 848c4ed commit ea6e77d

28 files changed

+302
-165
lines changed

Diff for: api-example-stream.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def random_hash():
1717
letters = string.ascii_lowercase + string.digits
1818
return ''.join(random.choice(letters) for i in range(9))
1919

20+
2021
async def run(context):
2122
server = "127.0.0.1"
2223
params = {
@@ -41,7 +42,7 @@ async def run(context):
4142

4243
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
4344
while content := json.loads(await websocket.recv()):
44-
#Python3.10 syntax, replace with if elif on older
45+
# Python3.10 syntax, replace with if elif on older
4546
match content["msg"]:
4647
case "send_hash":
4748
await websocket.send(json.dumps({
@@ -62,13 +63,14 @@ async def run(context):
6263
pass
6364
case "process_generating" | "process_completed":
6465
yield content["output"]["data"][0]
65-
# You can search for your desired end indicator and
66+
# You can search for your desired end indicator and
6667
# stop generation by closing the websocket here
6768
if (content["msg"] == "process_completed"):
6869
break
6970

7071
prompt = "What I would like to say is the following: "
7172

73+
7274
async def get_result():
7375
async for response in run(prompt):
7476
# Print intermediate steps

Diff for: convert-to-flexgen.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from tqdm import tqdm
1414
from transformers import AutoModelForCausalLM, AutoTokenizer
1515

16-
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
16+
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
1717
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
1818
args = parser.parse_args()
1919

20+
2021
def disable_torch_init():
2122
"""
2223
Disable the redundant torch default initialization to accelerate model creation.
@@ -31,20 +32,22 @@ def disable_torch_init():
3132
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
3233
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
3334

35+
3436
def restore_torch_init():
3537
"""Rollback the change made by disable_torch_init."""
3638
import torch
3739
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
3840
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
3941

42+
4043
if __name__ == '__main__':
4144
path = Path(args.MODEL)
4245
model_name = path.name
4346

4447
print(f"Loading {model_name}...")
45-
#disable_torch_init()
48+
# disable_torch_init()
4649
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
47-
#restore_torch_init()
50+
# restore_torch_init()
4851

4952
tokenizer = AutoTokenizer.from_pretrained(path)
5053

Diff for: convert-to-safetensors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from transformers import AutoModelForCausalLM, AutoTokenizer
1919

20-
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
20+
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
2121
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
2222
parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
2323
parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")

Diff for: download-model.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
3030
args = parser.parse_args()
3131

32+
3233
def get_file(url, output_folder):
3334
filename = Path(url.rsplit('/', 1)[1])
3435
output_path = output_folder / filename
@@ -54,13 +55,15 @@ def get_file(url, output_folder):
5455
t.update(len(data))
5556
f.write(data)
5657

58+
5759
def sanitize_branch_name(branch_name):
5860
pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
5961
if pattern.match(branch_name):
6062
return branch_name
6163
else:
6264
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
6365

66+
6467
def select_model_from_default_options():
6568
models = {
6669
"OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -78,11 +81,11 @@ def select_model_from_default_options():
7881
choices = {}
7982

8083
print("Select the model that you want to download:\n")
81-
for i,name in enumerate(models):
82-
char = chr(ord('A')+i)
84+
for i, name in enumerate(models):
85+
char = chr(ord('A') + i)
8386
choices[char] = name
8487
print(f"{char}) {name}")
85-
char = chr(ord('A')+len(models))
88+
char = chr(ord('A') + len(models))
8689
print(f"{char}) None of the above")
8790

8891
print()
@@ -106,6 +109,7 @@ def select_model_from_default_options():
106109

107110
return model, branch
108111

112+
109113
def get_download_links_from_huggingface(model, branch):
110114
base = "https://huggingface.co"
111115
page = f"/api/models/{model}/tree/{branch}?cursor="
@@ -166,15 +170,17 @@ def get_download_links_from_huggingface(model, branch):
166170

167171
# If both pytorch and safetensors are available, download safetensors only
168172
if (has_pytorch or has_pt) and has_safetensors:
169-
for i in range(len(classifications)-1, -1, -1):
173+
for i in range(len(classifications) - 1, -1, -1):
170174
if classifications[i] in ['pytorch', 'pt']:
171175
links.pop(i)
172176

173177
return links, sha256, is_lora
174178

179+
175180
def download_files(file_list, output_folder, num_threads=8):
176181
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
177182

183+
178184
if __name__ == '__main__':
179185
model = args.MODEL
180186
branch = args.branch
@@ -224,7 +230,7 @@ def download_files(file_list, output_folder, num_threads=8):
224230
validated = False
225231
else:
226232
print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}')
227-
233+
228234
if validated:
229235
print('[+] Validated checksums of all model files!')
230236
else:

Diff for: extensions/api/script.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
'port': 5000,
1010
}
1111

12+
1213
class Handler(BaseHTTPRequestHandler):
1314
def do_GET(self):
1415
if self.path == '/api/v1/model':
@@ -32,34 +33,34 @@ def do_POST(self):
3233
self.end_headers()
3334

3435
prompt = body['prompt']
35-
prompt_lines = [l.strip() for l in prompt.split('\n')]
36+
prompt_lines = [k.strip() for k in prompt.split('\n')]
3637

3738
max_context = body.get('max_context_length', 2048)
3839

3940
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
4041
prompt_lines.pop(0)
4142

4243
prompt = '\n'.join(prompt_lines)
43-
generate_params = {
44-
'max_new_tokens': int(body.get('max_length', 200)),
44+
generate_params = {
45+
'max_new_tokens': int(body.get('max_length', 200)),
4546
'do_sample': bool(body.get('do_sample', True)),
46-
'temperature': float(body.get('temperature', 0.5)),
47-
'top_p': float(body.get('top_p', 1)),
48-
'typical_p': float(body.get('typical', 1)),
49-
'repetition_penalty': float(body.get('rep_pen', 1.1)),
47+
'temperature': float(body.get('temperature', 0.5)),
48+
'top_p': float(body.get('top_p', 1)),
49+
'typical_p': float(body.get('typical', 1)),
50+
'repetition_penalty': float(body.get('rep_pen', 1.1)),
5051
'encoder_repetition_penalty': 1,
51-
'top_k': int(body.get('top_k', 0)),
52+
'top_k': int(body.get('top_k', 0)),
5253
'min_length': int(body.get('min_length', 0)),
53-
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
54-
'num_beams': int(body.get('num_beams',1)),
54+
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
55+
'num_beams': int(body.get('num_beams', 1)),
5556
'penalty_alpha': float(body.get('penalty_alpha', 0)),
5657
'length_penalty': float(body.get('length_penalty', 1)),
5758
'early_stopping': bool(body.get('early_stopping', False)),
5859
'seed': int(body.get('seed', -1)),
5960
}
6061

6162
generator = generate_reply(
62-
prompt,
63+
prompt,
6364
generate_params,
6465
stopping_strings=body.get('stopping_strings', []),
6566
)
@@ -84,9 +85,9 @@ def do_POST(self):
8485
def run_server():
8586
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
8687
server = ThreadingHTTPServer(server_addr, Handler)
87-
if shared.args.share:
88+
if shared.args.share:
8889
try:
89-
from flask_cloudflared import _run_cloudflared
90+
from flask_cloudflared import _run_cloudflared
9091
public_url = _run_cloudflared(params['port'], params['port'] + 1)
9192
print(f'Starting KoboldAI compatible api at {public_url}/api')
9293
except ImportError:
@@ -95,5 +96,6 @@ def run_server():
9596
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
9697
server.serve_forever()
9798

99+
98100
def setup():
99101
Thread(target=run_server, daemon=True).start()

Diff for: extensions/character_bias/script.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -5,33 +5,37 @@
55
"bias string": " *I am so happy*",
66
}
77

8+
89
def input_modifier(string):
910
"""
1011
This function is applied to your text inputs before
1112
they are fed into the model.
12-
"""
13+
"""
1314

1415
return string
1516

17+
1618
def output_modifier(string):
1719
"""
1820
This function is applied to the model outputs.
1921
"""
2022

2123
return string
2224

25+
2326
def bot_prefix_modifier(string):
2427
"""
2528
This function is only applied in chat mode. It modifies
2629
the prefix text for the Bot and can be used to bias its
2730
behavior.
2831
"""
2932

30-
if params['activate'] == True:
33+
if params['activate']:
3134
return f'{string} {params["bias string"].strip()} '
3235
else:
3336
return string
3437

38+
3539
def ui():
3640
# Gradio elements
3741
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

Diff for: extensions/elevenlabs_tts/script.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -20,41 +20,47 @@
2020
if not shared.args.no_stream:
2121
print("Please add --no-stream. This extension is not meant to be used with streaming.")
2222
raise ValueError
23-
23+
2424
# Check if the API is valid and refresh the UI accordingly.
25+
26+
2527
def check_valid_api():
26-
28+
2729
global user, user_info, params
2830

2931
user = ElevenLabsUser(params['api_key'])
3032
user_info = user._get_subscription_data()
3133
print('checking api')
32-
if params['activate'] == False:
34+
if not params['activate']:
3335
return gr.update(value='Disconnected')
3436
elif user_info is None:
3537
print('Incorrect API Key')
3638
return gr.update(value='Disconnected')
3739
else:
3840
print('Got an API Key!')
3941
return gr.update(value='Connected')
40-
42+
4143
# Once the API is verified, get the available voices and update the dropdown list
44+
45+
4246
def refresh_voices():
43-
47+
4448
global user, user_info
45-
49+
4650
your_voices = [None]
4751
if user_info is not None:
4852
for voice in user.get_available_voices():
4953
your_voices.append(voice.initialName)
50-
return gr.Dropdown.update(choices=your_voices)
54+
return gr.Dropdown.update(choices=your_voices)
5155
else:
5256
return
5357

58+
5459
def remove_surrounded_chars(string):
5560
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
5661
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
57-
return re.sub('\*[^\*]*?(\*|$)','',string)
62+
return re.sub('\*[^\*]*?(\*|$)', '', string)
63+
5864

5965
def input_modifier(string):
6066
"""
@@ -64,16 +70,17 @@ def input_modifier(string):
6470

6571
return string
6672

73+
6774
def output_modifier(string):
6875
"""
6976
This function is applied to the model outputs.
7077
"""
7178

7279
global params, wav_idx, user, user_info
73-
74-
if params['activate'] == False:
80+
81+
if not params['activate']:
7582
return string
76-
elif user_info == None:
83+
elif user_info is None:
7784
return string
7885

7986
string = remove_surrounded_chars(string)
@@ -84,7 +91,7 @@ def output_modifier(string):
8491

8592
if string == '':
8693
string = 'empty reply, try regenerating'
87-
94+
8895
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
8996
voice = user.get_voices_by_name(params['selected_voice'])[0]
9097
audio_data = voice.generate_audio_bytes(string)
@@ -94,6 +101,7 @@ def output_modifier(string):
94101
wav_idx += 1
95102
return string
96103

104+
97105
def ui():
98106

99107
# Gradio elements
@@ -110,4 +118,4 @@ def ui():
110118
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
111119
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
112120
connect.click(check_valid_api, [], connection_status)
113-
connect.click(refresh_voices, [], voice)
121+
connect.click(refresh_voices, [], voice)

Diff for: extensions/gallery/script.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ def select_character(evt: gr.SelectData):
8585
def ui():
8686
with gr.Accordion("Character gallery", open=False):
8787
update = gr.Button("Refresh")
88-
gr.HTML(value="<style>"+generate_css()+"</style>")
88+
gr.HTML(value="<style>" + generate_css() + "</style>")
8989
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
9090
label="",
9191
samples=generate_html(),
9292
elem_classes=["character-gallery"],
9393
samples_per_page=50
9494
)
9595
update.click(generate_html, [], gallery)
96-
gallery.select(select_character, None, gradio['character_menu'])
96+
gallery.select(select_character, None, gradio['character_menu'])

0 commit comments

Comments
 (0)