diff --git a/patch.go b/patch.go index b2c5328..2cd881d 100644 --- a/patch.go +++ b/patch.go @@ -11,6 +11,7 @@ import ( type Patches struct { originals map[uintptr][]byte + targets map[uintptr]uintptr values map[reflect.Value]reflect.Value valueHolders map[reflect.Value]reflect.Value } @@ -70,13 +71,25 @@ func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches { } func create() *Patches { - return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)} + return &Patches{originals: make(map[uintptr][]byte), targets: map[uintptr]uintptr{}, + values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)} } func NewPatches() *Patches { return create() } +func (this *Patches) Origin(fn func()) { + for target, bytes := range this.originals { + modifyBinary(target, bytes) + } + fn() + for target, targetPtr := range this.targets { + code := buildJmpDirective(targetPtr) + modifyBinary(target, code) + } +} + func (this *Patches) ApplyFunc(target, double interface{}) *Patches { t := reflect.ValueOf(target) d := reflect.ValueOf(double) @@ -214,6 +227,7 @@ func (this *Patches) ApplyCore(target, double reflect.Value) *Patches { if _, ok := this.originals[assTarget]; !ok { this.originals[assTarget] = original } + this.targets[assTarget] = uintptr(getPointer(double)) this.valueHolders[double] = double return this } @@ -227,6 +241,7 @@ func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double if _, ok := this.originals[assTarget]; !ok { this.originals[assTarget] = original } + this.targets[assTarget] = uintptr(getPointer(double)) this.valueHolders[double] = double return this } diff --git a/test/apply_func_test.go b/test/apply_func_test.go index 0edd8c6..f567209 100644 --- a/test/apply_func_test.go +++ b/test/apply_func_test.go @@ -26,6 +26,32 @@ func TestApplyFunc(t *testing.T) { So(output, ShouldEqual, outputExpect) }) + Convey("one func for succ with origin", func() { + patches := ApplyFunc(fake.Belong, func(_ string, _ []string) bool { + return false + }) + defer patches.Reset() + output := fake.Belong("a", []string{"a", "b"}) + So(output, ShouldEqual, false) + patches.Origin(func() { + output = fake.Belong("a", []string{"a", "b"}) + }) + So(output, ShouldEqual, true) + }) + + Convey("one func for succ with origin inside", func() { + var output bool + var patches *Patches + patches = ApplyFunc(fake.Belong, func(_ string, _ []string) bool { + patches.Origin(func() { + output = fake.Belong("a", []string{"a", "b"}) + So(output, ShouldEqual, true) + }) + return false + }) + defer patches.Reset() + }) + Convey("one func for fail", func() { patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) { return "", fake.ErrActual @@ -51,6 +77,27 @@ func TestApplyFunc(t *testing.T) { So(flag, ShouldBeTrue) }) + Convey("two funcs with origin", func() { + patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) { + return outputExpect, nil + }) + defer patches.Reset() + patches.ApplyFunc(fake.Belong, func(_ string, _ []string) bool { + return true + }) + output, err := fake.Exec("", "") + So(err, ShouldEqual, nil) + So(output, ShouldEqual, outputExpect) + flag := fake.Belong("", nil) + So(flag, ShouldBeTrue) + + var outputBool bool + patches.Origin(func() { + outputBool = fake.Belong("c", []string{"a", "b"}) + }) + So(outputBool, ShouldEqual, false) + }) + Convey("input and output param", func() { patches := ApplyFunc(json.Unmarshal, func(data []byte, v interface{}) error { if data == nil { diff --git a/test/apply_method_func_test.go b/test/apply_method_func_test.go index 0dee818..a349daf 100755 --- a/test/apply_method_func_test.go +++ b/test/apply_method_func_test.go @@ -30,6 +30,24 @@ func TestApplyMethodFunc(t *testing.T) { So(len(slice), ShouldEqual, 0) }) + Convey("for origin", func() { + patches := ApplyMethodFunc(s, "Add", func(_ int) error { + return nil + }) + defer patches.Reset() + + var err error + patches.Origin(func() { + err = slice.Add(1) + So(err, ShouldEqual, nil) + err = slice.Add(1) + So(err, ShouldEqual, fake.ErrElemExsit) + err = slice.Remove(1) + So(err, ShouldEqual, nil) + }) + So(len(slice), ShouldEqual, 0) + }) + Convey("for already exist", func() { err := slice.Add(2) So(err, ShouldEqual, nil)