@@ -170,6 +170,12 @@ use_backend <- function(backend, gpu = NA) {
170170 reticulate :: import(" os" )$ environ $ update(list (KERAS_BACKEND = backend ))
171171 }
172172
173+ # tensorflow requirements are by default registered from .onLoad (unless KERAS_BACKEND envvar is set). Undo that action first.
174+ # in case user has multiple conflicting `use_backend()` calls, last one wins
175+ py_require_remove_all_tensorflow()
176+ py_require_remove_all_jax()
177+ py_require_remove_all_torch()
178+
173179 set_envvar(" UV_CONSTRAINT" , pkg_file(" keras-constraints.txt" ),
174180 action = " append" , sep = " " , unique = TRUE )
175181
@@ -184,15 +190,12 @@ use_backend <- function(backend, gpu = NA) {
184190 if (gpu ) {
185191 py_require(c(" tensorflow" , " tensorflow-metal" ))
186192 } else {
187- py_require(action = " remove" , c(" tensorflow-macos" , " tensorflow-metal" ))
188193 py_require(" tensorflow" )
189194 }
190195
191196 },
192197
193198 macOS_jax = {
194- py_require(c(" tensorflow-metal" , " tensorflow-macos" ), action = " remove" )
195-
196199 if (is.na(gpu ))
197200 gpu <- TRUE
198201
@@ -207,72 +210,61 @@ use_backend <- function(backend, gpu = NA) {
207210 if (isTRUE(gpu ))
208211 warning(" GPU usage not supported on macOS. Please use a different backend to use the GPU (jax)" )
209212
210- py_require(c(" tensorflow-metal" , " tensorflow-macos" ), action = " remove" )
211-
212213 py_require(c(" tensorflow" , " torch" , " torchvision" , " torchaudio" ))
213214 },
214215
215216 macOS_numpy = {
216- py_require(c(" tensorflow-metal" , " tensorflow-macos" ), action = " remove" )
217217 py_require(c(" tensorflow" , " numpy" , " jax[cpu]" )) # numpy backend requires jax for some image ops
218218 },
219219
220+
220221 Linux_tensorflow = {
221- py_require(c(" jax[cuda12]" , " jax[cpu]" ), action = " remove" )
222222
223223 if (is.na(gpu ))
224224 gpu <- has_gpu()
225225
226226 if (gpu ) {
227- uv_unset_override_tf_cpu()
228- py_require(action = " remove" , c(" tensorflow" , " tensorflow-cpu" ))
229227 py_require(" tensorflow[and-cuda]" )
230228 } else {
231- uv_set_override_tf_cpu ()
229+ py_require_tensorflow_cpu ()
232230 }
233231 },
234232
235233 Linux_jax = {
236- py_require(action = " remove" ,
237- c(" tensorflow" , " tensorflow[and-cuda]" ,
238- " jax[cuda12]" , " jax[cpu]" ))
239- uv_set_override_tf_cpu()
234+ py_require_tensorflow_cpu()
240235
241236 if (is.na(gpu ))
242237 gpu <- has_gpu()
243238
244239 if (gpu ) {
245240 Sys.setenv(" XLA_PYTHON_CLIENT_MEM_FRACTION" = " 1.00" )
246- py_require(c(" tensorflow-cpu " , " jax[cuda12]!=0.6.1" ))
241+ py_require(c(" jax[cuda12]!=0.6.1" ))
247242 } else {
248- py_require(c(" tensorflow-cpu " , " jax[cpu]" ))
243+ py_require(c(" jax[cpu]" ))
249244 }
250245 },
251246
252247 Linux_torch = {
253- py_require(c(" tensorflow" , " tensorflow[and-cuda]" ), action = " remove" )
254- uv_set_override_tf_cpu()
248+ py_require_tensorflow_cpu()
255249
256250 if (is.na(gpu ))
257251 gpu <- has_gpu()
258252
259253 if (gpu ) {
260- py_require(c(" tensorflow-cpu " , " torch" , " torchvision" , " torchaudio" ))
254+ py_require(c(" torch" , " torchvision" , " torchaudio" ))
261255 } else {
262- Sys.setenv(" UV_INDEX" = trimws(paste(sep = " " ,
263- " https://download.pytorch.org/whl/cpu" ,
264- Sys.getenv(" UV_INDEX" )
265- )))
266- py_require(c(" tensorflow-cpu" , " torch" , " torchvision" , " torchaudio" ))
256+ set_envvar(" UV_INDEX" , " https://download.pytorch.org/whl/cpu" ,
257+ action = " append" , sep = " " , unique = TRUE )
258+ py_require(c(" torch" , " torchvision" , " torchaudio" ))
267259 }
268260 },
269261
270262 Linux_numpy = {
271- uv_set_override_tf_cpu()
272- py_require(c(" tensorflow" , " tensorflow[and-cuda]" ), action = " remove" )
273- py_require(c(" tensorflow-cpu" , " numpy" , " jax[cpu]" ))
263+ py_require_tensorflow_cpu()
264+ py_require(c(" numpy" , " jax[cpu]" ))
274265 },
275266
267+
276268 Windows_tensorflow = {
277269 if (isTRUE(gpu )) warning(" GPU usage not supported on Windows. Please use WSL." )
278270 py_require(c(" tensorflow" , " numpy<2" ))
@@ -288,10 +280,8 @@ use_backend <- function(backend, gpu = NA) {
288280 gpu <- FALSE
289281
290282 if (gpu ) {
291- Sys.setenv(" UV_INDEX" = trimws(paste(sep = " " ,
292- " https://download.pytorch.org/whl/cu126" ,
293- Sys.getenv(" UV_INDEX" )
294- )))
283+ set_envvar(" UV_INDEX" , " https://download.pytorch.org/whl/cu129" ,
284+ action = " append" , sep = " " , unique = TRUE )
295285 py_require(c(" tensorflow" , " torch" , " torchvision" , " torchaudio" ))
296286 } else {
297287 py_require(c(" tensorflow" , " torch" , " torchvision" , " torchaudio" ))
@@ -329,6 +319,7 @@ set_envvar <- function(
329319 )
330320 if (unique ) {
331321 value <- unique(unlist(strsplit(value , sep , fixed = TRUE )))
322+ value <- value [nzchar(value )]
332323 value <- paste0(value , collapse = sep )
333324 }
334325 }
@@ -339,20 +330,63 @@ set_envvar <- function(
339330 invisible (old )
340331}
341332
342- uv_set_override_tf_cpu <- function () {
343- py_require(action = " remove" , c(
344- " tensorflow" , " tensorflow[and-cuda]" , " tensorflow-cpu" ,
345- " tensorflow-metal" , " tensorflow-macos"
346- ))
347- py_require(if (is_linux()) " tensorflow-cpu" else " tensorflow" )
348- set_envvar(" UV_OVERRIDE" , pkg_file(" tf-cpu-override.txt" ),
333+
334+ py_require_remove_all_tensorflow <- function () {
335+ pkgs <- py_require()$ packages
336+ tf_pkgs <- grep(
337+ " ^tensorflow(-cpu|-metal|-macos|\\ [and-cuda\\ ])?[=~*!<>0-9.]*$" ,
338+ pkgs , value = TRUE
339+ )
340+ py_require(tf_pkgs , action = " remove" )
341+ uv_unset_override_never_tensorflow()
342+ }
343+
344+ py_require_remove_all_jax <- function () {
345+ pkgs <- py_require()$ packages
346+ jax_pkgs <- grep(
347+ " ^(jax(-metal)?|jax\\ [[^]]*\\ ]|jaxlib)[=~*!<>0-9A-Za-z.+-]*$" ,
348+ pkgs , value = TRUE
349+ )
350+ py_require(jax_pkgs , action = " remove" )
351+ }
352+
353+ py_require_remove_all_torch <- function () {
354+ pkgs <- py_require()$ packages
355+ torch_pkgs <- grep(
356+ " ^(torch|torchvision|torchaudio)(\\ [[^]]+\\ ])?[=~*!<>0-9A-Za-z.+-]*$" ,
357+ pkgs , value = TRUE , perl = TRUE
358+ )
359+ py_require(torch_pkgs , action = " remove" )
360+ uv_unset_index_download_pytorch()
361+ }
362+
363+ py_require_tensorflow_cpu <- function () {
364+ if (is_linux()) {
365+
366+ # pin 2.18.* because later versions of 'tensorflow-cpu' are not
367+ # compatible with 'tensorflow-text', used by 'keras-hub'
368+ py_require(" tensorflow-cpu==2.18.*" )
369+
370+ # set override so tensorflow-text is prevented from pulling in 'tensorflow'
371+ uv_set_override_never_tensorflow()
372+
373+ } else {
374+ # macOS and Windows only support CPU
375+ py_require(" tensorflow" )
376+ }
377+ }
378+
379+ uv_set_override_never_tensorflow <- function () {
380+ # packages like tensorflow-text pull in tensorflow, even if we specify
381+ # tensorflow-cpu. This override it to allow forcing `tensorflow-cpu`
382+ set_envvar(" UV_OVERRIDE" , pkg_file(" never-tensorflow-override.txt" ),
349383 action = " append" , sep = " " , unique = TRUE )
350384}
351385
352- uv_unset_override_tf_cpu <- function () {
386+ uv_unset_override_never_tensorflow <- function () {
353387 override <- Sys.getenv(" UV_OVERRIDE" , NA )
354388 if (is.na(override )) return ()
355- cpu_override <- pkg_file(" tf-cpu -override.txt" )
389+ cpu_override <- pkg_file(" never-tensorflow -override.txt" )
356390 if (override == cpu_override ) {
357391 Sys.unsetenv(override )
358392 } else {
@@ -363,6 +397,27 @@ uv_unset_override_tf_cpu <- function() {
363397 invisible (override )
364398}
365399
400+ uv_unset_index_download_pytorch <- function () {
401+ index <- Sys.getenv(" UV_INDEX" , NA )
402+ if (is.na(index ) || ! nzchar(index ))
403+ return (invisible (index ))
404+
405+ entries <- strsplit(trimws(index ), " [[:space:]]+" )[[1L ]]
406+ entries <- entries [nzchar(entries )]
407+ if (! length(entries ))
408+ return (invisible (index ))
409+
410+ keep <- entries [! startsWith(entries , " https://download.pytorch.org/whl/" )]
411+
412+ if (length(keep )) {
413+ Sys.setenv(" UV_INDEX" = paste(keep , collapse = " " ))
414+ } else {
415+ Sys.unsetenv(" UV_INDEX" )
416+ }
417+
418+ invisible (index )
419+ }
420+
366421get_os <- function () {
367422 if (is_windows()) " Windows" else if (is_mac_arm64()) " macOS" else " Linux"
368423}
@@ -381,7 +436,7 @@ is_keras_loaded <- function() {
381436}
382437
383438pkg_file <- function (... , package = " keras3" ) {
384- path <- system.file(... , package = " keras3 " , mustWork = TRUE )
439+ path <- system.file(... , package = package , mustWork = TRUE )
385440 if (is_windows())
386441 path <- utils :: shortPathName(path )
387442 path
0 commit comments