Replies: 1 comment
It can now handle any model. For example:

```python
import torch
import torch.nn as nn

class MyFeatureBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()

    def forward(self, x):  # forward added for completeness; the original snippet omitted it
        return self.flatten(self.relu(self.conv(x)))

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = MyFeatureBlock(3, 16)
        self.fc = nn.Linear(16 * 224 * 224, 10)

    def forward(self, x):  # forward added for completeness; the original snippet omitted it
        return self.fc(self.block(x))

def init_weights(m):
    # Initialize the weights (and biases) of every layer that has them to 1.0
    if hasattr(m, 'weight') and m.weight is not None:
        nn.init.constant_(m.weight, 1.0)
    if hasattr(m, 'bias') and m.bias is not None:
        nn.init.constant_(m.bias, 1.0)

model = MyModel()
model.apply(init_weights)  # recursively apply the initialization to every submodule
torch.save(model, "custom_model_ones.pt")
```

log
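A quick way to sanity-check such a saved file is to load it back and confirm every parameter is 1.0. This sketch uses a tiny `nn.Linear` stand-in (a hypothetical example, not part of the original snippet) so the check runs fast; the same pattern applies to the full model above:

```python
import torch
import torch.nn as nn

# Build a small module, force all parameters to 1.0, and pickle the whole object.
model = nn.Linear(4, 2)
for p in model.parameters():
    nn.init.constant_(p, 1.0)
torch.save(model, "ones_check.pt")

# torch.load unpickles the full nn.Module. Recent PyTorch versions default to
# weights_only=True, which rejects pickled module objects, so pass it explicitly;
# fall back for older versions that lack the keyword.
try:
    reloaded = torch.load("ones_check.pt", weights_only=False)
except TypeError:
    reloaded = torch.load("ones_check.pt")

all_ones = all(bool(torch.all(p == 1.0)) for p in reloaded.parameters())
```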
I wrote some code to tackle a difficult task: reading a PyTorch model saved as a Python pickle file and translating it into a JavaCPP-PyTorch model. This is a real challenge; in ten years nobody has implemented it, simply because it is so difficult.
This is currently just a proof of concept, and I'm showing the logs. Full support would require many more layer-type mappings to be implemented; so far the translation primarily covers Sequential and Linear.
This feature is worth improving: in the future, Java should be able to load original .pth models directly, which would be very appealing to users.
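For comparison, the established interop route today is a TorchScript export on the Python side, which libtorch (and therefore libtorch-based bindings such as JavaCPP-PyTorch) can already load via `torch::jit::load`; loading pickled .pth models directly, as proposed here, would remove that extra export step. A minimal sketch of the export, using a hypothetical small model:

```python
import torch
import torch.nn as nn

# Script a small Sequential model and save it in TorchScript format,
# which is loadable from C++ (and JavaCPP) without the Python pickle machinery.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")

# Round-trip on the Python side to show the saved artifact is self-contained.
restored = torch.jit.load("model_scripted.pt")
out = restored(torch.zeros(1, 8))
```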
log