Convert single column binary predictions to two #375
Conversation
probably should add a test case
```r
N_data <- length(predict_tensor)
vec_dim <- c(N_data, 1)
pos_scores <- predict_tensor$reshape(vec_dim)
neg_scores <- torch::torch_zeros(N_data)$reshape(vec_dim)
```
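A minimal sketch of such a test, assuming testthat and a stand-in `predict_tensor` (variable names follow the diff above); it only checks that the zero-padded scores softmax to a valid two-column probability matrix:

```r
library(testthat)
library(torch)

test_that("single-column binary predictions convert to two columns", {
  predict_tensor <- torch_randn(8)  # stand-in for the network output
  N_data <- length(predict_tensor)
  vec_dim <- c(N_data, 1)
  pos_scores <- predict_tensor$reshape(vec_dim)
  neg_scores <- torch_zeros(N_data)$reshape(vec_dim)
  # two columns: negative class fixed at 0, positive class logit
  prob <- nnf_softmax(torch_cat(list(neg_scores, pos_scores), dim = 2), dim = 2)
  expect_equal(dim(prob), c(8L, 2L))
  # each row is a probability distribution over the two classes
  expect_equal(as_array(prob$sum(dim = 2)), rep(1, 8), tolerance = 1e-6)
  # column 2 reproduces the sigmoid of the logits
  expect_equal(as_array(prob[, 2]), as_array(torch_sigmoid(predict_tensor)),
    tolerance = 1e-6)
})
```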
For justification of using zero for the negative scores before putting them through softmax, see my slides, section "Multi-class classification": https://github.com/tdhock/2023-08-deep-learning/blob/main/slides/torch-part1/06-backprop.pdf
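The reason zero works as the negative score: a two-class softmax with one score fixed at zero reduces exactly to the sigmoid, so the padded representation changes nothing numerically:

$$\operatorname{softmax}(0, x)_2 = \frac{e^x}{e^0 + e^x} = \frac{1}{1 + e^{-x}} = \sigma(x), \qquad \operatorname{softmax}(0, x)_1 = 1 - \sigma(x)$$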
I need to think a bit about this to be sure how to handle this nicely (e.g. via …)
Also when we introduce this, I think it should be done consistently across classification learners. The following example mixes a binary and a multi-class task:

```r
library(mlr3torch)
design = benchmark_grid(
  # sonar is binary, iris multi-class
  tasks = tsks(c("sonar", "iris")),
  learners = lrn("classif.mlp", epochs = 10, batch_size = 16, loss = t_loss("cross_entropy")),
  resamplings = rsmp("cv")
)
benchmark(design)
```

I think the solution is to make `t_loss("cross_entropy")` generate a `nn_bce_loss_with_logits` when it encounters a binary classification problem and otherwise a `nn_cross_entropy_loss`.
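A minimal sketch of that dispatch, assuming the torch R package; `make_cross_entropy()` and its `task` argument are hypothetical illustrations, not mlr3torch API:

```r
library(torch)

# Hypothetical constructor: pick the loss module from the task's class count.
make_cross_entropy <- function(task) {
  if (length(task$class_names) == 2L) {
    # binary task: the network outputs a single logit, targets are 0/1
    nn_bce_with_logits_loss()
  } else {
    # multi-class task: one logit per class, integer class targets
    nn_cross_entropy_loss()
  }
}
```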
agree about "make t_loss("cross_entropy") generate a nn_bce_loss_with_logits when it encounters a binary classification problem and otherwise a nn_cross_entropy_loss." not sure I understand "load the targets should be attached to the loss function, which currently is not the case. Maybe even the prediction encoder should be attached to the loss" |
When the loss changes between the binary and the multi-class case, the target encoding changes with it, which is why loading the targets should be attached to the loss function. The prediction encoder defines how a torch prediction tensor, as output by the underlying network, is converted to an mlr3 prediction.
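An illustrative sketch of such prediction encoders, assuming the torch R package; the function names and signatures are hypothetical, not mlr3torch API:

```r
library(torch)

# Hypothetical encoder for binary tasks: one logit per observation.
encode_binary <- function(predict_tensor, class_names) {
  prob_pos <- as_array(torch_sigmoid(predict_tensor$reshape(-1)))
  prob <- cbind(1 - prob_pos, prob_pos)  # one column per class
  colnames(prob) <- class_names
  prob
}

# Hypothetical encoder for multi-class tasks: one logit per class.
encode_multiclass <- function(predict_tensor, class_names) {
  prob <- as_array(nnf_softmax(predict_tensor, dim = 2))
  colnames(prob) <- class_names
  prob
}
```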
agree with "a runtime check on the dimension as done here is a little hacky so ideally I would like to do this differently (and solve it more generically)." |
Superseded by #385
Closes #374