almost new reducer, with nested fixpoint passes

This commit is contained in:
agroce 2019-05-20 11:51:54 -07:00
parent d39000393d
commit f916a93405

View File

@ -17,46 +17,47 @@ from __future__ import print_function
import argparse import argparse
import subprocess import subprocess
import os import os
import sys
import time import time
def main(): def main():
global candidateRuns global candidateRuns, currentTest, s, passStart
parser = argparse.ArgumentParser(description="Intelligently reduce test case") parser = argparse.ArgumentParser(description="Intelligently reduce test case")
parser.add_argument( parser.add_argument(
"binary", type=str, help="Path to the test binary to run.") "binary", type=str, help="Path to the test binary to run.")
parser.add_argument( parser.add_argument(
"input_test", type=str, help="Path to test to reduce.") "input_test", type=str, help="Path to test to reduce.")
parser.add_argument( parser.add_argument(
"output_test", type=str, help="Path for reduced test.") "output_test", type=str, help="Path for reduced test.")
parser.add_argument( parser.add_argument(
"--which_test", type=str, help="Which test to run (equivalent to --input_which_test).", default=None) "--which_test", type=str, help="Which test to run (equivalent to --input_which_test).", default=None)
parser.add_argument( parser.add_argument(
"--criteria", type=str, help="String to search for in valid reduction outputs.", "--criteria", type=str, help="String to search for in valid reduction outputs.",
default=None) default=None)
parser.add_argument( parser.add_argument(
"--search", action="store_true", help="Allow initial test to not satisfy criteria (search for test).", "--search", action="store_true", help="Allow initial test to not satisfy criteria (search for test).",
default=None) default=None)
parser.add_argument( parser.add_argument(
"--timeout", type=int, help="After this amount of time (in seconds), give up on reduction.", "--timeout", type=int, help="After this amount of time (in seconds), give up on reduction.",
default=1200) default=1200)
parser.add_argument(
"--maxByteRange", type=int, help="Maximum size of byte chunk to try in range removals.",
default=16)
parser.add_argument( parser.add_argument(
"--fast", action='store_true', "--fast", action='store_true',
help="Faster, less complete, reduction (no range or byte pattern attempts).") help="Faster, less complete, reduction (no byte range removal pass).")
parser.add_argument(
"--slow", action='store_true',
help="Slower, more complete, reduction (byte pattern pass).")
parser.add_argument(
"--slowest", action='store_true',
help="Slowest, most complete, reduction (byte pattern pass, tries all byte ranges).")
parser.add_argument( parser.add_argument(
"--verbose", action='store_true', "--verbose", action='store_true',
help="Verbose reduction.") help="Verbose reduction.")
parser.add_argument( parser.add_argument(
"--fork", action='store_true', "--fork", action='store_true',
help="Fork when running.") help="Fork when running.")
@ -66,6 +67,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
maxByteRange = args.maxByteRange
deepstate = args.binary deepstate = args.binary
test = args.input_test test = args.input_test
out = args.output_test out = args.output_test
@ -77,6 +79,7 @@ def main():
def runCandidate(candidate): def runCandidate(candidate):
global candidateRuns global candidateRuns
candidateRuns += 1 candidateRuns += 1
if (time.time() - start) > args.timeout: if (time.time() - start) > args.timeout:
raise TimeoutException raise TimeoutException
@ -147,7 +150,7 @@ def main():
conversions.append((currentMulti, int(line.split()[-1]))) conversions.append((currentMulti, int(line.split()[-1])))
return conversions return conversions
def fixUp(test, conversions): def fixRangeConversions(test, conversions):
numConversions = 0 numConversions = 0
for (pos, value) in conversions: for (pos, value) in conversions:
if pos[1] >= len(test): if pos[1] >= len(test):
@ -158,262 +161,338 @@ def main():
test[b] = 0 test[b] = 0
test[pos[1]] = value test[pos[1]] = value
if numConversions > 0: if numConversions > 0:
print("APPLIED", numConversions, "RANGE CONVERSIONS") print("Applied", numConversions, "range conversions")
initial = runCandidate(test) initial = runCandidate(test)
if (not args.search) and (not checks(initial)): if (not args.search) and (not checks(initial)):
print("STARTING TEST DOES NOT SATISFY REDUCTION CRITERIA") print("STARTING TEST DOES NOT SATISFY REDUCTION CRITERIA!")
return 1 return 1
with open(test, 'rb') as test: with open(test, 'rb') as test:
currentTest = bytearray(test.read()) currentTest = bytearray(test.read())
original = bytearray(currentTest) original = bytearray(currentTest)
print("ORIGINAL TEST HAS", len(currentTest), "BYTES") print("Original test has", len(currentTest), "bytes")
if args.slowest:
maxByteRange = len(currentTest)
fixUp(currentTest, rangeConversions(initial)) fixRangeConversions(currentTest, rangeConversions(initial))
r = writeAndRunCandidate(currentTest) r = writeAndRunCandidate(currentTest)
assert(checks(r)) assert(checks(r))
s = structure(initial) s = structure(initial)
if (s[1]+1) < len(currentTest): if (s[1]+1) < len(currentTest):
print("LAST BYTE READ IS", s[1]) print("Last byte read:", s[1])
print("SHRINKING TO IGNORE UNREAD BYTES") print("Shrinking to ignore unread bytes")
currentTest = currentTest[:s[1]+1] currentTest = currentTest[:s[1]+1]
if currentTest != original: if currentTest != original:
print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out) print("Writing reduced test with", len(currentTest), "bytes to", out)
with open(out, 'wb') as outf: with open(out, 'wb') as outf:
outf.write(currentTest) outf.write(currentTest)
initialSize = float(len(currentTest)) initialSize = float(len(currentTest))
iteration = 0 iteration = 0
changed = True
rangeRemovePos = 0 def updateCurrent(newTest):
byteReducePos = 0 global currentTest, s
currentTest = newTest
fixRangeConversions(currentTest, rangeConversions(r))
print("Writing reduced test with", len(currentTest), "bytes to", out)
with open(out, 'wb') as outf:
outf.write(currentTest)
s = structure(r)
percent = 100.0 * ((initialSize - len(currentTest)) / initialSize)
print(round(time.time()-start, 2), "secs /",
candidateRuns, "execs /", str(round(percent, 2)) + "% reduction")
print("="*80)
sys.stdout.flush()
def passInfo():
global passStart
percent = 100.0 * ((initialSize - len(currentTest)) / initialSize)
print("PASS FINISHED IN", round(time.time() - passStart, 2), "SECONDS, RUN:", round(time.time()-start, 2),
"secs /", candidateRuns, "execs /", str(round(percent, 2)) + "% reduction")
passStart = time.time()
oldTest = []
lastOneOfRemovalTest = []
lastChunkRemovalTest = {}
lastChunkRemovalTest[1] = []
lastChunkRemovalTest[4] = []
lastChunkRemovalTest[8] = []
lastReduceAndDeleteTest = {}
lastReduceAndDeleteTest[1] = []
lastReduceAndDeleteTest[4] = []
lastReduceAndDeleteTest[8] = []
lastAllRangeTest = []
lastOneOfSwapTest = []
lastByteReduceTest = []
lastPatternSearchTest = []
passStart = time.time()
try: try:
while changed: while oldTest != currentTest:
changed = False oldTest = bytearray(currentTest)
iteration += 1 iteration += 1
percent = 100.0 * ((initialSize - len(currentTest)) / initialSize) percent = 100.0 * ((initialSize - len(currentTest)) / initialSize)
print("=" * 80) print("=" * 80)
print("STARTING ITERATION #" + str(iteration), round(time.time()-start, 2), "SECONDS /", print("Iteration #" + str(iteration), round(time.time()-start, 2), "secs /",
candidateRuns, "EXECUTIONS /", str(round(percent, 2)) + "% REDUCTION") candidateRuns, "execs /", str(round(percent, 2)) + "% reduction")
if args.verbose:
print("TRYING ONEOF REMOVALS...") if currentTest != lastOneOfRemovalTest:
cuts = s[0] if args.verbose:
for c in cuts: print("*"*80+"\nPASS: removing OneOfs...")
newTest = currentTest[:c[0]] + currentTest[c[1]+1:] changed = True
if len(newTest) == len(currentTest): while changed:
continue # Ignore non-shrinking reductions changed = False
r = writeAndRunCandidate(newTest) cuts = s[0]
if checks(r): for c in cuts:
print("ONEOF REMOVAL REDUCED TEST TO", len(newTest), "BYTES") newTest = currentTest[:c[0]] + currentTest[c[1]+1:]
if len(newTest) == len(currentTest):
continue # Ignore non-shrinking reductions
r = writeAndRunCandidate(newTest)
if checks(r):
print("OneOf removal reduced test to", len(newTest), "bytes")
changed = True
updateCurrent(newTest)
break
lastOneOfRemovalTest = bytearray(currentTest)
passInfo()
for k in [1, 4, 8]:
if currentTest != lastChunkRemovalTest[k]:
if args.verbose:
print("*"*80+"\nPASS: trying", k, "byte chunk removals...")
changed = True changed = True
rangeRemovePos = 0 startingPos = 0
byteReducePos = 0 while changed:
break changed = False
for b in range(startingPos, len(currentTest)):
if (not args.fast) and (not changed): newTest = currentTest[:b] + currentTest[b+k:]
for b in range(rangeRemovePos, len(currentTest)): r = writeAndRunCandidate(newTest)
if args.verbose: if checks(r):
print("TRYING BYTE RANGE REMOVAL FROM BYTE", str(b) + "...") print("Removed", k, "byte(s) @", str(b) + ": reduced test to", len(newTest), "bytes")
for v in range(b+1, len(currentTest)): changed = True
newTest = currentTest[:b] + currentTest[v:] updateCurrent(newTest)
r = writeAndRunCandidate(newTest) startingPos = b
if checks(r): break
print("BYTE RANGE REMOVAL REDUCED TEST TO", len(newTest), "BYTES") if not changed:
rangeRemovePos = b for b in range(0, startingPos):
byteReducePos = 0 newTest = currentTest[:b] + currentTest[b+k:]
changed = True
break
if changed:
break
if (not args.fast) and (not changed):
for b in range(0, rangeRemovePos):
if args.verbose:
print("TRYING BYTE RANGE REMOVAL FROM BYTE", str(b) + "...")
for v in range(b+1, len(currentTest)):
newTest = currentTest[:b] + currentTest[v:]
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE RANGE REMOVAL REDUCED TEST TO", len(newTest), "BYTES")
rangeRemovePos = b
byteReducePos = 0
changed = True
break
if changed:
break
if not changed:
rangeRemovePos = 0
if not changed:
if args.verbose:
print("TRYING ONEOF SWAPPING...")
cuts = s[0]
for i in range(len(cuts)-1):
cuti = cuts[i]
bytesi = currentTest[cuti[0]:cuti[1] + 1]
if args.verbose:
print("TRYING ONEOF SWAPPING FROM BYTE", cuti[0], "[" + " ".join(map(str, bytesi)) + "]")
for j in range(i + 1, len(cuts)):
cutj = cuts[j]
if cutj[0] > cuti[1]:
bytesj = currentTest[cutj[0]:cutj[1] + 1]
if (len(bytesj) > 0) and (bytesi > bytesj):
newTest = currentTest[:cuti[0]] + bytesj + currentTest[cuti[1]+1:cutj[0]]
newTest += bytesi
newTest += currentTest[cutj[1]+1:]
newTest = bytearray(newTest)
r = writeAndRunCandidate(newTest) r = writeAndRunCandidate(newTest)
if checks(r): if checks(r):
print("ONEOF SWAP @ BYTE", cuti[0], "[" + " ".join(map(str, bytesi)) + "]", "WITH", print("Removed", k, "byte(s) @", str(b) + ": reduced test to", len(newTest), "bytes")
cutj[0], "[" + " ".join(map(str, bytesj)) + "]")
changed = True changed = True
byteReducePos = 0 updateCurrent(newTest)
startingPos = b
break break
if changed: lastChunkRemovalTest[k] = bytearray(currentTest)
break passInfo()
if not changed: for k in [1, 4, 8]:
if args.verbose: if currentTest != lastReduceAndDeleteTest[k]:
print("TRYING BYTE REDUCTIONS...")
for b in range(byteReducePos, len(currentTest)):
for v in range(0, currentTest[b]):
newTest = bytearray(currentTest)
newTest[b] = v
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE REDUCTION: BYTE", b, "FROM", currentTest[b], "TO", v)
changed = True
byteReducePos = b+1
break
if changed:
break
if not changed:
for b in range(0, byteReducePos):
for v in range(0, currentTest[b]):
newTest = bytearray(currentTest)
newTest[b] = v
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE REDUCTION: BYTE", b, "FROM", currentTest[b], "TO", v)
changed = True
byteReducePos = b+1
break
if changed:
break
if not changed:
byteReducePos = 0
if not changed:
if args.verbose:
print("TRYING BYTE REDUCE AND DELETE...")
for b in range(0, len(currentTest)-1):
if currentTest[b] == 0:
continue
newTest = bytearray(currentTest)
newTest[b] = currentTest[b]-1
newTest = newTest[:b+1] + newTest[b+2:]
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE REDUCE AND DELETE AT BYTE", b)
changed = True
break
if not changed:
if args.verbose:
print("TRYING BYTE REDUCE AND DELETE 4...")
for b in range(0, len(currentTest)-5):
if currentTest[b] == 0:
continue
newTest = bytearray(currentTest)
newTest[b] = currentTest[b]-1
newTest = newTest[:b+1] + newTest[b+5:]
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE REDUCE AND DELETE 4 AT BYTE", b)
changed = True
break
if not changed:
if args.verbose:
print("TRYING BYTE REDUCE AND DELETE 8...")
for b in range(0, len(currentTest)-9):
if currentTest[b] == 0:
continue
newTest = bytearray(currentTest)
newTest[b] = currentTest[b]-1
newTest = newTest[:b+1] + newTest[b+9:]
r = writeAndRunCandidate(newTest)
if checks(r):
print("BYTE REDUCE AND DELETE 8 AT BYTE", b)
changed = True
break
if (not args.fast) and (not changed):
for b1 in range(0, len(currentTest)-4):
if args.verbose: if args.verbose:
print("TRYING BYTE PATTERN SEARCH FROM BYTE", str(b1) + "...") print("*"*80+"\nPASS: byte reduce and delete", str(k) + "...")
for b2 in range(b1+2, len(currentTest)-4): changed = True
v1 = (currentTest[b1], currentTest[b1+1]) while changed:
v2 = (currentTest[b2], currentTest[b2+1]) changed = False
if (v1 == v2): for b in range(0, len(currentTest)-k):
ba = bytearray(v1) if currentTest[b] == 0:
part1 = currentTest[:b1] continue
part2 = currentTest[b1+2:b2] newTest = bytearray(currentTest)
part3 = currentTest[b2+2:] newTest[b] = currentTest[b]-1
banews = [] newTest = newTest[:b+1] + newTest[b+k+1:]
banews.append(ba[0:1]) r = writeAndRunCandidate(newTest)
banews.append(ba[1:2]) if checks(r):
if ba[0] > 0: print("Reduced byte", b, "by 1 and deleted", k, "bytes, reducing test to", len(newTest), "bytes")
for v in range(0, ba[0]): changed = True
banews.append(bytearray([v, ba[1]])) updateCurrent(newTest)
banews.append(bytearray([ba[0]-1])) break
if ba[1] > 0: lastReduceAndDeleteTest[k] = bytearray(currentTest)
for v in range(0, ba[1]): passInfo()
banews.append(bytearray([ba[0], v]))
for banew in banews: if not args.fast:
newTest = part1 + banew + part2 + banew + part3 if currentTest != lastAllRangeTest:
if args.verbose:
print("*"*80+"\nPASS: trying all byte range removals...")
changed = True
startingPos = 0
while changed:
changed = False
for b in range(startingPos, len(currentTest)):
if args.verbose:
print("Trying byte range removal from", str(b) + "...")
for v in range(b+2, min(len(currentTest), b+maxByteRange)):
if (v-b) in [4, 8]:
continue
newTest = currentTest[:b] + currentTest[v:]
r = writeAndRunCandidate(newTest) r = writeAndRunCandidate(newTest)
if checks(r): if checks(r):
print("BYTE PATTERN", tuple(ba), "AT", b1, "AND", b2, "CHANGED TO", tuple(banew)) print("Byte range removal of bytes", str(b) + "-" + str(v-1),
"reduced test to", len(newTest), "bytes")
changed = True changed = True
updateCurrent(newTest)
startingPos = b
break break
if changed: if changed:
break break
if changed: if not changed:
break for b in range(0, startingPos):
if args.verbose:
print("Trying byte range removal from", str(b) + "...")
for v in range(b+2, min(len(currentTest), b+maxByteRange)):
if (v-b) in [4, 8]:
continue
newTest = currentTest[:b] + currentTest[v:]
r = writeAndRunCandidate(newTest)
if checks(r):
print("Byte range removal of bytes", str(b) + "-" + str(v-1),
"reduced test to", len(newTest), "bytes")
changed = True
updateCurrent(newTest)
startingPos = b
break
if changed:
break
lastAllRangeTest = bytearray(currentTest)
passInfo()
if changed: if currentTest != lastOneOfSwapTest:
currentTest = newTest if args.verbose:
print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out) print("*"*80+"\nPASS: swapping OneOfs...")
with open(out, 'wb') as outf: changed = True
outf.write(currentTest) while changed:
s = structure(r) changed = False
fixUp(currentTest, rangeConversions(r)) cuts = s[0]
else: for i in range(len(cuts)-1):
print("*" * 80) cuti = cuts[i]
print("NO (MORE) REDUCTIONS FOUND") bytesi = currentTest[cuti[0]:cuti[1] + 1]
if args.verbose:
print("Trying OneOf swaps from byte", cuti[0], "[" + " ".join(map(str, bytesi)) + "]")
for j in range(i + 1, len(cuts)):
cutj = cuts[j]
if cutj[0] > cuti[1]:
bytesj = currentTest[cutj[0]:cutj[1] + 1]
if (len(bytesj) > 0) and (bytesi > bytesj):
newTest = currentTest[:cuti[0]] + bytesj + currentTest[cuti[1]+1:cutj[0]]
newTest += bytesi
newTest += currentTest[cutj[1]+1:]
newTest = bytearray(newTest)
r = writeAndRunCandidate(newTest)
if checks(r):
print("OneOf swap @ byte", cuti[0], "[" + " ".join(map(str, bytesi)) + "]", "with",
cutj[0], "[" + " ".join(map(str, bytesj)) + "]")
changed = True
updateCurrent(newTest)
break
if changed:
break
if changed:
break
lastOneOfSwapTest = bytearray(currentTest)
passInfo()
if currentTest != lastByteReduceTest:
if args.verbose:
print("*"*80+"\nPASS: byte reductions...")
changed = True
startingPos = 0
while changed:
changed = False
for b in range(startingPos, len(currentTest)):
for v in range(0, currentTest[b]):
newTest = bytearray(currentTest)
newTest[b] = v
r = writeAndRunCandidate(newTest)
if checks(r):
print("Reduced byte", b, "from", currentTest[b], "to", v)
changed = True
updateCurrent(newTest)
startingPos = b+1
break
if changed:
break
if changed:
continue
for b in range(0, startingPos):
for v in range(0, currentTest[b]):
newTest = bytearray(currentTest)
newTest[b] = v
r = writeAndRunCandidate(newTest)
if checks(r):
print("Reduced byte", b, "from", currentTest[b], "to", v)
changed = True
updateCurrent(newTest)
startingPos = b+1
break
if changed:
break
lastByteReduceTest = bytearray(currentTest)
passInfo()
if (args.slow or args.slowest) and (oldTest == currentTest):
if currentTest != lastPatternSearchTest:
if args.verbose:
print("*"*80+"\nPASS: byte pattern search...")
changed = True
while changed:
changed = False
for b1 in range(0, len(currentTest)-4):
if args.verbose:
print("Trying byte pattern search from byte", str(b1) + "...")
for b2 in range(b1+2, len(currentTest)-4):
v1 = (currentTest[b1], currentTest[b1+1])
v2 = (currentTest[b2], currentTest[b2+1])
if (v1 == v2):
ba = bytearray(v1)
part1 = currentTest[:b1]
part2 = currentTest[b1+2:b2]
part3 = currentTest[b2+2:]
banews = []
banews.append(ba[0:1])
banews.append(ba[1:2])
if ba[0] > 0:
for v in range(0, ba[0]):
banews.append(bytearray([v, ba[1]]))
banews.append(bytearray([ba[0]-1]))
if ba[1] > 0:
for v in range(0, ba[1]):
banews.append(bytearray([ba[0], v]))
for banew in banews:
newTest = part1 + banew + part2 + banew + part3
r = writeAndRunCandidate(newTest)
if checks(r):
print("Byte pattern", tuple(ba), "at", b1, "and", b2, "changed to", tuple(banew))
changed = True
updateCurrent(newTest)
break
if changed:
break
if changed:
break
lastPatternSearchTest = bytearray(currentTest)
passInfo()
if oldTest == currentTest:
print("*" * 80)
print("DONE: NO (MORE) REDUCTIONS FOUND")
except TimeoutException: except TimeoutException:
print("*" * 80) print("*" * 80)
print("REDUCTION TIMED OUT AFTER", args.timeout, "SECONDS") print("DONE: REDUCTION TIMED OUT AFTER", args.timeout, "SECONDS")
print("=" * 80)
percent = 100.0 * ((initialSize - len(currentTest)) / initialSize)
print("Completed", iteration, "iterations:", round(time.time()-start, 2), "secs /",
candidateRuns, "execs /", str(round(percent, 2)) + "% reduction")
if (s[1] + 1) > len(currentTest): if (s[1] + 1) > len(currentTest):
print("PADDING TEST WITH", (s[1] + 1) - len(currentTest), "ZEROS") print("Padding test with", (s[1] + 1) - len(currentTest), "zeroes")
padding = bytearray('\x00' * ((s[1] + 1) - len(currentTest)), 'utf-8') padding = bytearray('\x00' * ((s[1] + 1) - len(currentTest)), 'utf-8')
currentTest = currentTest + padding currentTest = currentTest + padding
print("=" * 80) print("Writing reduced test with", len(currentTest), "bytes to", out)
percent = 100.0 * ((initialSize - len(currentTest)) / initialSize)
print("COMPLETED AFTER", iteration, "ITERATIONS:", round(time.time()-start, 2), "SECONDS /",
candidateRuns, "EXECUTIONS /", str(round(percent, 2)) + "% REDUCTION")
print("WRITING REDUCED TEST WITH", len(currentTest), "BYTES TO", out)
with open(out, 'wb') as outf: with open(out, 'wb') as outf:
outf.write(currentTest) outf.write(currentTest)