From a46e06b03b5839bc8276a565726c322703382e86 Mon Sep 17 00:00:00 2001 From: Peter Goodman Date: Wed, 1 Nov 2017 02:42:31 -0400 Subject: [PATCH] Refactored to split common code between Manticore and Angr out into common.py. Implemented the new deferred streaming stuff, it seems to work semi-well for this simple cases I've tested, but there's still work to do. The latest code has some remaining issues. Printing out the final input bytes in Angr shows the wrong thing, although what gets streamed out is right. This is visible when running mctest-angr examples/ArtihmeticProperties. With Manticore, the big issue is that it doesn't properly pickle smt expressions (or something to this effect), so I'm ending up with multiple definitions of the same stuff and that throws exceptions. This is tricky to deal with because the streaming of output needs to be able to save symbolic data. --- CMakeLists.txt | 2 + bin/mctest/__main__.py | 278 +++++++------- bin/mctest/common.py | 381 +++++++++++++++++++ bin/mctest/main_angr.py | 351 ++++++++--------- examples/ArithmeticProperties.cpp | 3 +- examples/CMakeLists.txt | 2 + examples/StreamingAndFormatting.cpp | 44 +++ src/include/mctest/Compiler.h | 6 +- src/include/mctest/Log.h | 51 +++ src/include/mctest/McTest.h | 63 ++-- src/include/mctest/McTest.hpp | 6 +- src/include/mctest/McUnit.hpp | 82 ++-- src/include/mctest/Quantified.hpp | 6 +- src/include/mctest/Stream.h | 79 ++++ src/include/mctest/Stream.hpp | 102 +++++ src/lib/Log.c | 109 ++++++ src/lib/McTest.c | 132 +++---- src/lib/Stream.c | 563 ++++++++++++++++++++++++++++ 18 files changed, 1765 insertions(+), 495 deletions(-) create mode 100644 bin/mctest/common.py create mode 100644 examples/StreamingAndFormatting.cpp create mode 100644 src/include/mctest/Log.h create mode 100644 src/include/mctest/Stream.h create mode 100644 src/include/mctest/Stream.hpp create mode 100644 src/lib/Log.c create mode 100644 src/lib/Stream.c diff --git a/CMakeLists.txt b/CMakeLists.txt index 32cbc20..2026ac7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,8 @@ endif () add_library(${PROJECT_NAME} STATIC src/lib/McTest.c + src/lib/Log.c + src/lib/Stream.c ) target_include_directories(${PROJECT_NAME} diff --git a/bin/mctest/__main__.py b/bin/mctest/__main__.py index 7f65973..70eb293 100644 --- a/bin/mctest/__main__.py +++ b/bin/mctest/__main__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse import collections import logging @@ -21,86 +20,107 @@ import manticore import multiprocessing import sys import traceback +from .common import McTest from manticore.core.state import TerminateState from manticore.utils.helpers import issymbolic -L = logging.getLogger("mctest") +L = logging.getLogger("mctest.mcore") L.setLevel(logging.INFO) +OUR_TERMINATION_REASON = "I McTest'd it" -def read_c_string(state, ea): - """Read a concrete NUL-terminated string from `ea`.""" - return state.cpu.read_string(ea) +class MCoreTest(McTest): + def __init__(self, state): + super(MCoreTest, self).__init__() + self.state = state + def __del__(self): + self.state = None -def read_uintptr_t(state, ea): - """Read a uintptr_t value from memory.""" - addr_size_bits = state.cpu.address_bit_size - next_ea = ea + (addr_size_bits // 8) - val = state.cpu.read_int(ea, size=addr_size_bits) - if issymbolic(val): - val = state.solve_one(val) - return val, next_ea + def get_context(self): + return self.state.context + def is_symbolic(self, val): + return manticore.utils.helpers.issymbolic(val) -def read_uint32_t(state, ea): - """Read a uint32_t value from memory.""" - next_ea = ea + 4 - val = state.cpu.read_int(ea, size=32) - if issymbolic(val): - val = state.solve_one(val) - return val, next_ea + def read_uintptr_t(self, ea, concretize=True, constrain=False): + addr_size_bits = self.state.cpu.address_bit_size + next_ea = ea + (addr_size_bits // 8) + val = self.state.cpu.read_int(ea, size=addr_size_bits) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, next_ea + def read_uint64_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=64) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 8 -TestInfo = collections.namedtuple( - 'TestInfo', 'ea name file_name line_number') + def read_uint32_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=32) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 4 + def read_uint8_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=8) + if concretize: + val = self.concretize(val, constrain=constrain) + if isinstance(val, str): + assert len(val) == 1 + val = ord(val) + return val, ea + 1 -def read_test_info(state, ea): - """Read in a `McTest_TestInfo` info structure from memory.""" - prev_test_ea, ea = read_uintptr_t(state, ea) - test_func_ea, ea = read_uintptr_t(state, ea) - test_name_ea, ea = read_uintptr_t(state, ea) - file_name_ea, ea = read_uintptr_t(state, ea) - file_line_num, _ = read_uint32_t(state, ea) + def concretize(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + elif isinstance(val, str): + assert len(val) == 1 + return ord(val[0]) - if not test_func_ea or \ - not test_name_ea or \ - not file_name_ea or \ - not file_line_num: # `__LINE__` in C always starts at `1` ;-) - return None, prev_test_ea + assert self.is_symbolic(val) + concrete_val = self.state.solve_one(val) + if isinstance(concrete_val, str): + assert len(concrete_val) == 1 + concrete_val = ord(concrete_val[0]) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val - test_name = read_c_string(state, test_name_ea) - file_name = read_c_string(state, file_name_ea) - info = TestInfo(test_func_ea, test_name, file_name, file_line_num) - return info, prev_test_ea + def concretize_min(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + concrete_val = self.state.solve_n(val) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val + def concretize_many(self, val, max_num): + assert 0 < max_num + if isinstance(val, (int, long)): + return [val] + return self.state.solver.eval_upto(val, max_num) -def find_test_cases(state, info_ea): - """Find the test case descriptors.""" - tests = [] - info_ea, _ = read_uintptr_t(state, info_ea) - while info_ea: - test, info_ea = read_test_info(state, info_ea) - if test: - tests.append(test) - tests.sort(key=lambda t: (t.file_name, t.line_number)) - return tests + def add_constraint(self, expr): + if self.is_symbolic(expr): + self.state.constrain(expr) + # TODO(pag): How to check satisfiability? + return True + def pass_test(self): + super(MCoreTest, self).pass_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) -def read_api_table(state, ea): - """Reads in the API table.""" - apis = {} - while True: - api_name_ea, ea = read_uintptr_t(state, ea) - api_ea, ea = read_uintptr_t(state, ea) - if not api_name_ea or not api_ea: - break - api_name = read_c_string(state, api_name_ea) - apis[api_name] = api_ea - return apis + def fail_test(self): + super(MCoreTest, self).fail_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) + + def abandon_test(self): + super(MCoreTest, self).abandon_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) def make_symbolic_input(state, input_begin_ea, input_end_ea): @@ -118,78 +138,63 @@ def make_symbolic_input(state, input_begin_ea, input_end_ea): def hook_IsSymbolicUInt(state, arg): """Implements McTest_IsSymblicUInt, which returns 1 if its input argument has more then one solutions, and zero otherwise.""" - solutions = state.solve_n(arg, 2) - if not solutions: - return 0 - elif 1 == len(solutions): - if issymbolic(arg): - state.constrain(arg == solutions[0]) - return 0 - else: - return 1 + return MCoreTest(state).api_is_symbolic_uint(arg) def hook_Assume(state, arg): """Implements _McTest_Assume, which tries to inject a constraint.""" - constraint = arg != 0 - if issymbolic(constraint): - state.constrain(constraint) + MCoreTest(state).api_assume(arg) -OUR_TERMINATION_REASON = "I McTest'd it" +def hook_StreamInt(state, level, format_ea, unpack_ea, uint64_ea): + """Implements _McTest_StreamInt, which gives us an integer to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_int(level, format_ea, unpack_ea, uint64_ea) -def report_state(state): - test = state.context['test'] - if state.context['failed']: - message = (3, "Failed: {}".format(test.name)) - else: - message = (1, "Passed: {}".format(test.name)) - state.context['log_messages'].append(message) - raise TerminateState(OUR_TERMINATION_REASON, testcase=False) +def hook_StreamFloat(state, level, format_ea, unpack_ea, double_ea): + """Implements _McTest_StreamFloat, which gives us an double to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_float(level, format_ea, unpack_ea, double_ea) + + +def hook_StreamString(state, level, format_ea, str_ea): + """Implements _McTest_StreamString, which gives us an double to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_string(level, format_ea, str_ea) + + +def hook_LogStream(state, level): + """Implements McTest_LogStream, which converts the contents of a stream for + level `level` into a log for level `level`.""" + MCoreTest(state).api_log_stream(level) def hook_Pass(state): """Implements McTest_Pass, which notifies us of a passing test.""" - report_state(state) + MCoreTest(state).api_pass() def hook_Fail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - state.context['failed'] = 1 - report_state(state) + MCoreTest(state).api_fail() + + +def hook_Abandon(state, reason): + """Implements McTest_Abandon, which notifies us that a problem happened + in McTest.""" + MCoreTest(state).api_abandon(reason) def hook_SoftFail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - state.context['failed'] = 1 + MCoreTest(state).api_soft_fail() -LEVEL_TO_LOGGER = { - 0: L.debug, - 1: L.info, - 2: L.warning, - 3: L.error, - 4: L.critical -} - - -def hook_Log(state, level, begin_ea): +def hook_Log(state, level, ea): """Implements McTest_Log, which lets Manticore intercept and handle the printing of log messages from the simulated tests.""" - level = state.solve_one(level) - assert level in LEVEL_TO_LOGGER - - begin_ea = state.solve_one(begin_ea) - - message_bytes = [] - for i in xrange(4096): - b = state.cpu.read_int(begin_ea + i, 8) - if not issymbolic(b) and b == 0: - break - message_bytes.append(b) - - state.context['log_messages'].append((level, message_bytes)) + MCoreTest(state).api_log(level, ea) def hook(func): @@ -202,36 +207,8 @@ def done_test(_, state, state_id, reason): L.error("State {} terminated for unknown reason: {}".format( state_id, reason)) return - - test = state.context['test'] - input_length, _ = read_uint32_t(state, state.context['InputIndex']) - - # Dump out any pending log messages reported by `McTest_Log`. - for level, message_bytes in state.context['log_messages']: - message = [] - for b in message_bytes: - b_ord = state.solve_one(b) - if b_ord == 0: - break - else: - message.append(chr(b_ord)) - - LEVEL_TO_LOGGER[level]("".join(message)) - - max_length = state.context['InputEnd'] - state.context['InputBegin'] - if input_length > max_length: - L.critical("Test used too many input bytes ({} vs. {})".format( - input_length, max_length)) - return - - # Solve for the input bytes. - output = [] - for i in xrange(input_length): - b = state.cpu.read_int(state.context['InputBegin'] + i, 8) - b = state.solve_one(b) - output.append("{:2x}".format(b)) - - L.info("Input: {}".format(" ".join(output))) + mc = MCoreTest(state) + mc.report() def do_run_test(state, apis, test): @@ -241,15 +218,9 @@ def do_run_test(state, apis, test): m.verbosity(1) state = m.initial_state - messages = [(1, "Running {} from {}({})".format( - test.name, test.file_name, test.line_number))] - - state.context['InputBegin'] = apis['InputBegin'] - state.context['InputEnd'] = apis['InputEnd'] - state.context['InputIndex'] = apis['InputIndex'] - state.context['test'] = test - state.context['failed'] = 0 - state.context['log_messages'] = messages + mc = MCoreTest(state) + mc.begin_test(test) + del mc make_symbolic_input(state, apis['InputBegin'], apis['InputEnd']) @@ -258,7 +229,13 @@ def do_run_test(state, apis, test): m.add_hook(apis['Pass'], hook(hook_Pass)) m.add_hook(apis['Fail'], hook(hook_Fail)) m.add_hook(apis['SoftFail'], hook(hook_SoftFail)) + m.add_hook(apis['Abandon'], hook(hook_Abandon)) m.add_hook(apis['Log'], hook(hook_Log)) + m.add_hook(apis['StreamInt'], hook(hook_StreamInt)) + m.add_hook(apis['StreamFloat'], hook(hook_StreamFloat)) + m.add_hook(apis['StreamString'], hook(hook_StreamString)) + m.add_hook(apis['LogStream'], hook(hook_LogStream)) + m.subscribe('will_terminate_state', done_test) m.run() @@ -275,7 +252,8 @@ def run_tests(args, state, apis): """Run all of the test cases.""" pool = multiprocessing.Pool(processes=max(1, args.num_workers)) results = [] - tests = find_test_cases(state, apis['LastTestInfo']) + mc = MCoreTest(state) + tests = mc.find_test_cases() L.info("Running {} tests across {} workers".format( len(tests), args.num_workers)) @@ -313,9 +291,11 @@ def main(): setup_ea = m._get_symbol_address('McTest_Setup') setup_state = m._initial_state - ea_of_api_table = m._get_symbol_address('McTest_API') - apis = read_api_table(setup_state, ea_of_api_table) + mc = MCoreTest(setup_state) + ea_of_api_table = m._get_symbol_address('McTest_API') + apis = mc.read_api_table(ea_of_api_table) + del mc m.add_hook(setup_ea, lambda state: run_tests(args, state, apis)) m.run() diff --git a/bin/mctest/common.py b/bin/mctest/common.py new file mode 100644 index 0000000..df807cd --- /dev/null +++ b/bin/mctest/common.py @@ -0,0 +1,381 @@ +# Copyright (c) 2017 Trail of Bits, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import logging +import struct + +# Represents a TestInfo data structure (the information we know about a test.) +TestInfo = collections.namedtuple( + 'TestInfo', 'ea name file_name line_number') + + +LOG_LEVEL_DEBUG = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARNING = 2 +LOG_LEVEL_ERROR = 3 +LOG_LEVEL_FATAL = 4 + + +LOGGER = logging.getLogger("mctest") +LOGGER.setLevel(logging.DEBUG) + + +LOG_LEVEL_TO_LOGGER = { + LOG_LEVEL_DEBUG: LOGGER.debug, + LOG_LEVEL_INFO: LOGGER.info, + LOG_LEVEL_WARNING: LOGGER.warning, + LOG_LEVEL_ERROR: LOGGER.error, + LOG_LEVEL_FATAL: LOGGER.critical +} + + +class Stream(object): + def __init__(self, entries): + self.entries = entries + + +class McTest(object): + """Wrapper around a symbolic executor for making it easy to do common McTest- + specific things.""" + def __init__(self): + pass + + def get_context(self): + raise NotImplementedError("Must be implemented by engine.") + + @property + def context(self): + return self.get_context() + + def is_symbolic(self, val): + raise NotImplementedError("Must be implemented by engine.") + + def read_uintptr_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint64_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint32_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint8_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize(self, val, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize_min(self, val, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize_many(self, val, max_num): + raise NotImplementedError("Must be implemented by engine.") + + def add_constraint(self, expr): + raise NotImplementedError("Must be implemented by engine.") + + def read_c_string(self, ea, concretize=True): + """Read a NUL-terminated string from `ea`.""" + assert isinstance(ea, (int, long)) + chars = [] + i = 0 + while True: + b, ea = self.read_uint8_t(ea, concretize=concretize) + if self.is_symbolic(b): + concrete_b = self.concretize_min(b) # Find the NUL byte sooner. + if not concrete_b: + break + + if concretize: + chars.append(chr(concrete_b)) + else: + chars.append(b) + + continue + + # Concretize if it's not symbolic; we might have a concrete bitvector. + b = self.concretize(b) + if not b: + break + + else: + chars.append(chr(b)) + + next_ea = ea + len(chars) + 1 + if concretize: + return "".join(chars), next_ea + else: + return chars, next_ea + + def read_test_info(self, ea): + """Read in a `McTest_TestInfo` info structure from memory.""" + prev_test_ea, ea = self.read_uintptr_t(ea) + test_func_ea, ea = self.read_uintptr_t(ea) + test_name_ea, ea = self.read_uintptr_t(ea) + file_name_ea, ea = self.read_uintptr_t(ea) + file_line_num, _ = self.read_uint32_t(ea) + + if not test_func_ea or \ + not test_name_ea or \ + not file_name_ea or \ + not file_line_num: # `__LINE__` in C always starts at `1` ;-) + return None, prev_test_ea + + test_name, _ = self.read_c_string(test_name_ea) + file_name, _ = self.read_c_string(file_name_ea) + info = TestInfo(test_func_ea, test_name, file_name, file_line_num) + return info, prev_test_ea + + def find_test_cases(self): + """Find the test case descriptors.""" + tests = [] + info_ea, _ = self.read_uintptr_t(self.context['apis']['LastTestInfo']) + while info_ea: + test, info_ea = self.read_test_info(info_ea) + if test: + tests.append(test) + tests.sort(key=lambda t: (t.file_name, t.line_number)) + return tests + + def read_api_table(self, ea): + """Reads in the API table.""" + apis = {} + while True: + api_name_ea, ea = self.read_uintptr_t(ea) + api_ea, ea = self.read_uintptr_t(ea) + if not api_name_ea or not api_ea: + break + api_name, _ = self.read_c_string(api_name_ea) + apis[api_name] = api_ea + self.context['apis'] = apis + return apis + + def begin_test(self, info): + self.context['failed'] = False + self.context['abandoned'] = False + self.context['log'] = [] + for level in LOG_LEVEL_TO_LOGGER: + self.context['stream_{}'.format(level)] = [] + self.context['info'] = info + self.log_message(LOG_LEVEL_INFO, "Running {} from {}({})".format( + info.name, info.file_name, info.line_number)) + + def log_message(self, level, message): + """Log a message.""" + assert level in LOG_LEVEL_TO_LOGGER + log = list(self.context['log']) + if isinstance(message, (str, list, tuple)): + log.append((level, Stream([(str, "%s", None, message)]))) + else: + assert isinstance(message, Stream) + log.append((level, message)) + + self.context['log'] = log + + def _concretize_bytes(self, byte_str): + new_bytes = [] + for b in byte_str: + if isinstance(b, str): + assert len(b) == 1 + new_bytes.append(ord(b)) + elif isinstance(b, (int, long)): + new_bytes.append(b) + else: + new_bytes.append(self.concretize(b, constrain=True)) + return new_bytes + + def _stream_to_message(self, stream): + assert isinstance(stream, Stream) + message = [] + for val_type, format_str, unpack_str, val_bytes in stream.entries: + val_bytes = self._concretize_bytes(val_bytes) + if val_type == str: + val = "".join(chr(b) for b in val_bytes) + elif val_type == float: + data = struct.pack('BBBBBBBB', *val_bytes) + val = struct.unpack(unpack_str, data)[0] + else: + assert val_type == int + + # TODO(pag): I am pretty sure that this is wrong for big-endian. + data = struct.pack('BBBBBBBB', *val_bytes) + val = struct.unpack(unpack_str, data[:struct.calcsize(unpack_str)])[0] + + # Remove length specifiers that are not supported. + format_str = format_str.replace('l', '') + format_str = format_str.replace('h', '') + format_str = format_str.replace('z', '') + format_str = format_str.replace('t', '') + + message.extend(format_str % val) + + return "".join(message) + + def report(self): + info = self.context['info'] + apis = self.context['apis'] + input_length, _ = self.read_uint32_t(apis['InputIndex']) + input_bytes = [] + for i in xrange(input_length): + ea = apis['InputBegin'] + i + b, _ = self.read_uint8_t(ea + i, concretize=True, constrain=True) + input_bytes.append("{:02x}".format(b)) + + for level, stream in self.context['log']: + logger = LOG_LEVEL_TO_LOGGER[level] + logger(self._stream_to_message(stream)) + + LOGGER.info("Input: {}".format(" ".join(input_bytes))) + + def pass_test(self): + pass + + def fail_test(self): + self.context['failed'] = True + + def abandon_test(self): + self.context['abandoned'] = True + + def api_is_symbolic_uint(self, arg): + """Implements the `McTest_IsSymbolicUInt` API, which returns whether or + not a given value is symbolic.""" + solutions = self.concretize_many(arg, 2) + if not solutions: + return 0 + elif 1 == len(solutions): + if self.is_symbolic(arg): + self.add_constraint(arg == solutions[0]) + return 0 + else: + return 1 + + def api_assume(self, arg): + """Implements the `McTest_Assume` API function, which injects a constraint + into the solver.""" + constraint = arg != 0 + if not self.add_constraint(constraint): + self.log_message(LOG_LEVEL_FATAL, + "Failed to add assumption {}".format(constraint)) + self.abandon_test() + + def api_pass(self): + """Implements the `McTest_Pass` API function, which marks this test as + having passed, and stops further execution.""" + if self.context['failed']: + self.api_fail() + else: + info = self.context['info'] + self.log_message(LOG_LEVEL_INFO, "Passed: {}".format(info.name)) + self.pass_test() + + def api_fail(self): + """Implements the `McTest_Fail` API function, which marks this test as + having failed, and stops further execution.""" + self.context['failed'] = True + info = self.context['info'] + self.log_message(LOG_LEVEL_ERROR, "Failed: {}".format(info.name)) + self.fail_test() + + def api_soft_fail(self): + """Implements the `McTest_SoftFail` API function, which marks this test + as having failed, but lets execution continue.""" + self.context['failed'] = True + + def api_abandon(self, arg): + """Implements the `McTest_Abandon` API function, which marks this test + as having aborted due to some unrecoverable error.""" + info = self.context['info'] + ea = self.concretize(arg, constrain=True) + self.log_message(LOG_LEVEL_FATAL, self.read_c_string(ea)[0]) + self.log_message(LOG_LEVEL_FATAL, "Abandoned: {}".format(info.name)) + self.abandon_test() + + def api_log(self, level, ea): + """Implements the `McTest_Log` API function, which prints a C string + to a specific log level.""" + self.api_log_stream(level) + + level = self.concretize(level, constrain=True) + ea = self.concretize(ea, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + self.log_message(level, self.read_c_string(ea, concretize=False)) + + if level == LOG_LEVEL_FATAL: + self.api_fail() + elif level == LOG_LEVEL_ERROR: + self.api_soft_fail() + + def _api_stream_int_float(self, level, format_ea, unpack_ea, uint64_ea, + val_type): + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + + format_ea = self.concretize(format_ea, constrain=True) + unpack_ea = self.concretize(unpack_ea, constrain=True) + uint64_ea = self.concretize(uint64_ea, constrain=True) + + format_str, _ = self.read_c_string(format_ea) + unpack_str, _ = self.read_c_string(unpack_ea) + uint64_bytes = [] + for i in xrange(8): + b, _ = self.read_uint8_t(uint64_ea + i, concretize=False) + uint64_bytes.append(b) + + stream_id = 'stream_{}'.format(level) + stream = list(self.context[stream_id]) + stream.append((val_type, format_str, unpack_str, uint64_bytes)) + self.context[stream_id] = stream + + def api_stream_int(self, level, format_ea, unpack_ea, uint64_ea): + """Implements the `_McTest_StreamInt`, which streams an integer into a + holding buffer for the log.""" + return self._api_stream_int_float(level, format_ea, unpack_ea, + uint64_ea, int) + + def api_stream_float(self, level, format_ea, unpack_ea, double_ea): + """Implements the `_McTest_StreamFloat`, which streams an integer into a + holding buffer for the log.""" + return self._api_stream_int_float(level, format_ea, unpack_ea, + double_ea, float) + + def api_stream_string(self, level, format_ea, str_ea): + """Implements the `_McTest_StreamString`, which streams a C-string into a + holding buffer for the log.""" + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + + format_ea = self.concretize(format_ea, constrain=True) + str_ea = self.concretize(str_ea, constrain=True) + format_str, _ = self.read_c_string(format_ea) + print_str, _ = self.read_c_string(str_ea, concretize=False) + + stream_id = 'stream_{}'.format(level) + stream = list(self.context[stream_id]) + stream.append((str, format_str, None, print_str)) + self.context[stream_id] = stream + + def api_log_stream(self, level): + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + stream_id = 'stream_{}'.format(level) + stream = self.context[stream_id] + if len(stream): + self.context[stream_id] = [] + self.log_message(level, Stream(stream)) + + if level == LOG_LEVEL_FATAL: + self.api_fail() + elif level == LOG_LEVEL_ERROR: + self.api_soft_fail() diff --git a/bin/mctest/main_angr.py b/bin/mctest/main_angr.py index b8ef318..5294e67 100644 --- a/bin/mctest/main_angr.py +++ b/bin/mctest/main_angr.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - import angr import argparse import collections @@ -22,95 +20,114 @@ import logging import multiprocessing import sys import traceback +from .common import McTest - -L = logging.getLogger("mctest") +L = logging.getLogger("mctest.angr") L.setLevel(logging.INFO) +class AngrTest(McTest): + def __init__(self, state=None, procedure=None): + super(AngrTest, self).__init__() + if procedure: + self.procedure = procedure + self.state = procedure.state + elif state: + self.procedure = None + self.state = state + + def __del__(self): + self.procedure = None + self.state = None + + def get_context(self): + return self.state.globals + + def is_symbolic(self, val): + if isinstance(val, (int, long)): + return False + return self.state.se.symbolic(val) + + def read_uintptr_t(self, ea, concretize=True, constrain=False): + next_ea = ea + (self.state.arch.bits // 8) + val = self.state.mem[ea].uintptr_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, next_ea + + def read_uint64_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint64_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 8 + + def read_uint32_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint32_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 4 + + def read_uint8_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint8_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + if isinstance(val, str): + assert len(val) == 1 + val = ord(val) + return val, ea + 1 + + def concretize(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + elif isinstance(val, str): + assert len(val) == 1 + return ord(val[0]) + + concrete_val = self.state.solver.eval(val, cast_to=int) + if constrain: + self.add_constraint(val == concrete_val) + + return concrete_val + + def concretize_min(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + concrete_val = self.state.solver.min(val, cast_to=int) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val + + def concretize_many(self, val, max_num): + assert 0 < max_num + if isinstance(val, (int, long)): + return [val] + return self.state.solver.eval_upto(val, max_num, cast_to=int) + + def add_constraint(self, expr): + if self.is_symbolic(expr): + self.state.solver.add(expr) + return self.state.solver.satisfiable() + else: + return True + + def pass_test(self): + super(AngrTest, self).pass_test() + self.procedure.exit(0) + + def fail_test(self): + super(AngrTest, self).fail_test() + self.procedure.exit(1) + + def abandon_test(self): + super(AngrTest, self).abandon_test() + self.procedure.exit(1) + + def hook_function(project, ea, cls): """Hook the function `ea` with the SimProcedure `cls`.""" project.hook(ea, cls(project=project)) -def read_c_string(state, ea): - """Read a concrete NUL-terminated string from `ea`.""" - assert isinstance(ea, (int, long)) - chars = [] - i = 0 - while True: - char = state.mem[ea + i].char.resolved - char = state.solver.eval(char, cast_to=str) - if not ord(char[0]): - break - chars.append(char) - i += 1 - return "".join(chars) - - -def read_uintptr_t(state, ea): - """Read a uintptr_t value from memory.""" - next_ea = ea + (state.arch.bits // 8) - val = state.solver.eval(state.mem[ea].uintptr_t.resolved, cast_to=int) - return val, next_ea - - -def read_uint32_t(state, ea): - """Read a uint32_t value from memory.""" - next_ea = ea + 4 - val = state.solver.eval(state.mem[ea].uint32_t.resolved, cast_to=int) - return val, next_ea - - -TestInfo = collections.namedtuple( - 'TestInfo', 'ea name file_name line_number') - - -def read_test_info(state, ea): - """Read in a `McTest_TestInfo` info structure from memory.""" - prev_test_ea, ea = read_uintptr_t(state, ea) - test_func_ea, ea = read_uintptr_t(state, ea) - test_name_ea, ea = read_uintptr_t(state, ea) - file_name_ea, ea = read_uintptr_t(state, ea) - file_line_num, _ = read_uint32_t(state, ea) - - if not test_func_ea or \ - not test_name_ea or \ - not file_name_ea or \ - not file_line_num: # `__LINE__` in C always starts at `1` ;-) - return None, prev_test_ea - - test_name = read_c_string(state, test_name_ea) - file_name = read_c_string(state, file_name_ea) - info = TestInfo(test_func_ea, test_name, file_name, file_line_num) - return info, prev_test_ea - - -def find_test_cases(state, info_ea): - """Find the test case descriptors.""" - tests = [] - info_ea, _ = read_uintptr_t(state, info_ea) - while info_ea: - test, info_ea = read_test_info(state, info_ea) - if test: - tests.append(test) - tests.sort(key=lambda t: (t.file_name, t.line_number)) - return tests - - -def read_api_table(state, ea): - """Reads in the API table.""" - apis = {} - while True: - api_name_ea, ea = read_uintptr_t(state, ea) - api_ea, ea = read_uintptr_t(state, ea) - if not api_name_ea or not api_ea: - break - api_name = read_c_string(state, api_name_ea) - apis[api_name] = api_ea - return apis - - def make_symbolic_input(state, input_begin_ea, input_end_ea): """Fill in the input data array with symbolic data.""" input_size = input_end_ea - input_begin_ea @@ -123,132 +140,73 @@ class IsSymbolicUInt(angr.SimProcedure): """Implements McTest_IsSymblicUInt, which returns 1 if its input argument has more then one solutions, and zero otherwise.""" def run(self, arg): - solutions = self.state.solver.eval_upto(arg, 2) - if not solutions: - return 0 - elif 1 == len(solutions): - if self.state.se.symbolic(arg): - self.state.solver.add(arg == solutions[0]) - return 0 - else: - return 1 + return AngrTest(procedure=self).api_is_symbolic_uint(arg) class Assume(angr.SimProcedure): """Implements _McTest_Assume, which tries to inject a constraint.""" def run(self, arg): - constraint = arg != 0 - self.state.solver.add(constraint) - if not self.state.solver.satisfiable(): - L.error("Failed to assert assumption {}".format(constraint)) - self.exit(2) - - -def report_state(state): - test = state.globals['test'] - if state.globals['failed']: - add_log_message(state, 3, "Failed: {}".format(test.name)) - else: - add_log_message(state, 1, "Passed: {}".format(test.name)) + AngrTest(procedure=self).api_assume(arg) class Pass(angr.SimProcedure): """Implements McTest_Pass, which notifies us of a passing test.""" def run(self): - report_state(self.state) - self.exit(self.state.globals['failed']) + AngrTest(procedure=self).api_pass() class Fail(angr.SimProcedure): - """Implements McTest_Fail, which notifies us of a passing test.""" + """Implements McTest_Fail, which notifies us of a failing test.""" def run(self): - self.state.globals['failed'] = 1 - report_state(self.state) - self.exit(1) + AngrTest(procedure=self).api_fail() + + +class Abandon(angr.SimProcedure): + """Implements McTest_Fail, which notifies us of a failing test.""" + def run(self, reason): + AngrTest(procedure=self).api_abandon(reason) class SoftFail(angr.SimProcedure): - """Implements McTest_SoftFail, which notifies us of a passing test.""" + """Implements McTest_SoftFail, which notifies us of a failing test.""" def run(self): - self.state.globals['failed'] = 1 + AngrTest(procedure=self).api_soft_fail() -def add_log_message(state, level, message): - """Add a log message to a state.""" - messages = list(state.globals['log_messages']) - messages.append((level, message)) - state.globals['log_messages'] = messages +class StreamInt(angr.SimProcedure): + """Implements _McTest_StreamInt, which gives us an integer to stream, and + the format to use for streaming.""" + def run(self, level, format_ea, unpack_ea, uint64_ea): + AngrTest(procedure=self).api_stream_int(level, format_ea, unpack_ea, + uint64_ea) + +class StreamFloat(angr.SimProcedure): + """Implements _McTest_StreamFloat, which gives us an double to stream, and + the format to use for streaming.""" + def run(self, level, format_ea, unpack_ea, double_ea): + AngrTest(procedure=self).api_stream_float(level, format_ea, unpack_ea, + double_ea) -LEVEL_TO_LOGGER = { - 0: L.debug, - 1: L.info, - 2: L.warning, - 3: L.error, - 4: L.critical -} +class StreamString(angr.SimProcedure): + """Implements _McTest_StreamString, which gives us an double to stream, and + the format to use for streaming.""" + def run(self, level, format_ea, str_ea): + AngrTest(procedure=self).api_stream_string(level, format_ea, str_ea) + + +class LogStream(angr.SimProcedure): + """Implements McTest_LogStream, which converts the contents of a stream for + level `level` into a log for level `level`.""" + def run(self, level): + AngrTest(procedure=self).api_log_stream(level) class Log(angr.SimProcedure): """Implements McTest_Log, which lets Angr intercept and handle the printing of log messages from the simulated tests.""" - def run(self, level, begin_ea_): - level = self.state.solver.eval(level, cast_to=int) - assert level in LEVEL_TO_LOGGER - - begin_ea = self.state.solver.eval(begin_ea_, cast_to=int) - if self.state.se.symbolic(begin_ea_): - self.state.solver.add(begin_ea_ == begin_ea) - - data = [] - for i in xrange(4096): - b = self.state.memory.load(begin_ea + i, size=1) - solutions = self.state.solver.eval_upto(b, 2, cast_to=int) - if 1 == len(solutions) and solutions[0] == 0: - break - data.append(b) - - # Deep copy the message. - add_log_message(self.state, level, data) - - if 3 == level: - self.state.globals['failed'] = 1 # Soft failure on an error log message. - elif 4 == level: - self.state.globals['failed'] = 1 - report_state(self.state) - self.exit(1) - - -def done_test(state): - test = state.globals['test'] - input_length, _ = read_uint32_t(state, state.globals['InputIndex']) - - # Dump out any pending log messages reported by `McTest_Log`. - for level, message in state.globals['log_messages']: - data = [] - for b in message: - if not isinstance(b, str): - b = state.solver.eval(b, cast_to=str) - if not ord(b): - break - data.append(b) - - LEVEL_TO_LOGGER[level]("".join(data)) - - max_length = state.globals['InputEnd'] - state.globals['InputBegin'] - if input_length > max_length: - L.critical("Test used too many input bytes ({} vs. {})".format( - input_length, max_length)) - return - - # Solve for the input bytes. - output = [] - data = state.memory.load(state.globals['InputBegin'], size=input_length) - data = state.solver.eval(data, cast_to=str) - for i in xrange(input_length): - output.append("{:2x}".format(ord(data[i]))) - - L.info("Input: {}".format(" ".join(output))) + def run(self, level, ea): + AngrTest(procedure=self).api_log(level, ea) def do_run_test(project, test, apis, run_state): @@ -258,15 +216,9 @@ def do_run_test(project, test, apis, run_state): test.ea, base_state=run_state) - messages = [(1, "Running {} from {}({})".format( - test.name, test.file_name, test.line_number))] - - test_state.globals['InputBegin'] = apis['InputBegin'] - test_state.globals['InputEnd'] = apis['InputEnd'] - test_state.globals['InputIndex'] = apis['InputIndex'] - test_state.globals['test'] = test - test_state.globals['failed'] = 0 - test_state.globals['log_messages'] = messages + mc = AngrTest(state=test_state) + mc.begin_test(test) + del mc make_symbolic_input(test_state, apis['InputBegin'], apis['InputEnd']) @@ -283,8 +235,10 @@ def do_run_test(project, test, apis, run_state): sys.exc_info()[0], traceback.format_exc())) for state in test_manager.deadended: - done_test(state) + AngrTest(state=state).report() + for state in test_manager.errored: + print "ErrorL", state.error def run_test(project, test, apis, run_state): """Symbolically executes a single test function.""" @@ -297,7 +251,6 @@ def run_test(project, test, apis, run_state): def main(): """Run McTest.""" - parser = argparse.ArgumentParser( description="Symbolically execute unit tests with Angr") @@ -330,24 +283,32 @@ def main(): setup_ea = project.kb.labels.lookup('McTest_Setup') concrete_manager.explore(find=setup_ea) run_state = concrete_manager.found[0] - + # Read the API table, which will tell us about the location of various # symbols. Technically we can look these up with the `labels.lookup` API, # but we have the API table for Manticore-compatibility, so we may as well # use it. ea_of_api_table = project.kb.labels.lookup('McTest_API') - apis = read_api_table(run_state, ea_of_api_table) + + mc = AngrTest(state=run_state) + apis = mc.read_api_table(ea_of_api_table) # Hook various functions. hook_function(project, apis['IsSymbolicUInt'], IsSymbolicUInt) hook_function(project, apis['Assume'], Assume) hook_function(project, apis['Pass'], Pass) hook_function(project, apis['Fail'], Fail) + hook_function(project, apis['Abandon'], Abandon) hook_function(project, apis['SoftFail'], SoftFail) hook_function(project, apis['Log'], Log) + hook_function(project, apis['StreamInt'], StreamInt) + hook_function(project, apis['StreamFloat'], StreamFloat) + hook_function(project, apis['StreamString'], StreamString) + hook_function(project, apis['LogStream'], LogStream) # Find the test cases that we want to run. - tests = find_test_cases(run_state, apis['LastTestInfo']) + tests = mc.find_test_cases() + del mc L.info("Running {} tests across {} workers".format( len(tests), args.num_workers)) diff --git a/examples/ArithmeticProperties.cpp b/examples/ArithmeticProperties.cpp index 59f3557..e65902c 100644 --- a/examples/ArithmeticProperties.cpp +++ b/examples/ArithmeticProperties.cpp @@ -39,7 +39,8 @@ TEST(Arithmetic, AdditionIsAssociative) { TEST(Arithmetic, InvertibleMultiplication_CanFail) { ForAll([] (int x, int y) { - ASSUME_NE(y, 0); + ASSUME_NE(y, 0) + << "Assumed non-zero value for y: " << y; ASSERT_EQ(x, (x / y) * y) << x << " != (" << x << " / " << y << ") * " << y; }); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e9cabff..09ce33c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -21,3 +21,5 @@ target_link_libraries(ArithmeticProperties mctest) add_executable(Lists Lists.cpp) target_link_libraries(Lists mctest) +add_executable(StreamingAndFormatting StreamingAndFormatting.cpp) +target_link_libraries(StreamingAndFormatting mctest) diff --git a/examples/StreamingAndFormatting.cpp b/examples/StreamingAndFormatting.cpp new file mode 100644 index 0000000..3698ea0 --- /dev/null +++ b/examples/StreamingAndFormatting.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +TEST(Streaming, BasicLevels) { + LOG(DEBUG) << "This is a debug message"; + LOG(INFO) << "This is an info message"; + LOG(WARNING) << "This is a warning message"; + LOG(ERROR) << "This is a error message"; + LOG(INFO) << "This is a info message again"; + ASSERT(true) << "This should not be printed."; +} + +TEST(Streaming, BasicTypes) { + LOG(INFO) << 'a'; + LOG(INFO) << 1; + LOG(INFO) << 1.0; + LOG(INFO) << "string"; + LOG(INFO) << nullptr; +} + +TEST(Formatting, OverridePrintf) { + printf("hello string=%s hex_lower=%x hex_upper=%X octal=%o char=%c dec=%d" + "double=%f sci=%e SCI=%E pointer=%p", + "world", 999, 999, 999, 'a', 999, 999.0, 999.0, 999.0, "world"); +} + +int main(void) { + return McTest_Run(); +} diff --git a/src/include/mctest/Compiler.h b/src/include/mctest/Compiler.h index 5e420b1..210cab1 100644 --- a/src/include/mctest/Compiler.h +++ b/src/include/mctest/Compiler.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_COMPILER_H_ -#define INCLUDE_MCTEST_COMPILER_H_ +#ifndef SRC_INCLUDE_MCTEST_COMPILER_H_ +#define SRC_INCLUDE_MCTEST_COMPILER_H_ #include #include @@ -93,4 +93,4 @@ static void f(void) #endif -#endif /* INCLUDE_MCTEST_COMPILER_H_ */ +#endif /* SRC_INCLUDE_MCTEST_COMPILER_H_ */ diff --git a/src/include/mctest/Log.h b/src/include/mctest/Log.h new file mode 100644 index 0000000..43b8888 --- /dev/null +++ b/src/include/mctest/Log.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_INCLUDE_MCTEST_LOG_H_ +#define SRC_INCLUDE_MCTEST_LOG_H_ + +#include + +#include + +MCTEST_BEGIN_EXTERN_C + +struct McTest_Stream; + +enum McTest_LogLevel { + McTest_LogDebug = 0, + McTest_LogInfo = 1, + McTest_LogWarning = 2, + McTest_LogWarn = McTest_LogWarning, + McTest_LogError = 3, + McTest_LogFatal = 4, + McTest_LogCritical = McTest_LogFatal +}; + +/* Log a C string. */ +extern void McTest_Log(enum McTest_LogLevel level, const char *str); + +/* Log some formatted output. */ +extern void McTest_LogFormat(enum McTest_LogLevel level, + const char *format, ...); + +/* Log some formatted output. */ +extern void McTest_LogVFormat(enum McTest_LogLevel level, + const char *format, va_list args); + +MCTEST_END_EXTERN_C + +#endif /* SRC_INCLUDE_MCTEST_LOG_H_ */ diff --git a/src/include/mctest/McTest.h b/src/include/mctest/McTest.h index 8cdc192..06db23f 100644 --- a/src/include/mctest/McTest.h +++ b/src/include/mctest/McTest.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_MCTEST_H_ -#define INCLUDE_MCTEST_MCTEST_H_ +#ifndef SRC_INCLUDE_MCTEST_MCTEST_H_ +#define SRC_INCLUDE_MCTEST_MCTEST_H_ #include #include @@ -25,7 +25,9 @@ #include #include +#include #include +#include #ifdef assert # undef assert @@ -54,6 +56,12 @@ extern int McTest_IsTrue(int expr); /* Symbolize the data in the range `[begin, end)`. */ extern void McTest_SymbolizeData(void *begin, void *end); +/* Concretize some data i nthe range `[begin, end)`. */ +extern void McTest_ConcretizeData(void *begin, void *end); + +/* Concretize a C string */ +extern void McTest_ConcretizeCStr(const char *begin); + MCTEST_INLINE static void *McTest_Malloc(size_t num_bytes) { void *data = malloc(num_bytes); uintptr_t data_end = ((uintptr_t) data) + num_bytes; @@ -96,6 +104,10 @@ extern void _McTest_Assume(int expr); #define McTest_Assume(x) _McTest_Assume(!!(x)) +/* Abandon this test. We've hit some kind of internal problem. */ +MCTEST_NORETURN +extern void McTest_Abandon(const char *reason); + MCTEST_NORETURN extern void McTest_Fail(void); @@ -121,20 +133,6 @@ MCTEST_INLINE static void McTest_Check(int expr) { } } -enum McTest_LogLevel { - McTest_LogDebug = 0, - McTest_LogInfo = 1, - McTest_LogWarning = 2, - McTest_LogWarn = McTest_LogWarning, - McTest_LogError = 3, - McTest_LogFatal = 4, - McTest_LogCritical = McTest_LogFatal -}; - -/* Outputs information to a log, using a specific log level. */ -extern void McTest_Log(enum McTest_LogLevel level, const char *begin, - const char *end); - /* Return a symbolic value in a the range `[low_inc, high_inc]`. */ #define MCTEST_MAKE_SYMBOLIC_RANGE(Tname, tname) \ MCTEST_INLINE static tname McTest_ ## Tname ## InRange( \ @@ -248,12 +246,21 @@ extern struct McTest_TestInfo *McTest_LastTestInfo; /* Set up McTest. */ extern void McTest_Setup(void); +/* Tear down McTest. */ +extern void McTest_Teardown(void); + +/* Notify that we're about to begin a test. */ +extern void McTest_Begin(struct McTest_TestInfo *info); + /* Return the first test case to run. */ extern struct McTest_TestInfo *McTest_FirstTest(void); /* Returns 1 if a failure was caught, otherwise 0. */ extern int McTest_CatchFail(void); +/* Returns 1 if this test case was abandoned. */ +extern int McTest_CatchAbandoned(void); + /* Jump buffer for returning to `McTest_Run`. */ extern jmp_buf McTest_ReturnToRun; @@ -261,18 +268,11 @@ extern jmp_buf McTest_ReturnToRun; static int McTest_Run(void) { int num_failed_tests = 0; struct McTest_TestInfo *test = NULL; - char buff[1024]; - int num_buff_bytes_used = 0; McTest_Setup(); for (test = McTest_FirstTest(); test != NULL; test = test->prev) { - - /* Print the test that we're going to run. */ - num_buff_bytes_used = sprintf(buff, "Running: %s from %s(%u)", - test->test_name, test->file_name, - test->line_number); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + McTest_Begin(test); /* Run the test. */ if (!setjmp(McTest_ReturnToRun)) { @@ -291,23 +291,28 @@ static int McTest_Run(void) { } #endif /* __cplusplus */ + /* We caught a failure when running the test. */ } else if (McTest_CatchFail()) { ++num_failed_tests; + McTest_LogFormat(McTest_LogError, "Failed: %s", test->test_name); - num_buff_bytes_used = sprintf(buff, "Failed: %s", test->test_name); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + /* The test was abandoned. We may have gotten soft failures before + * abandoning, so we prefer to catch those first. */ + } else if (McTest_CatchAbandoned()) { + McTest_LogFormat(McTest_LogFatal, "Abandoned: %s", test->test_name); /* The test passed. */ } else { - num_buff_bytes_used = sprintf(buff, "Passed: %s", test->test_name); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + McTest_LogFormat(McTest_LogInfo, "Passed: %s", test->test_name); } } + McTest_Teardown(); + return num_failed_tests; } MCTEST_END_EXTERN_C -#endif /* INCLUDE_MCTEST_MCTEST_H_ */ +#endif /* SRC_INCLUDE_MCTEST_MCTEST_H_ */ diff --git a/src/include/mctest/McTest.hpp b/src/include/mctest/McTest.hpp index 2fb96dc..ab23539 100644 --- a/src/include/mctest/McTest.hpp +++ b/src/include/mctest/McTest.hpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_MCTEST_HPP_ -#define INCLUDE_MCTEST_MCTEST_HPP_ +#ifndef SRC_INCLUDE_MCTEST_MCTEST_HPP_ +#define SRC_INCLUDE_MCTEST_MCTEST_HPP_ #include @@ -191,4 +191,4 @@ MAKE_SYMBOL_SPECIALIZATION(Char, int8_t) } // namespace mctest -#endif // INCLUDE_MCTEST_MCTEST_HPP_ +#endif // SRC_INCLUDE_MCTEST_MCTEST_HPP_ diff --git a/src/include/mctest/McUnit.hpp b/src/include/mctest/McUnit.hpp index c0a96b2..8a605d5 100644 --- a/src/include/mctest/McUnit.hpp +++ b/src/include/mctest/McUnit.hpp @@ -14,59 +14,44 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_MCUNIT_HPP_ -#define INCLUDE_MCTEST_MCUNIT_HPP_ +#ifndef SRC_INCLUDE_MCTEST_MCUNIT_HPP_ +#define SRC_INCLUDE_MCTEST_MCUNIT_HPP_ #include - -#include +#include #define TEST(category, name) \ McTest_EntryPoint(category ## _ ## name) -namespace mctest { +#define LOG_DEBUG(cond) \ + ::mctest::Stream(McTest_LogDebug, (cond), __FILE__, __LINE__) -/* Base logger */ -class Logger { - public: - MCTEST_INLINE Logger(McTest_LogLevel level_, bool expr_, - const char *file_, unsigned line_) - : level(level_), - expr(!!McTest_IsTrue(expr_)), - file(file_), - line(line_) {} +#define LOG_INFO(cond) \ + ::mctest::Stream(McTest_LogInfo, (cond), __FILE__, __LINE__) - MCTEST_INLINE ~Logger(void) { - if (!expr) { - std::stringstream report_ss; - report_ss << file << "(" << line << "): " << ss.str(); - auto report_str = report_ss.str(); - auto report_c_str = report_str.c_str(); - McTest_Log(level, report_c_str, report_c_str + report_str.size()); - } - } +#define LOG_WARNING(cond) \ + ::mctest::Stream(McTest_LogWarning, (cond), __FILE__, __LINE__) - MCTEST_INLINE std::stringstream &stream(void) { - return ss; - } +#define LOG_WARN(cond) \ + ::mctest::Stream(McTest_LogWarning, (cond), __FILE__, __LINE__) - private: - Logger(void) = delete; - Logger(const Logger &) = delete; - Logger &operator=(const Logger &) = delete; +#define LOG_ERROR(cond) \ + ::mctest::Stream(McTest_LogError, (cond), __FILE__, __LINE__) - const McTest_LogLevel level; - const bool expr; - const char * const file; - const unsigned line; - std::stringstream ss; -}; +#define LOG_FATAL(cond) \ + ::mctest::Stream(McTest_LogFatal, (cond), __FILE__, __LINE__) + +#define LOG_CRITICAl(cond) \ + ::mctest::Stream(McTest_LogFatal, (cond), __FILE__, __LINE__) + +#define LOG(LEVEL) LOG_ ## LEVEL(true) + +#define LOG_IF(LEVEL, cond) LOG_ ## LEVEL(cond) -} // namespace mctest #define MCTEST_LOG_BINOP(a, b, op, level) \ - ::mctest::Logger( \ - level, ((a) op (b)), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + level, !((a) op (b)), __FILE__, __LINE__) #define ASSERT_EQ(a, b) MCTEST_LOG_BINOP(a, b, ==, McTest_LogFatal) #define ASSERT_NE(a, b) MCTEST_LOG_BINOP(a, b, !=, McTest_LogFatal) @@ -83,27 +68,26 @@ class Logger { #define CHECK_GE(a, b) MCTEST_LOG_BINOP(a, b, >=, McTest_LogError) #define ASSERT(expr) \ - ::mctest::Logger( \ - McTest_LogFatal, !!(expr), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + McTest_LogFatal, !(expr), __FILE__, __LINE__) #define ASSERT_TRUE ASSERT #define ASSERT_FALSE(expr) ASSERT(!(expr)) #define CHECK(expr) \ - ::mctest::Logger( \ - McTest_LogError, !!(expr), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + McTest_LogError, !(expr), __FILE__, __LINE__) #define CHECK_TRUE CHECK #define CHECK_FALSE(expr) CHECK(!(expr)) #define ASSUME(expr) \ - McTest_Assume(expr), ::mctest::Logger( \ - McTest_LogInfo, false, __FILE__, __LINE__).stream() - + McTest_Assume(expr), ::mctest::Stream( \ + McTest_LogInfo, true, __FILE__, __LINE__) #define MCTEST_ASSUME_BINOP(a, b, op) \ - McTest_Assume(((a) op (b))), ::mctest::Logger( \ - McTest_LogInfo, false, __FILE__, __LINE__).stream() + McTest_Assume(((a) op (b))), ::mctest::Stream( \ + McTest_LogInfo, true, __FILE__, __LINE__) #define ASSUME_EQ(a, b) MCTEST_ASSUME_BINOP(a, b, ==) #define ASSUME_NE(a, b) MCTEST_ASSUME_BINOP(a, b, !=) @@ -112,4 +96,4 @@ class Logger { #define ASSUME_GT(a, b) MCTEST_ASSUME_BINOP(a, b, >) #define ASSUME_GE(a, b) MCTEST_ASSUME_BINOP(a, b, >=) -#endif // INCLUDE_MCTEST_MCUNIT_HPP_ +#endif // SRC_INCLUDE_MCTEST_MCUNIT_HPP_ diff --git a/src/include/mctest/Quantified.hpp b/src/include/mctest/Quantified.hpp index 543fbf6..bc5d372 100644 --- a/src/include/mctest/Quantified.hpp +++ b/src/include/mctest/Quantified.hpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_QUANTIFIED_HPP_ -#define INCLUDE_MCTEST_QUANTIFIED_HPP_ +#ifndef SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ +#define SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ #include @@ -35,4 +35,4 @@ inline static void ForAll(Closure func) { } // namespace mctest -#endif // INCLUDE_MCTEST_QUANTIFIED_HPP_ +#endif // SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ diff --git a/src/include/mctest/Stream.h b/src/include/mctest/Stream.h new file mode 100644 index 0000000..52560ac --- /dev/null +++ b/src/include/mctest/Stream.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_INCLUDE_MCTEST_STREAM_H_ +#define SRC_INCLUDE_MCTEST_STREAM_H_ + +#include +#include + +#include +#include + +MCTEST_BEGIN_EXTERN_C + +/* Flush the contents of the stream to a log. */ +extern void McTest_LogStream(enum McTest_LogLevel level); + +/* Stream a C string into the stream's message. */ +extern void McTest_StreamCStr(enum McTest_LogLevel level, const char *begin); + +/* TODO(pag): Implement `McTest_StreamWCStr` with `wchar_t`. */ + +/* Stream a some data in the inclusive range `[begin, end]` into the + * stream's message. */ +/*extern void McTest_StreamData(enum McTest_LogLevel level, const void *begin, + const void *end);*/ + +/* Stream some formatted input */ +extern void McTest_StreamFormat(enum McTest_LogLevel level, const char *format, + ...); + +/* Stream some formatted input */ +extern void McTest_StreamVFormat(enum McTest_LogLevel level, const char *format, + va_list args); + +#define MCTEST_DECLARE_STREAMER(Type, type) \ + extern void McTest_Stream ## Type(enum McTest_LogLevel level, type val); + +MCTEST_DECLARE_STREAMER(Double, double); +MCTEST_DECLARE_STREAMER(Pointer, void *); + +MCTEST_DECLARE_STREAMER(UInt64, uint64_t) +MCTEST_DECLARE_STREAMER(Int64, int64_t) + +MCTEST_DECLARE_STREAMER(UInt32, uint32_t) +MCTEST_DECLARE_STREAMER(Int32, int32_t) + +MCTEST_DECLARE_STREAMER(UInt16, uint16_t) +MCTEST_DECLARE_STREAMER(Int16, int16_t) + +MCTEST_DECLARE_STREAMER(UInt8, uint8_t) +MCTEST_DECLARE_STREAMER(Int8, int8_t) + +#undef MCTEST_DECLARE_STREAMER + +MCTEST_INLINE static void McTest_StreamFloat(enum McTest_LogLevel level, + float val) { + McTest_StreamDouble(level, (double) val); +} + +/* Reset the formatting in a stream. */ +extern void McTest_StreamResetFormatting(enum McTest_LogLevel level); + +MCTEST_END_EXTERN_C + +#endif /* SRC_INCLUDE_MCTEST_STREAM_H_ */ diff --git a/src/include/mctest/Stream.hpp b/src/include/mctest/Stream.hpp new file mode 100644 index 0000000..54213c6 --- /dev/null +++ b/src/include/mctest/Stream.hpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SRC_INCLUDE_MCTEST_STREAM_HPP_ +#define SRC_INCLUDE_MCTEST_STREAM_HPP_ + +#include +#include + +#include +#include + +namespace mctest { + +/* Conditionally stream output to a log using the streaming APIs. */ +class Stream { + public: + MCTEST_INLINE Stream(McTest_LogLevel level_, bool do_log_, + const char *file, unsigned line) + : level(level_), + do_log(McTest_IsTrue(do_log_)) { + McTest_LogStream(level); + if (do_log) { + McTest_StreamFormat(level, "%s(%u): ", file, line); + } + } + + MCTEST_INLINE ~Stream(void) { + if (do_log) { + McTest_LogStream(level); + } + } + +#define MCTEST_DEFINE_STREAMER(Type, type, expr) \ + MCTEST_INLINE const Stream &operator<<(type val) const { \ + if (do_log) { \ + McTest_Stream ## Type(level, expr); \ + } \ + return *this; \ + } + + MCTEST_DEFINE_STREAMER(UInt64, uint64_t, val) + MCTEST_DEFINE_STREAMER(Int64, int64_t, val) + + MCTEST_DEFINE_STREAMER(UInt32, uint32_t, val) + MCTEST_DEFINE_STREAMER(Int32, int32_t, val) + + MCTEST_DEFINE_STREAMER(UInt16, uint16_t, val) + MCTEST_DEFINE_STREAMER(Int16, int16_t, val) + + MCTEST_DEFINE_STREAMER(UInt8, uint8_t, val) + MCTEST_DEFINE_STREAMER(Int8, int8_t, val) + + MCTEST_DEFINE_STREAMER(Float, float, val) + MCTEST_DEFINE_STREAMER(Double, double, val) + + MCTEST_DEFINE_STREAMER(CStr, const char *, val) + MCTEST_DEFINE_STREAMER(CStr, char *, const_cast(val)) + + MCTEST_DEFINE_STREAMER(Pointer, nullptr_t, nullptr) + + template + MCTEST_DEFINE_STREAMER(Pointer, T *, val); + + template + MCTEST_DEFINE_STREAMER(Pointer, const T *, const_cast(val)); + +#undef MCTEST_DEFINE_INT_STREAMER + + MCTEST_INLINE const Stream &operator<<(const std::string &str) const { + if (do_log && !str.empty()) { + McTest_StreamCStr(level, str.c_str()); + } + return *this; + } + + // TODO(pag): Implement a `std::wstring` streamer. + + private: + Stream(void) = delete; + Stream(const Stream &) = delete; + Stream &operator=(const Stream &) = delete; + + const McTest_LogLevel level; + const int do_log; +}; + +} // namespace mctest + +#endif // SRC_INCLUDE_MCTEST_STREAM_HPP_ diff --git a/src/lib/Log.c b/src/lib/Log.c new file mode 100644 index 0000000..aa739ae --- /dev/null +++ b/src/lib/Log.c @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +MCTEST_BEGIN_EXTERN_C + +/* Returns a printable string version of the log level. */ +static const char *McTest_LogLevelStr(enum McTest_LogLevel level) { + switch (level) { + case McTest_LogDebug: + return "DEBUG"; + case McTest_LogInfo: + return "INFO"; + case McTest_LogWarning: + return "WARNING"; + case McTest_LogError: + return "ERROR"; + case McTest_LogFatal: + return "FATAL"; + default: + return "UNKNOWN"; + } +} + +enum { + McTest_LogBufSize = 4096 +}; + +char McTest_LogBuf[McTest_LogBufSize + 1] = {}; + +/* Log a C string. */ +void McTest_Log(enum McTest_LogLevel level, const char *str) { + memset(McTest_LogBuf, 0, McTest_LogBufSize); + snprintf(McTest_LogBuf, McTest_LogBufSize, "%s: %s", + McTest_LogLevelStr(level), str); + fputs(McTest_LogBuf, stderr); + + if (McTest_LogError == level) { + McTest_SoftFail(); + } else if (McTest_LogFatal == level) { + McTest_Fail(); + } +} + +/* Log some formatted output. */ +void McTest_LogFormat(enum McTest_LogLevel level, const char *format, ...) { + McTest_LogStream(level); + va_list args; + va_start(args, format); + McTest_StreamVFormat(level, format, args); + va_end(args); + McTest_LogStream(level); +} + +/* Log some formatted output. */ +void McTest_LogVFormat(enum McTest_LogLevel level, + const char *format, va_list args) { + McTest_LogStream(level); + McTest_StreamVFormat(level, format, args); + McTest_LogStream(level); +} + +/* Override libc! */ +int printf(const char *format, ...) { + McTest_LogStream(McTest_LogInfo); + va_list args; + va_start(args, format); + McTest_StreamVFormat(McTest_LogInfo, format, args); + va_end(args); + McTest_LogStream(McTest_LogInfo); + return 0; +} + +int fprintf(FILE *file, const char *format, ...) { + enum McTest_LogLevel level = McTest_LogInfo; + if (stderr == file) { + level = McTest_LogDebug; + } else if (stdout != file) { + return 0; /* TODO(pag): This is probably evil. */ + } + + McTest_LogStream(level); + va_list args; + va_start(args, format); + McTest_StreamVFormat(level, format, args); + va_end(args); + McTest_LogStream(level); + return 0; +} + +MCTEST_END_EXTERN_C diff --git a/src/lib/McTest.c b/src/lib/McTest.c index a19907c..d25343e 100644 --- a/src/lib/McTest.c +++ b/src/lib/McTest.c @@ -14,17 +14,13 @@ * limitations under the License. */ +#include #include #include #include #include -#if defined(unix) || defined(__unix) || defined(__unix__) -# define _GNU_SOURCE -# include /* For `syscall` */ -#endif - MCTEST_BEGIN_EXTERN_C /* Pointer to the last registers McTest_TestInfo data structure */ @@ -45,8 +41,17 @@ static uint32_t McTest_InputIndex = 0; /* Jump buffer for returning to `McTest_Run`. */ jmp_buf McTest_ReturnToRun = {}; +static const char *McTest_TestAbandoned = NULL; static int McTest_TestFailed = 0; +/* Abandon this test. We've hit some kind of internal problem. */ +MCTEST_NORETURN +void McTest_Abandon(const char *reason) { + McTest_Log(McTest_LogFatal, reason); + McTest_TestAbandoned = reason; + longjmp(McTest_ReturnToRun, 1); +} + /* Mark this test as failing. */ MCTEST_NORETURN void McTest_Fail(void) { @@ -80,6 +85,16 @@ void McTest_SymbolizeData(void *begin, void *end) { } } +/* Concretize some data i nthe range `[begin, end)`. */ +void McTest_ConcretizeData(void *begin, void *end) { + (void) begin; +} + +/* Concretize a C string */ +void McTest_ConcretizeCStr(const char *begin) { + (void) begin; +} + MCTEST_NOINLINE int McTest_One(void) { return 1; } @@ -148,62 +163,15 @@ int McTest_IsSymbolicUInt(uint32_t x) { return 0; } -/* Returns a printable string version of the log level. */ -static const char *McTest_LogLevelStr(enum McTest_LogLevel level) { - switch (level) { - case McTest_LogDebug: - return "DEBUG"; - case McTest_LogInfo: - return "INFO"; - case McTest_LogWarning: - return "WARNING"; - case McTest_LogError: - return "ERROR"; - case McTest_LogFatal: - return "FATAL"; - default: - return "UNKNOWN"; - } -} +/* Defined in Stream.c */ +extern void _McTest_StreamInt(enum McTest_LogLevel level, const char *format, + const char *unpack, uint64_t *val); -enum { - McTest_LogBufSize = 4096 -}; +extern void _McTest_StreamFloat(enum McTest_LogLevel level, const char *format, + const char *unpack, double *val); -char McTest_LogBuf[McTest_LogBufSize + 1] = {}; - - -void _McTest_Log(enum McTest_LogLevel level, const char *message) { - fprintf(stderr, "%s: %s\n", McTest_LogLevelStr(level), message); - if (McTest_LogError == level) { - McTest_SoftFail(); - - } else if (McTest_LogFatal == level) { - McTest_Fail(); - } -} - -/* Outputs information to a log, using a specific log level. */ -void McTest_Log(enum McTest_LogLevel level, const char *begin, - const char *end) { - if (end <= begin) { - return; - } - - size_t size = (size_t) (end - begin); - if (size > McTest_LogBufSize) { - size = McTest_LogBufSize; - } - - /* When we interpose on _McTest_Log, we are looking for the first non-symbolic - * zero byte as our end of string character, so we want to guarantee that we - * have a bunch of those */ - memset(McTest_LogBuf, 0, McTest_LogBufSize); - memcpy(McTest_LogBuf, begin, size); - McTest_LogBuf[McTest_LogBufSize] = '\0'; - - _McTest_Log(level, McTest_LogBuf); -} +extern void _McTest_StreamString(enum McTest_LogLevel level, const char *format, + const char *str); /* A McTest-specific symbol that is needed for hooking. */ struct McTest_IndexEntry { @@ -214,16 +182,36 @@ struct McTest_IndexEntry { /* An index of symbols that the symbolic executors will hook or * need access to. */ const struct McTest_IndexEntry McTest_API[] = { + + /* Control-flow during the test. */ {"Pass", (void *) McTest_Pass}, {"Fail", (void *) McTest_Fail}, {"SoftFail", (void *) McTest_SoftFail}, - {"Log", (void *) _McTest_Log}, - {"Assume", (void *) _McTest_Assume}, - {"IsSymbolicUInt", (void *) McTest_IsSymbolicUInt}, + {"Abandon", (void *) McTest_Abandon}, + + /* Locating the tests. */ + {"LastTestInfo", (void *) &McTest_LastTestInfo}, + + /* Source of symbolic bytes. */ {"InputBegin", (void *) &(McTest_Input[0])}, {"InputEnd", (void *) &(McTest_Input[McTest_InputLength])}, {"InputIndex", (void *) &McTest_InputIndex}, - {"LastTestInfo", (void *) &McTest_LastTestInfo}, + + /* Solver APIs. */ + {"Assume", (void *) _McTest_Assume}, + {"IsSymbolicUInt", (void *) McTest_IsSymbolicUInt}, + {"ConcretizeData", (void *) McTest_ConcretizeData}, + {"ConcretizeCStr", (void *) McTest_ConcretizeCStr}, + + /* Logging API. */ + {"Log", (void *) McTest_Log}, + + /* Streaming API for deferred logging. */ + {"LogStream", (void *) McTest_LogStream}, + {"StreamInt", (void *) _McTest_StreamInt}, + {"StreamFloat", (void *) _McTest_StreamFloat}, + {"StreamString", (void *) _McTest_StreamString}, + {NULL, NULL}, }; @@ -232,6 +220,19 @@ void McTest_Setup(void) { /* TODO(pag): Sort the test cases by file name and line number. */ } +/* Tear down McTest. */ +void McTest_Teardown(void) { + +} + +/* Notify that we're about to begin a test. */ +void McTest_Begin(struct McTest_TestInfo *info) { + McTest_TestFailed = 0; + McTest_TestAbandoned = NULL; + McTest_LogFormat(McTest_LogInfo, "Running: %s from %s(%u)", + info->test_name, info->file_name, info->line_number); +} + /* Return the first test case to run. */ struct McTest_TestInfo *McTest_FirstTest(void) { return McTest_LastTestInfo; @@ -242,4 +243,9 @@ int McTest_CatchFail(void) { return McTest_TestFailed; } +/* Returns 1 if this test case was abandoned. */ +int McTest_CatchAbandoned(void) { + return McTest_TestAbandoned != NULL; +} + MCTEST_END_EXTERN_C diff --git a/src/lib/Stream.c b/src/lib/Stream.c new file mode 100644 index 0000000..787d3ad --- /dev/null +++ b/src/lib/Stream.c @@ -0,0 +1,563 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +MCTEST_BEGIN_EXTERN_C + +enum { + McTest_StreamSize = 4096 +}; + +/* Formatting options availale to the streaming API. */ +struct McTest_StreamFormatOptions { + /* int radix; */ + int hex; + int oct; + int show_base; + /*int width; */ + int left_justify; + int add_sign; + char fill; +}; + +/* Stream type that accumulates formatted data to be printed. This loosely + * mirrors C++ I/O streams, not because I/O streams are good, but instead + * because the ability to stream in data to things like the C++-backed + * `ASSERT` and `CHECK` macros is really nice. */ +struct McTest_Stream { + int size; + struct McTest_StreamFormatOptions options; + char message[McTest_StreamSize + 2]; + char staging[32]; + char format[32]; + char unpack[32]; + union { + uint64_t as_uint64; + double as_fp64; + } value; +}; + +/* Hard-coded streams for each log level. */ +static struct McTest_Stream McTest_Streams[McTest_LogFatal + 1] = {}; + +static char McTest_EndianSpecifier = '='; + +/* Figure out what the Python `struct` endianness specifier should be. */ +MCTEST_INITIALIZER(DetectEndianness) { + static const int one = 1; + if ((const char *) &one) { + McTest_EndianSpecifier = '<'; /* Little endian. */ + } else { + McTest_EndianSpecifier = '>'; /* Big endian. */ + } +} + +static void McTest_StreamUnpack(struct McTest_Stream *stream, char type) { + stream->unpack[0] = McTest_EndianSpecifier; + stream->unpack[1] = type; + stream->unpack[2] = '\0'; +} + +/* Fill in the format for when we want to stream an integer. */ +static void McTest_StreamIntFormat(struct McTest_Stream *stream, + size_t val_size, int is_unsigned) { + char *format = stream->format; + int i = 0; + + format[i++] = '%'; + if(stream->options.left_justify) { + format[i++] = '-'; + } + + if(stream->options.add_sign) { + format[i++] = '+'; + } + + if (stream->options.fill) { + format[i++] = stream->options.fill; + } + + if (stream->options.show_base) { + format[i++] = '#'; /* Show the radix. */ + } + + if (8 == val_size) { + format[i++] = 'l'; + format[i++] = 'l'; + + } else if (2 == val_size) { + format[i++] = 'h'; + + } else if (1 == val_size) { + if (is_unsigned) { + format[i++] = 'h'; + format[i++] = 'h'; + } + } + + if (stream->options.hex) { + format[i++] = 'x'; + } else if (stream->options.oct) { + format[i++] = 'o'; + } else if (is_unsigned) { + format[i++] = 'u'; + } else { + if (1 == val_size) { + format[i++] = 'c'; + } else { + format[i++] = 'd'; + } + } + + format[i++] = '\0'; +} + +static void CheckCapacity(struct McTest_Stream *stream, int num_chars_to_add) { + if (0 > num_chars_to_add) { + McTest_Abandon("Can't add a negative number of characters to a stream."); + } else if ((stream->size + num_chars_to_add) >= McTest_StreamSize) { + McTest_Abandon("Exceeded capacity of stream buffer."); + } +} + +/* Stream an integer into the stream's message. This function is designed to + * be hooked by the symbolic executor, so that it can easily pull out the + * relevant data from `*val`, which may be symbolic, and defer the actual + * formatting until later. */ +MCTEST_NOINLINE +void _McTest_StreamInt(enum McTest_LogLevel level, const char *format, + const char *unpack, uint64_t *val) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + int size = 0; + int remaining_size = McTest_StreamSize - stream->size; + if (unpack[1] == 'Q' || unpack[1] == 'q') { + size = snprintf(&(stream->message[stream->size]), + remaining_size, format, *val); + } else { + size = snprintf(&(stream->message[stream->size]), + remaining_size, format, (uint32_t) *val); + } + CheckCapacity(stream, size); + stream->size += size; +} + +MCTEST_NOINLINE +void _McTest_StreamFloat(enum McTest_LogLevel level, const char *format, + const char *unpack, double *val) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + int remaining_size = McTest_StreamSize - stream->size; + int size = snprintf(&(stream->message[stream->size]), + remaining_size, format, *val); + CheckCapacity(stream, size); + stream->size += size; +} + +MCTEST_NOINLINE +void _McTest_StreamString(enum McTest_LogLevel level, const char *format, + const char *str) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + int remaining_size = McTest_StreamSize - stream->size; + int size = snprintf(&(stream->message[stream->size]), + remaining_size, format, str); + CheckCapacity(stream, size); + stream->size += size; +} + +#define MAKE_INT_STREAMER(Type, type, is_unsigned, pack_kind) \ + void McTest_Stream ## Type(enum McTest_LogLevel level, type val) { \ + struct McTest_Stream *stream = &(McTest_Streams[level]); \ + McTest_StreamIntFormat(stream, sizeof(val), is_unsigned); \ + McTest_StreamUnpack(stream, pack_kind); \ + stream->value.as_uint64 = (uint64_t) val; \ + _McTest_StreamInt(level, stream->format, stream->unpack, \ + &(stream->value.as_uint64)); \ + } + +MAKE_INT_STREAMER(Pointer, void *, 1, (sizeof(void *) == 8 ? 'Q' : 'I')) + +MAKE_INT_STREAMER(UInt64, uint64_t, 1, 'Q') +MAKE_INT_STREAMER(Int64, int64_t, 0, 'q') + +MAKE_INT_STREAMER(UInt32, uint32_t, 1, 'I') +MAKE_INT_STREAMER(Int32, int32_t, 0, 'i') + +MAKE_INT_STREAMER(UInt16, uint16_t, 1, 'h') +MAKE_INT_STREAMER(Int16, int16_t, 0, 'H') + +MAKE_INT_STREAMER(UInt8, uint8_t, 1, 'B') +MAKE_INT_STREAMER(Int8, int8_t, 0, 'c') + +#undef MAKE_INT_STREAMER + +/* Stream a C string into the stream's message. */ +void McTest_StreamCStr(enum McTest_LogLevel level, const char *begin) { + _McTest_StreamString(level, "%s", begin); +} + +/* Stream a some data in the inclusive range `[begin, end]` into the + * stream's message. */ +/*void McTest_StreamData(enum McTest_LogLevel level, const void *begin, + const void *end) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + int remaining_size = McTest_StreamSize - stream->size; + int input_size = (int) ((uintptr_t) end - (uintptr_t) begin) + 1; + CheckCapacity(stream, input_size); + memcpy(&(stream->message[stream->size]), begin, (size_t) input_size); + stream->size += input_size; +}*/ + +/* Stream a `double` into the stream's message. This function is designed to + * be hooked by the symbolic executor, so that it can easily pull out the + * relevant data from `*val`, which may be symbolic, and defer the actual + * formatting until later. */ +void McTest_StreamDouble(enum McTest_LogLevel level, double val) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + const char *format = "%f"; /* TODO(pag): Support more? */ + stream->value.as_fp64 = val; + McTest_StreamUnpack(stream, 'd'); + _McTest_StreamFloat(level, format, stream->unpack, &(stream->value.as_fp64)); +} + +/* Flush the contents of the stream to a log. */ +void McTest_LogStream(enum McTest_LogLevel level) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + if (stream->size) { + stream->message[stream->size] = '\n'; + stream->message[stream->size + 1] = '\0'; + stream->message[McTest_StreamSize] = '\0'; + McTest_Log(level, stream->message); + memset(stream->message, 0, McTest_StreamSize); + stream->size = 0; + } +} + +/* Reset the formatting in a stream. */ +void McTest_StreamResetFormatting(enum McTest_LogLevel level) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + memset(&(stream->options), 0, sizeof(stream->options)); +} + +/* Approximately do string format parsing and convert it into calls into our + * streaming API. */ +static int McTest_StreamFormatValue(enum McTest_LogLevel level, + const char *format, va_list args) { + struct McTest_Stream *stream = &(McTest_Streams[level]); + char format_buf[32] = {'\0'}; + int i = 0; + int k = 0; + int length = 4; + char ch = '\0'; + int is_string = 0; + int is_unsigned = 0; + int is_float = 0; + int long_double = 0; + char extract = '\0'; + +#define READ_FORMAT_CHAR \ + ch = format[i]; \ + format_buf[i - k] = ch; \ + i++ + + READ_FORMAT_CHAR; /* Read the '%' */ + + if ('%' != ch) { + McTest_Abandon("Invalid format."); + return 0; + } + + /* Flags */ +get_flag_char: + READ_FORMAT_CHAR; + switch (ch) { + case '\0': + McTest_Abandon("Incomplete format (flags)."); + return 0; + case '-': + case '+': + case ' ': + case '#': + case '0': + goto get_flag_char; + default: + break; + } + + /* Width */ +get_width_char: + switch (ch) { + case '\0': + McTest_Abandon("Incomplete format (width)."); + return 0; + case '*': + McTest_Abandon("Variable width printing not supported."); + return 0; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + READ_FORMAT_CHAR; + goto get_width_char; + default: + break; + } + + /* Precision */ + if ('.' == ch) { + get_precision_char: + READ_FORMAT_CHAR; + switch (ch) { + case '\0': + McTest_Abandon("Incomplete format (precision)."); + return 0; + case '*': + McTest_Abandon("Variable precision printing not supported."); + break; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + goto get_precision_char; + default: + break; + } + } + + /* Length */ +get_length_char: + switch (ch) { + case '\0': + McTest_Abandon("Incomplete format (length)."); + return 0; + case 'L': + long_double = 1; + k += 1; /* Overwrite the `L`. */ + READ_FORMAT_CHAR; + break; + case 'h': + length /= 2; + READ_FORMAT_CHAR; + goto get_length_char; + case 'l': + length *= 2; + READ_FORMAT_CHAR; + goto get_length_char; + case 'j': + length = (int) sizeof(intmax_t); + READ_FORMAT_CHAR; + break; + case 'z': + length = (int) sizeof(size_t); + READ_FORMAT_CHAR; + break; + case 't': + length = (int) sizeof(ptrdiff_t); + READ_FORMAT_CHAR; + break; + default: + break; + } + + if (!length) { + length = 1; + } else if (8 < length) { + length = 8; + } + + format_buf[i] = '\0'; + + /* Specifier */ + switch(ch) { + case '\0': + McTest_Abandon("Incomplete format (specifier)."); + return 0; + + case 'n': + return i; /* Nothing printed. */ + + /* Print a character. */ + case 'c': + stream->value.as_uint64 = (uint64_t) (char) va_arg(args, int); + extract = 'c'; + goto common_stream_int; + + /* Signed integer. */ + case 'd': + case 'i': + if (1 == length) { + stream->value.as_uint64 = (uint64_t) (int8_t) va_arg(args, int); + extract = 'b'; + } else if (2 == length) { + stream->value.as_uint64 = (uint64_t) (int16_t) va_arg(args, int); + extract = 'h'; + } else if (4 == length) { + stream->value.as_uint64 = (uint64_t) (int32_t) va_arg(args, int); + extract = 'i'; + } else if (8 == length) { + stream->value.as_uint64 = (uint64_t) va_arg(args, int64_t); + extract = 'q'; + } else { + McTest_Abandon("Unsupported integer length."); + } + goto common_stream_int; + + /* Pointer. */ + case 'p': + length = (int) sizeof(void *); + format_buf[i - 1] = 'x'; + /* Note: Falls through. */ + + /* Unsigned, hex, octal */ + case 'u': + case 'o': + case 'x': + case 'X': + if (1 == length) { + stream->value.as_uint64 = (uint64_t) (uint8_t) va_arg(args, int); + extract = 'B'; + } else if (2 == length) { + stream->value.as_uint64 = (uint64_t) (uint16_t) va_arg(args, int); + extract = 'H'; + } else if (4 == length) { + stream->value.as_uint64 = (uint64_t) (uint32_t) va_arg(args, int); + extract = 'I'; + } else if (8 == length) { + stream->value.as_uint64 = (uint64_t) va_arg(args, uint64_t); + extract = 'Q'; + } else { + McTest_Abandon("Unsupported integer length."); + } + + common_stream_int: + McTest_StreamUnpack(stream, extract); + _McTest_StreamInt(level, format_buf, stream->unpack, + &(stream->value.as_uint64)); + break; + + /* Floating point, scientific notation, etc. */ + case 'f': + case 'F': + case 'e': + case 'E': + case 'g': + case 'G': + case 'a': + case 'A': + if (long_double) { + stream->value.as_fp64 = (double) va_arg(args, long double); + } else { + stream->value.as_fp64 = va_arg(args, double); + } + McTest_StreamUnpack(stream, 'd'); + break; + + case 's': + _McTest_StreamString(level, format_buf, va_arg(args, const char *)); + break; + + default: + McTest_Abandon("Unsupported format specifier."); + return 0; + } + + return i; +} + +static char McTest_Format[McTest_StreamSize + 1]; + +/* Stream some formatted input */ +void McTest_StreamVFormat(enum McTest_LogLevel level, const char *format_, + va_list args) { + char *begin = NULL; + char *end = NULL; + char *format = McTest_Format; + int i = 0; + char ch = '\0'; + char next_ch = '\0'; + + strncpy(format, format_, McTest_StreamSize); + format[McTest_StreamSize] = '\0'; + + McTest_ConcretizeCStr(format); + + for (i = 0; '\0' != (ch = format[i]); ) { + if (!begin) { + begin = &(format[i]); + } + + if ('%' == ch) { + if ('%' == format[i + 1]) { + end = &(format[i]); + next_ch = end[1]; + end[1] = '\0'; + McTest_StreamCStr(level, begin); + end[1] = next_ch; + begin = NULL; + end = NULL; + i += 2; + + } else { + if (end) { + next_ch = end[1]; + end[1] = '\0'; + McTest_StreamCStr(level, begin); + end[1] = next_ch; + } + begin = NULL; + end = NULL; + i += McTest_StreamFormatValue(level, &(format[i]), args); + } + } else { + end = &(format[i]); + i += 1; + } + } + + if (begin && begin[0]) { + McTest_StreamCStr(level, begin); + } +} + +/* Stream some formatted input */ +void McTest_StreamFormat(enum McTest_LogLevel level, const char *format, ...) { + va_list args; + va_start(args, format); + McTest_StreamVFormat(level, format, args); + va_end(args); +} + +MCTEST_END_EXTERN_C