Skip to content

Commit c701a23

Browse files
authored
Merge pull request #1522 from rstudio/use_backend-updates
Updates to `use_backend()`
2 parents 15923ca + c080399 commit c701a23

File tree

11 files changed

+119
-53
lines changed

11 files changed

+119
-53
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ export(image_smart_resize)
259259
export(image_to_array)
260260
export(imagenet_decode_predictions)
261261
export(imagenet_preprocess_input)
262+
export(import)
262263
export(initializer_constant)
263264
export(initializer_glorot_normal)
264265
export(initializer_glorot_uniform)
@@ -805,6 +806,7 @@ export(optimizer_sgd)
805806
export(pad_sequences)
806807
export(pop_layer)
807808
export(predict_on_batch)
809+
export(py_require)
808810
export(quantize_weights)
809811
export(random_beta)
810812
export(random_binomial)
@@ -886,6 +888,7 @@ importFrom(reticulate,py_has_attr)
886888
importFrom(reticulate,py_install)
887889
importFrom(reticulate,py_is_null_xptr)
888890
importFrom(reticulate,py_iterator)
891+
importFrom(reticulate,py_require)
889892
importFrom(reticulate,py_str)
890893
importFrom(reticulate,py_to_r)
891894
importFrom(reticulate,py_to_r_wrapper)

R/install.R

Lines changed: 96 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ use_backend <- function(backend, gpu = NA) {
170170
reticulate::import("os")$environ$update(list(KERAS_BACKEND = backend))
171171
}
172172

173+
# tensorflow requirements are by default registered from .onLoad (unless KERAS_BACKEND envvar is set). Undo that action first.
174+
# in case user has multiple conflicting `use_backend()` calls, last one wins
175+
py_require_remove_all_tensorflow()
176+
py_require_remove_all_jax()
177+
py_require_remove_all_torch()
178+
173179
set_envvar("UV_CONSTRAINT", pkg_file("keras-constraints.txt"),
174180
action = "append", sep = " ", unique = TRUE)
175181

@@ -184,15 +190,12 @@ use_backend <- function(backend, gpu = NA) {
184190
if (gpu) {
185191
py_require(c("tensorflow", "tensorflow-metal"))
186192
} else {
187-
py_require(action = "remove", c("tensorflow-macos", "tensorflow-metal"))
188193
py_require("tensorflow")
189194
}
190195

191196
},
192197

