|
1 |
| -# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved. |
| 1 | +# Copyright (c) 2019-2025, NVIDIA CORPORATION. All rights reserved. |
2 | 2 | # See file LICENSE for terms.
|
3 | 3 |
|
4 | 4 | """UCX-Py: Python bindings for UCX <www.openucx.org>"""
|
|
39 | 39 | except ImportError:
|
40 | 40 | pynvml = None
|
41 | 41 |
|
| 42 | +_ucx_version = get_ucx_version() |
| 43 | + |
| 44 | +__ucx_min_version__ = "1.15.0" |
| 45 | +__ucx_version__ = "%d.%d.%d" % _ucx_version |
| 46 | + |
| 47 | +if _ucx_version < tuple(int(i) for i in __ucx_min_version__.split(".")): |
| 48 | + raise ImportError( |
| 49 | + f"Support for UCX {__ucx_version__} has ended. Please upgrade to " |
| 50 | + f"{__ucx_min_version__} or newer. If you believe the wrong version " |
| 51 | + "is being loaded, please check the path from where UCX is loaded " |
| 52 | + "by rerunning with the environment variable `UCX_LOG_LEVEL=debug`." |
| 53 | + ) |
| 54 | + |
42 | 55 | # Setup UCX-Py logger
|
43 | 56 | logger = get_ucxpy_logger()
|
44 | 57 |
|
|
53 | 66 | if (
|
54 | 67 | pynvml is not None
|
55 | 68 | and "UCX_CUDA_COPY_MAX_REG_RATIO" not in os.environ
|
56 |
| - and get_ucx_version() >= (1, 12, 0) |
| 69 | + and _ucx_version >= (1, 12, 0) |
57 | 70 | ):
|
58 | 71 | try:
|
59 | 72 | pynvml.nvmlInit()
|
@@ -98,23 +111,11 @@ def _is_mig_device(handle):
|
98 | 111 | ):
|
99 | 112 | pass
|
100 | 113 |
|
101 |
| -if "UCX_MAX_RNDV_RAILS" not in os.environ and get_ucx_version() >= (1, 12, 0): |
| 114 | +if "UCX_MAX_RNDV_RAILS" not in os.environ and _ucx_version >= (1, 12, 0): |
102 | 115 | logger.info("Setting UCX_MAX_RNDV_RAILS=1")
|
103 | 116 | os.environ["UCX_MAX_RNDV_RAILS"] = "1"
|
104 | 117 |
|
105 |
| -if "UCX_PROTO_ENABLE" not in os.environ and get_ucx_version() >= (1, 12, 0): |
| 118 | +if "UCX_PROTO_ENABLE" not in os.environ and (1, 12, 0) <= _ucx_version < (1, 18, 0): |
106 | 119 | # UCX protov2 still doesn't support CUDA async/managed memory
|
107 | 120 | logger.info("Setting UCX_PROTO_ENABLE=n")
|
108 | 121 | os.environ["UCX_PROTO_ENABLE"] = "n"
|
109 |
| - |
110 |
| - |
111 |
| -__ucx_min_version__ = "1.15.0" |
112 |
| -__ucx_version__ = "%d.%d.%d" % get_ucx_version() |
113 |
| - |
114 |
| -if get_ucx_version() < tuple(int(i) for i in __ucx_min_version__.split(".")): |
115 |
| - raise ImportError( |
116 |
| - f"Support for UCX {__ucx_version__} has ended. Please upgrade to " |
117 |
| - f"{__ucx_min_version__} or newer. If you believe the wrong version " |
118 |
| - "is being loaded, please check the path from where UCX is loaded " |
119 |
| - "by rerunning with the environment variable `UCX_LOG_LEVEL=debug`." |
120 |
| - ) |
|
0 commit comments