finetune.py
# The code here supports finetuning BERT on downstream tasks that can
# be framed as sentence or sentence pair classification (or regression).
from dataclasses import dataclass

import torch
import torch.nn as nn
import yaml

@dataclass
class FineTuneConfig:
    tasks: list
    num_epochs: int
    batch_size: int
    lr: float
    weight_decay: float
    dropout: float
    metadata_file: str
    checkpoint_path: str
    tokenizer_path: str

    @classmethod
    def from_yaml(cls, path):
        with open(path, 'r') as f:
            config = yaml.safe_load(f)
        return cls(**config)
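
# Example usage (illustrative only; the YAML keys mirror the dataclass fields
# above, while the file name, paths, and values below are made up):
#
#   # finetune.yaml
#   tasks: [sst2, mrpc, stsb]
#   num_epochs: 3
#   batch_size: 32
#   lr: 2.0e-5
#   weight_decay: 0.01
#   dropout: 0.1
#   metadata_file: data/metadata.json
#   checkpoint_path: checkpoints/pretrained_bert.pt
#   tokenizer_path: tokenizer/
#
#   config = FineTuneConfig.from_yaml("finetune.yaml")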

class FineTuneDataset(torch.utils.data.Dataset):
    def __init__(self, sentence1s, sentence2s, labels, num_classes, tokenizer, max_len=128):
        self.sentence1s = sentence1s
        self.sentence2s = sentence2s
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.num_classes = num_classes

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        input_ids = self.tokenizer(self.sentence1s[idx], max_length=self.max_len, truncation=True, padding=False)['input_ids']
        if self.sentence2s is not None:
            input_ids.append(self.tokenizer.sep_token_id)
            input_ids.extend(self.tokenizer(self.sentence2s[idx], max_length=self.max_len, truncation=True, padding=False)['input_ids'])
            input_ids.append(self.tokenizer.sep_token_id)
        # Truncate if too long.
        input_ids = input_ids[:self.max_len]
        # Pad if too short.
        input_ids.extend([self.tokenizer.pad_token_id] * (self.max_len - len(input_ids)))
        # Attention mask: 1 for tokens that are not padding, 0 for padding tokens.
        attention_mask = [1 if token != self.tokenizer.pad_token_id else 0 for token in input_ids]
        label = self.labels[idx]
        if self.num_classes == 1:
            label = torch.tensor(label, dtype=torch.float)
        else:
            label = torch.tensor(label, dtype=torch.long)
        return (
            torch.LongTensor(input_ids),
            label,
            torch.BoolTensor(attention_mask)
        )
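
# Minimal usage sketch (an assumption, not part of this file: a HuggingFace-style
# tokenizer that is callable and exposes sep_token_id / pad_token_id, e.g.
# transformers.BertTokenizerFast):
#
#   from transformers import BertTokenizerFast
#   from torch.utils.data import DataLoader
#
#   tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
#   dataset = FineTuneDataset(
#       sentence1s=["a great movie", "dull and slow"],
#       sentence2s=None,
#       labels=[1, 0],
#       num_classes=2,
#       tokenizer=tokenizer,
#   )
#   loader = DataLoader(dataset, batch_size=2, shuffle=True)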

# Supports binary, multiclass, and regression tasks.
class BERTForFineTuning(nn.Module):
    def __init__(self, bert, num_classes, dropout=0.1):
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.output_head = nn.Linear(bert.d_model, num_classes)
        nn.init.normal_(self.output_head.weight, std=bert.initializer_range)

    def forward(self, input_ids, targets=None, attention_mask=None):
        outputs = self.bert(input_ids, mask=attention_mask)  # (bsz, seq_len, hidden_size)
        pooled = torch.mean(outputs, dim=1)  # (bsz, hidden_size)
        logits = self.output_head(self.dropout(pooled))  # (bsz, num_classes)
        if targets is None:
            return logits
        if self.output_head.out_features == 1:
            # Squeeze only the class dimension so a batch of size 1 keeps its shape.
            loss = torch.nn.functional.mse_loss(logits.squeeze(-1), targets)
        else:
            loss = torch.nn.functional.cross_entropy(logits, targets)
        return loss
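
# Illustrative end-to-end sketch (assumptions: `bert` is this repo's BERT encoder,
# e.g. restored from config.checkpoint_path, exposing d_model and
# initializer_range; `loader` is a DataLoader over FineTuneDataset as above):
#
#   model = BERTForFineTuning(bert, num_classes=2, dropout=config.dropout)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr,
#                                 weight_decay=config.weight_decay)
#   model.train()
#   for epoch in range(config.num_epochs):
#       for input_ids, labels, attention_mask in loader:
#           loss = model(input_ids, targets=labels, attention_mask=attention_mask)
#           optimizer.zero_grad()
#           loss.backward()
#           optimizer.step()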