diff --git a/bin/deepstate/reducer.py b/bin/deepstate/reducer.py index 2756a61..7bf12ce 100644 --- a/bin/deepstate/reducer.py +++ b/bin/deepstate/reducer.py @@ -17,46 +17,47 @@ from __future__ import print_function import argparse import subprocess import os +import sys import time def main(): - global candidateRuns + global candidateRuns, currentTest, s, passStart parser = argparse.ArgumentParser(description="Intelligently reduce test case") parser.add_argument( "binary", type=str, help="Path to the test binary to run.") - parser.add_argument( "input_test", type=str, help="Path to test to reduce.") - parser.add_argument( "output_test", type=str, help="Path for reduced test.") - parser.add_argument( "--which_test", type=str, help="Which test to run (equivalent to --input_which_test).", default=None) - parser.add_argument( "--criteria", type=str, help="String to search for in valid reduction outputs.", default=None) - parser.add_argument( "--search", action="store_true", help="Allow initial test to not satisfy criteria (search for test).", default=None) - parser.add_argument( "--timeout", type=int, help="After this amount of time (in seconds), give up on reduction.", default=1200) - + parser.add_argument( + "--maxByteRange", type=int, help="Maximum size of byte chunk to try in range removals.", + default=16) parser.add_argument( "--fast", action='store_true', - help="Faster, less complete, reduction (no range or byte pattern attempts).") - + help="Faster, less complete, reduction (no byte range removal pass).") + parser.add_argument( + "--slow", action='store_true', + help="Slower, more complete, reduction (byte pattern pass).") + parser.add_argument( + "--slowest", action='store_true', + help="Slowest, most complete, reduction (byte pattern pass, tries all byte ranges).") parser.add_argument( "--verbose", action='store_true', help="Verbose reduction.") - parser.add_argument( "--fork", action='store_true', help="Fork when running.") @@ -66,6 +67,7 @@ def main(): args = parser.parse_args() + maxByteRange = args.maxByteRange deepstate = args.binary test = args.input_test out = args.output_test @@ -77,6 +79,7 @@ def main(): def runCandidate(candidate): global candidateRuns + candidateRuns += 1 if (time.time() - start) > args.timeout: raise TimeoutException @@ -147,7 +150,7 @@ def main(): conversions.append((currentMulti, int(line.split()[-1]))) return conversions - def fixUp(test, conversions): + def fixRangeConversions(test, conversions): numConversions = 0 for (pos, value) in conversions: if pos[1] >= len(test): @@ -158,262 +161,338 @@ def main(): test[b] = 0 test[pos[1]] = value if numConversions > 0: - print("APPLIED", numConversions, "RANGE CONVERSIONS") + print("Applied", numConversions, "range conversions") initial = runCandidate(test) if (not args.search) and (not checks(initial)): - print("STARTING TEST DOES NOT SATISFY REDUCTION CRITERIA") + print("STARTING TEST DOES NOT SATISFY REDUCTION CRITERIA!") return 1 with open(test, 'rb') as test: currentTest = bytearray(test.read()) original = bytearray(currentTest) - print("ORIGINAL TEST HAS", len(currentTest), "BYTES") + print("Original test has", len(currentTest), "bytes") + if args.slowest: + maxByteRange = len(currentTest) - fixUp(currentTest, rangeConversions(initial)) + fixRangeConversions(currentTest, rangeConversions(initial)) r = writeAndRunCandidate(currentTest) assert(checks(r)) s = structure(initial) if (s[1]+1) < len(currentTest): - print("LAST BYTE READ IS", s[1]) - print("SHRINKING TO IGNORE UNREAD BYTES") + print("Last byte read:", s[1]) + print("Shrinking to ignore unread bytes") currentTest = currentTest[:s[1]+1] if currentTest != original: - print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out) + print("Writing reduced test with", len(currentTest), "bytes to", out) with open(out, 'wb') as outf: outf.write(currentTest) initialSize = float(len(currentTest)) iteration = 0 - changed = True - rangeRemovePos = 0 - byteReducePos = 0 + def updateCurrent(newTest): + global currentTest, s + currentTest = newTest + fixRangeConversions(currentTest, rangeConversions(r)) + print("Writing reduced test with", len(currentTest), "bytes to", out) + with open(out, 'wb') as outf: + outf.write(currentTest) + s = structure(r) + percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) + print(round(time.time()-start, 2), "secs /", + candidateRuns, "execs /", str(round(percent, 2)) + "% reduction") + print("="*80) + sys.stdout.flush() + def passInfo(): + global passStart + percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) + print("PASS FINISHED IN", round(time.time() - passStart, 2), "SECONDS, RUN:", round(time.time()-start, 2), + "secs /", candidateRuns, "execs /", str(round(percent, 2)) + "% reduction") + passStart = time.time() + + oldTest = [] + lastOneOfRemovalTest = [] + lastChunkRemovalTest = {} + lastChunkRemovalTest[1] = [] + lastChunkRemovalTest[4] = [] + lastChunkRemovalTest[8] = [] + lastReduceAndDeleteTest = {} + lastReduceAndDeleteTest[1] = [] + lastReduceAndDeleteTest[4] = [] + lastReduceAndDeleteTest[8] = [] + lastAllRangeTest = [] + lastOneOfSwapTest = [] + lastByteReduceTest = [] + lastPatternSearchTest = [] + + passStart = time.time() try: - while changed: - changed = False + while oldTest != currentTest: + oldTest = bytearray(currentTest) iteration += 1 percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) print("=" * 80) - print("STARTING ITERATION #" + str(iteration), round(time.time()-start, 2), "SECONDS /", - candidateRuns, "EXECUTIONS /", str(round(percent, 2)) + "% REDUCTION") - if args.verbose: - print("TRYING ONEOF REMOVALS...") - cuts = s[0] - for c in cuts: - newTest = currentTest[:c[0]] + currentTest[c[1]+1:] - if len(newTest) == len(currentTest): - continue # Ignore non-shrinking reductions - r = writeAndRunCandidate(newTest) - if checks(r): - print("ONEOF REMOVAL REDUCED TEST TO", len(newTest), "BYTES") + print("Iteration #" + str(iteration), round(time.time()-start, 2), "secs /", + candidateRuns, "execs /", str(round(percent, 2)) + "% reduction") + + if currentTest != lastOneOfRemovalTest: + if args.verbose: + print("*"*80+"\nPASS: removing OneOfs...") + changed = True + while changed: + changed = False + cuts = s[0] + for c in cuts: + newTest = currentTest[:c[0]] + currentTest[c[1]+1:] + if len(newTest) == len(currentTest): + continue # Ignore non-shrinking reductions + r = writeAndRunCandidate(newTest) + if checks(r): + print("OneOf removal reduced test to", len(newTest), "bytes") + changed = True + updateCurrent(newTest) + break + lastOneOfRemovalTest = bytearray(currentTest) + passInfo() + + for k in [1, 4, 8]: + if currentTest != lastChunkRemovalTest[k]: + if args.verbose: + print("*"*80+"\nPASS: trying", k, "byte chunk removals...") changed = True - rangeRemovePos = 0 - byteReducePos = 0 - break - - if (not args.fast) and (not changed): - for b in range(rangeRemovePos, len(currentTest)): - if args.verbose: - print("TRYING BYTE RANGE REMOVAL FROM BYTE", str(b) + "...") - for v in range(b+1, len(currentTest)): - newTest = currentTest[:b] + currentTest[v:] - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE RANGE REMOVAL REDUCED TEST TO", len(newTest), "BYTES") - rangeRemovePos = b - byteReducePos = 0 - changed = True - break - if changed: - break - - if (not args.fast) and (not changed): - for b in range(0, rangeRemovePos): - if args.verbose: - print("TRYING BYTE RANGE REMOVAL FROM BYTE", str(b) + "...") - for v in range(b+1, len(currentTest)): - newTest = currentTest[:b] + currentTest[v:] - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE RANGE REMOVAL REDUCED TEST TO", len(newTest), "BYTES") - rangeRemovePos = b - byteReducePos = 0 - changed = True - break - if changed: - break - if not changed: - rangeRemovePos = 0 - - if not changed: - if args.verbose: - print("TRYING ONEOF SWAPPING...") - cuts = s[0] - for i in range(len(cuts)-1): - cuti = cuts[i] - bytesi = currentTest[cuti[0]:cuti[1] + 1] - if args.verbose: - print("TRYING ONEOF SWAPPING FROM BYTE", cuti[0], "[" + " ".join(map(str, bytesi)) + "]") - for j in range(i + 1, len(cuts)): - cutj = cuts[j] - if cutj[0] > cuti[1]: - bytesj = currentTest[cutj[0]:cutj[1] + 1] - if (len(bytesj) > 0) and (bytesi > bytesj): - newTest = currentTest[:cuti[0]] + bytesj + currentTest[cuti[1]+1:cutj[0]] - newTest += bytesi - newTest += currentTest[cutj[1]+1:] - newTest = bytearray(newTest) + startingPos = 0 + while changed: + changed = False + for b in range(startingPos, len(currentTest)): + newTest = currentTest[:b] + currentTest[b+k:] + r = writeAndRunCandidate(newTest) + if checks(r): + print("Removed", k, "byte(s) @", str(b) + ": reduced test to", len(newTest), "bytes") + changed = True + updateCurrent(newTest) + startingPos = b + break + if not changed: + for b in range(0, startingPos): + newTest = currentTest[:b] + currentTest[b+k:] r = writeAndRunCandidate(newTest) if checks(r): - print("ONEOF SWAP @ BYTE", cuti[0], "[" + " ".join(map(str, bytesi)) + "]", "WITH", - cutj[0], "[" + " ".join(map(str, bytesj)) + "]") + print("Removed", k, "byte(s) @", str(b) + ": reduced test to", len(newTest), "bytes") changed = True - byteReducePos = 0 + updateCurrent(newTest) + startingPos = b break - if changed: - break + lastChunkRemovalTest[k] = bytearray(currentTest) + passInfo() - if not changed: - if args.verbose: - print("TRYING BYTE REDUCTIONS...") - for b in range(byteReducePos, len(currentTest)): - for v in range(0, currentTest[b]): - newTest = bytearray(currentTest) - newTest[b] = v - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE REDUCTION: BYTE", b, "FROM", currentTest[b], "TO", v) - changed = True - byteReducePos = b+1 - break - if changed: - break - - if not changed: - for b in range(0, byteReducePos): - for v in range(0, currentTest[b]): - newTest = bytearray(currentTest) - newTest[b] = v - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE REDUCTION: BYTE", b, "FROM", currentTest[b], "TO", v) - changed = True - byteReducePos = b+1 - break - if changed: - break - if not changed: - byteReducePos = 0 - - if not changed: - if args.verbose: - print("TRYING BYTE REDUCE AND DELETE...") - for b in range(0, len(currentTest)-1): - if currentTest[b] == 0: - continue - newTest = bytearray(currentTest) - newTest[b] = currentTest[b]-1 - newTest = newTest[:b+1] + newTest[b+2:] - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE REDUCE AND DELETE AT BYTE", b) - changed = True - break - - if not changed: - if args.verbose: - print("TRYING BYTE REDUCE AND DELETE 4...") - for b in range(0, len(currentTest)-5): - if currentTest[b] == 0: - continue - newTest = bytearray(currentTest) - newTest[b] = currentTest[b]-1 - newTest = newTest[:b+1] + newTest[b+5:] - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE REDUCE AND DELETE 4 AT BYTE", b) - changed = True - break - - if not changed: - if args.verbose: - print("TRYING BYTE REDUCE AND DELETE 8...") - for b in range(0, len(currentTest)-9): - if currentTest[b] == 0: - continue - newTest = bytearray(currentTest) - newTest[b] = currentTest[b]-1 - newTest = newTest[:b+1] + newTest[b+9:] - r = writeAndRunCandidate(newTest) - if checks(r): - print("BYTE REDUCE AND DELETE 8 AT BYTE", b) - changed = True - break - - if (not args.fast) and (not changed): - for b1 in range(0, len(currentTest)-4): + for k in [1, 4, 8]: + if currentTest != lastReduceAndDeleteTest[k]: if args.verbose: - print("TRYING BYTE PATTERN SEARCH FROM BYTE", str(b1) + "...") - for b2 in range(b1+2, len(currentTest)-4): - v1 = (currentTest[b1], currentTest[b1+1]) - v2 = (currentTest[b2], currentTest[b2+1]) - if (v1 == v2): - ba = bytearray(v1) - part1 = currentTest[:b1] - part2 = currentTest[b1+2:b2] - part3 = currentTest[b2+2:] - banews = [] - banews.append(ba[0:1]) - banews.append(ba[1:2]) - if ba[0] > 0: - for v in range(0, ba[0]): - banews.append(bytearray([v, ba[1]])) - banews.append(bytearray([ba[0]-1])) - if ba[1] > 0: - for v in range(0, ba[1]): - banews.append(bytearray([ba[0], v])) - for banew in banews: - newTest = part1 + banew + part2 + banew + part3 + print("*"*80+"\nPASS: byte reduce and delete", str(k) + "...") + changed = True + while changed: + changed = False + for b in range(0, len(currentTest)-k): + if currentTest[b] == 0: + continue + newTest = bytearray(currentTest) + newTest[b] = currentTest[b]-1 + newTest = newTest[:b+1] + newTest[b+k+1:] + r = writeAndRunCandidate(newTest) + if checks(r): + print("Reduced byte", b, "by 1 and deleted", k, "bytes, reducing test to", len(newTest), "bytes") + changed = True + updateCurrent(newTest) + break + lastReduceAndDeleteTest[k] = bytearray(currentTest) + passInfo() + + if not args.fast: + if currentTest != lastAllRangeTest: + if args.verbose: + print("*"*80+"\nPASS: trying all byte range removals...") + changed = True + startingPos = 0 + while changed: + changed = False + for b in range(startingPos, len(currentTest)): + if args.verbose: + print("Trying byte range removal from", str(b) + "...") + for v in range(b+2, min(len(currentTest), b+maxByteRange)): + if (v-b) in [4, 8]: + continue + newTest = currentTest[:b] + currentTest[v:] r = writeAndRunCandidate(newTest) if checks(r): - print("BYTE PATTERN", tuple(ba), "AT", b1, "AND", b2, "CHANGED TO", tuple(banew)) + print("Byte range removal of bytes", str(b) + "-" + str(v-1), + "reduced test to", len(newTest), "bytes") changed = True + updateCurrent(newTest) + startingPos = b break if changed: break - if changed: - break + if not changed: + for b in range(0, startingPos): + if args.verbose: + print("Trying byte range removal from", str(b) + "...") + for v in range(b+2, min(len(currentTest), b+maxByteRange)): + if (v-b) in [4, 8]: + continue + newTest = currentTest[:b] + currentTest[v:] + r = writeAndRunCandidate(newTest) + if checks(r): + print("Byte range removal of bytes", str(b) + "-" + str(v-1), + "reduced test to", len(newTest), "bytes") + changed = True + updateCurrent(newTest) + startingPos = b + break + if changed: + break + lastAllRangeTest = bytearray(currentTest) + passInfo() - if changed: - currentTest = newTest - print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out) - with open(out, 'wb') as outf: - outf.write(currentTest) - s = structure(r) - fixUp(currentTest, rangeConversions(r)) - else: - print("*" * 80) - print("NO (MORE) REDUCTIONS FOUND") + if currentTest != lastOneOfSwapTest: + if args.verbose: + print("*"*80+"\nPASS: swapping OneOfs...") + changed = True + while changed: + changed = False + cuts = s[0] + for i in range(len(cuts)-1): + cuti = cuts[i] + bytesi = currentTest[cuti[0]:cuti[1] + 1] + if args.verbose: + print("Trying OneOf swaps from byte", cuti[0], "[" + " ".join(map(str, bytesi)) + "]") + for j in range(i + 1, len(cuts)): + cutj = cuts[j] + if cutj[0] > cuti[1]: + bytesj = currentTest[cutj[0]:cutj[1] + 1] + if (len(bytesj) > 0) and (bytesi > bytesj): + newTest = currentTest[:cuti[0]] + bytesj + currentTest[cuti[1]+1:cutj[0]] + newTest += bytesi + newTest += currentTest[cutj[1]+1:] + newTest = bytearray(newTest) + r = writeAndRunCandidate(newTest) + if checks(r): + print("OneOf swap @ byte", cuti[0], "[" + " ".join(map(str, bytesi)) + "]", "with", + cutj[0], "[" + " ".join(map(str, bytesj)) + "]") + changed = True + updateCurrent(newTest) + break + if changed: + break + if changed: + break + lastOneOfSwapTest = bytearray(currentTest) + passInfo() + + if currentTest != lastByteReduceTest: + if args.verbose: + print("*"*80+"\nPASS: byte reductions...") + changed = True + startingPos = 0 + while changed: + changed = False + for b in range(startingPos, len(currentTest)): + for v in range(0, currentTest[b]): + newTest = bytearray(currentTest) + newTest[b] = v + r = writeAndRunCandidate(newTest) + if checks(r): + print("Reduced byte", b, "from", currentTest[b], "to", v) + changed = True + updateCurrent(newTest) + startingPos = b+1 + break + if changed: + break + if changed: + continue + for b in range(0, startingPos): + for v in range(0, currentTest[b]): + newTest = bytearray(currentTest) + newTest[b] = v + r = writeAndRunCandidate(newTest) + if checks(r): + print("Reduced byte", b, "from", currentTest[b], "to", v) + changed = True + updateCurrent(newTest) + startingPos = b+1 + break + if changed: + break + lastByteReduceTest = bytearray(currentTest) + passInfo() + + if (args.slow or args.slowest) and (oldTest == currentTest): + if currentTest != lastPatternSearchTest: + if args.verbose: + print("*"*80+"\nPASS: byte pattern search...") + changed = True + while changed: + changed = False + for b1 in range(0, len(currentTest)-4): + if args.verbose: + print("Trying byte pattern search from byte", str(b1) + "...") + for b2 in range(b1+2, len(currentTest)-4): + v1 = (currentTest[b1], currentTest[b1+1]) + v2 = (currentTest[b2], currentTest[b2+1]) + if (v1 == v2): + ba = bytearray(v1) + part1 = currentTest[:b1] + part2 = currentTest[b1+2:b2] + part3 = currentTest[b2+2:] + banews = [] + banews.append(ba[0:1]) + banews.append(ba[1:2]) + if ba[0] > 0: + for v in range(0, ba[0]): + banews.append(bytearray([v, ba[1]])) + banews.append(bytearray([ba[0]-1])) + if ba[1] > 0: + for v in range(0, ba[1]): + banews.append(bytearray([ba[0], v])) + for banew in banews: + newTest = part1 + banew + part2 + banew + part3 + r = writeAndRunCandidate(newTest) + if checks(r): + print("Byte pattern", tuple(ba), "at", b1, "and", b2, "changed to", tuple(banew)) + changed = True + updateCurrent(newTest) + break + if changed: + break + if changed: + break + lastPatternSearchTest = bytearray(currentTest) + passInfo() + + if oldTest == currentTest: + print("*" * 80) + print("DONE: NO (MORE) REDUCTIONS FOUND") except TimeoutException: print("*" * 80) - print("REDUCTION TIMED OUT AFTER", args.timeout, "SECONDS") + print("DONE: REDUCTION TIMED OUT AFTER", args.timeout, "SECONDS") + + print("=" * 80) + percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) + print("Completed", iteration, "iterations:", round(time.time()-start, 2), "secs /", + candidateRuns, "execs /", str(round(percent, 2)) + "% reduction") if (s[1] + 1) > len(currentTest): - print("PADDING TEST WITH", (s[1] + 1) - len(currentTest), "ZEROS") + print("Padding test with", (s[1] + 1) - len(currentTest), "zeroes") padding = bytearray('\x00' * ((s[1] + 1) - len(currentTest)), 'utf-8') currentTest = currentTest + padding - print("=" * 80) - percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) - print("COMPLETED AFTER", iteration, "ITERATIONS:", round(time.time()-start, 2), "SECONDS /", - candidateRuns, "EXECUTIONS /", str(round(percent, 2)) + "% REDUCTION") - print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out) + print("Writing reduced test with", len(currentTest), "bytes to", out) with open(out, 'wb') as outf: outf.write(currentTest)