
How to train Sundial? #1

Open

@JasonStraka

Description

Great work! I tried migrating the model to the BLAST dataset for training, but the trends of the loss and the MAE metric don't seem to match. Could you tell me whether there is a problem in my code, or provide the training code for Sundial? Thanks!

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from safetensors.torch import load_model

# SundialConfig and SundialForPrediction are assumed to be importable from the
# modeling files distributed with the Sundial checkpoint; adjust the paths to
# wherever those files live in your setup.
from modeling_sundial import SundialForPrediction
from configuration_sundial import SundialConfig


class Sundial(nn.Module):
    def __init__(self, model_id: str, from_pretrained: bool,
                 context_length: int,
                 trust_remote_code: bool):
        super().__init__()

        self.model_type = 'causal'  # Sundial is a causal (decoder-only) model
        self.context_length = context_length

        if from_pretrained:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=trust_remote_code,
            )
        else:
            kwargs = {'torch_dtype': 'float32'}
            # kwargs['attn_implementation'] = 'flash_attention_2'
            config, model_kwargs = SundialConfig.from_pretrained(
                pretrained_model_name_or_path=model_id,
                return_unused_kwargs=True,
                **kwargs)
            self.model = SundialForPrediction(config)
            load_model(self.model, "baselines/Sundial/ckpt/model.safetensors")

        self.chunk_size = self.model.config.input_token_len
        self.output_token_len = self.model.config.output_token_lens[-1]

    def forward(self, context: torch.Tensor, target: torch.Tensor,
                mask: torch.Tensor, target_mask: torch.Tensor):
        # Should the labels be rolling multi-step here?
        # Shapes: context [64, 4096], target [64, 720], mask [64, 4096], target_mask [64, 720]
        # Labels are the series shifted forward by one input token: [64, 4800]
        labels = torch.cat([context[:, self.chunk_size:], target], dim=-1).detach()
        loss_masks = torch.cat([mask[:, self.chunk_size:], target_mask], dim=-1)
        # Collapse each output window into a single validity flag:
        # [64, 4800] -> unfold -> [64, 256, 720] -> any -> [64, 256]
        loss_masks = loss_masks.unfold(dimension=-1, size=self.output_token_len,
                                       step=self.chunk_size).any(dim=-1).detach()
        mask = self.mask_pre(mask).detach()  # [64, 256]
        context, labels = context.nan_to_num(), labels.nan_to_num()
        # The attention_mask convention here is inverted relative to the mask in the
        # BLAST dataset, so it may need to be negated with ~; loss_masks does not.
        # The dataset is already normalized, so RevIN is disabled here.
        output, _mae, _mse = self.model(input_ids=context, labels=labels,
                                        attention_mask=~mask, loss_masks=loss_masks,
                                        mask_y=None, revin=False)
        loss = output.loss  # output.logits is unused here
        return loss, _mae, _mse

    def mask_pre(self, mask):
        # Compress a per-timestep mask into a per-token mask: a token counts as
        # valid if any of its chunk_size timesteps is valid.
        B, L = mask.shape
        compressed_mask = mask.reshape(B, L // self.chunk_size, self.chunk_size).any(dim=2)
        return compressed_mask
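
For context, here is a minimal smoke-test sketch of one training step through this wrapper. The checkpoint id, optimizer, learning rate, and tensor shapes (batch 4, context 4096, horizon 720, matching the shape comments above) are illustrative assumptions, and it presumes the Sundial forward loaded via trust_remote_code accepts the keyword arguments used in the wrapper:

import torch

# Hypothetical checkpoint id; substitute the Sundial checkpoint you actually use.
model = Sundial(model_id="thuml/sundial-base-128m",
                from_pretrained=True,
                context_length=4096,
                trust_remote_code=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

B, L, H = 4, 4096, 720                      # batch, context length, horizon
context = torch.randn(B, L)
target = torch.randn(B, H)
mask = torch.ones(B, L, dtype=torch.bool)   # True = observed (BLAST convention)
target_mask = torch.ones(B, H, dtype=torch.bool)

loss, mae, mse = model(context, target, mask, target_mask)
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(loss.item(), mae, mse)

If loss and MAE still trend differently even with fully observed (all-True) masks like these, the ~mask inversion passed as attention_mask would be the first thing I'd rule out.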
