diff --git a/patch.go b/patch.go index 60c1094..454c5ed 100644 --- a/patch.go +++ b/patch.go @@ -28,6 +28,10 @@ func ApplyMethod(target reflect.Type, methodName string, double interface{}) *Pa return create().ApplyMethod(target, methodName, double) } +func ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches { + return create().ApplyMethodFunc(target, methodName, doubleFunc) +} + func ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches { return create().ApplyPrivateMethod(target, methodName, double) } @@ -52,6 +56,18 @@ func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches { return create().ApplyFuncVarSeq(target, outputs) } +func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches { + return create().ApplyFuncReturn(target, output...) +} + +func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches { + return create().ApplyMethodReturn(target, methodName, output...) +} + +func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches { + return create().ApplyFuncVarReturn(target, output...) +} + func create() *Patches { return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)} } @@ -75,6 +91,15 @@ func (this *Patches) ApplyMethod(target reflect.Type, methodName string, double return this.ApplyCore(m.Func, d) } +func (this *Patches) ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches { + m, ok := target.MethodByName(methodName) + if !ok { + panic("retrieve method by name failed") + } + d := funcToMethod(m.Type, doubleFunc) + return this.ApplyCore(m.Func, d) +} + func (this *Patches) ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches { m, ok := creflect.MethodByName(target, methodName) if !ok { @@ -136,6 +161,40 @@ func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) * return this.ApplyGlobalVar(target, double) } +func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches { + funcType := reflect.TypeOf(target) + t := reflect.ValueOf(target) + outputs := []OutputCell{{Values: returns, Times: -1}} + d := getDoubleFunc(funcType, outputs) + return this.ApplyCore(t, d) +} + +func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches { + m, ok := reflect.TypeOf(target).MethodByName(methodName) + if !ok { + panic("retrieve method by name failed") + } + + outputs := []OutputCell{{Values: returns, Times: -1}} + d := getDoubleFunc(m.Type, outputs) + return this.ApplyCore(m.Func, d) +} + +func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches { + t := reflect.ValueOf(target) + if t.Type().Kind() != reflect.Ptr { + panic("target is not a pointer") + } + if t.Elem().Kind() != reflect.Func { + panic("target is not a func") + } + + funcType := reflect.TypeOf(target).Elem() + outputs := []OutputCell{{Values: returns, Times: -1}} + double := getDoubleFunc(funcType, outputs).Interface() + return this.ApplyGlobalVar(target, double) +} + func (this *Patches) Reset() { for target, bytes := range this.originals { modifyBinary(target, bytes) @@ -203,8 +262,14 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value { funcType.NumOut(), len(outputs[0].Values))) } + needReturn := false slice := make([]Params, 0) for _, output := range outputs { + if output.Times == -1 { + needReturn = true + slice = []Params{output.Values} + break + } t := 0 if output.Times <= 1 { t = 1 @@ -217,9 +282,12 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value { } i := 0 - len := len(slice) + lenOutputs := len(slice) return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value { - if i < len { + if needReturn { + return GetResultValues(funcType, slice[0]...) + } + if i < lenOutputs { i++ return GetResultValues(funcType, slice[i-1]...) } @@ -259,3 +327,14 @@ func entryAddress(p uintptr, l int) []byte { func pageStart(ptr uintptr) uintptr { return ptr & ^(uintptr(syscall.Getpagesize() - 1)) } + +func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value { + rf := reflect.TypeOf(doubleFunc) + if rf.Kind() != reflect.Func { + panic("doubleFunc is not a func") + } + vf := reflect.ValueOf(doubleFunc) + return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value { + return vf.Call(in[1:]) + }) +} diff --git a/test/apply_func_return_test.go b/test/apply_func_return_test.go new file mode 100755 index 0000000..8e0900f --- /dev/null +++ b/test/apply_func_return_test.go @@ -0,0 +1,39 @@ +package test + +import ( + "testing" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/agiledragon/gomonkey/v2/test/fake" + . "github.com/smartystreets/goconvey/convey" +) + +/* + compare with apply_func_seq_test.go +*/ +func TestApplyFuncReturn(t *testing.T) { + Convey("TestApplyFuncReturn", t, func() { + + Convey("declares the values to be returned", func() { + info1 := "hello cpp" + + patches := ApplyFuncReturn(fake.ReadLeaf, info1, nil) + defer patches.Reset() + + for i := 0; i < 10; i++ { + output, err := fake.ReadLeaf("") + So(err, ShouldEqual, nil) + So(output, ShouldEqual, info1) + } + + patches.Reset() // if not reset will occur:patch has been existed + info2 := "hello golang" + patches.ApplyFuncReturn(fake.ReadLeaf, info2, nil) + for i := 0; i < 10; i++ { + output, err := fake.ReadLeaf("") + So(err, ShouldEqual, nil) + So(output, ShouldEqual, info2) + } + }) + }) +} diff --git a/test/apply_func_var_return_test.go b/test/apply_func_var_return_test.go new file mode 100755 index 0000000..1aed9f7 --- /dev/null +++ b/test/apply_func_var_return_test.go @@ -0,0 +1,38 @@ +package test + +import ( + "testing" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/agiledragon/gomonkey/v2/test/fake" + . "github.com/smartystreets/goconvey/convey" +) + +/* + compare with apply_func_var_seq_test.go +*/ +func TestApplyFuncVarReturn(t *testing.T) { + Convey("TestApplyFuncVarReturn", t, func() { + + Convey("declares the values to be returned", func() { + info1 := "hello cpp" + + patches := ApplyFuncVarReturn(&fake.Marshal, []byte(info1), nil) + defer patches.Reset() + for i := 0; i < 10; i++ { + bytes, err := fake.Marshal("") + So(err, ShouldEqual, nil) + So(string(bytes), ShouldEqual, info1) + } + + info2 := "hello golang" + patches.ApplyFuncVarReturn(&fake.Marshal, []byte(info2), nil) + for i := 0; i < 10; i++ { + bytes, err := fake.Marshal("") + So(err, ShouldEqual, nil) + So(string(bytes), ShouldEqual, info2) + } + }) + + }) +} diff --git a/test/apply_method_func_test.go b/test/apply_method_func_test.go new file mode 100755 index 0000000..05ebaba --- /dev/null +++ b/test/apply_method_func_test.go @@ -0,0 +1,87 @@ +package test + +import ( + "reflect" + "testing" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/agiledragon/gomonkey/v2/test/fake" + . "github.com/smartystreets/goconvey/convey" +) + +/* + compare with apply_method_test.go, no need pass receiver +*/ + +func TestApplyMethodFunc(t *testing.T) { + slice := fake.NewSlice() + var s *fake.Slice + Convey("TestApplyMethodFunc", t, func() { + Convey("for succ", func() { + err := slice.Add(1) + So(err, ShouldEqual, nil) + patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error { + return nil + }) + defer patches.Reset() + err = slice.Add(1) + So(err, ShouldEqual, nil) + err = slice.Remove(1) + So(err, ShouldEqual, nil) + So(len(slice), ShouldEqual, 0) + }) + + Convey("for already exist", func() { + err := slice.Add(2) + So(err, ShouldEqual, nil) + patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error { + return fake.ErrElemExsit + }) + defer patches.Reset() + err = slice.Add(1) + So(err, ShouldEqual, fake.ErrElemExsit) + err = slice.Remove(2) + So(err, ShouldEqual, nil) + So(len(slice), ShouldEqual, 0) + }) + + Convey("two methods", func() { + err := slice.Add(3) + So(err, ShouldEqual, nil) + defer slice.Remove(3) + patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error { + return fake.ErrElemExsit + }) + defer patches.Reset() + patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error { + return fake.ErrElemNotExsit + }) + err = slice.Add(2) + So(err, ShouldEqual, fake.ErrElemExsit) + err = slice.Remove(1) + So(err, ShouldEqual, fake.ErrElemNotExsit) + So(len(slice), ShouldEqual, 1) + So(slice[0], ShouldEqual, 3) + }) + + Convey("one func and one method", func() { + err := slice.Add(4) + So(err, ShouldEqual, nil) + defer slice.Remove(4) + patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) { + return outputExpect, nil + }) + defer patches.Reset() + patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error { + return fake.ErrElemNotExsit + }) + output, err := fake.Exec("", "") + So(err, ShouldEqual, nil) + So(output, ShouldEqual, outputExpect) + err = slice.Remove(1) + So(err, ShouldEqual, fake.ErrElemNotExsit) + So(len(slice), ShouldEqual, 1) + So(slice[0], ShouldEqual, 4) + }) + }) +} diff --git a/test/apply_method_return_test.go b/test/apply_method_return_test.go new file mode 100755 index 0000000..c4741e6 --- /dev/null +++ b/test/apply_method_return_test.go @@ -0,0 +1,38 @@ +package test + +import ( + "testing" + + . "github.com/agiledragon/gomonkey/v2" + "github.com/agiledragon/gomonkey/v2/test/fake" + . "github.com/smartystreets/goconvey/convey" +) + +/* + compare with apply_method_seq_test.go +*/ + +func TestApplyMethodReturn(t *testing.T) { + e := &fake.Etcd{} + Convey("TestApplyMethodReturn", t, func() { + Convey("declares the values to be returned", func() { + info1 := "hello cpp" + patches := ApplyMethodReturn(e, "Retrieve", info1, nil) + defer patches.Reset() + for i := 0; i < 10; i++ { + output1, err1 := e.Retrieve("") + So(err1, ShouldEqual, nil) + So(output1, ShouldEqual, info1) + } + + patches.Reset() // if not reset will occur:patch has been existed + info2 := "hello golang" + patches.ApplyMethodReturn(e, "Retrieve", info2, nil) + for i := 0; i < 10; i++ { + output2, err2 := e.Retrieve("") + So(err2, ShouldEqual, nil) + So(output2, ShouldEqual, info2) + } + }) + }) +}