
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 21, 2023
1 parent 8c08162 commit 80e7db7
Showing 7 changed files with 305 additions and 167 deletions.
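The diffs below are the output of the repository's configured pre-commit hooks, mostly line-length reformatting. As a rough sketch, the same auto-fixes can usually be reproduced locally along these lines, assuming the pre-commit package is installed and the repository provides a .pre-commit-config.yaml:

# Illustrative only: run the configured hooks over all files, which is what
# pre-commit.ci does before pushing a commit like this one.
import subprocess

result = subprocess.run(["pre-commit", "run", "--all-files"], check=False)
print("hooks clean" if result.returncode == 0 else "hooks failed or modified files")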
116 changes: 27 additions & 89 deletions misc/exports_to_py.py
@@ -12,7 +12,7 @@
class TieAPIError(Exception):
pass
tie_last_exception = None
@@ -219,31 +219,21 @@ def is_bool(self):

def is_handle(self):
typename = self.base_type_name
return (
typename.startswith("Tie")
and typename.endswith("Handle")
and self.ptr_levels == 0
)
return typename.startswith("Tie") and typename.endswith("Handle") and self.ptr_levels == 0

def is_callback(self):
return self.base_type_name == "TieCallback" and self.ptr_levels == 0

def is_basic_type(self):
return (
self.base_type_name in C_BUILTIN_TYPE_TO_CTYPES_TYPE
and self.ptr_levels == 0
)
return self.base_type_name in C_BUILTIN_TYPE_TO_CTYPES_TYPE and self.ptr_levels == 0

def __str__(self):
return f"{self.base_type_name}{'*' * self.ptr_levels}"

def __eq__(self, other: object) -> bool:
if not isinstance(other, CType):
return False
return (
self.base_type_name == other.base_type_name
and self.ptr_levels == other.ptr_levels
)
return self.base_type_name == other.base_type_name and self.ptr_levels == other.ptr_levels


class EnumDecl:
@@ -268,7 +258,7 @@ def is_arr_param(self):

def is_cstr_param(self):
return self.type.is_cstr()

def is_basic_type_param(self):
return self.type.is_basic_type()

@@ -421,9 +411,7 @@ def printer(*args, verbose_level=1, **kwargs):
return printer


def parse_exports_header(
filename: str, cpp_path: str, cpp_args: List[str], printer
) -> ExportsHeader:
def parse_exports_header(filename: str, cpp_path: str, cpp_args: List[str], printer) -> ExportsHeader:
assert filename.endswith(".h")

printer(f"Parsing {filename} ...")
@@ -460,9 +448,7 @@ def translate_c_type_to_ctypes_type(type: CType, exclude_ptr_levels: int = 0) ->
elif type.is_cstr():
return "ctypes.c_char_p"
else:
return (
f"ctypes.POINTER({translate_c_type_to_ctypes_type(type.exclude_ptr(1), 0)})"
)
return f"ctypes.POINTER({translate_c_type_to_ctypes_type(type.exclude_ptr(1), 0)})"


def translate_func_decl(
@@ -476,9 +462,7 @@
)
tie = splited_by_[0]
if tie != "tie":
raise ValueError(
f"Invliad function name: {funcname}, which is not started with 'tie'"
)
raise ValueError(f"Invliad function name: {funcname}, which is not started with 'tie'")
class_name = splited_by_[1]
method_name = "_".join(splited_by_[2:])

@@ -524,7 +508,7 @@ def translate_arg_from_python_to_c(arg_name: str, param: FuncParameter) -> str:
elif param.is_cstr_param():
return f'{arg_name}.encode("utf-8")'
elif param.is_basic_type_param():
return f'{C_BUILTIN_TYPE_TO_PYTHON_TYPE[arg_type.base_type_name]}({arg_name})'
return f"{C_BUILTIN_TYPE_TO_PYTHON_TYPE[arg_type.base_type_name]}({arg_name})"
elif param.is_callback_param():
return f"ctypes.CFUNCTYPE(ctypes.c_int)(wrap_callback_to_c({arg_name}))" # NOTE: Maybe crash
elif param.is_in_param():
@@ -568,46 +552,28 @@ def fp_write(content: str):
in_params = original_func.in_params
out_params = original_func.out_params
# Func def
in_params = [
in_params[i]
for i in range(len(in_params))
if i == 0 or not in_params[i - 1].is_arr_param()
]
fp_write(
f"def {func_def_name}("
+ ", ".join([param.name for param in in_params])
+ "):\n"
)
in_params = [in_params[i] for i in range(len(in_params)) if i == 0 or not in_params[i - 1].is_arr_param()]
fp_write(f"def {func_def_name}(" + ", ".join([param.name for param in in_params]) + "):\n")
# Func body
args = []
args.extend(
[translate_arg_from_python_to_c(param.name, param) for param in in_params]
)
args.extend(
[translate_arg_from_python_to_c(param.name, param) for param in out_params]
)
args.extend([translate_arg_from_python_to_c(param.name, param) for param in in_params])
args.extend([translate_arg_from_python_to_c(param.name, param) for param in out_params])
for param in in_params:
if param.is_arr_param():
fp_write(
f"{tab}{param.name} = ({translate_c_type_to_ctypes_type(param.type, exclude_ptr_levels=1)} * len({param.name}))(*{param.name})\n"
)
for param in out_params:
fp_write(
f"{tab}{param.name} = {translate_c_type_to_ctypes_type(param.type, exclude_ptr_levels=1)}()\n"
)
fp_write(f"{tab}{param.name} = {translate_c_type_to_ctypes_type(param.type, exclude_ptr_levels=1)}()\n")
fp_write(f"{tab}ret = taichi_ccore.{original_func.name}(" + ", ".join(args) + ")\n")

# Process ret (error code)
assert original_func.ret_type == CType("int")
if func_def_name == "get_last_error": # NOTE: Avoid infinite recursion
fp_write(f"{tab}if ret != 0:\n")
fp_write(
f'{tab*2}raise RuntimeError(f"Failed to call get_last_error, err={{ret}}")\n'
)
fp_write(f'{tab*2}raise RuntimeError(f"Failed to call get_last_error, err={{ret}}")\n')
else:
fp_write(
f"{tab}ex = get_exception_to_throw_if_not_success(ret, *get_last_error())\n"
)
fp_write(f"{tab}ex = get_exception_to_throw_if_not_success(ret, *get_last_error())\n")
fp_write(f"{tab}if ex is not None:\n")
fp_write(f"{tab*2}raise ex\n")

@@ -626,9 +592,7 @@ def fp_write(content: str):
fp_write(f"{tab})\n")


def generate_py_module_from_exports_header(
dirname: str, header: ExportsHeader, printer
):
def generate_py_module_from_exports_header(dirname: str, header: ExportsHeader, printer):
COMMENT_HEADER = """# This file is auto-generated by misc/exports_to_py.py
# DO NOT edit this file manually!
# To regenerate this file, run:
@@ -675,9 +639,7 @@ def generate_py_module_from_exports_header(
)
f.write(COMMENT_HEADER)
f.write("\n")
f.write(
CCORE_PYTHON_FILE_FORMAT.format(exported_functions=exported_functions_def)
)
f.write(CCORE_PYTHON_FILE_FORMAT.format(exported_functions=exported_functions_def))

# Generate Python class from exported functions (class0.py, class1.py, ...)
classes: Mapping[str, ClassDecl] = {}
@@ -693,9 +655,7 @@ def generate_py_module_from_exports_header(
classes[class_name].methods[fn.method_name] = fn

for class_name, class_decl in classes.items():
printer(
f"Generating class {class_name} ({os.path.join(dirname, f'{class_name}.py')}) ..."
)
printer(f"Generating class {class_name} ({os.path.join(dirname, f'{class_name}.py')}) ...")
with open(os.path.join(dirname, f"{class_name}.py"), "w") as f:
methods = class_decl.methods
attrs = {} # {attr_name: (getter_name, setter_name), ...}
@@ -707,9 +667,7 @@ def generate_py_module_from_exports_header(
f.write("\n")
f.write(f"# Class {class_name}\n")
f.write(f"class {class_name}:\n")
f.write(
" def __init__(self, *args, handle=None, manage_handle=False):\n"
)
f.write(" def __init__(self, *args, handle=None, manage_handle=False):\n")
f.write(" if handle is not None:\n")
f.write(" self._manage_handle = manage_handle\n")
f.write(" self._handle = handle\n")
@@ -733,21 +691,14 @@ def generate_py_module_from_exports_header(
ClassMethodDecl.STATIC_METHOD_TYPE,
):
f.write(f" @staticmethod\n")
elif (
method_decl.type == ClassMethodDecl.METHOD_TYPE
and method_name.startswith(("get_", "set_"))
):
elif method_decl.type == ClassMethodDecl.METHOD_TYPE and method_name.startswith(("get_", "set_")):
attr_name = method_name[4:]
if attr_name not in attrs:
getter_name = f"get_{attr_name}"
setter_name = f"set_{attr_name}"
if not ClassMethodDecl.is_getter_method(
(methods.get(getter_name, None))
):
if not ClassMethodDecl.is_getter_method((methods.get(getter_name, None))):
getter_name = None
if not ClassMethodDecl.is_setter_method(
(methods.get(setter_name, None))
):
if not ClassMethodDecl.is_setter_method((methods.get(setter_name, None))):
setter_name = None
if getter_name is not None or setter_name is not None:
attrs[attr_name] = (getter_name, setter_name)
@@ -766,9 +717,7 @@ def generate_py_module_from_exports_header(
f.write(f"__all__ = ['{class_name}']\n")

# Generate global functions (global_functions.py)
printer(
f"Generating global functions ({os.path.join(dirname, 'global_functions.py')}) ..."
)
printer(f"Generating global functions ({os.path.join(dirname, 'global_functions.py')}) ...")
with open(os.path.join(dirname, "global_functions.py"), "w") as f:
f.write(COMMENT_HEADER)
f.write("\n")
@@ -785,14 +734,7 @@ def generate_py_module_from_exports_header(
)
f.write("\n")
f.write(f"__all__ = [\n")
f.write(
",\n".join(
[
f' "{func.name.replace("tie_G_", "")}"'
for func in global_functions
]
)
)
f.write(",\n".join([f' "{func.name.replace("tie_G_", "")}"' for func in global_functions]))
f.write("\n]\n")

# Generate __init__.py
@@ -826,9 +768,7 @@ def generate_py_module_from_exports_header(
# --output-dir python/taichi/_lib/exports \
# --verbose 2

parser = argparse.ArgumentParser(
description=f"Generate Python module from exports.h"
)
parser = argparse.ArgumentParser(description=f"Generate Python module from exports.h")

parser.add_argument(
"--exports-header",
@@ -879,6 +819,4 @@ def generate_py_module_from_exports_header(
cpp_args=cpp_args,
printer=printer,
)
generate_py_module_from_exports_header(
dirname=args.output_dir, header=exports_header, printer=printer
)
generate_py_module_from_exports_header(dirname=args.output_dir, header=exports_header, printer=printer)
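Most of the exports_to_py.py hunks above collapse multi-line expressions in the ctypes binding generator back onto single lines. For context, here is a minimal stand-alone sketch of the C-type-to-ctypes translation that generator performs; the simplified CType class, the trimmed lookup table, and the is_cstr test are illustrative stand-ins rather than the actual implementation.

# Sketch only: a simplified mirror of translate_c_type_to_ctypes_type above.
# The generator emits ctypes source text, so no ctypes import is needed here.
C_BUILTIN_TYPE_TO_CTYPES_TYPE = {
    "int": "ctypes.c_int",
    "float": "ctypes.c_float",
    "double": "ctypes.c_double",
}


class CType:
    def __init__(self, base_type_name, ptr_levels=0):
        self.base_type_name = base_type_name
        self.ptr_levels = ptr_levels

    def is_cstr(self):
        # Assumption: a C string is modelled as one pointer level over char.
        return self.base_type_name == "char" and self.ptr_levels == 1

    def exclude_ptr(self, levels):
        return CType(self.base_type_name, self.ptr_levels - levels)


def translate_ctype(t):
    if t.ptr_levels == 0:
        return C_BUILTIN_TYPE_TO_CTYPES_TYPE[t.base_type_name]
    if t.is_cstr():
        return "ctypes.c_char_p"
    return f"ctypes.POINTER({translate_ctype(t.exclude_ptr(1))})"


print(translate_ctype(CType("int")))        # ctypes.c_int
print(translate_ctype(CType("char", 1)))    # ctypes.c_char_p
print(translate_ctype(CType("double", 2)))  # ctypes.POINTER(ctypes.POINTER(ctypes.c_double))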
18 changes: 10 additions & 8 deletions python/taichi/lang/kernel_impl.py
@@ -639,7 +639,9 @@ def materialize(self, key=None, args=None, arg_features=None):
KernelSimplicityASTChecker(self.func).visit(tree)

# FIXME: Remove the lambda argument
taichi_kernel = _ti_ccore.Kernel(handle=impl.get_runtime().prog.c_create_kernel(lambda k: None, kernel_name, self.autodiff_mode))
taichi_kernel = _ti_ccore.Kernel(
handle=impl.get_runtime().prog.c_create_kernel(lambda k: None, kernel_name, self.autodiff_mode)
)

if self.runtime.inside_kernel:
raise TaichiSyntaxError(
@@ -745,7 +747,9 @@ def flatten_argpack(argpack, argpack_type):
array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
if isinstance(v, np.ndarray):
if v.flags.c_contiguous:
launch_ctx.set_arg_external_array_with_shape(actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0)
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0
)
elif v.flags.f_contiguous:
# TODO: A better way that avoids copying is saving strides info.
tmp = np.ascontiguousarray(v)
@@ -756,7 +760,9 @@ def callback(original, updated):
np.copyto(original, np.asfortranarray(updated))

callbacks.append(functools.partial(callback, v, tmp))
launch_ctx.set_arg_external_array_with_shape(actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0)
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
)
else:
raise ValueError(
"Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
@@ -832,11 +838,7 @@ def call_back():
f"Taichi do not support backend {v.place} that Paddle support"
)
launch_ctx.set_arg_external_array_with_shape(
actual_argument_slot,
int(tmp._ptr()),
v.element_size() * v.size,
array_shape,
0
actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
)
else:
raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
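The kernel_impl.py hunks above only re-wrap calls to set_arg_external_array_with_shape; the surrounding logic is the numpy contiguity handling sketched below. This is a simplified stand-alone illustration: set_arg stands in for the launch-context call, and the element_dim and array_shape bookkeeping is omitted.

# Sketch only: C-contiguous arrays are passed directly, while Fortran-ordered
# arrays go through a temporary copy plus a copy-back callback.
import functools

import numpy as np


def pass_ndarray(v, callbacks, set_arg):
    if v.flags.c_contiguous:
        # Row-major arrays are handed to the runtime as-is.
        set_arg(int(v.ctypes.data), v.nbytes, v.shape)
    elif v.flags.f_contiguous:
        # Column-major arrays are copied into a row-major temporary; the
        # callback writes any updates back into the original array afterwards.
        tmp = np.ascontiguousarray(v)

        def callback(original, updated):
            np.copyto(original, np.asfortranarray(updated))

        callbacks.append(functools.partial(callback, v, tmp))
        set_arg(int(tmp.ctypes.data), tmp.nbytes, tmp.shape)
    else:
        raise ValueError("Non-contiguous numpy arrays are not supported; call np.ascontiguousarray(arr) first.")


# Example: a Fortran-ordered array takes the copy-back path.
callbacks = []
arr = np.asfortranarray(np.zeros((4, 3)))
pass_ndarray(arr, callbacks, set_arg=lambda ptr, nbytes, shape: None)
for cb in callbacks:
    cb()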
7 changes: 6 additions & 1 deletion python/taichi/lang/simt/block.py
@@ -6,7 +6,12 @@


def arch_uses_spv(arch):
return arch == _ti_ccore.TIE_ARCH_VULKAN or arch == _ti_ccore.TIE_ARCH_METAL or arch == _ti_ccore.TIE_ARCH_OPENGL or arch == misc.dx11
return (
arch == _ti_ccore.TIE_ARCH_VULKAN
or arch == _ti_ccore.TIE_ARCH_METAL
or arch == _ti_ccore.TIE_ARCH_OPENGL
or arch == misc.dx11
)


def sync():
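The re-wrapped or-chain in arch_uses_spv is logically a membership test; a tiny equivalent sketch, with illustrative string constants standing in for _ti_ccore.TIE_ARCH_VULKAN, _ti_ccore.TIE_ARCH_METAL, _ti_ccore.TIE_ARCH_OPENGL and misc.dx11:

# Sketch only: equivalent membership-test form of the boolean chain above.
VULKAN, METAL, OPENGL, DX11 = "vulkan", "metal", "opengl", "dx11"
SPV_ARCHS = (VULKAN, METAL, OPENGL, DX11)


def arch_uses_spv(arch):
    return arch in SPV_ARCHS


assert arch_uses_spv(METAL)
assert not arch_uses_spv("cuda")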
4 changes: 1 addition & 3 deletions python/taichi/linalg/sparse_solver.py
@@ -28,9 +28,7 @@ def __init__(self, dtype=f32, solver_type="LLT", ordering="AMD"):
if solver_type in solver_type_list and ordering in solver_ordering:
taichi_arch = taichi.lang.impl.current_cfg().arch
assert (
taichi_arch == misc.x64
or taichi_arch == misc.arm64
or taichi_arch == misc.cuda
taichi_arch == misc.x64 or taichi_arch == misc.arm64 or taichi_arch == misc.cuda
), "SparseSolver only supports CPU and CUDA for now."
if taichi_arch == misc.cuda:
self.solver = _ti_core.make_cusparse_solver(dtype, solver_type, ordering)
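Similarly, the sparse_solver.py hunk collapses the arch assertion onto one line. The assert-then-dispatch pattern around it looks roughly like the sketch below; the string constants and the lambda factories are placeholders, and the CPU branch is stubbed because it is not shown in this hunk.

# Sketch only: guard the supported archs, then pick a solver factory.
X64, ARM64, CUDA = "x64", "arm64", "cuda"


def pick_solver(arch, make_cpu_solver, make_cuda_solver):
    assert arch in (X64, ARM64, CUDA), "SparseSolver only supports CPU and CUDA for now."
    return make_cuda_solver() if arch == CUDA else make_cpu_solver()


print(pick_solver(CUDA, make_cpu_solver=lambda: "cpu-llt", make_cuda_solver=lambda: "cusparse"))  # cusparse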
