Refine logging for Gaudi warmup #3222

Open · wants to merge 4 commits into main
24 changes: 13 additions & 11 deletions backends/gaudi/server/text_generation_server/models/causal_lm.py
@@ -1367,9 +1367,16 @@ def warmup(
prefill_seqlen_list.append(max_input_tokens)
prefill_batch_size_list.sort(reverse=True)
prefill_seqlen_list.sort(reverse=True)
logger.info(
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
)
try:
for batch_size in prefill_batch_size_list:
for seq_len in prefill_seqlen_list:
logger.info(
f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
)
batch = self.generate_warmup_batch(request, seq_len - 1, batch_size)
_, prefill_batch, _ = self.generate_token([batch])
except Exception:
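
The pattern this hunk introduces: the shape lists are logged once before the sweep, and each combination announces itself before its potentially slow compilation. A minimal self-contained sketch, with illustrative shape lists and loguru standing in for the module's logger:

```python
from loguru import logger

prefill_batch_size_list = [8, 4, 2, 1]   # assumed values for illustration
prefill_seqlen_list = [1024, 512, 256]   # assumed values for illustration

# Summarize the whole sweep once, up front.
logger.info(
    f"Prefill batch size list:{prefill_batch_size_list}\n"
    f"Prefill sequence length list:{prefill_seqlen_list}\n"
)
for batch_size in prefill_batch_size_list:
    for seq_len in prefill_seqlen_list:
        # Announce each shape *before* the slow step, so a stalled warmup
        # shows exactly which shape it is working on.
        logger.info(
            f"Prefill warmup for `batch_size={batch_size}` and "
            f"`sequence_length={seq_len}`, this may take a while..."
        )
        # ...the real code builds a warmup batch and runs generate_token() here.
```
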
@@ -1384,22 +1391,21 @@ def warmup(
prefill_seqlen_list.sort()
prefill_batch_size_list.sort()
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill warmup successfully.\n"
f"Prefill batch size list:{prefill_batch_size_list}\n"
f"Prefill sequence length list:{prefill_seqlen_list}\n"
f"Memory stats: {mem_stats} "
)
logger.info(f"Prefill warmup successful.\n" f"Memory stats: {mem_stats} ")

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
decode_batch_size_list = [
BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
decode_batch_size_list.sort(reverse=True)
logger.info(f"Decode batch size list:{decode_batch_size_list}\n")

try:
for batch_size in decode_batch_size_list:
logger.info(
f"Decode warmup for `batch_size={batch_size}`, this may take a while..."
)
batches = []
iters = math.floor(batch_size / max_prefill_batch_size)
for i in range(iters):
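
For context on the list driven by this loop: decode batch sizes are powers of `BATCH_SIZE_EXPONENT_BASE`, capped by the token budget. A worked example with assumed values (not the server defaults):

```python
import math

MAX_BATCH_TOTAL_TOKENS = 65536   # assumed value for illustration
MAX_TOTAL_TOKENS = 2048          # assumed value for illustration
BATCH_SIZE_EXPONENT_BASE = 2

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)   # 32
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 5
decode_batch_size_list = [
    BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)  # [32, 16, 8, 4, 2, 1] -- warmed up largest first
```
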
@@ -1432,11 +1438,7 @@ def warmup(
decode_batch_size_list.sort()
max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]
mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{decode_batch_size_list}\n"
f"Memory stats: {mem_stats} "
)
logger.info(f"Decode warmup successful.\n" f"Memory stats: {mem_stats} ")

max_input_tokens = max_input_tokens
max_total_tokens = MAX_TOTAL_TOKENS
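
Taken together, the causal_lm.py changes replace two verbose end-of-phase summaries with upfront shape lists, one progress line per shape, and terse success messages. Roughly the following sequence, paraphrased rather than verbatim server output:

```
Prefill batch size list:[8, 4, 2, 1]
Prefill sequence length list:[1024, 512, 256]
Prefill warmup for `batch_size=8` and `sequence_length=1024`, this may take a while...
...one line per (batch_size, sequence_length) pair...
Prefill warmup successful.
Memory stats: {...}
Decode warmup for `batch_size=32`, this may take a while...
...one line per decode batch size...
Decode warmup successful.
Memory stats: {...}
```
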
@@ -232,7 +232,7 @@ def __init__(
self.prefilling = prefilling

@property
def token_idx(self):
def token_idx(self): # noqa: F811
if self.prefilling:
# no right padding for prefill
token_idx_scalar = self.attention_mask.shape[-1] - 1
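
The new `# noqa: F811` silences pyflakes' "redefinition of unused name" check, which fires when a name is bound twice in the same scope. A minimal hypothetical example of what F811 flags (not the real class):

```python
class Example:
    def token_idx(self):  # first binding of the name
        return 0

    @property
    def token_idx(self):  # noqa: F811 -- rebinds the name above, which pyflakes reports
        return 1
```
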
@@ -1511,9 +1511,16 @@ def warmup(
DECODE_WARMUP_BATCH_SIZE_LIST = []
prefill_batch = None
decode_batch = None
logger.info(
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
)
try:
for batch_size in PREFILL_WARMUP_BATCH_SIZE_LIST:
for seq_len in PREFILL_WARMUP_SEQLEN_LIST:
logger.info(
f"Prefill warmup for `batch_size={batch_size}` and `sequence_length={seq_len}`, this may take a while..."
)
batch = self.generate_warmup_batch(
request, seq_len, batch_size, is_warmup=True
)
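
As in causal_lm.py, the sweep runs inside try/except so an out-of-memory failure on any shape surfaces as a single actionable error. A minimal sketch of the pattern; `run_one_prefill_warmup` is a hypothetical stand-in for the real warmup step:

```python
def run_one_prefill_warmup(batch_size: int, seq_len: int) -> None:
    ...  # build a warmup batch and run one generation step

try:
    for batch_size in [8, 4, 2, 1]:        # illustrative shape lists
        for seq_len in [1024, 512, 256]:
            run_one_prefill_warmup(batch_size, seq_len)
except Exception:
    # Re-raise with guidance the operator can act on, not a raw HPU error.
    raise RuntimeError(
        "Not enough memory to handle following prefill and decode warmup. "
        "You need to decrease `--max-batch-prefill-tokens`"
    )
```
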
@@ -1527,24 +1534,16 @@ def warmup(

except Exception:
raise RuntimeError(
f"Not enough memory to handle following prefill and decode warmup."
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}"
f"You need to decrease `--max-batch-prefill-tokens`"
"Not enough memory to handle following prefill and decode warmup."
"You need to decrease `--max-batch-prefill-tokens`"
)

mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing prefill and decode warmup successfully.\n"
f"Prefill batch size list:{PREFILL_WARMUP_BATCH_SIZE_LIST}\n"
f"Prefill sequence length list:{PREFILL_WARMUP_SEQLEN_LIST}\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats} "
)
logger.info(f"Prefill warmup successful.\n" f"Memory stats: {mem_stats} ")

max_decode_batch_size = MAX_BATCH_SIZE
batch_size = max_prefill_batch_size * 2
logger.info(f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n")
# Decode warmup with bigger batch_size
try:
if (
@@ -1554,6 +1553,9 @@
batches = []
while batch_size <= max_decode_batch_size:
for i in range(int(batch_size / max_prefill_batch_size)):
logger.info(
f"Decode warmup for `batch_size={batch_size}`, this may take a while..."
)
batch = self.generate_warmup_batch(
request,
PREFILL_WARMUP_SEQLEN_LIST[0] - 1,
@@ -1596,11 +1598,7 @@
)

mem_stats = get_hpu_memory_stats(self.device)
logger.info(
f"\nFollowing decode warmup successfully.\n"
f"Decode batch size list:{DECODE_WARMUP_BATCH_SIZE_LIST}\n"
f"Memory stats: {mem_stats}"
)
logger.info(f"Decode warmup successful.\n" f"Memory stats: {mem_stats}")

max_supported_total_tokens = MAX_BATCH_SIZE * MAX_TOTAL_TOKENS
max_input_tokens = max_input_tokens
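
The two files derive the final token budget differently: causal_lm.py multiplies `MAX_TOTAL_TOKENS` by the largest decode batch size that actually warmed up, while this file takes the configured `MAX_BATCH_SIZE` ceiling. With the same assumed numbers as earlier:

```python
MAX_TOTAL_TOKENS = 2048                        # assumed value for illustration
decode_batch_size_list = [1, 2, 4, 8, 16, 32]  # ascending after the final .sort()

# causal_lm.py: budget follows the largest batch size that survived warmup.
max_supported_total_tokens = MAX_TOTAL_TOKENS * decode_batch_size_list[-1]  # 65536

# This file: a fixed ceiling from the configured maximum batch size.
MAX_BATCH_SIZE = 32                            # assumed value for illustration
assert MAX_BATCH_SIZE * MAX_TOTAL_TOKENS == max_supported_total_tokens
```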