
Commit 2c79f17

add some format
1 parent 9d61847 commit 2c79f17

File tree

4 files changed: +22, -5 lines


.gitignore (+2)

@@ -129,3 +129,5 @@ dmypy.json
 .pyre/
 
 chinese-roberta-wwm-ext/
+
+output/*.pth

data.py (+4)

@@ -17,8 +17,12 @@ def __getitem__(self, item):
 
         encoding = self.tokenizer.encode_plus(
             review,
+            add_special_tokens=True,
             max_length=self.max_len,
+            return_token_type_ids=True,
             padding='max_length',
+            truncation=True,
+            return_attention_mask=True,
             return_tensors='pt')
 
         return {
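For context, a minimal sketch (not from this repo) of what `encode_plus` returns with the flags this commit adds. The checkpoint name below is an assumption based on the `chinese-roberta-wwm-ext/` entry in `.gitignore`, and `max_length=128` stands in for `self.max_len`:

```python
# Hedged sketch: what encode_plus returns with the flags added above.
# 'hfl/chinese-roberta-wwm-ext' is an assumed checkpoint; the repo's
# .gitignore suggests a local chinese-roberta-wwm-ext/ directory.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('hfl/chinese-roberta-wwm-ext')
encoding = tokenizer.encode_plus(
    '这部电影很好看',            # a sample review ("this movie is great")
    add_special_tokens=True,     # wrap the text in [CLS] ... [SEP]
    max_length=128,              # stand-in for self.max_len
    return_token_type_ids=True,
    padding='max_length',        # right-pad every sample to max_length
    truncation=True,             # cut inputs longer than max_length
    return_attention_mask=True,  # 1 for real tokens, 0 for padding
    return_tensors='pt')         # PyTorch tensors of shape (1, 128)

print(encoding['input_ids'].shape)       # torch.Size([1, 128])
print(encoding['attention_mask'].shape)  # torch.Size([1, 128])
```

The `truncation=True` flag matters: recent `transformers` releases warn when `max_length` is set without an explicit truncation strategy, which is likely what this commit is fixing.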

main.py (+15, -5)
@@ -7,6 +7,8 @@
 from model import TextBackbone
 from data import ReviewsDataset, load_data
 
+from transformers import AdamW
+
 def parse_arguments():
     parser = argparse.ArgumentParser(
         description='Simple Sentiment Analysis with PyTorch and Transformers'
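One caveat on the new import: `transformers.AdamW` has long been deprecated, and the newest releases of the library drop it entirely, so this line may fail on a current install. A hedged fallback sketch, using `torch.optim.AdamW` as the usual drop-in replacement:

```python
# Sketch: transformers.AdamW is deprecated and removed in newer
# transformers releases; fall back to PyTorch's implementation.
try:
    from transformers import AdamW  # older transformers only
except ImportError:
    from torch.optim import AdamW   # drop-in replacement
```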
@@ -15,7 +17,7 @@ def parse_arguments():
     parser.add_argument('--n_classes', default=2, type=int, help='number of classes')
 
     parser.add_argument('--data_path', type=str, default='data/data.txt', help='the path of dataset')
-    parser.add_argument('--batch_size', default=20, type=int, help='batch size')
+    parser.add_argument('--batch_size', default=8, type=int, help='batch size')
 
     parser.add_argument('--epochs', default=50, type=int, help='number of epochs tp train for')
     parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
@@ -28,9 +30,12 @@ def parse_arguments():
 def train(model, dataset, optimizer, device, batch_size, epochs):
     model.train()
     train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
-    pbar = tqdm(train_loader)
+
+    min_loss = float('inf')
     for epoch in range(epochs):
+        pbar = tqdm(train_loader)
         pbar.set_description("Epoch {}:".format(epoch))
+        total_loss = 0
         for batch in pbar:
             batch = {key: value.to(device) for key, value in batch.items()}
             optimizer.zero_grad()
@@ -39,9 +44,13 @@ def train(model, dataset, optimizer, device, batch_size, epochs):
             loss.backward()
             optimizer.step()
             pbar.set_postfix(loss=loss.item())
+            total_loss += loss.item()
 
-        if epoch % 10 == 0:
-            model.save('model_{}.pth'.format(epoch))
+        if total_loss < min_loss:
+            min_loss = total_loss
+            torch.save(model.state_dict(), 'output/model_best.pth')
+
+        print("Epoch {}: Average loss: {}".format(epoch, total_loss / len(train_loader)))
 
     return model

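Two things happen in this hunk: the dead `model.save(...)` call goes away (a plain `nn.Module` has no `save` method, so the old checkpoint branch would have crashed), and the loop now keeps only the best checkpoint by summed epoch loss. A minimal self-contained sketch of that pattern, with stand-ins for the repo's model and loader:

```python
# Hedged sketch of the best-checkpoint pattern this commit introduces.
# The model and "batches" below are stand-ins, not the repo's objects.
import os
import torch
import torch.nn as nn

os.makedirs('output', exist_ok=True)  # torch.save needs the directory
model = nn.Linear(10, 2)              # stand-in for TextBackbone
min_loss = float('inf')

for epoch in range(3):
    total_loss = 0.0
    for _ in range(5):                # stand-in for the DataLoader
        loss = torch.rand(1)          # stand-in for the batch loss
        total_loss += loss.item()

    if total_loss < min_loss:         # save only when the epoch improves
        min_loss = total_loss
        torch.save(model.state_dict(), 'output/model_best.pth')

    print("Epoch {}: Average loss: {}".format(epoch, total_loss / 5))
```

Comparing summed losses across epochs works here because every epoch sees the same number of batches; comparing the printed average would be equivalent.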
@@ -51,7 +60,8 @@ def main():
     reviews, targets = load_data(args.data_path)
     dataset = ReviewsDataset(reviews, targets)
     model = TextBackbone(num_classes=args.n_classes).to(args.device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
+
     model = train(model, dataset, optimizer, args.device, args.batch_size, args.epochs)
     torch.save(model.state_dict(), 'model.pth')

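Worth flagging: the new optimizer line hard-codes `lr=2e-5`, so the `--lr` flag (default `1e-3`) is now silently ignored. The hard-coded value is the sensible one, since learning rates around 2e-5 to 5e-5 are standard for fine-tuning BERT-style encoders and 1e-3 would likely diverge, but wiring the flag back in is a small change. A hedged, self-contained sketch (the flag's default is updated to match):

```python
# Sketch: keep the CLI flag authoritative instead of hard-coding the LR.
import argparse
import torch
from torch.optim import AdamW  # or transformers.AdamW on older installs

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=2e-5, type=float, help='learning rate')
args = parser.parse_args([])   # [] -> use defaults

model = torch.nn.Linear(10, 2)  # stand-in for TextBackbone
optimizer = AdamW(model.parameters(), lr=args.lr, eps=1e-8)
print(optimizer.defaults['lr'])  # 2e-05
```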
output/README.md (+1)

@@ -0,0 +1 @@
+# Output
