Fix collection of context/continuation, etc.
chimezie committed Nov 18, 2024
1 parent 4b6c971 commit d11aa55
Showing 1 changed file with 7 additions and 52 deletions.
59 changes: 7 additions & 52 deletions lm_eval/models/mlx_llms.py
@@ -137,6 +137,8 @@ def loglikelihood(

         for j in batch_idx[i]:
             prompt, completion = requests[j].args
+            context_batch.append(prompt)
+            continuation_batch.append(completion)
             prompt_lengths.append(input_length(prompt, completion, self.tokenizer))
             full_sequence = self.tokenizer.apply_chat_template(
                 [
@@ -174,6 +176,7 @@ def loglikelihood(
         )
         batch = mx.array(batch_arr)
         lengths = mx.array(lengths)
+        prompt_lengths = mx.array(prompt_lengths)
         # yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
 
         shifted_padded_full_sequence = batch[
@@ -211,18 +214,18 @@ def loglikelihood(
         for idx, (is_greedy, log_prob) in enumerate(
             zip(batch_target_is_greedy_values, log_probs)
         ):
-            input_length = prompt_lengths[idx].item()
+            prompt_length = prompt_lengths[idx].item()
             target_length = lengths[idx]
             context = context_batch[idx]
             continuation = continuation_batch[idx]
 
             del info[(context, continuation)]
 
             # Extract log prob scores at token sequence positions in the logits
-            target_end_idx = input_length + target_length
-            target_sequence = self.tok_encode(continuation)
+            target_end_idx = (prompt_length + target_length).item()
+            target_sequence = batch[idx][prompt_length:target_end_idx + 1]
             target_log_prob_scores = log_prob[
-                input_length - 1 : target_end_idx - 1, :
+                prompt_length - 1 : target_end_idx - 1, :
             ]
 
             # Use the target sequence for extracting log prob values from logits vocabulary distribution
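Aside (an editorial reading, not part of the commit): the slicing in this hunk follows the usual causal-LM convention that row t of the log-probability matrix scores token t + 1, so a continuation occupying positions [prompt_length, target_end_idx) of the padded sequence is scored by rows [prompt_length - 1, target_end_idx - 1). A minimal NumPy sketch of that alignment, with hypothetical names:

    import numpy as np

    def continuation_logprob(log_probs, tokens, prompt_length, target_end_idx):
        # Row t of log_probs predicts token t + 1, so the scoring window
        # sits one position left of the continuation itself.
        rows = log_probs[prompt_length - 1 : target_end_idx - 1]
        targets = tokens[prompt_length:target_end_idx]
        # Gather each continuation token's log-probability from its row.
        return float(rows[np.arange(len(targets)), targets].sum())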
@@ -246,54 +249,6 @@ def loglikelihood(
         # Return the answers in the original order (lost by the batch creation process, which )
         return list(map(lambda i: i[1:], sorted(res, key=lambda i: i[0])))
 
-    def delineated_batches(self, batch_size, context_text, continuation_text):
-        try:
-            import mlx.core as mx
-        except ModuleNotFoundError:
-            raise Exception(
-                "attempted to use 'mlx' LM type, but package `mlx` is not installed. Please install mlx via "
-                "`pip install 'lm-eval[mlx]'` or `pip install -e '.[mlx]'`",
-            )
-
-        batch_size = min(batch_size, len(context_text))
-        encoded_context_batch = [self.tok_encode(record) for record in context_text]
-        encoded_continuation_batch = [
-            self.tok_encode(record) for record in continuation_text
-        ]
-
-        input_lengths = [len(x) for x in encoded_context_batch]
-        target_lengths = [len(x) for x in encoded_continuation_batch]
-
-        full_labels = [
-            encoded_context_batch[idx] + encoded_continuation_batch[idx]
-            for idx in range(batch_size)
-        ]
-        lengths = [len(x) for x in full_labels]
-
-        if max(lengths) > self.max_tokens:
-            print(
-                f"[WARNING] Some sequences are longer than {self.max_tokens} tokens. "
-                f"The longest sentence {max(lengths)} would normally be truncated to {self.max_tokens}. "
-                "Consider pre-splitting your data to save memory."
-            )
-        pad_to = 8
-        max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-
-        batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
-        adjusted_lengths = []
-        for j in range(batch_size):
-            input_length = input_lengths[j]
-            full_ids_end_idx = input_length + min(
-                target_lengths[j], max_length_in_batch - input_length
-            )
-            adjusted_lengths.append(full_ids_end_idx)
-            batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
-
-        batch = mx.array(batch_arr)
-        input_lengths = mx.array(input_lengths)
-        non_padding_lengths = mx.array(adjusted_lengths)
-
-        return batch, input_lengths, target_lengths, non_padding_lengths
 
     def loglikelihood_rolling(
         self, requests: list[Instance]
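Why slice target_sequence out of the already-tokenized batch rather than re-encode the continuation, as the removed self.tok_encode(continuation) call did? One plausible motivation, offered as a reading of the change rather than something the commit states: subword tokenizers can merge across the prompt/continuation seam, so the continuation encoded on its own may not reproduce the ids it has inside the full sequence. A small illustration, assuming the transformers package and the gpt2 tokenizer purely for demonstration:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    full = tok.encode("hel" + "lo")  # "hello" merges into a single token
    alone = tok.encode("lo")         # "lo" on its own is a different id
    print(full, alone)               # the ids at the seam need not line up

Taking the ids directly from batch[idx] sidesteps this: the scored tokens are exactly the ones the model was given.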
