From 5b7a600364a8e88f424a3ae33a7c2a043257603d Mon Sep 17 00:00:00 2001
From: Joel Mason <jobba1@hotmail.com>
Date: Thu, 29 Mar 2018 04:38:36 +1100
Subject: [PATCH 1/2] Add PyFuncWrap

---
 benchmarks/callperf.jl  | 45 +++++++++++++++++++++
 src/PyCall.jl           |  4 +-
 src/conversions.jl      |  5 +++
 src/pyfuncwrap.jl       | 90 +++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl        |  2 +
 test/test_pyfuncwrap.jl | 29 +++++++++++++
 6 files changed, 174 insertions(+), 1 deletion(-)
 create mode 100644 benchmarks/callperf.jl
 create mode 100644 src/pyfuncwrap.jl
 create mode 100644 test/test_pyfuncwrap.jl

diff --git a/benchmarks/callperf.jl b/benchmarks/callperf.jl
new file mode 100644
index 00000000..4e924438
--- /dev/null
+++ b/benchmarks/callperf.jl
@@ -0,0 +1,45 @@
+using PyCall, BenchmarkTools, DataStructures
+
+results = OrderedDict{String,Any}()
+
+let
+    np = pyimport("numpy")
+    nprand = np["random"]["rand"]
+    nprand_pyo(sz...) = pycall(nprand, PyObject, sz...)
+    nprand2d_wrap = PyFuncWrap(nprand, (Int, Int))
+
+    arr_size = (2,2)
+
+    results["nprand_pyo"] = @benchmark $nprand_pyo($arr_size...)
+    println("nprand_pyo:\n"); display(results["nprand_pyo"])
+    println("--------------------------------------------------")
+
+    results["nprand2d_wrap"] = @benchmark $nprand2d_wrap($arr_size...)
+    println("nprand2d_wrap:\n"); display(results["nprand2d_wrap"])
+    println("--------------------------------------------------")
+
+    # args already set by nprand2d_wrap calls above
+    results["nprand2d_wrap_noargs"] = @benchmark $nprand2d_wrap()
+    println("nprand2d_wrap_noargs:\n"); display(results["nprand2d_wrap_noargs"])
+    println("--------------------------------------------------")
+
+    arr_size = ntuple(i->2, 15)
+
+    results["nprand_pyo2"] = @benchmark $nprand_pyo($arr_size...)
+    println("nprand_pyo2:\n"); display(results["nprand_pyo2"])
+    println("--------------------------------------------------")
+
+    results["nprand2d_wrap2"] = @benchmark $nprand2d_wrap($arr_size...)
+    println("nprand2d_wrap2:\n"); display(results["nprand2d_wrap2"])
+    println("--------------------------------------------------")
+
+    # args already set by nprand2d_wrap calls above
+    results["nprand2d_wrap_noargs2"] = @benchmark $nprand2d_wrap()
+    println("nprand2d_wrap_noargs2:\n"); display(results["nprand2d_wrap_noargs2"])
+    println("--------------------------------------------------")
+end
+
+println("")
+println("Mean times")
+println("----------")
+foreach((r)->println(rpad(r[1],23), ": ", mean(r[2])), results)
diff --git a/src/PyCall.jl b/src/PyCall.jl
index 52a163bd..687bd227 100644
--- a/src/PyCall.jl
+++ b/src/PyCall.jl
@@ -10,7 +10,8 @@ export pycall, pyimport, pybuiltin, PyObject, PyReverseDims,
        pyisinstance, pywrap, pytypeof, pyeval, PyVector, pystring, pystr, pyrepr,
        pyraise, pytype_mapping, pygui, pygui_start, pygui_stop,
        pygui_stop_all, @pylab, set!, PyTextIO, @pysym, PyNULL, @pydef,
-       pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret
+       pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret,
+       PyFuncWrap, setarg!, setargs!
 
 import Base: size, ndims, similar, copy, getindex, setindex!, stride,
        convert, pointer, summary, convert, show, haskey, keys, values,
@@ -170,6 +171,7 @@ include("pytype.jl")
 include("pyiterator.jl")
 include("pyclass.jl")
 include("callback.jl")
+include("pyfuncwrap.jl")
 include("io.jl")
 
 #########################################################################
