From 2baf9d683328aa3f91c53139cc61d977a0c2454a Mon Sep 17 00:00:00 2001 From: ObeliskGate <96614155+ObeliskGate@users.noreply.github.com> Date: Thu, 22 Aug 2024 03:33:06 +0800 Subject: [PATCH] fix: `` support for `py::tuple` and `py::list` (#5314) * feat: add `` support for `py::tuple` and `py::list` * fix: format the code * fix: disable `ranges` in clang < 16 * refactor: move `` test macro to `test_pytypes.h` * refactor: seperate `ranges` test into 3 funcs * style: compress the if statement * style: pre-commit fixes * style: better formatting --------- Co-authored-by: Henry Schreiner Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- include/pybind11/pytypes.h | 2 ++ tests/test_pytypes.cpp | 63 ++++++++++++++++++++++++++++++++++++++ tests/test_pytypes.py | 42 +++++++++++++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index f26c307a87..1e76d7bc13 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1259,6 +1259,7 @@ class sequence_fast_readonly { using pointer = arrow_proxy; sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) {} + sequence_fast_readonly() = default; // NOLINTNEXTLINE(readability-const-return-type) // PR #3263 reference dereference() const { return *ptr; } @@ -1281,6 +1282,7 @@ class sequence_slow_readwrite { using pointer = arrow_proxy; sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) {} + sequence_slow_readwrite() = default; reference dereference() const { return {obj, static_cast(index)}; } void increment() { ++index; } diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index ecb44939aa..19f65ce7eb 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -13,6 +13,14 @@ #include +//__has_include has been part of C++17, no need to check it +#if defined(PYBIND11_CPP20) && __has_include() +# if !defined(PYBIND11_COMPILER_CLANG) || __clang_major__ >= 16 // llvm/llvm-project#52696 +# define PYBIND11_TEST_PYTYPES_HAS_RANGES +# include +# endif +#endif + namespace external { namespace detail { bool check(PyObject *o) { return PyFloat_Check(o) != 0; } @@ -923,4 +931,59 @@ TEST_SUBMODULE(pytypes, m) { #else m.attr("defined_PYBIND11_TYPING_H_HAS_STRING_LITERAL") = false; #endif + +#if defined(PYBIND11_TEST_PYTYPES_HAS_RANGES) + + // test_tuple_ranges + m.def("tuple_iterator_default_initialization", []() { + using TupleIterator = decltype(std::declval().begin()); + static_assert(std::random_access_iterator); + return TupleIterator{} == TupleIterator{}; + }); + + m.def("transform_tuple_plus_one", [](py::tuple &tpl) { + py::list ret{}; + for (auto it : tpl | std::views::transform([](auto &o) { return py::cast(o) + 1; })) { + ret.append(py::int_(it)); + } + return ret; + }); + + // test_list_ranges + m.def("list_iterator_default_initialization", []() { + using ListIterator = decltype(std::declval().begin()); + static_assert(std::random_access_iterator); + return ListIterator{} == ListIterator{}; + }); + + m.def("transform_list_plus_one", [](py::list &lst) { + py::list ret{}; + for (auto it : lst | std::views::transform([](auto &o) { return py::cast(o) + 1; })) { + ret.append(py::int_(it)); + } + return ret; + }); + + // test_dict_ranges + m.def("dict_iterator_default_initialization", []() { + using DictIterator = decltype(std::declval().begin()); + static_assert(std::forward_iterator); + return DictIterator{} == DictIterator{}; + }); + + m.def("transform_dict_plus_one", [](py::dict &dct) { + py::list ret{}; + for (auto it : dct | std::views::transform([](auto &o) { + return std::pair{py::cast(o.first) + 1, + py::cast(o.second) + 1}; + })) { + ret.append(py::make_tuple(py::int_(it.first), py::int_(it.second))); + } + return ret; + }); + + m.attr("defined_PYBIND11_TEST_PYTYPES_HAS_RANGES") = true; +#else + m.attr("defined_PYBIND11_TEST_PYTYPES_HAS_RANGES") = false; +#endif } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 218092b434..1c6335f757 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1048,3 +1048,45 @@ def test_typevar(doc): assert doc(m.annotate_listT_to_T) == "annotate_listT_to_T(arg0: list[T]) -> T" assert doc(m.annotate_object_to_T) == "annotate_object_to_T(arg0: object) -> T" + + +@pytest.mark.skipif( + not m.defined_PYBIND11_TEST_PYTYPES_HAS_RANGES, + reason=" not available.", +) +@pytest.mark.parametrize( + ("tested_tuple", "expected"), + [((1,), [2]), ((3, 4), [4, 5]), ((7, 8, 9), [8, 9, 10])], +) +def test_tuple_ranges(tested_tuple, expected): + assert m.tuple_iterator_default_initialization() + assert m.transform_tuple_plus_one(tested_tuple) == expected + + +@pytest.mark.skipif( + not m.defined_PYBIND11_TEST_PYTYPES_HAS_RANGES, + reason=" not available.", +) +@pytest.mark.parametrize( + ("tested_list", "expected"), [([1], [2]), ([3, 4], [4, 5]), ([7, 8, 9], [8, 9, 10])] +) +def test_list_ranges(tested_list, expected): + assert m.list_iterator_default_initialization() + assert m.transform_list_plus_one(tested_list) == expected + + +@pytest.mark.skipif( + not m.defined_PYBIND11_TEST_PYTYPES_HAS_RANGES, + reason=" not available.", +) +@pytest.mark.parametrize( + ("tested_dict", "expected"), + [ + ({1: 2}, [(2, 3)]), + ({3: 4, 5: 6}, [(4, 5), (6, 7)]), + ({7: 8, 9: 10, 11: 12}, [(8, 9), (10, 11), (12, 13)]), + ], +) +def test_dict_ranges(tested_dict, expected): + assert m.dict_iterator_default_initialization() + assert m.transform_dict_plus_one(tested_dict) == expected