diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index bb60b2d180..e1804c7b18 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -3,6 +3,7 @@
 import math
 import os
+import random
 import time
 from pathlib import Path
 from pprint import pprint
 from typing import Dict, List, Literal, Optional, Tuple, Union
@@ -235,10 +236,13 @@ def fit(
     data: DataModule,
 ) -> None:
     tokenizer = Tokenizer(checkpoint_dir)
-    longest_seq_length, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
+
+    longest_seq_length_train, longest_seq_ix = get_longest_seq_length(train_dataloader.dataset)
+    longest_seq_length_val, _ = get_longest_seq_length(val_dataloader.dataset)
+    longest_seq_length = max(longest_seq_length_train, longest_seq_length_val)
     model.max_seq_length = min(longest_seq_length, train.max_seq_length or float("inf"))
     fabric.print(
-        f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is"
+        f"The longest sequence lengths are {longest_seq_length_train} in the training data and {longest_seq_length_val} in the validation data. The model's maximum sequence length is"
         f" {model.max_seq_length} and context length is {model.config.block_size}"
     )
@@ -344,7 +348,13 @@ def validate(
     val_loss = losses.mean()
 
     # produce an example:
-    instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
+    # sample a random instruction from the validation set; fall back to a fixed prompt if sampling fails
+    try:
+        rand = random.randint(0, len(val_dataloader.dataset.data) - 1)
+        instruction = val_dataloader.dataset.data[rand]["instruction"]
+    except Exception as e:
+        fabric.print(f"Could not sample an instruction from the validation data ({e}), using the default instruction.")
+        instruction = "Recommend a movie for me to watch during the weekend and explain the reason."
     fabric.print(instruction)
     prompt = data.prompt_style.apply(instruction)
     encoded = tokenizer.encode(prompt, device=fabric.device)