diff --git a/src/conversions.jl b/src/conversions.jl
index 014b8f26..5d5b8bbf 100644
--- a/src/conversions.jl
+++ b/src/conversions.jl
@@ -174,6 +174,11 @@ end
 
 # somewhat annoying to get the length and types in a tuple type
 # ... would be better not to have to use undocumented internals!
+function tuplen(T::DataType) # number of element types of a fixed-length tuple type
+    isvatuple(T) && throw(ArgumentError("can't determine length of vararg tuple: $T"))
+    return length(T.parameters)
+end
+tuplen(T::UnionAll) = tuplen(T.body) # look through the `where` wrapper
 istuplen(T,isva,n) = isva ? n ≥ length(T.parameters)-1 : n == length(T.parameters)
 function tuptype(T::DataType,isva,i)
     if isva && i ≥ length(T.parameters)
diff --git a/src/pyfuncwrap.jl b/src/pyfuncwrap.jl
new file mode 100644
index 00000000..7279a269
--- /dev/null
+++ b/src/pyfuncwrap.jl
@@ -0,0 +1,90 @@
+struct PyFuncWrap{P<:Union{PyObject,PyPtr}, AT<:Tuple, N, RT}
+    o::P # the Python callable that is handed to PyObject_Call
+    oargs::Vector{PyObject} # PyObject for each argument, filled in by setarg!
+    pyargsptr::PyPtr # Python tuple of length N that is passed as the call's arguments
+    ret::PyObject # holder whose pointer is overwritten with each call's result
+end
+
+"""
+```
+PyFuncWrap(o::P, argtypes::Tuple #= of Types =#, returntype::Type)
+```
+
+Wrap a callable PyObject/PyPtr to reduce the number of allocations made for
+passing its arguments, and its return value, sometimes providing a speedup.
+Mainly useful for functions called in a tight loop, particularly if most or
+all of the arguments to the function don't change.
+```
+@pyimport numpy as np
+rand22fn = PyFuncWrap(np.random["rand"], (Int, Int), PyArray)
+setargs!(rand22fn, 2, 2)
+for i in 1:10^9
+    arr = rand22fn()
+    ...
+end
+```
+"""
+function PyFuncWrap(o::P, argtypes::Tuple{Vararg{<:Union{Tuple, Type}}},
+            returntype::Type{RT}=PyObject) where {P<:Union{PyObject,PyPtr}, RT}
+    AT = typeof(argtypes)
+    isvatuple(AT) && throw(ArgumentError("Vararg functions not supported, arg signature provided: $AT"))
+    N = tuplen(AT)
+    oargs = Array{PyObject}(N) # slots stay unassigned until setarg!/setargs! fills them
+    pyargsptr = @pycheckn ccall((@pysym :PyTuple_New), PyPtr, (Int,), N) # throw on a NULL tuple
+    return PyFuncWrap{P, AT, N, RT}(o, oargs, pyargsptr, PyNULL())
+end
+
+"""
+```
+setargs!(pf::PyFuncWrap, args...)
+```
+Set the arguments to a python function wrapped in a PyFuncWrap, and convert them
+to `PyObject`s that can be passed directly to python when the function is
+called. After the arguments have been set, the function can be efficiently
+called with `pf()`
+"""
+function setargs!(pf::PyFuncWrap{P, AT, N, RT}, args...) where {P, AT, RT, N}
+    # Store the first N positional arguments, one tuple slot at a time; indexing
+    # `args[slot]` is what raises when fewer than N arguments are supplied.
+    foreach(slot -> setarg!(pf, args[slot], slot), 1:N)
+    return nothing
+end
+
+"""
+```
+setarg!(pf::PyFuncWrap, arg, i::Integer=1)
+```
+Set the `i`th argument to a python function wrapped in a PyFuncWrap, and convert
+it to a `PyObject` that can be passed directly to python when the function is
+called. Useful if a function takes multiple arguments, but only one or two of
+them change, when calling the function in a tight loop
+"""
+function setarg!(pf::PyFuncWrap{P, AT, N, RT}, arg, i::Integer=1) where {P, AT, N, RT}
+    pf.oargs[i] = PyObject(arg)
+    pyincref(pf.oargs[i]) # incref first: PyTuple_SetItem steals a reference even on failure
+    @pycheckz ccall((@pysym :PyTuple_SetItem), Cint,
+                     (PyPtr,Int,PyPtr), pf.pyargsptr, i-1, pf.oargs[i])
+    nothing
+end
+
+function (pf::PyFuncWrap{P, AT, N, RT})(args...) where {P, AT, N, RT}
+    # Refresh the stored argument tuple, then go through the zero-argument method.
+    setargs!(pf, args...); return pf()
+end
+
+"""
+Call with the stored arguments. Warning: if `pf(args...)` or `setargs!(pf, ...)` hasn't been called yet, this will likely segfault
+"""
+function (pf::PyFuncWrap{P, AT, N, RT})() where {P, AT, N, RT}
+    sigatomic_begin()
+    try
+        kw = C_NULL # no keyword arguments are passed to the callable
+        retptr = ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o,
+                        pf.pyargsptr, kw)
+        pyincref_(retptr)
+        pf.ret.o = retptr
+    finally
+        sigatomic_end()
+    end
+    convert(RT, pf.ret) # convert the reused holder to the requested return type
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index b7dd2e8f..8ecbe0c9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -538,3 +538,5 @@ end
     @test pyfunctionret(factorial, nothing, Int)(3) === nothing
     @test PyCall.is_pyjlwrap(pycall(pyfunctionret(factorial, Any, Int), PyObject, 3))
 end
