diff --git a/vd b/vd index f97c871..1ffeb66 100755 --- a/vd +++ b/vd @@ -48,6 +48,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description='') parser.add_argument("-i", help="", type=str, default=None, required=True, dest="seeds") parser.add_argument("-o", help="", type=str, default=None, required=True, dest="out") + parser.add_argument("-m", help="", type=str, nargs='+', dest="mods") + #parser.add_argument("-v", help="", type=str, default=None, required=False, dest="vectorizer") #parser.add_argument("-m", help="", type=str, default="afl", dest="fuzzer") parser.add_argument("cmd", help="", type=str, default=None) @@ -55,6 +57,7 @@ if __name__ == "__main__": options = parser.parse_args() seeds = options.seeds outfile = options.out + mods = options.mods #fuzzer = options.fuzzer cmd = options.cmd #vectorizer = options.vectorizer @@ -70,8 +73,10 @@ if __name__ == "__main__": traces = traces_path else: - #app = Process(program, envs, timeout, ["libpixman-1.so","libcairo.so.2","libpango"], [], True) modules_to_trace = [main_module] + if mods is not None: + modules_to_trace = modules_to_trace + mods + if "LD_LIBRARY_PATH" in os.environ: libs = os.environ["LD_LIBRARY_PATH"] for _,_,files in os.walk(libs): diff --git a/vdiscover/Process.py b/vdiscover/Process.py index a7e9bbc..e3dfdc8 100644 --- a/vdiscover/Process.py +++ b/vdiscover/Process.py @@ -77,7 +77,7 @@ def __init__(self, program, envs, timeout, included_mods = [], ignored_mods = [] self.last_signal = {} self.last_call = None self.crashed = False - self.nevents = 0 + self.nevents = dict() self.events = [] self.binfo = dict() @@ -125,6 +125,7 @@ def createEvents(self, signal): for (range, mod, atts) in self.mm.items(): if '/' in mod and 'x' in atts and not ("libc-" in mod): + # FIXME: self.elf.path should be absolute if mod == self.elf.path: base = 0 else: @@ -155,7 +156,15 @@ def createEvents(self, signal): call_ip = ip self.process.singleStep() self.debugger.waitProcessEvent() - self.breakpoint(call_ip) + + n = self.nevents.get(name, 0) + self.nevents[name] = n + 1 + + if n < self.max_events: + self.breakpoint(call_ip) + #else: + #print "disabled!" + #print "call detected!" return [call] @@ -264,7 +273,7 @@ def cont(self, signum=None): #vulns = self.DetectVulnerabilities(self.events, events) #print "vulns detected" self.events = self.events + events #+ vulns - self.nevents = self.nevents + len(events) + #self.nevents = self.nevents + len(events) def readInstrSize(self, address, default_size=None): @@ -337,12 +346,12 @@ def runProcess(self, cmd): while True: #self.cont() - if self.nevents > self.max_events: - - self.events.append(Timeout(timeout)) - alarm(0) - return - elif not self.debugger or self.crashed: + #if self.nevents > self.max_events: + # + # self.events.append(Timeout(timeout)) + # alarm(0) + # return + if not self.debugger or self.crashed: # There is no more process: quit alarm(0) return @@ -381,7 +390,7 @@ def runProcess(self, cmd): def getData(self, inputs): self.events = [] - self.nevents = 0 + self.nevents = dict() self.debugger = PtraceDebugger() self.runProcess([self.program]+inputs) @@ -396,7 +405,7 @@ def getData(self, inputs): self.process.terminate() self.process.detach() - #print "terminated!" + #print self.nevents self.process = None return self.events