
from bloom_inference.tpu_manager import TPUManager

-num_mp_partitions = 8
-
-#tpu_name = "suraj-tpu-v3-32"
-# tpu_name = "patrick-tpu-v3-32"
-# region = "europe-west4-a"
tpu_name="bloom-tpu-v4-64"
region="us-central2-b"

-ckpt = "bigscience/bloom-6b3",
-t5x_path = "gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0",
-max_len = 256,
-max_input_len = 64,
-model_parallel_submesh = (1, 2, 4, 1), # for v4-64
-
-
-# get Python list of TPU hosts
-conns = get_connection(tpu_name, region)
-
-head_info = ray.init(include_dashboard=False, object_store_memory=10**9)
-address = head_info.address_info['address']
-
-# start ray CPU<->TPU on all hosts
-with pool.ThreadPool(processes=len(conns)) as p:
-    p.map(functools.partial(start_ray, address=address), conns)
-
-# initialise TPU manager
-t = TPUManager(
-    len(conns),
-    ckpt=ckpt,
-    t5x_path=t5x_path,
-    max_len=max_len,
-    max_input_len=max_input_len,
-    model_parallel_submesh=model_parallel_submesh,
-)
-
-# benchmark compile step
-start = time.time()
-print(t.generate(4*['Recipe for coconut pasta:']))
-print(f"Generations completed in {time.time() - start:.06}s")
-
-# benchmark generate
-start = time.time()
-print(t.generate(4*['Recipe for coconut pasta:']))
-print(f"Generations completed in {time.time() - start:.06}s")
-
-# shutdown ray rpc
-ray.shutdown()
+ckpt = "bigscience/bloom"
+t5x_path = "gs://bloom-jax-us-central2-b/bloom-176B-scan-t5x/checkpoint_0"
+max_len = 128
+max_input_len = 64
+model_parallel_submesh = (1, 2, 4, 1) # for v4-64
+
+
+def setup():
+    # get Python list of TPU hosts
+    conns = get_connection(tpu_name, region)
+    print(len(conns))
+    address='10.130.0.10:8080'
+    head_info = ray.init(include_dashboard=False, address="auto")
+    # object_store_memory=10**9,
+
+    # start ray CPU<->TPU on all hosts
+    with pool.ThreadPool(processes=len(conns)) as p:
+        p.map(functools.partial(start_ray, address=address), conns)
+
+def init_manager():
+    # initialise TPU manager
+    t = TPUManager(
+        8,
+        ckpt=ckpt,
+        t5x_path=t5x_path,
+        max_len=max_len,
+        max_input_len=max_input_len,
+        model_parallel_submesh=model_parallel_submesh,
+    )
+    return t
+
+# # benchmark compile step
+# start = time.time()
+# print(t.generate(4*['Recipe for coconut pasta:']))
+# print(f"Generations completed in {time.time() - start:.06}s")
+
+# # benchmark generate
+# start = time.time()
+# print(t.generate(4*['Recipe for coconut pasta:']))
+# print(f"Generations completed in {time.time() - start:.06}s")
+
+# # shutdown ray rpc
+# ray.shutdown()
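
For context, below is a minimal driver sketch showing how the new setup() / init_manager() entry points could be exercised, mirroring the benchmark code this commit comments out. It is an illustration based only on the calls visible in the diff, not part of the commit, and it assumes it runs in the same module, where get_connection, start_ray, pool, and functools are already imported above the hunk shown here.

import time

import ray

if __name__ == "__main__":
    setup()             # connect to the TPU hosts and start a Ray worker on each
    t = init_manager()  # build the TPUManager across the 8 hosts

    # the first generate() call includes compilation, so time it separately
    start = time.time()
    print(t.generate(4 * ["Recipe for coconut pasta:"]))
    print(f"Compile + generate: {time.time() - start:.2f}s")

    # subsequent calls reuse the compiled program
    start = time.time()
    print(t.generate(4 * ["Recipe for coconut pasta:"]))
    print(f"Generate: {time.time() - start:.2f}s")

    # shut down the Ray connection when finished
    ray.shutdown()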