From 5b7a600364a8e88f424a3ae33a7c2a043257603d Mon Sep 17 00:00:00 2001 From: Joel Mason <jobba1@hotmail.com> Date: Thu, 29 Mar 2018 04:38:36 +1100 Subject: [PATCH 1/2] Add PyFuncWrap --- benchmarks/callperf.jl | 45 +++++++++++++++++++++ src/PyCall.jl | 4 +- src/conversions.jl | 5 +++ src/pyfuncwrap.jl | 90 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + test/test_pyfuncwrap.jl | 29 +++++++++++++ 6 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 benchmarks/callperf.jl create mode 100644 src/pyfuncwrap.jl create mode 100644 test/test_pyfuncwrap.jl diff --git a/benchmarks/callperf.jl b/benchmarks/callperf.jl new file mode 100644 index 00000000..4e924438 --- /dev/null +++ b/benchmarks/callperf.jl @@ -0,0 +1,45 @@ +using PyCall, BenchmarkTools, DataStructures + +results = OrderedDict{String,Any}() + +let + np = pyimport("numpy") + nprand = np["random"]["rand"] + nprand_pyo(sz...) = pycall(nprand, PyObject, sz...) + nprand2d_wrap = PyFuncWrap(nprand, (Int, Int)) + + arr_size = (2,2) + + results["nprand_pyo"] = @benchmark $nprand_pyo($arr_size...) + println("nprand_pyo:\n"); display(results["nprand_pyo"]) + println("--------------------------------------------------") + + results["nprand2d_wrap"] = @benchmark $nprand2d_wrap($arr_size...) + println("nprand2d_wrap:\n"); display(results["nprand2d_wrap"]) + println("--------------------------------------------------") + + # args already set by nprand2d_wrap calls above + results["nprand2d_wrap_noargs"] = @benchmark $nprand2d_wrap() + println("nprand2d_wrap_noargs:\n"); display(results["nprand2d_wrap_noargs"]) + println("--------------------------------------------------") + + arr_size = ntuple(i->2, 15) + + results["nprand_pyo2"] = @benchmark $nprand_pyo($arr_size...) + println("nprand_pyo2:\n"); display(results["nprand_pyo2"]) + println("--------------------------------------------------") + + results["nprand2d_wrap2"] = @benchmark $nprand2d_wrap($arr_size...) + println("nprand2d_wrap2:\n"); display(results["nprand2d_wrap2"]) + println("--------------------------------------------------") + + # args already set by nprand2d_wrap calls above + results["nprand2d_wrap_noargs2"] = @benchmark $nprand2d_wrap() + println("nprand2d_wrap_noargs2:\n"); display(results["nprand2d_wrap_noargs2"]) + println("--------------------------------------------------") +end + +println("") +println("Mean times") +println("----------") +foreach((r)->println(rpad(r[1],23), ": ", mean(r[2])), results) diff --git a/src/PyCall.jl b/src/PyCall.jl index 52a163bd..687bd227 100644 --- a/src/PyCall.jl +++ b/src/PyCall.jl @@ -10,7 +10,8 @@ export pycall, pyimport, pybuiltin, PyObject, PyReverseDims, pyisinstance, pywrap, pytypeof, pyeval, PyVector, pystring, pystr, pyrepr, pyraise, pytype_mapping, pygui, pygui_start, pygui_stop, pygui_stop_all, @pylab, set!, PyTextIO, @pysym, PyNULL, @pydef, - pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret + pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret, + PyFuncWrap, setarg!, setargs! import Base: size, ndims, similar, copy, getindex, setindex!, stride, convert, pointer, summary, convert, show, haskey, keys, values, @@ -170,6 +171,7 @@ include("pytype.jl") include("pyiterator.jl") include("pyclass.jl") include("callback.jl") +include("pyfuncwrap.jl") include("io.jl") ######################################################################### diff --git a/src/conversions.jl b/src/conversions.jl index 014b8f26..5d5b8bbf 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -174,6 +174,11 @@ end # somewhat annoying to get the length and types in a tuple type # ... would be better not to have to use undocumented internals! +function tuplen(T::DataType) + isvatuple(T) && ArgumentError("can't determine length of vararg tuple: $T") + return length(T.parameters) +end +tuplen(T::UnionAll) = tuplen(T.body) istuplen(T,isva,n) = isva ? n ≥ length(T.parameters)-1 : n == length(T.parameters) function tuptype(T::DataType,isva,i) if isva && i ≥ length(T.parameters) diff --git a/src/pyfuncwrap.jl b/src/pyfuncwrap.jl new file mode 100644 index 00000000..7279a269 --- /dev/null +++ b/src/pyfuncwrap.jl @@ -0,0 +1,90 @@ +struct PyFuncWrap{P<:Union{PyObject,PyPtr}, AT<:Tuple, N, RT} + o::P + oargs::Vector{PyObject} + pyargsptr::PyPtr + ret::PyObject +end + +""" +``` +PyFuncWrap(o::P, argtypes::Tuple #= of Types =#, returntype::Type) +``` + +Wrap a callable PyObject/PyPtr to reduce the number of allocations made for +passing its arguments, and its return value, sometimes providing a speedup. +Mainly useful for functions called in a tight loop, particularly if most or +all of the arguments to the function don't change. +``` +@pyimport numpy as np +rand22fn = PyFuncWrap(np.random["rand"], (Int, Int), PyArray) +setargs!(rand22fn, 2, 2) +for i in 1:10^9 + arr = rand22fn() + ... +end +``` +""" +function PyFuncWrap(o::P, argtypes::Tuple{Vararg{<:Union{Tuple, Type}}}, + returntype::Type{RT}=PyObject) where {P<:Union{PyObject,PyPtr}, RT} + AT = typeof(argtypes) + isvatuple(AT) && throw(ArgumentError("Vararg functions not supported, arg signature provided: $AT")) + N = tuplen(AT) + oargs = Array{PyObject}(N) + pyargsptr = ccall((@pysym :PyTuple_New), PyPtr, (Int,), N) + return PyFuncWrap{P, AT, N, RT}(o, oargs, pyargsptr, PyNULL()) +end + +""" +``` +setargs!(pf::PyFuncWrap, args...) +``` +Set the arguments to a python function wrapped in a PyFuncWrap, and convert them +to `PyObject`s that can be passed directly to python when the function is +called. After the arguments have been set, the function can be efficiently +called with `pf()` +""" +function setargs!(pf::PyFuncWrap{P, AT, N, RT}, args...) where {P, AT, RT, N} + for i = 1:N + setarg!(pf, args[i], i) + end + nothing +end + +""" +``` +setarg!(pf::PyFuncWrap, arg, i::Integer=1) +``` +Set the `i`th argument to a python function wrapped in a PyFuncWrap, and convert +it to a `PyObject` that can be passed directly to python when the function is +called. Useful if a function takes multiple arguments, but only one or two of +them change, when calling the function in a tight loop +""" +function setarg!(pf::PyFuncWrap{P, AT, N, RT}, arg, i::Integer=1) where {P, AT, N, RT} + pf.oargs[i] = PyObject(arg) + @pycheckz ccall((@pysym :PyTuple_SetItem), Cint, + (PyPtr,Int,PyPtr), pf.pyargsptr, i-1, pf.oargs[i]) + pyincref(pf.oargs[i]) # PyTuple_SetItem steals the reference + nothing +end + +function (pf::PyFuncWrap{P, AT, N, RT})(args...) where {P, AT, N, RT} + setargs!(pf, args...) + return pf() +end + +""" +Warning: if pf(args) or setargs(pf, ...) hasn't been called yet, this will likely segfault +""" +function (pf::PyFuncWrap{P, AT, N, RT})() where {P, AT, N, RT} + sigatomic_begin() + try + kw = C_NULL + retptr = ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o, + pf.pyargsptr, kw) + pyincref_(retptr) + pf.ret.o = retptr + finally + sigatomic_end() + end + convert(RT, pf.ret) +end diff --git a/test/runtests.jl b/test/runtests.jl index b7dd2e8f..8ecbe0c9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -538,3 +538,5 @@ end @test pyfunctionret(factorial, nothing, Int)(3) === nothing @test PyCall.is_pyjlwrap(pycall(pyfunctionret(factorial, Any, Int), PyObject, 3)) end + +include("test_pyfuncwrap.jl") \ No newline at end of file diff --git a/test/test_pyfuncwrap.jl b/test/test_pyfuncwrap.jl new file mode 100644 index 00000000..ed213d54 --- /dev/null +++ b/test/test_pyfuncwrap.jl @@ -0,0 +1,29 @@ +using Compat.Test, PyCall + +@testset "PyFuncWrap" begin + np = pyimport("numpy") + ops = pyimport("operator") + eq = ops["eq"] + npzeros = np["zeros"] + npzeros_pyo(sz, dtype="d", order="F") = pycall(npzeros, PyObject, sz, dtype, order) + npzeros_pyany(sz, dtype="d", order="F") = pycall(npzeros, PyAny, sz, dtype, order) + npzeros_pyarray(sz, dtype="d", order="F") = pycall(npzeros, PyArray, sz, dtype, order) + + # PyObject is default returntype + npzeros2dwrap_pyo = PyFuncWrap(npzeros, ((Int, Int), String, String)) + npzeros2dwrap_pyany = PyFuncWrap(npzeros, ((Int, Int), String, String), PyAny) + npzeros2dwrap_pyarray = PyFuncWrap(npzeros, ((Int, Int), String, String), PyArray) + + arr_size = (2,2) + + # all args + @test np["array_equal"](npzeros2dwrap_pyo(arr_size, "d", "F"), npzeros_pyo(arr_size)) + # args already set + @test np["array_equal"](npzeros2dwrap_pyo(), npzeros_pyo(arr_size)) + + @test all(npzeros2dwrap_pyany(arr_size, "d", "F") .== npzeros_pyany(arr_size)) + @test all(npzeros2dwrap_pyany() .== npzeros_pyany(arr_size)) + + @test all(npzeros2dwrap_pyarray(arr_size, "d", "F") .== npzeros_pyarray(arr_size)) + @test all(npzeros2dwrap_pyarray() .== npzeros_pyarray(arr_size)) +end From 1826f4ad7025a0a7b7249e3969f3b16fd0d6940c Mon Sep 17 00:00:00 2001 From: Joel Mason <jobba1@hotmail.com> Date: Wed, 11 Apr 2018 22:26:05 +1000 Subject: [PATCH 2/2] Add pydecref to old ret value, and fix bug in benchmark --- benchmarks/callperf.jl | 54 +++++++++++++++++------------------------- src/pyfuncwrap.jl | 1 + 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/benchmarks/callperf.jl b/benchmarks/callperf.jl index 4e924438..1f0f4b00 100644 --- a/benchmarks/callperf.jl +++ b/benchmarks/callperf.jl @@ -6,40 +6,30 @@ let np = pyimport("numpy") nprand = np["random"]["rand"] nprand_pyo(sz...) = pycall(nprand, PyObject, sz...) - nprand2d_wrap = PyFuncWrap(nprand, (Int, Int)) - - arr_size = (2,2) - - results["nprand_pyo"] = @benchmark $nprand_pyo($arr_size...) - println("nprand_pyo:\n"); display(results["nprand_pyo"]) - println("--------------------------------------------------") - - results["nprand2d_wrap"] = @benchmark $nprand2d_wrap($arr_size...) - println("nprand2d_wrap:\n"); display(results["nprand2d_wrap"]) - println("--------------------------------------------------") - - # args already set by nprand2d_wrap calls above - results["nprand2d_wrap_noargs"] = @benchmark $nprand2d_wrap() - println("nprand2d_wrap_noargs:\n"); display(results["nprand2d_wrap_noargs"]) - println("--------------------------------------------------") - - arr_size = ntuple(i->2, 15) - - results["nprand_pyo2"] = @benchmark $nprand_pyo($arr_size...) - println("nprand_pyo2:\n"); display(results["nprand_pyo2"]) - println("--------------------------------------------------") - - results["nprand2d_wrap2"] = @benchmark $nprand2d_wrap($arr_size...) - println("nprand2d_wrap2:\n"); display(results["nprand2d_wrap2"]) - println("--------------------------------------------------") - - # args already set by nprand2d_wrap calls above - results["nprand2d_wrap_noargs2"] = @benchmark $nprand2d_wrap() - println("nprand2d_wrap_noargs2:\n"); display(results["nprand2d_wrap_noargs2"]) - println("--------------------------------------------------") + ret = PyNULL() + args_lens = (0,3,7,12,17) + arr_sizes = (ntuple(i->1, len) for len in args_lens) + nprand_wraps = [PyFuncWrap(nprand, map(typeof, arr_size)) for arr_size in arr_sizes] + @show typeof(nprand_wraps) + for (i, arr_size) in enumerate(arr_sizes) + nprand_wrap = nprand_wraps[i] + arr_size_str = args_lens[i] < 5 ? "$arr_size" : "$(args_lens[i])*(1,1,...)" + results["nprand_pyo $arr_size_str"] = @benchmark $nprand_pyo($arr_size...) + println("nprand_pyo $arr_size_str:\n"); display(results["nprand_pyo $arr_size_str"]) + println("--------------------------------------------------") + + results["nprand_wrap $arr_size_str"] = @benchmark $nprand_wrap($arr_size...) + println("nprand_wrap $arr_size_str:\n"); display(results["nprand_wrap $arr_size_str"]) + println("--------------------------------------------------") + + # args already set by nprand_wrap calls above + results["nprand_wrap_noargs $arr_size_str"] = @benchmark $nprand_wrap() + println("nprand_wrap_noargs $arr_size_str:\n"); display(results["nprand_wrap_noargs $arr_size_str"]) + println("--------------------------------------------------") + end end println("") println("Mean times") println("----------") -foreach((r)->println(rpad(r[1],23), ": ", mean(r[2])), results) +foreach((r)->println(rpad(r[1],33), ": ", mean(r[2])), results) diff --git a/src/pyfuncwrap.jl b/src/pyfuncwrap.jl index 7279a269..6d118cbc 100644 --- a/src/pyfuncwrap.jl +++ b/src/pyfuncwrap.jl @@ -82,6 +82,7 @@ function (pf::PyFuncWrap{P, AT, N, RT})() where {P, AT, N, RT} retptr = ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o, pf.pyargsptr, kw) pyincref_(retptr) + pydecref(pf.ret) pf.ret.o = retptr finally sigatomic_end()