From e4f4cfe0db8033ada9eb864db72dc04be512ee3d Mon Sep 17 00:00:00 2001 From: Peter Goodman Date: Mon, 30 Oct 2017 00:45:59 -0400 Subject: [PATCH] Kind of feature parity between Manticore and Angr on these tests. --- bin/mctest/__main__.py | 117 +++++++++++++++++---- bin/mctest/main_angr.py | 162 ++++++++++++++++-------------- examples/ArithmeticProperties.cpp | 30 +++--- src/include/mctest/McTest.h | 2 + 4 files changed, 204 insertions(+), 107 deletions(-) diff --git a/bin/mctest/__main__.py b/bin/mctest/__main__.py index d0e9600..92d3b17 100644 --- a/bin/mctest/__main__.py +++ b/bin/mctest/__main__.py @@ -20,6 +20,7 @@ import logging import manticore import multiprocessing import sys +import traceback from manticore.core.state import TerminateState from manticore.utils.helpers import issymbolic @@ -105,8 +106,12 @@ def read_api_table(state, ea): def make_symbolic_input(state, input_begin_ea, input_end_ea): """Fill in the input data array with symbolic data.""" input_size = input_end_ea - input_begin_ea - data = state.new_symbolic_buffer(nbytes=input_size, name='MCTEST_INPUT') - state.cpu.write_bytes(input_begin_ea, data) + data = [] + for i in xrange(input_end_ea - input_begin_ea): + input_byte = state.new_symbolic_value(8, "MCTEST_INPUT_{}".format(i)) + data.append(input_byte) + state.cpu.write_int(input_begin_ea + i, input_byte, 8) + return data @@ -131,25 +136,32 @@ def hook_Assume(state, arg): state.constrain(constraint) +OUR_TERMINATION_REASON = "I McTest'd it" + + +def report_state(state): + test = state.context['test'] + if state.context['failed']: + message = (3, "Failed: {}".format(test.name)) + else: + message = (1, "Passed: {}".format(test.name)) + state.context['log_messages'].append(message) + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) + + def hook_Pass(state): """Implements McTest_Pass, which notifies us of a passing test.""" - L.info("Passed test case") - if state.context['failed']: - raise TerminateState("Got to end of failing test case.") - else: - raise TerminateState("Passed test case") + report_state(state) def hook_Fail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - L.error("Failed test case") state.context['failed'] = 1 - raise TerminateState("Failed test case") + report_state(state) def hook_SoftFail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - L.error("Soft failure in test case, continuing") state.context['failed'] = 1 @@ -162,41 +174,112 @@ LEVEL_TO_LOGGER = { } -def hook_Log(state): +def hook_Log(state, level, begin_ea, end_ea): """Implements McTest_Log, which lets Manticore intercept and handle the printing of log messages from the simulated tests.""" - pass + level = state.solve_one(level) + assert level in LEVEL_TO_LOGGER + + begin_ea = state.solve_one(begin_ea) + end_ea = state.solve_one(end_ea) + assert begin_ea <= end_ea + + message_bytes = [] + for i in xrange(end_ea - begin_ea): + message_bytes.append(state.cpu.memory[begin_ea + i]) + + state.context['log_messages'].append((level, message_bytes)) + def hook(func): return lambda state: state.invoke_model(func) -def run_test(state, apis, test): +def done_test(_, state, state_id, reason): + """Called when a state is terminated.""" + if OUR_TERMINATION_REASON not in reason: + L.error("State {} terminated for unknown reason: {}".format( + state_id, reason)) + return + + test = state.context['test'] + input_length, _ = read_uint32_t(state, state.context['InputIndex']) + + # Dump out any pending log messages reported by `McTest_Log`. + for level, message_bytes in state.context['log_messages']: + message = [] + for b in message_bytes: + if issymbolic(b): + b_ord = state.solve_one(b) + state.constrain(b == b_ord) + message.append(chr(b_ord)) + elif isinstance(b, (int, long)): + message.append(chr(b)) + else: + message.append(b) + + LEVEL_TO_LOGGER[level]("".join(message)) + + max_length = state.context['InputEnd'] - state.context['InputBegin'] + if input_length > max_length: + L.critical("Test used too many input bytes ({} vs. {})".format( + input_length, max_length)) + return + + # Solve for the input bytes. + output = [] + for i in xrange(input_length): + b = state.cpu.read_int(state.context['InputBegin'] + i, 8) + if issymbolic(b): + b = state.solve_one(b) + output.append("{:2x}".format(b)) + + L.info("Input: {}".format(" ".join(output))) + + +def do_run_test(state, apis, test): """Run an individual test case.""" state.cpu.PC = test.ea m = manticore.Manticore(state, sys.argv[1:]) + m.verbosity(1) state = m.initial_state + messages = [(1, "Running {} from {}:{}".format( + test.name, test.file_name, test.line_number))] + + state.context['InputBegin'] = apis['InputBegin'] + state.context['InputEnd'] = apis['InputEnd'] + state.context['InputIndex'] = apis['InputIndex'] + state.context['test'] = test state.context['failed'] = 0 - state.context['log_messages'] = [] - state.context['input'] = make_symbolic_input( - state, apis['InputBegin'], apis['InputEnd']) + state.context['log_messages'] = messages + + make_symbolic_input(state, apis['InputBegin'], apis['InputEnd']) m.add_hook(apis['IsSymbolicUInt'], hook(hook_IsSymbolicUInt)) m.add_hook(apis['Assume'], hook(hook_Assume)) m.add_hook(apis['Pass'], hook(hook_Pass)) m.add_hook(apis['Fail'], hook(hook_Fail)) m.add_hook(apis['SoftFail'], hook(hook_SoftFail)) + m.add_hook(apis['Log'], hook(hook_Log)) + m.subscribe('will_terminate_state', done_test) m.run() +def run_test(state, apis, test): + try: + do_run_test(state, apis, test) + except: + L.error("Uncaught exception: {}\n{}".format( + sys.exc_info()[0], traceback.format_exc())) + + def run_tests(args, state, apis): """Run all of the test cases.""" pool = multiprocessing.Pool(processes=max(1, args.num_workers)) results = [] tests = find_test_cases(state, apis['LastTestInfo']) for test in tests: - print "Found test", test.name res = pool.apply_async(run_test, (state, apis, test)) results.append(res) diff --git a/bin/mctest/main_angr.py b/bin/mctest/main_angr.py index e17879e..33340a1 100644 --- a/bin/mctest/main_angr.py +++ b/bin/mctest/main_angr.py @@ -21,6 +21,7 @@ import collections import logging import multiprocessing import sys +import traceback L = logging.getLogger("mctest") @@ -143,120 +144,137 @@ class Assume(angr.SimProcedure): self.exit(2) +def report_state(state): + test = state.globals['test'] + if state.globals['failed']: + message = (3, "Failed: {}".format(test.name)) + else: + message = (1, "Passed: {}".format(test.name)) + state.globals['log_messages'].append(message) + + class Pass(angr.SimProcedure): """Implements McTest_Pass, which notifies us of a passing test.""" def run(self): - L.info("Passed test case") + report_state(self.state) self.exit(self.state.globals['failed']) class Fail(angr.SimProcedure): """Implements McTest_Fail, which notifies us of a passing test.""" def run(self): - L.error("Failed test case") self.state.globals['failed'] = 1 + report_state(self.state) self.exit(1) class SoftFail(angr.SimProcedure): """Implements McTest_SoftFail, which notifies us of a passing test.""" def run(self): - L.error("Soft failure in test case, continuing") self.state.globals['failed'] = 1 +LEVEL_TO_LOGGER = { + 0: L.debug, + 1: L.info, + 2: L.warning, + 3: L.error, + 4: L.critical +} + + class Log(angr.SimProcedure): """Implements McTest_Log, which lets Angr intercept and handle the printing of log messages from the simulated tests.""" - - LEVEL_TO_LOGGER = { - 0: L.debug, - 1: L.info, - 2: L.warning, - 3: L.error, - 4: L.critical - } - def run(self, level, begin_ea, end_ea): - print self.state.regs.rdi, level - print self.state.regs.rsi, begin_ea - print self.state.regs.rdx, end_ea level = self.state.solver.eval(level, cast_to=int) - assert level in self.LEVEL_TO_LOGGER + assert level in LEVEL_TO_LOGGER begin_ea = self.state.solver.eval(begin_ea, cast_to=int) end_ea = self.state.solver.eval(end_ea, cast_to=int) assert begin_ea <= end_ea - print level, begin_ea, end_ea - - # Read the message from memory. - message = "" - if begin_ea < end_ea: - size = end_ea - begin_ea - - # If this is an error message, then concretize the message, adding the - # concretizations to the state so that errors we output will eventually - # match the concrete inputs that we generate. - if 3 <= level: - message_bytes = [] - for i in xrange(size): - byte = self.state.memory.load(begin_ea + i, size=8) - byte_as_ord = self.state.solver.eval(byte, cast_to=int) - if self.state.se.symbolic(byte): - self.state.solver.add(byte == byte_as_ord) - message_bytes.append(chr(byte_as_ord)) - - message = "".join(message_bytes) - - # Warning/informational message, don't assert any new constraints. - else: - data = self.state.memory.load(begin_ea, size=size) - message = self.state.solver.eval(data, cast_to=str) - - # Log the message (produced by the program) through to the Python-based - # logger. - self.LEVEL_TO_LOGGER[level](message) + size = end_ea - begin_ea + data = self.state.memory.load(begin_ea, size=size) + self.state.globals['log_messages'].append((level, data)) if 3 == level: - L.error("Soft failure in test case, continuing") self.state.globals['failed'] = 1 # Soft failure on an error log message. elif 4 == level: - L.error("Failed test case") self.state.globals['failed'] = 1 - self.exit(1) # Hard failure on a fatal/critical log message. + report_state(self.state) + self.exit(1) -def run_test(project, test, run_state): +def done_test(state): + test = state.globals['test'] + input_length, _ = read_uint32_t(state, state.globals['InputIndex']) + + # Dump out any pending log messages reported by `McTest_Log`. + for level, message in state.globals['log_messages']: + if not isinstance(message, str): + message = state.solver.eval(message, cast_to=str) + LEVEL_TO_LOGGER[level]("".join(message)) + + max_length = state.globals['InputEnd'] - state.globals['InputBegin'] + if input_length > max_length: + L.critical("Test used too many input bytes ({} vs. {})".format( + input_length, max_length)) + return + + # Solve for the input bytes. + output = [] + data = state.memory.load(state.globals['InputBegin'], size=input_length) + data = state.solver.eval(data, cast_to=str) + for i in xrange(input_length): + output.append("{:2x}".format(ord(data[i]))) + + L.info("Input: {}".format(" ".join(output))) + + +def do_run_test(project, test, apis, run_state): """Symbolically executes a single test function.""" + test_state = project.factory.call_state( test.ea, base_state=run_state) + messages = [(1, "Running {} from {}:{}".format( + test.name, test.file_name, test.line_number))] + + test_state.globals['InputBegin'] = apis['InputBegin'] + test_state.globals['InputEnd'] = apis['InputEnd'] + test_state.globals['InputIndex'] = apis['InputIndex'] + test_state.globals['test'] = test + test_state.globals['failed'] = 0 + test_state.globals['log_messages'] = messages + + make_symbolic_input(test_state, apis['InputBegin'], apis['InputEnd']) + errored = [] test_manager = angr.SimulationManager( project=project, active_states=[test_state], errored=errored) - L.info("Running test case {} from {}:{}".format( - test.name, test.file_name, test.line_number)) - print 'running...' try: test_manager.run() except Exception as e: - import traceback - print e - print traceback.format_exc() - print test_manager - print 'done running' - print errored + L.error("Uncaught exception: {}\n{}".format( + sys.exc_info()[0], traceback.format_exc())) + for state in test_manager.deadended: - last_event = state.history.events[-1] - if 'terminate' == last_event.type: - code = last_event.objects['exit_code']._model_concrete.value - print '???' - print test_manager + done_test(state) + + +def run_test(project, test, apis, run_state): + """Symbolically executes a single test function.""" + try: + do_run_test(project, test, apis, run_state) + except Exception as e: + L.error("Uncaught exception: {}\n{}".format( + sys.exc_info()[0], traceback.format_exc())) + def main(): """Run McTest.""" @@ -278,9 +296,12 @@ def main(): use_sim_procedures=True, translation_cache=True, support_selfmodifying_code=False, - auto_load_libs=False) + auto_load_libs=True) + + entry_state = project.factory.entry_state( + add_options={angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY, + angr.options.STRICT_PAGE_ACCESS}) - entry_state = project.factory.entry_state() addr_size_bits = entry_state.arch.bits # Concretely execute up until `McTest_InjectAngr`. @@ -309,15 +330,6 @@ def main(): # Find the test cases that we want to run. tests = find_test_cases(run_state, apis['LastTestInfo']) - # This will track whether or not a particular state has failed. Some states - # will soft fail, but continue their execution, and so we want to make sure - # that if they continue to a pass function, that they nonetheless are treated - # as failing. - run_state.globals['failed'] = 0 - run_state.globals['log_messages'] = [] - run_state.globals['input'] = make_symbolic_input( - run_state, apis['InputBegin'], apis['InputEnd']) - pool = multiprocessing.Pool(processes=max(1, args.num_workers)) results = [] @@ -325,7 +337,7 @@ def main(): # the test case function. test_managers = [] for test in tests: - res = pool.apply_async(run_test, (project, test, run_state)) + res = pool.apply_async(run_test, (project, test, apis, run_state)) results.append(res) pool.close() diff --git a/examples/ArithmeticProperties.cpp b/examples/ArithmeticProperties.cpp index 84cc373..59f3557 100644 --- a/examples/ArithmeticProperties.cpp +++ b/examples/ArithmeticProperties.cpp @@ -19,23 +19,23 @@ using namespace mctest; -// MCTEST_NOINLINE int add(int x, int y) { -// return x + y; -// } +MCTEST_NOINLINE int add(int x, int y) { + return x + y; +} -// TEST(Arithmetic, AdditionIsCommutative) { -// ForAll([] (int x, int y) { -// ASSERT_EQ(add(x, y), add(y, x)) -// << "Addition of signed integers must commute."; -// }); -// } +TEST(Arithmetic, AdditionIsCommutative) { + ForAll([] (int x, int y) { + ASSERT_EQ(add(x, y), add(y, x)) + << "Addition of signed integers must commute."; + }); +} -// TEST(Arithmetic, AdditionIsAssociative) { -// ForAll([] (int x, int y, int z) { -// ASSERT_EQ(add(x, add(y, z)), add(add(x, y), z)) -// << "Addition of signed integers must associate."; -// }); -// } +TEST(Arithmetic, AdditionIsAssociative) { + ForAll([] (int x, int y, int z) { + ASSERT_EQ(add(x, add(y, z)), add(add(x, y), z)) + << "Addition of signed integers must associate."; + }); +} TEST(Arithmetic, InvertibleMultiplication_CanFail) { ForAll([] (int x, int y) { diff --git a/src/include/mctest/McTest.h b/src/include/mctest/McTest.h index d02f268..dc058a1 100644 --- a/src/include/mctest/McTest.h +++ b/src/include/mctest/McTest.h @@ -125,8 +125,10 @@ enum McTest_LogLevel { McTest_LogDebug = 0, McTest_LogInfo = 1, McTest_LogWarning = 2, + McTest_LogWarn = McTest_LogWarning, McTest_LogError = 3, McTest_LogFatal = 4, + McTest_LogCritical = McTest_LogFatal }; /* Outputs information to a log, using a specific log level. */