+
+include("test_pyfuncwrap.jl") # PyFuncWrap tests are kept in their own file
diff --git a/test/test_pyfuncwrap.jl b/test/test_pyfuncwrap.jl
new file mode 100644
index 00000000..ed213d54
--- /dev/null
+++ b/test/test_pyfuncwrap.jl
@@ -0,0 +1,29 @@
+using Compat.Test, PyCall
+
+@testset "PyFuncWrap" begin
+    # Each wrapper is compared against a plain `pycall` of numpy.zeros, once with the
+    # arguments passed explicitly and once reusing the arguments stored by that call.
+    np = pyimport("numpy")
+    npzeros = np["zeros"]
+    npzeros_pyo(sz, dtype="d", order="F")     = pycall(npzeros, PyObject, sz, dtype, order)
+    npzeros_pyany(sz, dtype="d", order="F")   = pycall(npzeros, PyAny, sz, dtype, order)
+    npzeros_pyarray(sz, dtype="d", order="F") = pycall(npzeros, PyArray, sz, dtype, order)
+
+    # PyObject is default returntype
+    npzeros2dwrap_pyo     = PyFuncWrap(npzeros, ((Int, Int), String, String))
+    npzeros2dwrap_pyany   = PyFuncWrap(npzeros, ((Int, Int), String, String), PyAny)
+    npzeros2dwrap_pyarray = PyFuncWrap(npzeros, ((Int, Int), String, String), PyArray)
+
+    arr_size = (2,2)
+
+    # all args
+    @test np["array_equal"](npzeros2dwrap_pyo(arr_size, "d", "F"), npzeros_pyo(arr_size))
+    # args already set
+    @test np["array_equal"](npzeros2dwrap_pyo(), npzeros_pyo(arr_size))
+
+    @test all(npzeros2dwrap_pyany(arr_size, "d", "F") .== npzeros_pyany(arr_size))
+    @test all(npzeros2dwrap_pyany() .== npzeros_pyany(arr_size))
+
+    @test all(npzeros2dwrap_pyarray(arr_size, "d", "F") .== npzeros_pyarray(arr_size))
+    @test all(npzeros2dwrap_pyarray() .== npzeros_pyarray(arr_size))
+end

From 1826f4ad7025a0a7b7249e3969f3b16fd0d6940c Mon Sep 17 00:00:00 2001
From: Joel Mason <jobba1@hotmail.com>
Date: Wed, 11 Apr 2018 22:26:05 +1000
Subject: [PATCH 2/2] Add pydecref to old ret value, and fix bug in benchmark

---
 benchmarks/callperf.jl | 54 +++++++++++++++++-------------------------
 src/pyfuncwrap.jl      |  1 +
 2 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/benchmarks/callperf.jl b/benchmarks/callperf.jl
index 4e924438..1f0f4b00 100644
--- a/benchmarks/callperf.jl
+++ b/benchmarks/callperf.jl
@@ -6,40 +6,30 @@ let
     np = pyimport("numpy")
     nprand = np["random"]["rand"]
     nprand_pyo(sz...) = pycall(nprand, PyObject, sz...)
