-
Notifications
You must be signed in to change notification settings - Fork 60
Open
Description
我使用的是 PyTorch,模型代码如下:
class DetectNet(nn.Module):
    """Small CNN classifier producing ``num_classes`` logits per image.

    Architecture (unchanged from the original, so ``state_dict`` keys such as
    ``model.0.weight`` stay compatible):
    conv3x3(3->32) + ReLU + maxpool2 -> conv3x3(32->64) + ReLU + maxpool2
    -> conv3x3(64->128) + ReLU -> flatten -> Linear(128*6*6 -> 128) + ReLU
    -> Linear(128 -> num_classes).

    NOTE(review): the 3x3 convs use padding=1 and keep the spatial size, and
    each 2x2 pool halves it. The first Linear layer therefore expects a
    6x6 map, i.e. 3-channel inputs of about 24x24 pixels (24 -> 12 -> 6).
    Confirm that the training images have this size; other sizes raise a
    shape mismatch at the first Linear layer.
    """

    def __init__(self, num_classes=15):
        """Build the network.

        Args:
            num_classes: number of output logits. The default of 15 matches
                the original hard-coded value. What the 15 classes are is not
                shown here (presumably digits plus operator symbols; confirm).
        """
        # The original snippet read `def init` / `super(...).init()`: Markdown
        # swallowed the double underscores. Without a real `__init__`,
        # `self.model` is never built and every call to `forward` fails.
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return raw (unnormalized) class logits for the input batch ``x``."""
        return self.model(x)

    def loss(self, x, t):
        """Compute the cross-entropy loss and the number of correct predictions.

        Args:
            x: batch of images. Either a NumPy array (as in the original) or
                anything ``torch.as_tensor`` accepts; it is converted to
                float32.
            t: integer class labels, one per sample (list, NumPy array or
                tensor).

        Returns:
            ``(loss, count_equal)``: ``loss`` is the scalar loss tensor
            (differentiable, suitable for ``.backward()``); ``count_equal`` is
            the number of samples whose arg-max prediction equals the label.
        """
        # Put the data on the model's device so this still works after
        # `net.to("cuda")`; the original always built CPU tensors.
        device = next(self.parameters()).device
        input_tensor = torch.as_tensor(x, dtype=torch.float32, device=device)
        labels_tensor = torch.as_tensor(t, dtype=torch.long, device=device)

        logits = self.forward(input_tensor)
        loss = self.loss_fn(logits, labels_tensor)

        # Accuracy bookkeeping only: stay in torch rather than calling
        # `.numpy()`, which fails for CUDA tensors.
        with torch.no_grad():
            predictions = logits.argmax(dim=1)
            count_equal = int((predictions == labels_tensor).sum().item())
        return loss, count_equal
模型如上,预测的结果如下:
20=5=4
3+4=
9=5=5
5=244
4+24
23=545
8=5=
2+2=
4+2=
4=24
8=849
3=247
8=3=4
2+4=5
8=8=
2+846
8=34
2=249
22=84
8=247
8=24
20=2=
8=5=
20=7=
可以看到效果很不好,很多字符都预测错了。
Metadata
Metadata
Assignees
Labels
No labels