33#
44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
6-
6+ import argparse
77import os .path
88import runpy
99import subprocess
10- from typing import List
10+ from typing import List , Tuple
1111
1212# required env vars:
1313# CU_VERSION: E.g. cu112
2323source_root_dir = os .environ ["PWD" ]
2424
2525
26- def version_constraint (version ):
26+ def version_constraint (version ) -> str :
2727 """
2828 Given version "11.3" returns " >=11.3,<11.4"
2929 """
@@ -32,7 +32,7 @@ def version_constraint(version):
3232 return f" >={ version } ,<{ upper } "
3333
3434
35- def get_cuda_major_minor ():
35+ def get_cuda_major_minor () -> Tuple [ str , str ] :
3636 if CU_VERSION == "cpu" :
3737 raise ValueError ("fn only for cuda builds" )
3838 if len (CU_VERSION ) != 5 or CU_VERSION [:2 ] != "cu" :
@@ -42,11 +42,17 @@ def get_cuda_major_minor():
4242 return major , minor
4343
4444
45- def setup_cuda () :
45+ def setup_cuda (use_conda_cuda : bool ) -> None :
4646 if CU_VERSION == "cpu" :
4747 return
4848 major , minor = get_cuda_major_minor ()
49- os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
49+ if use_conda_cuda :
50+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1" ] = "- cudatoolkit"
51+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2" ] = (
52+ f"- cuda-version={ major } .{ minor } "
53+ )
54+ else :
55+ os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
5056 os .environ ["FORCE_CUDA" ] = "1"
5157
5258 basic_nvcc_flags = (
@@ -95,7 +101,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
95101 return ["-c" , "pytorch" , "-c" , "nvidia" ]
96102
97103
98- def setup_conda_cudatoolkit_constraint ():
104+ def setup_conda_cudatoolkit_constraint () -> None :
99105 if CU_VERSION == "cpu" :
100106 os .environ ["CONDA_CPUONLY_FEATURE" ] = "- cpuonly"
101107 os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = ""
@@ -116,14 +122,25 @@ def setup_conda_cudatoolkit_constraint():
116122 os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = toolkit
117123
118124
119- def do_build (start_args : List [str ]):
125+ def do_build (start_args : List [str ]) -> None :
120126 args = start_args .copy ()
121127
122128 test_flag = os .environ .get ("TEST_FLAG" )
123129 if test_flag is not None :
124130 args .append (test_flag )
125131
126- args .extend (["-c" , "bottler" , "-c" , "iopath" , "-c" , "conda-forge" ])
132+ args .extend (
133+ [
134+ "-c" ,
135+ "bottler" ,
136+ "-c" ,
137+ "iopath" ,
138+ "-c" ,
139+ "conda-forge" ,
140+ "-c" ,
141+ "nvidia/label/cuda-12.1.0" ,
142+ ]
143+ )
127144 args .append ("--no-anaconda-upload" )
128145 args .extend (["--python" , os .environ ["PYTHON_VERSION" ]])
129146 args .append ("packaging/pytorch3d" )
@@ -132,8 +149,16 @@ def do_build(start_args: List[str]):
132149
133150
134151if __name__ == "__main__" :
152+ parser = argparse .ArgumentParser (description = "Build the conda package." )
153+ parser .add_argument (
154+ "--use-conda-cuda" ,
155+ action = "store_true" ,
156+ help = "get cuda from conda ignoring local cuda" ,
157+ )
158+ our_args = parser .parse_args ()
159+
135160 args = ["conda" , "build" ]
136- setup_cuda ()
161+ setup_cuda (use_conda_cuda = our_args . use_conda_cuda )
137162
138163 init_path = source_root_dir + "/pytorch3d/__init__.py"
139164 build_version = runpy .run_path (init_path )["__version__" ]
0 commit comments