
After training on code-generated fonts, predictions on handwriting are very poor. What could be the cause? #2

@mnzn2530


I'm using PyTorch:
```python
import torch
import torch.nn as nn


class DetectNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            # 128 * 6 * 6 means the input must be 24x24 (up to 27x27):
            # the two 2x2 max-pools reduce 24 -> 12 -> 6
            nn.Linear(128 * 6 * 6, 128),
            nn.ReLU(),
            nn.Linear(128, 15))  # 15 output classes
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def loss(self, x, t):
        labels_tensor = torch.as_tensor(t, dtype=torch.long)
        input_tensor = torch.from_numpy(x).float()  # convert to a PyTorch tensor and make sure it is float

        xret = self.forward(input_tensor)

        # count correctly classified samples in the batch
        ok = torch.argmax(xret, dim=1)
        count_equal = (ok == labels_tensor).sum().item()

        # compute the loss
        loss = self.loss_fn(xret, labels_tensor)
        return loss, count_equal
```
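Because `nn.Linear(128 * 6 * 6, 128)` hard-codes the flattened feature size, every image passed to this model (synthetic or handwritten) must be resized to the same input size used in training. The sketch below is a quick check of that; `feature_shape` is a hypothetical helper, and its layers are copied from the model above without the classifier head.

```python
import torch
import torch.nn as nn


def feature_shape(height: int, width: int) -> torch.Size:
    """Run a dummy image through DetectNet's conv/pool layers and return
    the shape that would be fed into the first nn.Linear layer."""
    features = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Flatten(),
    )
    with torch.no_grad():
        dummy = torch.zeros(1, 3, height, width)  # a batch of one 3-channel image
        return features(dummy).shape


print(feature_shape(24, 24))  # torch.Size([1, 4608]), i.e. 128 * 6 * 6
print(feature_shape(32, 32))  # torch.Size([1, 8192]), too large for nn.Linear(4608, 128)
```

Padded 3x3 convolutions keep the spatial size, so only the two max-pools shrink it; that is why a 24x24 input ends up as 6x6 feature maps.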

The network model is shown above. The prediction results are as follows:

20=5=4
3+4=
9=5=5
5=244
4+24
23=545
8=5=
2+2=
4+2=
4=24
8=849
3=247
8=3=4
2+4=5
8=8=
2+846
8=34
2=249
22=84
8=247
8=24
20=2=
8=5=
20=7=

As you can see, the results are poor: many characters are predicted incorrectly.
