
Commit

add mistral support
higgsfield committed Nov 14, 2023
1 parent 5d4bba7 commit 61bc432
Showing 5 changed files with 351 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -84,6 +84,7 @@ You can use whatever you want, whenever you want. We just introduce a simple int

**Clouds we have tested on:**

- Azure
- LambdaLabs
- FluidStack

Empty file added higgsfield/mistral/__init__.py
Empty file.
219 changes: 219 additions & 0 deletions higgsfield/mistral/mistral.py
@@ -0,0 +1,219 @@
import os
import functools
from pathlib import Path

import torch
import torch.distributed as dist

from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
)

from torch.distributed.fsdp import (
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers import (
    MistralForCausalLM,
    MistralConfig,
    default_data_collator,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from optimum.bettertransformer import BetterTransformer

from higgsfield.checkpoint.fsdp_checkpoint import (
    save_distributed_model_rank0,
    fsdp_model_state_dict_rank0,
)

from higgsfield.mistral.mistral_utils import (
    load_mistral_from_checkpoint,
    load_mistral_from_config,
)

class Mistral(FSDP):
    def __init__(
        self,
        model_name,
        checkpoint_path=None,
        zero_stage=3,
        fast_attn=False,
        precision="bf16",
        cpu_init_rank0=False,
        cpu_offload=False,
        num_embeddings=None,
        cache_dir=None,
    ):
        rank = dist.get_rank()

        model = MistralForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

        if num_embeddings:
            model.resize_token_embeddings(num_embeddings)

        if fast_attn:
            model = BetterTransformer.transform(model)

        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )

        bfSixteen_mixed = MixedPrecision(
            param_dtype=torch.float32,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        )

        pure_bf16 = False
        if precision == "fp16":
            mixed_precision_policy = fpSixteen

        elif precision == "bf16":
            # Pure bf16: cast the whole model instead of using a mixed-precision policy.
            mixed_precision_policy = None
            pure_bf16 = True

        elif precision == "bf16_mixed":
            mixed_precision_policy = bfSixteen_mixed

        else:
            mixed_precision_policy = None

        if pure_bf16:
            model.to(torch.bfloat16)

        # Wrap each MistralDecoderLayer in its own FSDP unit.
        wrapping_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        )

        if zero_stage == 0:
            sharding_strategy = ShardingStrategy.NO_SHARD

        elif zero_stage == 1:
            raise NotImplementedError("stage 1 is not supported, only 0, 2 and 3")

        elif zero_stage == 2:
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP

        elif zero_stage == 3:
            sharding_strategy = ShardingStrategy.FULL_SHARD
        else:
            raise NotImplementedError("stage can only be 0, 2 or 3")

        if cpu_init_rank0 and rank != 0:
            # Non-zero ranks allocate empty parameters on GPU; weights are
            # broadcast from rank 0 via sync_module_states.
            param_init_fn = lambda module: module.to_empty(
                device=torch.device("cuda"),
                recurse=False,
            )
        else:
            param_init_fn = None

        if cpu_offload:
            cpu_offload = CPUOffload(offload_params=True)
        else:
            cpu_offload = None

        super().__init__(
            model,
            auto_wrap_policy=wrapping_policy,
            cpu_offload=cpu_offload,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=cpu_init_rank0,
            param_init_fn=param_init_fn,
        )

        # Checkpoint activations of every decoder layer to reduce memory usage.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )

        check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)

        apply_activation_checkpointing(
            self,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=check_fn,
        )

        self.precision = precision
        self.fsdp = True
        self.model_name = model_name
        self.num_embeddings = num_embeddings

    def __call__(self, batch):
        local_rank = int(os.environ["LOCAL_RANK"])

        # Move the batch to this process's GPU.
        for key in batch.keys():
            batch[key] = batch[key].to(local_rank)

        if self.precision == "fp16":
            with torch.cuda.amp.autocast():
                loss = super().__call__(**batch).loss
        else:
            loss = super().__call__(**batch).loss

        return loss

    def save_model(self, save_path):
        """
        Save the model's weights on the master node under
        ~/.cache/higgsfield/{save_path}.
        """
        if save_path.startswith("/"):
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        save_distributed_model_rank0(path / tail, self)

    def save_huggingface_model(self, save_path):
        """
        Save the model's weights in Hugging Face format on the master node under
        ~/.cache/higgsfield/{save_path}.
        """
        if save_path.startswith("/"):
            save_path = save_path[1:]

        head, tail = os.path.split(save_path)

        path = Path.home() / ".cache/higgsfield" / head
        path.mkdir(exist_ok=True, parents=True)

        # Gather the full (unsharded) state dict onto rank 0.
        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.save_pretrained(path / tail)

    def push_to_hub(self, repo_id, token):
        cpu_state = fsdp_model_state_dict_rank0(self)

        if dist.get_rank() == 0:
            model = load_mistral_from_config(self.model_name, num_embeddings=self.num_embeddings)
            model.load_state_dict(cpu_state)
            model.push_to_hub(repo_id, token=token)
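
For context, here is a minimal usage sketch (not part of this commit) of how the Mistral wrapper and MistralLoader might be combined in a training step. The dataset, learning rate, and save path are assumptions, and the process group is expected to be initialized (e.g. by a torchrun-style launcher) before the wrapper is constructed, since it calls dist.get_rank().

# Hypothetical training sketch, assuming one process per GPU and an already
# initialized torch.distributed process group.
import torch

from higgsfield.mistral.mistral import Mistral
from higgsfield.mistral.mistral_loader import MistralLoader

def train(dataset, epochs=1, lr=1e-5):
    model = Mistral(
        "mistralai/Mistral-7B-v0.1",
        zero_stage=3,        # FULL_SHARD
        precision="bf16",    # pure bf16: model cast to bfloat16, no mixed-precision policy
    )
    loader = MistralLoader(dataset, batch_size_per_gpu=1, max_sequence_length=2048)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for _ in range(epochs):
        for batch in loader:
            loss = model(batch)   # Mistral.__call__ moves the batch to this rank's GPU
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Writes the full state dict from rank 0 under ~/.cache/higgsfield/my-run/final
    model.save_model("my-run/final")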
110 changes: 110 additions & 0 deletions higgsfield/mistral/mistral_loader.py
@@ -0,0 +1,110 @@
import torch.distributed as dist

from torch.utils.data import (
    DistributedSampler,
    DataLoader,
)

from transformers import (
    AutoTokenizer,
    default_data_collator,
)

from higgsfield.dataset import TorchCompletionDataset

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "<|pad|>"
DEFAULT_EOS_TOKEN = "<|endoftext|>"
DEFAULT_UNK_TOKEN = "<|unk|>"

def get_tokenizer(model_name, max_length, cache_dir=None):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        model_max_length=max_length,
        padding_side="right",
        use_fast=False,
        pad_token=DEFAULT_PAD_TOKEN,
        trust_remote_code=True,
        cache_dir=cache_dir,
    )

    # Register any special tokens the pretrained tokenizer is missing.
    special_tokens_dict = dict()
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    tokenizer.add_special_tokens(special_tokens_dict)

    return tokenizer

class HiggsfieldSampler(DistributedSampler):
    def __init__(
        self,
        dataset,
        shuffle=True,
        seed=0,
        drop_last=False,
    ):
        # Infer rank and world size from the initialized process group.
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()

        super(HiggsfieldSampler, self).__init__(
            dataset=dataset,
            num_replicas=num_replicas,
            rank=rank,
            shuffle=shuffle,
            seed=seed,
            drop_last=drop_last,
        )

class MistralLoader(DataLoader):
    def __init__(
        self,
        dataset,
        tokenizer=None,
        max_sequence_length=2048,
        batch_size_per_gpu=1,
        shuffle=True,
        seed=0,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        multiprocessing_context=None,
        *,
        prefetch_factor=None,
        persistent_workers=False,
        pin_memory_device="",
    ):
        if not tokenizer:
            tokenizer = get_tokenizer("mistralai/Mistral-7B-v0.1", max_sequence_length)

        # Tokenize the raw dataset into fixed-length completion examples.
        dataset = TorchCompletionDataset(
            dataset,
            tokenizer,
            max_sequence_length,
        )

        sampler = HiggsfieldSampler(dataset, shuffle=shuffle, seed=seed)

        super(MistralLoader, self).__init__(
            dataset,
            batch_size=batch_size_per_gpu,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
            multiprocessing_context=multiprocessing_context,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
            pin_memory_device=pin_memory_device,
        )

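A brief sketch (again not part of the commit) of driving MistralLoader with an explicit tokenizer. The raw string-list input and the batch keys are assumptions about what TorchCompletionDataset produces.

# Hypothetical sketch: build the loader with a custom tokenizer and inspect one batch.
# HiggsfieldSampler reads rank/world size from torch.distributed, so the process
# group must already be initialized.
from higgsfield.mistral.mistral_loader import MistralLoader, get_tokenizer

samples = [
    "Instruction: say hello.\nResponse: Hello!",
    "Instruction: count to three.\nResponse: 1, 2, 3.",
]

tokenizer = get_tokenizer("mistralai/Mistral-7B-v0.1", max_length=512)
loader = MistralLoader(
    samples,
    tokenizer=tokenizer,
    max_sequence_length=512,
    batch_size_per_gpu=2,
    shuffle=False,
)

batch = next(iter(loader))
# Expected keys (assumption): input_ids, attention_mask, labels
print({k: tuple(v.shape) for k, v in batch.items()})
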
21 changes: 21 additions & 0 deletions higgsfield/mistral/mistral_utils.py
@@ -0,0 +1,21 @@
import torch
from transformers import (
    MistralConfig,
    MistralForCausalLM,
)
from higgsfield.checkpoint import fsdp_model_state_dict_rank0

def load_mistral_from_config(model_name, num_embeddings=None):
    # Build a randomly initialized model from the pretrained config only.
    config = MistralConfig.from_pretrained(model_name)
    model = MistralForCausalLM(config)

    if num_embeddings:
        model.resize_token_embeddings(num_embeddings)

    return model

def load_mistral_from_checkpoint(model_name, checkpoint_path, num_embeddings=None):
    model = load_mistral_from_config(model_name, num_embeddings=num_embeddings)
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict)
    return model
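
To round out the picture, a hedged example of reloading a checkpoint written by save_model for single-process inference. The checkpoint path and prompt are illustrative, and it assumes save_distributed_model_rank0 stores a plain state dict loadable with torch.load.

# Hypothetical sketch: reload weights saved by Mistral.save_model("my-run/final")
# and generate on a single process.
from pathlib import Path

from transformers import AutoTokenizer
from higgsfield.mistral.mistral_utils import load_mistral_from_checkpoint

checkpoint = Path.home() / ".cache/higgsfield" / "my-run/final"
model = load_mistral_from_checkpoint("mistralai/Mistral-7B-v0.1", checkpoint)
model.eval()

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
inputs = tokenizer("Higgsfield makes multi-node training", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))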
