
Commit 9d50f6a ("push")

1 parent d7cd619

File tree: 11 files changed, +128 −72 lines


bloom_inference/generator.py

Lines changed: 3 additions & 3 deletions
@@ -50,7 +50,7 @@ class Generator:
     def __init__(
         self,
         model_parallel_submesh=(1, 2, 4, 1),  # for v4-64
-        ckpt="bigscience/bloom-6b3",
+        ckpt="bigscience/bloom",
         t5x_path="gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0",
         max_len=256,
         max_input_len=64,
@@ -62,14 +62,14 @@ def __init__(
         self.max_input_len = max_input_len
 
         config = BloomConfig.from_pretrained(ckpt, max_length=max_len, do_sample=True, num_beams=1, top_p=0.9)
-        model = FlaxBloomForCausalLM(config, _do_init=False, dtype=jnp.bfloat16, use_scan=True)
+        self.model = FlaxBloomForCausalLM(config, _do_init=False, dtype=jnp.bfloat16, use_scan=True)
 
         def init_state():
             input_shape = (1, 1)
             input_ids = jnp.zeros(input_shape, dtype="i4")
             attention_mask = jnp.ones_like(input_ids)
             rng = jax.random.PRNGKey(0)
-            initial_vars = model.module.init(rng, input_ids, attention_mask, return_dict=False)
+            initial_vars = self.model.module.init(rng, input_ids, attention_mask, return_dict=False)
             return InferenceState.create(initial_vars)
 
         state_shapes = jax.eval_shape(init_state)
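
Note on the change: init_state is a closure traced by jax.eval_shape, so binding the model to self.model both exposes it to later methods and keeps it alive after __init__ returns. A minimal standalone sketch of the shape-only tracing idea (the dummy dict stands in for model.module.init; this is not the repo's code):

import jax
import jax.numpy as jnp

def init_state():
    # same dummy shapes as the Generator uses
    input_ids = jnp.zeros((1, 1), dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    # stand-in for model.module.init(rng, input_ids, attention_mask, ...)
    return {"params": jnp.ones((4, 4)) * input_ids.sum() * attention_mask.sum()}

# eval_shape traces abstractly: it returns a pytree of ShapeDtypeStruct
# without allocating any real arrays, which is what makes computing the
# 176B parameter shapes cheap on the host.
state_shapes = jax.eval_shape(init_state)
print(state_shapes)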

bloom_inference/host_worker.py

Lines changed: 14 additions & 2 deletions
@@ -1,13 +1,15 @@
+import os
 import ray
 import time
 from queue import Queue
 
 
 @ray.remote(resources={"tpu": 1})
+# @ray.remote
 class TPUHostWorker(object):
     def __init__(
         self,
-        ckpt="bigscience/bloom-6b3",
+        ckpt="bigscience/bloom",
         t5x_path="gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0",
         max_len=256,
         max_input_len=64,
@@ -22,14 +24,24 @@ def __init__(
         self.input_q = Queue(maxsize=1)
         self.output_q = Queue(maxsize=1)
 
+        self._is_cpu = os.path.exists("/home/suraj_huggingface_co/bloom-jax-inference/is_cpu.txt")
+
+    def is_cpu(self):
+        return self._is_cpu
+
     def run(self):
         # import here so JAX and the Generator are only loaded on the TPU host worker, not the CPU manager
         import jax
         from bloom_inference.generator import Generator, head_print
 
         print(f"jax runtime initialization starting")
         start = time.time()
-        head_print(f"jax devices: {jax.device_count()}")
+        device_count = jax.device_count()
+        if device_count == 1:
+            head_print("TPU not found. Returning")
+            ray.shutdown()
+            return
+        head_print(f"jax devices: {device_count}")
         head_print(f"jax runtime initialized in {time.time() - start:.06}s")
 
         # load model and params
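
The commit guards TPU-only work twice: a sentinel file (is_cpu.txt) marks CPU-only hosts before JAX is even imported, and the jax.device_count() == 1 check catches hosts where the TPU runtime did not come up. A hedged, standalone sketch of the sentinel pattern (the path and actor name are illustrative, not the repo's):

import os
import ray

@ray.remote
class SentinelWorker:
    def __init__(self, sentinel="/tmp/is_cpu.txt"):  # hypothetical path
        # cheap filesystem check; no JAX import happens here
        self._is_cpu = os.path.exists(sentinel)

    def is_cpu(self):
        return self._is_cpu

if __name__ == "__main__":
    ray.init()
    worker = SentinelWorker.remote()
    # True on hosts carrying the marker file, so the manager can discard them
    print(ray.get(worker.is_cpu.remote()))
    ray.shutdown()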

bloom_inference/tpu_manager.py

Lines changed: 21 additions & 4 deletions
@@ -7,14 +7,14 @@ class TPUManager:
     def __init__(
         self,
         node_count=8,
-        ckpt="bigscience/bloom-6b3",
+        ckpt="bigscience/bloom",
         t5x_path="gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0",
         max_len=256,
         max_input_len=64,
         model_parallel_submesh=(1, 2, 4, 1),  # for v4-64
     ):
         # needs a valid ray cluster to start
-        assert ray.is_initialized(), "ray not initialised"
+        # assert ray.is_initialized(), "ray not initialised"
 
         from bloom_inference.host_worker import TPUHostWorker
 
@@ -29,16 +29,33 @@ def __init__(
 
         start = time.time()
 
-        for i in range(node_count):
+        # for i in range(node_count):
+        #     worker = TPUHostWorker.options(max_concurrency=2).remote(
+        #         ckpt,
+        #         t5x_path,
+        #         max_len,
+        #         max_input_len,
+        #         model_parallel_submesh,
+        #     )
+        #     is_cpu = ray.get(worker.is_cpu.remote())
+        #     print(is_cpu)
+        #     if not is_cpu:
+        #         self.nodes.append(worker)
+
+        while len(self.nodes) < node_count:
             worker = TPUHostWorker.options(max_concurrency=2).remote(
                 ckpt,
                 t5x_path,
                 max_len,
                 max_input_len,
                 model_parallel_submesh,
             )
-            self.nodes.append(worker)
+            is_cpu = ray.get(worker.is_cpu.remote())
+            print(is_cpu)
+            if not is_cpu:
+                self.nodes.append(worker)
 
+        assert len(self.nodes) == node_count
         for node in self.nodes:
             node.run.remote()

dump.rdb

88 Bytes
Binary file not shown.

is_cpu.txt

Whitespace-only changes.

launch_generate.sh

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+INSTANCE=bloom-tpu-v4-64
+ZONE=us-central2-b
+PROJECT=huggingface-ml
+
+# run run_generate.sh on every TPU worker
+gcloud alpha compute tpus tpu-vm ssh $INSTANCE --project=$PROJECT --zone=$ZONE \
+    --force-key-file-overwrite --strict-host-key-checking=no \
+    --worker=all \
+    --command="bash ~/bloom-jax-inference/run_generate.sh"

ray_tpu.py

Lines changed: 14 additions & 7 deletions
@@ -46,21 +46,28 @@ def get_connection(
 
 def start_ray(conn, address):
     # start afresh each launch (temporarily)
-    conn.run("sudo rm -rf *.py bloom_inference")
+    conn.run("sudo rm -rf *.py bloom-jax-inference")
     # make directory of structure: bloom_inference/bloom_inference/modeling_bloom
-    conn.run("mkdir bloom_inference bloom_inference/bloom_inference bloom_inference/bloom_inference/modeling_bloom -p")
-
+    conn.run("mkdir bloom-jax-inference bloom-jax-inference/bloom_inference bloom-jax-inference/bloom_inference/modeling_bloom -p")
+
     # copy run files into bloom_inference
     for i in glob.glob("*.py"):
-        conn.put(i, "bloom_inference/")
+        conn.put(i, "bloom-jax-inference/")
 
     # copy CPU/TPU manager files into bloom_inference/bloom_inference
     for i in glob.glob("bloom_inference/*.py"):
-        conn.put(i, "bloom_inference/bloom_inference/")
+        conn.put(i, "bloom-jax-inference/bloom_inference/")
 
     # copy modeling files into bloom_inference/bloom_inference/modeling_bloom
     for i in glob.glob("bloom_inference/modeling_bloom/*.py"):
-        conn.put(i, "bloom_inference/bloom_inference/modeling_bloom/")
+        conn.put(i, "bloom-jax-inference/bloom_inference/modeling_bloom/")
+
+    # copy shell scripts into bloom-jax-inference
+    for i in glob.glob("*.sh"):
+        conn.put(i, "bloom-jax-inference/")
+
+    # copy key file into bloom-jax-inference
+    conn.put("key.json", "bloom-jax-inference/")
 
     # transfer start-up script from CPU -> hosts and give permissions
     conn.put("scripts/ray_tpu.sh", "/tmp/ray-tpu.sh")
@@ -74,6 +81,6 @@ def start_ray(conn, address):
     time.sleep(1)
 
     # run start-up script
-    out = conn.run(f"bash /tmp/ray-tpu.sh {address}", hide=True)
+    out = conn.run(f"bash /tmp/ray-tpu.sh {address}", hide=False)
     # display result
     print(out)
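
For context, conn here is a fabric-style SSH connection, and the rename only retargets the remote directory from bloom_inference to bloom-jax-inference. A hedged sketch of the put/run pattern (host name, address, and script path are illustrative; assumes the fabric package):

import glob
from fabric import Connection  # assumes fabric is installed

def push_and_launch(host="tpu-host-0", address="10.0.0.1:6379"):  # hypothetical host/address
    conn = Connection(host)
    conn.run("mkdir -p bloom-jax-inference/bloom_inference")
    # upload local sources to the remote directory
    for path in glob.glob("*.py"):
        conn.put(path, "bloom-jax-inference/")
    # hide=False streams remote stdout/stderr back to the local terminal,
    # which is what this commit switches to for easier debugging
    conn.run(f"bash /tmp/ray-tpu.sh {address}", hide=False)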

run.py

Lines changed: 43 additions & 44 deletions
@@ -7,50 +7,49 @@
 
 from bloom_inference.tpu_manager import TPUManager
 
-num_mp_partitions = 8
-
-#tpu_name = "suraj-tpu-v3-32"
-# tpu_name = "patrick-tpu-v3-32"
-# region = "europe-west4-a"
 tpu_name="bloom-tpu-v4-64"
 region="us-central2-b"
 
-ckpt = "bigscience/bloom-6b3",
-t5x_path = "gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0",
-max_len = 256,
-max_input_len = 64,
-model_parallel_submesh = (1, 2, 4, 1), # for v4-64
-
-
-# get Python list of TPU hosts
-conns = get_connection(tpu_name, region)
-
-head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
-address = head_info.address_info['address']
-
-# start ray CPU<->TPU on all hosts
-with pool.ThreadPool(processes=len(conns)) as p:
-    p.map(functools.partial(start_ray, address=address), conns)
-
-# initialise TPU manager
-t = TPUManager(
-    len(conns),
-    ckpt=ckpt,
-    t5x_path=t5x_path,
-    max_len=max_len,
-    max_input_len=max_input_len,
-    model_parallel_submesh=model_parallel_submesh,
-)
-
-# benchmark compile step
-start = time.time()
-print(t.generate(4*['Recipe for coconut pasta:']))
-print(f"Generations completed in {time.time() - start:.06}s")
-
-# benchmark generate
-start = time.time()
-print(t.generate(4*['Recipe for coconut pasta:']))
-print(f"Generations completed in {time.time() - start:.06}s")
-
-# shutdown ray rpc
-ray.shutdown()
+ckpt = "bigscience/bloom"
+t5x_path = "gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0"
+max_len = 128
+max_input_len = 64
+model_parallel_submesh = (1, 2, 4, 1)  # for v4-64
+
+
+def setup():
+    # get Python list of TPU hosts
+    conns = get_connection(tpu_name, region)
+    print(len(conns))
+    address = '10.130.0.10:8080'
+    head_info = ray.init(include_dashboard=False, address="auto")
+    # object_store_memory=10**9,
+
+    # start ray CPU<->TPU on all hosts
+    with pool.ThreadPool(processes=len(conns)) as p:
+        p.map(functools.partial(start_ray, address=address), conns)
+
+def init_manager():
+    # initialise TPU manager
+    t = TPUManager(
+        8,
+        ckpt=ckpt,
+        t5x_path=t5x_path,
+        max_len=max_len,
+        max_input_len=max_input_len,
+        model_parallel_submesh=model_parallel_submesh,
+    )
+    return t
+
+# # benchmark compile step
+# start = time.time()
+# print(t.generate(4*['Recipe for coconut pasta:']))
+# print(f"Generations completed in {time.time() - start:.06}s")
+
+# # benchmark generate
+# start = time.time()
+# print(t.generate(4*['Recipe for coconut pasta:']))
+# print(f"Generations completed in {time.time() - start:.06}s")
+
+# # shutdown ray rpc
+# ray.shutdown()
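
With this refactor, run.py no longer executes end-to-end; the benchmark block is commented out and the logic lives in two callables, which suggests driving it from a REPL. A hedged usage sketch (the module import and prompt are illustrative; assumes a Ray head node is already running, matching address="auto"):

# hypothetical interactive session, not part of the commit
from run import setup, init_manager

setup()             # SSH to all TPU hosts and join them to the Ray cluster
t = init_manager()  # collect 8 TPU workers and start their run() loops
print(t.generate(4 * ["Recipe for coconut pasta:"]))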

run_generate.sh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+source ~/venv/bin/activate
+export GOOGLE_APPLICATION_CREDENTIALS=~/bloom-jax-inference/key.json
+python ~/bloom-jax-inference/run_speed.py

run_speed.py

Lines changed: 3 additions & 3 deletions
@@ -1,5 +1,5 @@
 import argparse
-from time import time
+import time
 
 import numpy as np
 import jax
@@ -28,7 +28,7 @@
 input_len = args.input_len
 
 config = BloomConfig.from_pretrained(ckpt)
-model, params = FlaxBloomForCausalLM(config, _do_init=False, dtype=jnp.bfloat16, use_scan=True)
+model = FlaxBloomForCausalLM(config, _do_init=False, dtype=jnp.bfloat16, use_scan=True)
 tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-350m", use_fast=False)
 
 
@@ -102,7 +102,7 @@ def generate(params, input_ids, attention_mask):
 # This will auto-magically run in mesh context
 start = time.time()
 gen_ids = p_generate(loaded_state.params, inputs["input_ids"], inputs["attention_mask"])
-generated_text = tokenizer.batch_decode(gen_ids.local_shards[0].data, skip_special_tokens=False)
+generated_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=False)
 if jax.process_index() == 0:
     print("Compilation time:", time.time() - start)
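
The batch_decode fix works because the tokenizer accepts any (batch, seq_len) array-like of token ids, so the array returned by the pjit-ed generate call can be decoded directly rather than via .local_shards. A minimal standalone sketch (same tokenizer checkpoint as the script; the prompt is illustrative):

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-350m", use_fast=False)
# build a (1, seq_len) array of token ids, standing in for gen_ids
ids = np.array(tokenizer("Recipe for coconut pasta:")["input_ids"])[None, :]
print(tokenizer.batch_decode(ids, skip_special_tokens=True))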
