Open
Description
非常棒的工作!我尝试迁移模型到BLAST数据集进行训练,但是loss和mae指标变化趋势貌似并不匹配,请问是否可以告诉我代码是否有问题,或提供训练sundial的代码?谢谢!
class Sundial(nn.Module):
    """Training wrapper around the Sundial time-series forecasting model.

    Loads the model either fully pretrained via HuggingFace
    ``AutoModelForCausalLM`` or by instantiating ``SundialForPrediction``
    from a config and loading a local safetensors checkpoint.

    NOTE(review): this code was pasted into an issue with indentation
    stripped; indentation here is restored by inference — the placement of
    ``self.chunk_size`` / ``self.output_token_len`` at ``__init__`` level
    (outside the ``else`` branch) is assumed because ``forward`` needs them
    in both loading branches — TODO confirm against the original file.
    """

    def __init__(self, model_id: str, from_pretrained: bool,
                 context_length: int,
                 trust_remote_code: bool):
        """
        Args:
            model_id: HF hub id or local path of the Sundial model/config.
            from_pretrained: if True, load full pretrained weights through
                ``AutoModelForCausalLM``; otherwise build the model from its
                config and load a hard-coded local safetensors checkpoint.
            context_length: length of the input context window.
            trust_remote_code: forwarded to HF loading (needed for custom
                model code hosted on the hub).
        """
        super().__init__()
        # NOTE(review): the original comment said "TimeMoE is a causal
        # model" — likely copy-pasted from a TimeMoE baseline; this class
        # wraps Sundial, which is treated as causal here as well.
        self.model_type = 'causal'
        self.context_length = context_length
        if from_pretrained:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=trust_remote_code,
            )
        else:
            kwargs = {}
            # Force fp32 weights; flash-attention intentionally left
            # disabled (commented out below).
            kwargs['torch_dtype'] = 'float32'
            # kwargs['attn_implementation'] = 'flash_attention_2'
            config, model_kwargs = SundialConfig.from_pretrained(
                pretrained_model_name_or_path=model_id,
                return_unused_kwargs=True,
                **kwargs)
            # print(f'Using attention implementation: {kwargs.get("attn_implementation", "original")}')
            self.model = SundialForPrediction(config)
            from safetensors.torch import load_model, save_model
            # Hard-coded local checkpoint path — TODO confirm it exists in
            # the training environment.
            load_model(self.model, "baselines/Sundial/ckpt/model.safetensors")
        # Patch/chunk length of one input token and the longest supported
        # output horizon, read from the model config.
        self.chunk_size = self.model.config.input_token_len
        self.output_token_len = self.model.config.output_token_lens[-1]

    def forward(self, context: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, target_mask: torch.Tensor):
        """Compute the training loss for one batch.

        Args:
            context: (B, context_length) input series — e.g. (64, 4096)
                per the shapes noted in the original comments.
            target: (B, horizon) future values — e.g. (64, 720).
            mask: validity mask aligned with ``context``; presumably
                True/1 = observed — TODO confirm the BLAST convention.
            target_mask: validity mask aligned with ``target``.

        Returns:
            (loss, mae, mse) as produced by the underlying model.
        """
        # NOTE(review, from the issue): should the labels be rolling
        # multi-step here?  Open question from the original author.
        # Observed shapes: torch.Size([64, 4096]) torch.Size([64, 720])
        # torch.Size([64, 4096]) torch.Size([64, 720])
        # Labels = context shifted forward by one chunk, plus the target
        # horizon appended — e.g. torch.Size([64, 4800]).
        labels = torch.cat([context[:, self.chunk_size:], target], dim=-1).detach()
        loss_masks = torch.cat([mask[:, self.chunk_size:], target_mask], dim=-1)
        # Windowed any(): a chunk position contributes to the loss if ANY
        # step of its output window is valid.
        # Unfold: (64, 4800) -> torch.Size([64, 256, 720]) -> torch.Size([64, 256]).
        loss_masks = loss_masks.unfold(dimension=-1, size=self.output_token_len, step=self.chunk_size).any(dim=-1).detach()
        # Per-chunk attention mask — torch.Size([64, 256]).
        mask = self.mask_pre(mask).detach()
        # print(mask.sum(), loss_masks.sum())
        # Replace NaNs (missing values) with 0 before feeding the model.
        context, labels = context.nan_to_num(), labels.nan_to_num()
        # NOTE(review, translated): the model's attention_mask convention is
        # inverted relative to the mask in the BLAST dataset, hence the
        # bitwise-NOT (~); loss_masks does NOT need inverting.  TODO confirm
        # — the issue author suspects this area for the loss/MAE mismatch.
        # The dataset is already normalized, so RevIN is disabled here.
        output, _mae, _mse = self.model(input_ids=context, labels=labels, attention_mask=~mask, loss_masks=loss_masks, mask_y=None, revin=False)
        loss, _ = output.loss, output.logits  # _ is the logits
        return loss, _mae, _mse

    def mask_pre(self, mask):
        """Compress a per-timestep mask to a per-chunk mask.

        A chunk is considered valid if ANY of its ``chunk_size`` timesteps
        is valid.  Shape: (B, L) -> (B, L // chunk_size); assumes L is an
        exact multiple of ``chunk_size`` (reshape would fail otherwise).
        """
        B, L = mask.shape
        compressed_mask = mask.reshape(B, L // self.chunk_size, self.chunk_size).any(dim=2)
        # print(mask.size(), compressed_mask.size())
        return compressed_mask