193198
macOS_jax = {
194-
py_require(c("tensorflow-metal", "tensorflow-macos"), action = "remove")
195-
196199
if (is.na(gpu))
197200
gpu <- TRUE
198201

@@ -207,72 +210,61 @@ use_backend <- function(backend, gpu = NA) {
207210
if(isTRUE(gpu))
208211
warning("GPU usage not supported on macOS. Please use a different backend to use the GPU (jax)")
209212

210-
py_require(c("tensorflow-metal", "tensorflow-macos"), action = "remove")
211-
212213
py_require(c("tensorflow", "torch", "torchvision", "torchaudio"))
213214
},
214215

215216
macOS_numpy = {
216-
py_require(c("tensorflow-metal", "tensorflow-macos"), action = "remove")
217217
py_require(c("tensorflow", "numpy", "jax[cpu]")) # numpy backend requires jax for some image ops
218218
},
219219

220+
220221
Linux_tensorflow = {
221-
py_require(c("jax[cuda12]", "jax[cpu]"), action = "remove")
222222

223223
if (is.na(gpu))
224224
gpu <- has_gpu()
225225

226226
if (gpu) {
227-
uv_unset_override_tf_cpu()
228-
py_require(action = "remove", c("tensorflow", "tensorflow-cpu"))
229227
py_require("tensorflow[and-cuda]")
230228
} else {
231-
uv_set_override_tf_cpu()
229+
py_require_tensorflow_cpu()
232230
}
233231
},
234232

235233
Linux_jax = {
236-
py_require(action = "remove",
237-
c("tensorflow", "tensorflow[and-cuda]",
238-
"jax[cuda12]", "jax[cpu]"))
239-
uv_set_override_tf_cpu()
234+
py_require_tensorflow_cpu()
240235

241236
if (is.na(gpu))
242237
gpu <- has_gpu()
243238

244239
if (gpu) {
245240
Sys.setenv("XLA_PYTHON_CLIENT_MEM_FRACTION" = "1.00")
246-
py_require(c("tensorflow-cpu", "jax[cuda12]!=0.6.1"))
241+
py_require(c("jax[cuda12]!=0.6.1"))
247242
} else {
248-
py_require(c("tensorflow-cpu", "jax[cpu]"))
243+
py_require(c("jax[cpu]"))
249244
}
250245
},
251246

252247
Linux_torch = {
253-
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
254-
uv_set_override_tf_cpu()
248+
py_require_tensorflow_cpu()
255249

256250
if (is.na(gpu))
257251
gpu <- has_gpu()
258252

259253
if (gpu) {
260-
py_require(c("tensorflow-cpu", "torch", "torchvision", "torchaudio"))
254+
py_require(c("torch", "torchvision", "torchaudio"))
261255
} else {
262-
Sys.setenv("UV_INDEX" = trimws(paste(sep = " ",
263-
"https://download.pytorch.org/whl/cpu",
264-
Sys.getenv("UV_INDEX")
265-
)))
266-
py_require(c("tensorflow-cpu", "torch", "torchvision", "torchaudio"))
256+
set_envvar("UV_INDEX", "https://download.pytorch.org/whl/cpu",
257+
action = "append", sep = " ", unique = TRUE)
258+
py_require(c("torch", "torchvision", "torchaudio"))
267259
}
268260
},
269261

270262
Linux_numpy = {
271-
uv_set_override_tf_cpu()
272-
py_require(c("tensorflow", "tensorflow[and-cuda]"), action = "remove")
273-
py_require(c("tensorflow-cpu", "numpy", "jax[cpu]"))
263+
py_require_tensorflow_cpu()
264+
py_require(c("numpy", "jax[cpu]"))
274265
},
275266

267+
276268
Windows_tensorflow = {
277269
if(isTRUE(gpu)) warning("GPU usage not supported on Windows. Please use WSL.")
278270
py_require(c("tensorflow", "numpy<2"))
@@ -288,10 +280,8 @@ use_backend <- function(backend, gpu = NA) {
288280
gpu <- FALSE
289281

290282
if (gpu) {
291-
Sys.setenv("UV_INDEX" = trimws(paste(sep = " ",
292-
"https://download.pytorch.org/whl/cu126",
293-
Sys.getenv("UV_INDEX")
294-
)))
283+
set_envvar("UV_INDEX", "https://download.pytorch.org/whl/cu129",
284+
action = "append", sep = " ", unique = TRUE)
295285
py_require(c("tensorflow", "torch", "torchvision", "torchaudio"))
296286
} else {
297287
py_require(c("tensorflow", "torch", "torchvision", "torchaudio"))
@@ -329,6 +319,7 @@ set_envvar <- function(
329319
)
330320
if (unique) {
331321
value <- unique(unlist(strsplit(value, sep, fixed = TRUE)))
322+
value <- value[nzchar(value)]
332323
value <- paste0(value, collapse = sep)
333324
}
334325
}
@@ -339,20 +330,63 @@ set_envvar <- function(
339330
invisible(old)
340331
}
341332

342-
uv_set_override_tf_cpu <- function() {
343-
py_require(action = "remove", c(
344-
"tensorflow", "tensorflow[and-cuda]", "tensorflow-cpu",
345-
"tensorflow-metal", "tensorflow-macos"
346-
))
347-
py_require(if (is_linux()) "tensorflow-cpu" else "tensorflow")
348-
set_envvar("UV_OVERRIDE", pkg_file("tf-cpu-override.txt"),
333+
334+
py_require_remove_all_tensorflow <- function() {
335+
pkgs <- py_require()$packages
336+
tf_pkgs <- grep(
337+
"^tensorflow(-cpu|-metal|-macos|\\[and-cuda\\])?[=~*!<>0-9.]*$",
338+
pkgs, value = TRUE
339+
)
340+
py_require(tf_pkgs, action = "remove")
341+
uv_unset_override_never_tensorflow()
342+
}
343+
344+
py_require_remove_all_jax <- function() {
345+
pkgs <- py_require()$packages
346+
jax_pkgs <- grep(
347+
"^(jax(-metal)?|jax\\[[^]]*\\]|jaxlib)[=~*!<>0-9A-Za-z.+-]*$",
348+
pkgs, value = TRUE
349+
)
350+
py_require(jax_pkgs, action = "remove")
351+
}
352+
353+
py_require_remove_all_torch <- function() {
354+
pkgs <- py_require()$packages
355+
torch_pkgs <- grep(
356+
"^(torch|torchvision|torchaudio)(\\[[^]]+\\])?[=~*!<>0-9A-Za-z.+-]*$",
357+
pkgs, value = TRUE, perl = TRUE
358+
)
359+
py_require(torch_pkgs, action = "remove")
360+
uv_unset_index_download_pytorch()
361+
}
362+
363+
py_require_tensorflow_cpu <- function() {
364+
if (is_linux()) {
365+
366+
# pin 2.18.* because later versions of 'tensorflow-cpu' are not
367+
# compatible with 'tensorflow-text', used by 'keras-hub'
368+
py_require("tensorflow-cpu==2.18.*")
369+
370+
# set override so tensorflow-text is prevented from pulling in 'tensorflow'
371+
uv_set_override_never_tensorflow()
372+
373+
} else {
374+
# macOS and Windows only support CPU
375+
py_require("tensorflow")
376+
}
377+
}
378+
379+
uv_set_override_never_tensorflow <- function() {
380+
# packages like tensorflow-text pull in tensorflow, even if we specify
381+
# tensorflow-cpu. This override it to allow forcing `tensorflow-cpu`
382+
set_envvar("UV_OVERRIDE", pkg_file("never-tensorflow-override.txt"),
349383
action = "append", sep = " ", unique = TRUE)
350384
}
351385

352-
uv_unset_override_tf_cpu <- function() {
386+
uv_unset_override_never_tensorflow <- function() {
353387
override <- Sys.getenv("UV_OVERRIDE", NA)
354388
if (is.na(override)) return()
355-
cpu_override <- pkg_file("tf-cpu-override.txt")
389+
cpu_override <- pkg_file("never-tensorflow-override.txt")
356390
if (override == cpu_override) {
357391
Sys.unsetenv(override)
358392
} else {
@@ -363,6 +397,27 @@ uv_unset_override_tf_cpu <- function() {
363397
invisible(override)
364398
}
365399

400+
uv_unset_index_download_pytorch <- function() {
401+
index <- Sys.getenv("UV_INDEX", NA)
402+
if (is.na(index) || !nzchar(index))
403+
return(invisible(index))
404+
405+
entries <- strsplit(trimws(index), "[[:space:]]+")[[1L]]
406+
entries <- entries[nzchar(entries)]
407+
if (!length(entries))
408+
return(invisible(index))
409+
410+
keep <- entries[!startsWith(entries, "https://download.pytorch.org/whl/")]
411+
412+
if (length(keep)) {
413+
Sys.setenv("UV_INDEX" = paste(keep, collapse = " "))
414+
} else {
415+
Sys.unsetenv("UV_INDEX")
416+
}
417+
418+
invisible(index)
419+
}
420+
366421
get_os <- function() {
367422
if (is_windows()) "Windows" else if (is_mac_arm64()) "macOS" else "Linux"
368423
}
@@ -381,7 +436,7 @@ is_keras_loaded <- function() {
381436
}
382437

383438
pkg_file <- function(..., package = "keras3") {
384-
path <- system.file(..., package = "keras3", mustWork = TRUE)
439+
path <- system.file(..., package = package, mustWork = TRUE)
385440
if(is_windows())
386441
path <- utils::shortPathName(path)
387442
path

R/package.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ keras <- NULL
7070
Sys.setenv(RETICULATE_PYTHON = keras_python)
7171

7272
py_require(c(
73-
"keras", "pydot", "scipy", "pandas", "Pillow",
74-
"ipython" #, "tensorflow_datasets"
73+
"keras", "pydot", "scipy", "pandas", "Pillow", "ipython"
74+
#, "tensorflow_datasets"
7575
))
7676

7777
if (is.na(Sys.getenv("KERAS_HOME", NA))) {

R/reexports.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ reticulate::iterate
6565
#' @export
6666
reticulate::as_iterator
6767

68+
#' @export
69+
reticulate::py_require
70+
71+
#' @export
72+
reticulate::import
73+
6874
#' @importFrom tensorflow tensorboard
6975
#' @export
7076
tensorflow::tensorboard

inst/keras-constraints.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,3 @@
44
# This is a workaround to nudge uv to resolve the latest keras-hub.
55
keras-hub>0.19.0
66

7-
8-
# tensorflow-text 2.19.* fails to load with tensorflow-cpu>=2.19.0
9-
tensorflow-cpu==2.18.*

inst/never-tensorflow-override.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# packages like tensorflow-text pull in tensorflow, even if we specify
2+
# tensorflow-cpu. This override it to allow forcing `tensorflow-cpu`
3+
tensorflow; sys_platform == "never"

man/layer_tfsm.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/metric_mean_absolute_percentage_error.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/op_erf.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/op_gelu.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)