Skip to content

Commit 5ea1e7d

Browse files
committed
changes for #68
1 parent d94ccde commit 5ea1e7d

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
* Various changes and improvements to error and warning messages.
66

7+
* Fixed a bug that occurred when linear activation was used for neural networks (#68).
8+
79
# brulee 0.3.0
810

911
* Fixed bug where `coef()` didn't would error if used on a `brulee_logistic_reg()` that was trained with a recipe. (#66)

R/activation.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ brulee_activations <- function() {
1515
get_activation_fn <- function(arg, ...) {
1616

1717
if (arg == "linear") {
18-
res <- identity
19-
} else {
20-
cl <- rlang::call2(paste0("nn_", arg), .ns = "torch")
21-
res <- rlang::eval_bare(cl)
18+
arg <- "identity"
2219
}
2320

21+
cl <- rlang::call2(paste0("nn_", arg), .ns = "torch")
22+
res <- rlang::eval_bare(cl)
23+
2424
res
2525
}

tests/testthat/test-mlp-binary.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,19 @@ test_that("binomial mlp case weights", {
226226
sum(unweighted_pred$.pred_class == "class_2")
227227
)
228228
})
229+
230+
test_that('linear activations', {
231+
# See https://github.com/tidymodels/brulee/issues/68
232+
skip_if(!torch::torch_is_installed())
233+
skip_if_not_installed("modeldata")
234+
235+
data(bivariate, package = "modeldata")
236+
set.seed(20)
237+
nn_log_biv <-
238+
try(
239+
brulee_mlp(Class ~ log(A) + log(B), data = bivariate_train,
240+
epochs = 150, hidden_units = 3, activation = "linear"),
241+
silent = TRUE)
242+
expect_s3_class(nn_log_biv, "brulee_mlp")
243+
244+
})

0 commit comments

Comments
 (0)