-    nprand2d_wrap = PyFuncWrap(nprand, (Int, Int))
-
-    arr_size = (2,2)
-
-    results["nprand_pyo"] = @benchmark $nprand_pyo($arr_size...)
-    println("nprand_pyo:\n"); display(results["nprand_pyo"])
-    println("--------------------------------------------------")
-
-    results["nprand2d_wrap"] = @benchmark $nprand2d_wrap($arr_size...)
-    println("nprand2d_wrap:\n"); display(results["nprand2d_wrap"])
-    println("--------------------------------------------------")
-
-    # args already set by nprand2d_wrap calls above
-    results["nprand2d_wrap_noargs"] = @benchmark $nprand2d_wrap()
-    println("nprand2d_wrap_noargs:\n"); display(results["nprand2d_wrap_noargs"])
-    println("--------------------------------------------------")
-
-    arr_size = ntuple(i->2, 15)
-
-    results["nprand_pyo2"] = @benchmark $nprand_pyo($arr_size...)
-    println("nprand_pyo2:\n"); display(results["nprand_pyo2"])
-    println("--------------------------------------------------")
-
-    results["nprand2d_wrap2"] = @benchmark $nprand2d_wrap($arr_size...)
-    println("nprand2d_wrap2:\n"); display(results["nprand2d_wrap2"])
-    println("--------------------------------------------------")
-
-    # args already set by nprand2d_wrap calls above
-    results["nprand2d_wrap_noargs2"] = @benchmark $nprand2d_wrap()
-    println("nprand2d_wrap_noargs2:\n"); display(results["nprand2d_wrap_noargs2"])
-    println("--------------------------------------------------")
+    args_lens = (0,3,7,12,17)
+    arr_sizes = (ntuple(i->1, len) for len in args_lens)
+    # one wrapper per arity; their concrete types differ because N is a type parameter
+    nprand_wraps = [PyFuncWrap(nprand, map(typeof, arr_size)) for arr_size in arr_sizes]
+    @show typeof(nprand_wraps)
+    for (nargs, arr_size, nprand_wrap) in zip(args_lens, arr_sizes, nprand_wraps)
+        # spell out short argument lists, abbreviate long ones in the result labels
+        arr_size_str = nargs < 5 ? "$arr_size" : "$(nargs)*(1,1,...)"
+        results["nprand_pyo $arr_size_str"] = @benchmark $nprand_pyo($arr_size...)
+        println("nprand_pyo $arr_size_str:\n"); display(results["nprand_pyo $arr_size_str"])
+        println("--------------------------------------------------")
+
+        results["nprand_wrap $arr_size_str"] = @benchmark $nprand_wrap($arr_size...)
+        println("nprand_wrap $arr_size_str:\n"); display(results["nprand_wrap $arr_size_str"])
+        println("--------------------------------------------------")
+
+        # args already set by nprand_wrap calls above
+        results["nprand_wrap_noargs $arr_size_str"] = @benchmark $nprand_wrap()
+        println("nprand_wrap_noargs $arr_size_str:\n"); display(results["nprand_wrap_noargs $arr_size_str"])
+        println("--------------------------------------------------")
+    end
 end
 
 println("")
 println("Mean times")
 println("----------")
-foreach((r)->println(rpad(r[1],23), ": ", mean(r[2])), results)
+foreach((r)->println(rpad(r[1],33), ": ", mean(r[2])), results)
diff --git a/src/pyfuncwrap.jl b/src/pyfuncwrap.jl
index 7279a269..6d118cbc 100644
--- a/src/pyfuncwrap.jl
+++ b/src/pyfuncwrap.jl
@@ -82,6 +82,10 @@ function (pf::PyFuncWrap{P, AT, N, RT})() where {P, AT, N, RT}
-        retptr = ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o,
+        # a NULL result means the Python call raised; @pycheckn rethrows it in Julia
+        retptr = @pycheckn ccall((@pysym :PyObject_Call), PyPtr, (PyPtr,PyPtr,PyPtr), pf.o,
                         pf.pyargsptr, kw)
+        # NOTE(review): PyObject_Call already returns a new reference, so this extra
+        # incref looks like it leaks one reference per call -- confirm before removing
         pyincref_(retptr)
+        pydecref(pf.ret)
         pf.ret.o = retptr
     finally
         sigatomic_end()