diff --git a/CMakeLists.txt b/CMakeLists.txt index 32cbc20..2026ac7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,8 @@ endif () add_library(${PROJECT_NAME} STATIC src/lib/McTest.c + src/lib/Log.c + src/lib/Stream.c ) target_include_directories(${PROJECT_NAME} diff --git a/bin/mctest/__main__.py b/bin/mctest/__main__.py index 7f65973..70eb293 100644 --- a/bin/mctest/__main__.py +++ b/bin/mctest/__main__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse import collections import logging @@ -21,86 +20,107 @@ import manticore import multiprocessing import sys import traceback +from .common import McTest from manticore.core.state import TerminateState from manticore.utils.helpers import issymbolic -L = logging.getLogger("mctest") +L = logging.getLogger("mctest.mcore") L.setLevel(logging.INFO) +OUR_TERMINATION_REASON = "I McTest'd it" -def read_c_string(state, ea): - """Read a concrete NUL-terminated string from `ea`.""" - return state.cpu.read_string(ea) +class MCoreTest(McTest): + def __init__(self, state): + super(MCoreTest, self).__init__() + self.state = state + def __del__(self): + self.state = None -def read_uintptr_t(state, ea): - """Read a uintptr_t value from memory.""" - addr_size_bits = state.cpu.address_bit_size - next_ea = ea + (addr_size_bits // 8) - val = state.cpu.read_int(ea, size=addr_size_bits) - if issymbolic(val): - val = state.solve_one(val) - return val, next_ea + def get_context(self): + return self.state.context + def is_symbolic(self, val): + return manticore.utils.helpers.issymbolic(val) -def read_uint32_t(state, ea): - """Read a uint32_t value from memory.""" - next_ea = ea + 4 - val = state.cpu.read_int(ea, size=32) - if issymbolic(val): - val = state.solve_one(val) - return val, next_ea + def read_uintptr_t(self, ea, concretize=True, constrain=False): + addr_size_bits = self.state.cpu.address_bit_size + next_ea = ea + (addr_size_bits 
// 8) + val = self.state.cpu.read_int(ea, size=addr_size_bits) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, next_ea + def read_uint64_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=64) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 8 -TestInfo = collections.namedtuple( - 'TestInfo', 'ea name file_name line_number') + def read_uint32_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=32) + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 4 + def read_uint8_t(self, ea, concretize=True, constrain=False): + val = self.state.cpu.read_int(ea, size=8) + if concretize: + val = self.concretize(val, constrain=constrain) + if isinstance(val, str): + assert len(val) == 1 + val = ord(val) + return val, ea + 1 -def read_test_info(state, ea): - """Read in a `McTest_TestInfo` info structure from memory.""" - prev_test_ea, ea = read_uintptr_t(state, ea) - test_func_ea, ea = read_uintptr_t(state, ea) - test_name_ea, ea = read_uintptr_t(state, ea) - file_name_ea, ea = read_uintptr_t(state, ea) - file_line_num, _ = read_uint32_t(state, ea) + def concretize(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + elif isinstance(val, str): + assert len(val) == 1 + return ord(val[0]) - if not test_func_ea or \ - not test_name_ea or \ - not file_name_ea or \ - not file_line_num: # `__LINE__` in C always starts at `1` ;-) - return None, prev_test_ea + assert self.is_symbolic(val) + concrete_val = self.state.solve_one(val) + if isinstance(concrete_val, str): + assert len(concrete_val) == 1 + concrete_val = ord(concrete_val[0]) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val - test_name = read_c_string(state, test_name_ea) - file_name = read_c_string(state, file_name_ea) - info = TestInfo(test_func_ea, test_name, file_name, file_line_num) - return 
info, prev_test_ea + def concretize_min(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + concrete_val = self.state.solve_n(val) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val + def concretize_many(self, val, max_num): + assert 0 < max_num + if isinstance(val, (int, long)): + return [val] + return self.state.solver.eval_upto(val, max_num) -def find_test_cases(state, info_ea): - """Find the test case descriptors.""" - tests = [] - info_ea, _ = read_uintptr_t(state, info_ea) - while info_ea: - test, info_ea = read_test_info(state, info_ea) - if test: - tests.append(test) - tests.sort(key=lambda t: (t.file_name, t.line_number)) - return tests + def add_constraint(self, expr): + if self.is_symbolic(expr): + self.state.constrain(expr) + # TODO(pag): How to check satisfiability? + return True + def pass_test(self): + super(MCoreTest, self).pass_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) -def read_api_table(state, ea): - """Reads in the API table.""" - apis = {} - while True: - api_name_ea, ea = read_uintptr_t(state, ea) - api_ea, ea = read_uintptr_t(state, ea) - if not api_name_ea or not api_ea: - break - api_name = read_c_string(state, api_name_ea) - apis[api_name] = api_ea - return apis + def fail_test(self): + super(MCoreTest, self).fail_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) + + def abandon_test(self): + super(MCoreTest, self).abandon_test() + raise TerminateState(OUR_TERMINATION_REASON, testcase=False) def make_symbolic_input(state, input_begin_ea, input_end_ea): @@ -118,78 +138,63 @@ def make_symbolic_input(state, input_begin_ea, input_end_ea): def hook_IsSymbolicUInt(state, arg): """Implements McTest_IsSymblicUInt, which returns 1 if its input argument has more then one solutions, and zero otherwise.""" - solutions = state.solve_n(arg, 2) - if not solutions: - return 0 - elif 1 == len(solutions): - if issymbolic(arg): - state.constrain(arg == 
solutions[0]) - return 0 - else: - return 1 + return MCoreTest(state).api_is_symbolic_uint(arg) def hook_Assume(state, arg): """Implements _McTest_Assume, which tries to inject a constraint.""" - constraint = arg != 0 - if issymbolic(constraint): - state.constrain(constraint) + MCoreTest(state).api_assume(arg) -OUR_TERMINATION_REASON = "I McTest'd it" +def hook_StreamInt(state, level, format_ea, unpack_ea, uint64_ea): + """Implements _McTest_StreamInt, which gives us an integer to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_int(level, format_ea, unpack_ea, uint64_ea) -def report_state(state): - test = state.context['test'] - if state.context['failed']: - message = (3, "Failed: {}".format(test.name)) - else: - message = (1, "Passed: {}".format(test.name)) - state.context['log_messages'].append(message) - raise TerminateState(OUR_TERMINATION_REASON, testcase=False) +def hook_StreamFloat(state, level, format_ea, unpack_ea, double_ea): + """Implements _McTest_StreamFloat, which gives us an double to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_float(level, format_ea, unpack_ea, double_ea) + + +def hook_StreamString(state, level, format_ea, str_ea): + """Implements _McTest_StreamString, which gives us an double to stream, and + the format to use for streaming.""" + MCoreTest(state).api_stream_string(level, format_ea, str_ea) + + +def hook_LogStream(state, level): + """Implements McTest_LogStream, which converts the contents of a stream for + level `level` into a log for level `level`.""" + MCoreTest(state).api_log_stream(level) def hook_Pass(state): """Implements McTest_Pass, which notifies us of a passing test.""" - report_state(state) + MCoreTest(state).api_pass() def hook_Fail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - state.context['failed'] = 1 - report_state(state) + MCoreTest(state).api_fail() + + +def hook_Abandon(state, reason): + """Implements 
McTest_Abandon, which notifies us that a problem happened + in McTest.""" + MCoreTest(state).api_abandon(reason) def hook_SoftFail(state): """Implements McTest_Fail, which notifies us of a passing test.""" - state.context['failed'] = 1 + MCoreTest(state).api_soft_fail() -LEVEL_TO_LOGGER = { - 0: L.debug, - 1: L.info, - 2: L.warning, - 3: L.error, - 4: L.critical -} - - -def hook_Log(state, level, begin_ea): +def hook_Log(state, level, ea): """Implements McTest_Log, which lets Manticore intercept and handle the printing of log messages from the simulated tests.""" - level = state.solve_one(level) - assert level in LEVEL_TO_LOGGER - - begin_ea = state.solve_one(begin_ea) - - message_bytes = [] - for i in xrange(4096): - b = state.cpu.read_int(begin_ea + i, 8) - if not issymbolic(b) and b == 0: - break - message_bytes.append(b) - - state.context['log_messages'].append((level, message_bytes)) + MCoreTest(state).api_log(level, ea) def hook(func): @@ -202,36 +207,8 @@ def done_test(_, state, state_id, reason): L.error("State {} terminated for unknown reason: {}".format( state_id, reason)) return - - test = state.context['test'] - input_length, _ = read_uint32_t(state, state.context['InputIndex']) - - # Dump out any pending log messages reported by `McTest_Log`. - for level, message_bytes in state.context['log_messages']: - message = [] - for b in message_bytes: - b_ord = state.solve_one(b) - if b_ord == 0: - break - else: - message.append(chr(b_ord)) - - LEVEL_TO_LOGGER[level]("".join(message)) - - max_length = state.context['InputEnd'] - state.context['InputBegin'] - if input_length > max_length: - L.critical("Test used too many input bytes ({} vs. {})".format( - input_length, max_length)) - return - - # Solve for the input bytes. 
- output = [] - for i in xrange(input_length): - b = state.cpu.read_int(state.context['InputBegin'] + i, 8) - b = state.solve_one(b) - output.append("{:2x}".format(b)) - - L.info("Input: {}".format(" ".join(output))) + mc = MCoreTest(state) + mc.report() def do_run_test(state, apis, test): @@ -241,15 +218,9 @@ def do_run_test(state, apis, test): m.verbosity(1) state = m.initial_state - messages = [(1, "Running {} from {}({})".format( - test.name, test.file_name, test.line_number))] - - state.context['InputBegin'] = apis['InputBegin'] - state.context['InputEnd'] = apis['InputEnd'] - state.context['InputIndex'] = apis['InputIndex'] - state.context['test'] = test - state.context['failed'] = 0 - state.context['log_messages'] = messages + mc = MCoreTest(state) + mc.begin_test(test) + del mc make_symbolic_input(state, apis['InputBegin'], apis['InputEnd']) @@ -258,7 +229,13 @@ def do_run_test(state, apis, test): m.add_hook(apis['Pass'], hook(hook_Pass)) m.add_hook(apis['Fail'], hook(hook_Fail)) m.add_hook(apis['SoftFail'], hook(hook_SoftFail)) + m.add_hook(apis['Abandon'], hook(hook_Abandon)) m.add_hook(apis['Log'], hook(hook_Log)) + m.add_hook(apis['StreamInt'], hook(hook_StreamInt)) + m.add_hook(apis['StreamFloat'], hook(hook_StreamFloat)) + m.add_hook(apis['StreamString'], hook(hook_StreamString)) + m.add_hook(apis['LogStream'], hook(hook_LogStream)) + m.subscribe('will_terminate_state', done_test) m.run() @@ -275,7 +252,8 @@ def run_tests(args, state, apis): """Run all of the test cases.""" pool = multiprocessing.Pool(processes=max(1, args.num_workers)) results = [] - tests = find_test_cases(state, apis['LastTestInfo']) + mc = MCoreTest(state) + tests = mc.find_test_cases() L.info("Running {} tests across {} workers".format( len(tests), args.num_workers)) @@ -313,9 +291,11 @@ def main(): setup_ea = m._get_symbol_address('McTest_Setup') setup_state = m._initial_state - ea_of_api_table = m._get_symbol_address('McTest_API') - apis = read_api_table(setup_state, 
ea_of_api_table) + mc = MCoreTest(setup_state) + ea_of_api_table = m._get_symbol_address('McTest_API') + apis = mc.read_api_table(ea_of_api_table) + del mc m.add_hook(setup_ea, lambda state: run_tests(args, state, apis)) m.run() diff --git a/bin/mctest/common.py b/bin/mctest/common.py new file mode 100644 index 0000000..df807cd --- /dev/null +++ b/bin/mctest/common.py @@ -0,0 +1,381 @@ +# Copyright (c) 2017 Trail of Bits, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import logging +import struct + +# Represents a TestInfo data structure (the information we know about a test.) 
+TestInfo = collections.namedtuple( + 'TestInfo', 'ea name file_name line_number') + + +LOG_LEVEL_DEBUG = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARNING = 2 +LOG_LEVEL_ERROR = 3 +LOG_LEVEL_FATAL = 4 + + +LOGGER = logging.getLogger("mctest") +LOGGER.setLevel(logging.DEBUG) + + +LOG_LEVEL_TO_LOGGER = { + LOG_LEVEL_DEBUG: LOGGER.debug, + LOG_LEVEL_INFO: LOGGER.info, + LOG_LEVEL_WARNING: LOGGER.warning, + LOG_LEVEL_ERROR: LOGGER.error, + LOG_LEVEL_FATAL: LOGGER.critical +} + + +class Stream(object): + def __init__(self, entries): + self.entries = entries + + +class McTest(object): + """Wrapper around a symbolic executor for making it easy to do common McTest- + specific things.""" + def __init__(self): + pass + + def get_context(self): + raise NotImplementedError("Must be implemented by engine.") + + @property + def context(self): + return self.get_context() + + def is_symbolic(self, val): + raise NotImplementedError("Must be implemented by engine.") + + def read_uintptr_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint64_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint32_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def read_uint8_t(self, ea, concretize=True, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize(self, val, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize_min(self, val, constrain=False): + raise NotImplementedError("Must be implemented by engine.") + + def concretize_many(self, val, max_num): + raise NotImplementedError("Must be implemented by engine.") + + def add_constraint(self, expr): + raise NotImplementedError("Must be implemented by engine.") + + def read_c_string(self, ea, concretize=True): + """Read a NUL-terminated string from `ea`.""" + 
assert isinstance(ea, (int, long)) + chars = [] + i = 0 + while True: + b, ea = self.read_uint8_t(ea, concretize=concretize) + if self.is_symbolic(b): + concrete_b = self.concretize_min(b) # Find the NUL byte sooner. + if not concrete_b: + break + + if concretize: + chars.append(chr(concrete_b)) + else: + chars.append(b) + + continue + + # Concretize if it's not symbolic; we might have a concrete bitvector. + b = self.concretize(b) + if not b: + break + + else: + chars.append(chr(b)) + + next_ea = ea + len(chars) + 1 + if concretize: + return "".join(chars), next_ea + else: + return chars, next_ea + + def read_test_info(self, ea): + """Read in a `McTest_TestInfo` info structure from memory.""" + prev_test_ea, ea = self.read_uintptr_t(ea) + test_func_ea, ea = self.read_uintptr_t(ea) + test_name_ea, ea = self.read_uintptr_t(ea) + file_name_ea, ea = self.read_uintptr_t(ea) + file_line_num, _ = self.read_uint32_t(ea) + + if not test_func_ea or \ + not test_name_ea or \ + not file_name_ea or \ + not file_line_num: # `__LINE__` in C always starts at `1` ;-) + return None, prev_test_ea + + test_name, _ = self.read_c_string(test_name_ea) + file_name, _ = self.read_c_string(file_name_ea) + info = TestInfo(test_func_ea, test_name, file_name, file_line_num) + return info, prev_test_ea + + def find_test_cases(self): + """Find the test case descriptors.""" + tests = [] + info_ea, _ = self.read_uintptr_t(self.context['apis']['LastTestInfo']) + while info_ea: + test, info_ea = self.read_test_info(info_ea) + if test: + tests.append(test) + tests.sort(key=lambda t: (t.file_name, t.line_number)) + return tests + + def read_api_table(self, ea): + """Reads in the API table.""" + apis = {} + while True: + api_name_ea, ea = self.read_uintptr_t(ea) + api_ea, ea = self.read_uintptr_t(ea) + if not api_name_ea or not api_ea: + break + api_name, _ = self.read_c_string(api_name_ea) + apis[api_name] = api_ea + self.context['apis'] = apis + return apis + + def begin_test(self, info): + 
self.context['failed'] = False + self.context['abandoned'] = False + self.context['log'] = [] + for level in LOG_LEVEL_TO_LOGGER: + self.context['stream_{}'.format(level)] = [] + self.context['info'] = info + self.log_message(LOG_LEVEL_INFO, "Running {} from {}({})".format( + info.name, info.file_name, info.line_number)) + + def log_message(self, level, message): + """Log a message.""" + assert level in LOG_LEVEL_TO_LOGGER + log = list(self.context['log']) + if isinstance(message, (str, list, tuple)): + log.append((level, Stream([(str, "%s", None, message)]))) + else: + assert isinstance(message, Stream) + log.append((level, message)) + + self.context['log'] = log + + def _concretize_bytes(self, byte_str): + new_bytes = [] + for b in byte_str: + if isinstance(b, str): + assert len(b) == 1 + new_bytes.append(ord(b)) + elif isinstance(b, (int, long)): + new_bytes.append(b) + else: + new_bytes.append(self.concretize(b, constrain=True)) + return new_bytes + + def _stream_to_message(self, stream): + assert isinstance(stream, Stream) + message = [] + for val_type, format_str, unpack_str, val_bytes in stream.entries: + val_bytes = self._concretize_bytes(val_bytes) + if val_type == str: + val = "".join(chr(b) for b in val_bytes) + elif val_type == float: + data = struct.pack('BBBBBBBB', *val_bytes) + val = struct.unpack(unpack_str, data)[0] + else: + assert val_type == int + + # TODO(pag): I am pretty sure that this is wrong for big-endian. + data = struct.pack('BBBBBBBB', *val_bytes) + val = struct.unpack(unpack_str, data[:struct.calcsize(unpack_str)])[0] + + # Remove length specifiers that are not supported. 
+ format_str = format_str.replace('l', '') + format_str = format_str.replace('h', '') + format_str = format_str.replace('z', '') + format_str = format_str.replace('t', '') + + message.extend(format_str % val) + + return "".join(message) + + def report(self): + info = self.context['info'] + apis = self.context['apis'] + input_length, _ = self.read_uint32_t(apis['InputIndex']) + input_bytes = [] + for i in xrange(input_length): + ea = apis['InputBegin'] + i + b, _ = self.read_uint8_t(ea + i, concretize=True, constrain=True) + input_bytes.append("{:02x}".format(b)) + + for level, stream in self.context['log']: + logger = LOG_LEVEL_TO_LOGGER[level] + logger(self._stream_to_message(stream)) + + LOGGER.info("Input: {}".format(" ".join(input_bytes))) + + def pass_test(self): + pass + + def fail_test(self): + self.context['failed'] = True + + def abandon_test(self): + self.context['abandoned'] = True + + def api_is_symbolic_uint(self, arg): + """Implements the `McTest_IsSymbolicUInt` API, which returns whether or + not a given value is symbolic.""" + solutions = self.concretize_many(arg, 2) + if not solutions: + return 0 + elif 1 == len(solutions): + if self.is_symbolic(arg): + self.add_constraint(arg == solutions[0]) + return 0 + else: + return 1 + + def api_assume(self, arg): + """Implements the `McTest_Assume` API function, which injects a constraint + into the solver.""" + constraint = arg != 0 + if not self.add_constraint(constraint): + self.log_message(LOG_LEVEL_FATAL, + "Failed to add assumption {}".format(constraint)) + self.abandon_test() + + def api_pass(self): + """Implements the `McTest_Pass` API function, which marks this test as + having passed, and stops further execution.""" + if self.context['failed']: + self.api_fail() + else: + info = self.context['info'] + self.log_message(LOG_LEVEL_INFO, "Passed: {}".format(info.name)) + self.pass_test() + + def api_fail(self): + """Implements the `McTest_Fail` API function, which marks this test as + having failed, 
and stops further execution.""" + self.context['failed'] = True + info = self.context['info'] + self.log_message(LOG_LEVEL_ERROR, "Failed: {}".format(info.name)) + self.fail_test() + + def api_soft_fail(self): + """Implements the `McTest_SoftFail` API function, which marks this test + as having failed, but lets execution continue.""" + self.context['failed'] = True + + def api_abandon(self, arg): + """Implements the `McTest_Abandon` API function, which marks this test + as having aborted due to some unrecoverable error.""" + info = self.context['info'] + ea = self.concretize(arg, constrain=True) + self.log_message(LOG_LEVEL_FATAL, self.read_c_string(ea)[0]) + self.log_message(LOG_LEVEL_FATAL, "Abandoned: {}".format(info.name)) + self.abandon_test() + + def api_log(self, level, ea): + """Implements the `McTest_Log` API function, which prints a C string + to a specific log level.""" + self.api_log_stream(level) + + level = self.concretize(level, constrain=True) + ea = self.concretize(ea, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + self.log_message(level, self.read_c_string(ea, concretize=False)) + + if level == LOG_LEVEL_FATAL: + self.api_fail() + elif level == LOG_LEVEL_ERROR: + self.api_soft_fail() + + def _api_stream_int_float(self, level, format_ea, unpack_ea, uint64_ea, + val_type): + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + + format_ea = self.concretize(format_ea, constrain=True) + unpack_ea = self.concretize(unpack_ea, constrain=True) + uint64_ea = self.concretize(uint64_ea, constrain=True) + + format_str, _ = self.read_c_string(format_ea) + unpack_str, _ = self.read_c_string(unpack_ea) + uint64_bytes = [] + for i in xrange(8): + b, _ = self.read_uint8_t(uint64_ea + i, concretize=False) + uint64_bytes.append(b) + + stream_id = 'stream_{}'.format(level) + stream = list(self.context[stream_id]) + stream.append((val_type, format_str, unpack_str, uint64_bytes)) + self.context[stream_id] = stream + + def 
api_stream_int(self, level, format_ea, unpack_ea, uint64_ea): + """Implements the `_McTest_StreamInt`, which streams an integer into a + holding buffer for the log.""" + return self._api_stream_int_float(level, format_ea, unpack_ea, + uint64_ea, int) + + def api_stream_float(self, level, format_ea, unpack_ea, double_ea): + """Implements the `_McTest_StreamFloat`, which streams an integer into a + holding buffer for the log.""" + return self._api_stream_int_float(level, format_ea, unpack_ea, + double_ea, float) + + def api_stream_string(self, level, format_ea, str_ea): + """Implements the `_McTest_StreamString`, which streams a C-string into a + holding buffer for the log.""" + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + + format_ea = self.concretize(format_ea, constrain=True) + str_ea = self.concretize(str_ea, constrain=True) + format_str, _ = self.read_c_string(format_ea) + print_str, _ = self.read_c_string(str_ea, concretize=False) + + stream_id = 'stream_{}'.format(level) + stream = list(self.context[stream_id]) + stream.append((str, format_str, None, print_str)) + self.context[stream_id] = stream + + def api_log_stream(self, level): + level = self.concretize(level, constrain=True) + assert level in LOG_LEVEL_TO_LOGGER + stream_id = 'stream_{}'.format(level) + stream = self.context[stream_id] + if len(stream): + self.context[stream_id] = [] + self.log_message(level, Stream(stream)) + + if level == LOG_LEVEL_FATAL: + self.api_fail() + elif level == LOG_LEVEL_ERROR: + self.api_soft_fail() diff --git a/bin/mctest/main_angr.py b/bin/mctest/main_angr.py index b8ef318..5294e67 100644 --- a/bin/mctest/main_angr.py +++ b/bin/mctest/main_angr.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from __future__ import absolute_import - import angr import argparse import collections @@ -22,95 +20,114 @@ import logging import multiprocessing import sys import traceback +from .common import McTest - -L = logging.getLogger("mctest") +L = logging.getLogger("mctest.angr") L.setLevel(logging.INFO) +class AngrTest(McTest): + def __init__(self, state=None, procedure=None): + super(AngrTest, self).__init__() + if procedure: + self.procedure = procedure + self.state = procedure.state + elif state: + self.procedure = None + self.state = state + + def __del__(self): + self.procedure = None + self.state = None + + def get_context(self): + return self.state.globals + + def is_symbolic(self, val): + if isinstance(val, (int, long)): + return False + return self.state.se.symbolic(val) + + def read_uintptr_t(self, ea, concretize=True, constrain=False): + next_ea = ea + (self.state.arch.bits // 8) + val = self.state.mem[ea].uintptr_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, next_ea + + def read_uint64_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint64_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 8 + + def read_uint32_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint32_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + return val, ea + 4 + + def read_uint8_t(self, ea, concretize=True, constrain=False): + val = self.state.mem[ea].uint8_t.resolved + if concretize: + val = self.concretize(val, constrain=constrain) + if isinstance(val, str): + assert len(val) == 1 + val = ord(val) + return val, ea + 1 + + def concretize(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + elif isinstance(val, str): + assert len(val) == 1 + return ord(val[0]) + + concrete_val = self.state.solver.eval(val, cast_to=int) + if constrain: + self.add_constraint(val == concrete_val) + + return 
concrete_val + + def concretize_min(self, val, constrain=False): + if isinstance(val, (int, long)): + return val + concrete_val = self.state.solver.min(val, cast_to=int) + if constrain: + self.add_constraint(val == concrete_val) + return concrete_val + + def concretize_many(self, val, max_num): + assert 0 < max_num + if isinstance(val, (int, long)): + return [val] + return self.state.solver.eval_upto(val, max_num, cast_to=int) + + def add_constraint(self, expr): + if self.is_symbolic(expr): + self.state.solver.add(expr) + return self.state.solver.satisfiable() + else: + return True + + def pass_test(self): + super(AngrTest, self).pass_test() + self.procedure.exit(0) + + def fail_test(self): + super(AngrTest, self).fail_test() + self.procedure.exit(1) + + def abandon_test(self): + super(AngrTest, self).abandon_test() + self.procedure.exit(1) + + def hook_function(project, ea, cls): """Hook the function `ea` with the SimProcedure `cls`.""" project.hook(ea, cls(project=project)) -def read_c_string(state, ea): - """Read a concrete NUL-terminated string from `ea`.""" - assert isinstance(ea, (int, long)) - chars = [] - i = 0 - while True: - char = state.mem[ea + i].char.resolved - char = state.solver.eval(char, cast_to=str) - if not ord(char[0]): - break - chars.append(char) - i += 1 - return "".join(chars) - - -def read_uintptr_t(state, ea): - """Read a uintptr_t value from memory.""" - next_ea = ea + (state.arch.bits // 8) - val = state.solver.eval(state.mem[ea].uintptr_t.resolved, cast_to=int) - return val, next_ea - - -def read_uint32_t(state, ea): - """Read a uint32_t value from memory.""" - next_ea = ea + 4 - val = state.solver.eval(state.mem[ea].uint32_t.resolved, cast_to=int) - return val, next_ea - - -TestInfo = collections.namedtuple( - 'TestInfo', 'ea name file_name line_number') - - -def read_test_info(state, ea): - """Read in a `McTest_TestInfo` info structure from memory.""" - prev_test_ea, ea = read_uintptr_t(state, ea) - test_func_ea, ea = 
read_uintptr_t(state, ea) - test_name_ea, ea = read_uintptr_t(state, ea) - file_name_ea, ea = read_uintptr_t(state, ea) - file_line_num, _ = read_uint32_t(state, ea) - - if not test_func_ea or \ - not test_name_ea or \ - not file_name_ea or \ - not file_line_num: # `__LINE__` in C always starts at `1` ;-) - return None, prev_test_ea - - test_name = read_c_string(state, test_name_ea) - file_name = read_c_string(state, file_name_ea) - info = TestInfo(test_func_ea, test_name, file_name, file_line_num) - return info, prev_test_ea - - -def find_test_cases(state, info_ea): - """Find the test case descriptors.""" - tests = [] - info_ea, _ = read_uintptr_t(state, info_ea) - while info_ea: - test, info_ea = read_test_info(state, info_ea) - if test: - tests.append(test) - tests.sort(key=lambda t: (t.file_name, t.line_number)) - return tests - - -def read_api_table(state, ea): - """Reads in the API table.""" - apis = {} - while True: - api_name_ea, ea = read_uintptr_t(state, ea) - api_ea, ea = read_uintptr_t(state, ea) - if not api_name_ea or not api_ea: - break - api_name = read_c_string(state, api_name_ea) - apis[api_name] = api_ea - return apis - - def make_symbolic_input(state, input_begin_ea, input_end_ea): """Fill in the input data array with symbolic data.""" input_size = input_end_ea - input_begin_ea @@ -123,132 +140,73 @@ class IsSymbolicUInt(angr.SimProcedure): """Implements McTest_IsSymblicUInt, which returns 1 if its input argument has more then one solutions, and zero otherwise.""" def run(self, arg): - solutions = self.state.solver.eval_upto(arg, 2) - if not solutions: - return 0 - elif 1 == len(solutions): - if self.state.se.symbolic(arg): - self.state.solver.add(arg == solutions[0]) - return 0 - else: - return 1 + return AngrTest(procedure=self).api_is_symbolic_uint(arg) class Assume(angr.SimProcedure): """Implements _McTest_Assume, which tries to inject a constraint.""" def run(self, arg): - constraint = arg != 0 - self.state.solver.add(constraint) - if not 
self.state.solver.satisfiable(): - L.error("Failed to assert assumption {}".format(constraint)) - self.exit(2) - - -def report_state(state): - test = state.globals['test'] - if state.globals['failed']: - add_log_message(state, 3, "Failed: {}".format(test.name)) - else: - add_log_message(state, 1, "Passed: {}".format(test.name)) + AngrTest(procedure=self).api_assume(arg) class Pass(angr.SimProcedure): """Implements McTest_Pass, which notifies us of a passing test.""" def run(self): - report_state(self.state) - self.exit(self.state.globals['failed']) + AngrTest(procedure=self).api_pass() class Fail(angr.SimProcedure): - """Implements McTest_Fail, which notifies us of a passing test.""" + """Implements McTest_Fail, which notifies us of a failing test.""" def run(self): - self.state.globals['failed'] = 1 - report_state(self.state) - self.exit(1) + AngrTest(procedure=self).api_fail() + + +class Abandon(angr.SimProcedure): + """Implements McTest_Fail, which notifies us of a failing test.""" + def run(self, reason): + AngrTest(procedure=self).api_abandon(reason) class SoftFail(angr.SimProcedure): - """Implements McTest_SoftFail, which notifies us of a passing test.""" + """Implements McTest_SoftFail, which notifies us of a failing test.""" def run(self): - self.state.globals['failed'] = 1 + AngrTest(procedure=self).api_soft_fail() -def add_log_message(state, level, message): - """Add a log message to a state.""" - messages = list(state.globals['log_messages']) - messages.append((level, message)) - state.globals['log_messages'] = messages +class StreamInt(angr.SimProcedure): + """Implements _McTest_StreamInt, which gives us an integer to stream, and + the format to use for streaming.""" + def run(self, level, format_ea, unpack_ea, uint64_ea): + AngrTest(procedure=self).api_stream_int(level, format_ea, unpack_ea, + uint64_ea) + +class StreamFloat(angr.SimProcedure): + """Implements _McTest_StreamFloat, which gives us an double to stream, and + the format to use for 
streaming.""" + def run(self, level, format_ea, unpack_ea, double_ea): + AngrTest(procedure=self).api_stream_float(level, format_ea, unpack_ea, + double_ea) -LEVEL_TO_LOGGER = { - 0: L.debug, - 1: L.info, - 2: L.warning, - 3: L.error, - 4: L.critical -} +class StreamString(angr.SimProcedure): + """Implements _McTest_StreamString, which gives us an double to stream, and + the format to use for streaming.""" + def run(self, level, format_ea, str_ea): + AngrTest(procedure=self).api_stream_string(level, format_ea, str_ea) + + +class LogStream(angr.SimProcedure): + """Implements McTest_LogStream, which converts the contents of a stream for + level `level` into a log for level `level`.""" + def run(self, level): + AngrTest(procedure=self).api_log_stream(level) class Log(angr.SimProcedure): """Implements McTest_Log, which lets Angr intercept and handle the printing of log messages from the simulated tests.""" - def run(self, level, begin_ea_): - level = self.state.solver.eval(level, cast_to=int) - assert level in LEVEL_TO_LOGGER - - begin_ea = self.state.solver.eval(begin_ea_, cast_to=int) - if self.state.se.symbolic(begin_ea_): - self.state.solver.add(begin_ea_ == begin_ea) - - data = [] - for i in xrange(4096): - b = self.state.memory.load(begin_ea + i, size=1) - solutions = self.state.solver.eval_upto(b, 2, cast_to=int) - if 1 == len(solutions) and solutions[0] == 0: - break - data.append(b) - - # Deep copy the message. - add_log_message(self.state, level, data) - - if 3 == level: - self.state.globals['failed'] = 1 # Soft failure on an error log message. - elif 4 == level: - self.state.globals['failed'] = 1 - report_state(self.state) - self.exit(1) - - -def done_test(state): - test = state.globals['test'] - input_length, _ = read_uint32_t(state, state.globals['InputIndex']) - - # Dump out any pending log messages reported by `McTest_Log`. 
- for level, message in state.globals['log_messages']: - data = [] - for b in message: - if not isinstance(b, str): - b = state.solver.eval(b, cast_to=str) - if not ord(b): - break - data.append(b) - - LEVEL_TO_LOGGER[level]("".join(data)) - - max_length = state.globals['InputEnd'] - state.globals['InputBegin'] - if input_length > max_length: - L.critical("Test used too many input bytes ({} vs. {})".format( - input_length, max_length)) - return - - # Solve for the input bytes. - output = [] - data = state.memory.load(state.globals['InputBegin'], size=input_length) - data = state.solver.eval(data, cast_to=str) - for i in xrange(input_length): - output.append("{:2x}".format(ord(data[i]))) - - L.info("Input: {}".format(" ".join(output))) + def run(self, level, ea): + AngrTest(procedure=self).api_log(level, ea) def do_run_test(project, test, apis, run_state): @@ -258,15 +216,9 @@ def do_run_test(project, test, apis, run_state): test.ea, base_state=run_state) - messages = [(1, "Running {} from {}({})".format( - test.name, test.file_name, test.line_number))] - - test_state.globals['InputBegin'] = apis['InputBegin'] - test_state.globals['InputEnd'] = apis['InputEnd'] - test_state.globals['InputIndex'] = apis['InputIndex'] - test_state.globals['test'] = test - test_state.globals['failed'] = 0 - test_state.globals['log_messages'] = messages + mc = AngrTest(state=test_state) + mc.begin_test(test) + del mc make_symbolic_input(test_state, apis['InputBegin'], apis['InputEnd']) @@ -283,8 +235,10 @@ def do_run_test(project, test, apis, run_state): sys.exc_info()[0], traceback.format_exc())) for state in test_manager.deadended: - done_test(state) + AngrTest(state=state).report() + for state in test_manager.errored: + print "ErrorL", state.error def run_test(project, test, apis, run_state): """Symbolically executes a single test function.""" @@ -297,7 +251,6 @@ def run_test(project, test, apis, run_state): def main(): """Run McTest.""" - parser = argparse.ArgumentParser( 
description="Symbolically execute unit tests with Angr") @@ -330,24 +283,32 @@ def main(): setup_ea = project.kb.labels.lookup('McTest_Setup') concrete_manager.explore(find=setup_ea) run_state = concrete_manager.found[0] - + # Read the API table, which will tell us about the location of various # symbols. Technically we can look these up with the `labels.lookup` API, # but we have the API table for Manticore-compatibility, so we may as well # use it. ea_of_api_table = project.kb.labels.lookup('McTest_API') - apis = read_api_table(run_state, ea_of_api_table) + + mc = AngrTest(state=run_state) + apis = mc.read_api_table(ea_of_api_table) # Hook various functions. hook_function(project, apis['IsSymbolicUInt'], IsSymbolicUInt) hook_function(project, apis['Assume'], Assume) hook_function(project, apis['Pass'], Pass) hook_function(project, apis['Fail'], Fail) + hook_function(project, apis['Abandon'], Abandon) hook_function(project, apis['SoftFail'], SoftFail) hook_function(project, apis['Log'], Log) + hook_function(project, apis['StreamInt'], StreamInt) + hook_function(project, apis['StreamFloat'], StreamFloat) + hook_function(project, apis['StreamString'], StreamString) + hook_function(project, apis['LogStream'], LogStream) # Find the test cases that we want to run. 
- tests = find_test_cases(run_state, apis['LastTestInfo']) + tests = mc.find_test_cases() + del mc L.info("Running {} tests across {} workers".format( len(tests), args.num_workers)) diff --git a/examples/ArithmeticProperties.cpp b/examples/ArithmeticProperties.cpp index 59f3557..e65902c 100644 --- a/examples/ArithmeticProperties.cpp +++ b/examples/ArithmeticProperties.cpp @@ -39,7 +39,8 @@ TEST(Arithmetic, AdditionIsAssociative) { TEST(Arithmetic, InvertibleMultiplication_CanFail) { ForAll([] (int x, int y) { - ASSUME_NE(y, 0); + ASSUME_NE(y, 0) + << "Assumed non-zero value for y: " << y; ASSERT_EQ(x, (x / y) * y) << x << " != (" << x << " / " << y << ") * " << y; }); diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e9cabff..09ce33c 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -21,3 +21,5 @@ target_link_libraries(ArithmeticProperties mctest) add_executable(Lists Lists.cpp) target_link_libraries(Lists mctest) +add_executable(StreamingAndFormatting StreamingAndFormatting.cpp) +target_link_libraries(StreamingAndFormatting mctest) diff --git a/examples/StreamingAndFormatting.cpp b/examples/StreamingAndFormatting.cpp new file mode 100644 index 0000000..3698ea0 --- /dev/null +++ b/examples/StreamingAndFormatting.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +TEST(Streaming, BasicLevels) { + LOG(DEBUG) << "This is a debug message"; + LOG(INFO) << "This is an info message"; + LOG(WARNING) << "This is a warning message"; + LOG(ERROR) << "This is a error message"; + LOG(INFO) << "This is a info message again"; + ASSERT(true) << "This should not be printed."; +} + +TEST(Streaming, BasicTypes) { + LOG(INFO) << 'a'; + LOG(INFO) << 1; + LOG(INFO) << 1.0; + LOG(INFO) << "string"; + LOG(INFO) << nullptr; +} + +TEST(Formatting, OverridePrintf) { + printf("hello string=%s hex_lower=%x hex_upper=%X octal=%o char=%c dec=%d" + "double=%f sci=%e SCI=%E pointer=%p", + "world", 999, 999, 999, 'a', 999, 999.0, 999.0, 999.0, "world"); +} + +int main(void) { + return McTest_Run(); +} diff --git a/src/include/mctest/Compiler.h b/src/include/mctest/Compiler.h index 5e420b1..210cab1 100644 --- a/src/include/mctest/Compiler.h +++ b/src/include/mctest/Compiler.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_COMPILER_H_ -#define INCLUDE_MCTEST_COMPILER_H_ +#ifndef SRC_INCLUDE_MCTEST_COMPILER_H_ +#define SRC_INCLUDE_MCTEST_COMPILER_H_ #include #include @@ -93,4 +93,4 @@ static void f(void) #endif -#endif /* INCLUDE_MCTEST_COMPILER_H_ */ +#endif /* SRC_INCLUDE_MCTEST_COMPILER_H_ */ diff --git a/src/include/mctest/Log.h b/src/include/mctest/Log.h new file mode 100644 index 0000000..43b8888 --- /dev/null +++ b/src/include/mctest/Log.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_INCLUDE_MCTEST_LOG_H_ +#define SRC_INCLUDE_MCTEST_LOG_H_ + +#include + +#include + +MCTEST_BEGIN_EXTERN_C + +struct McTest_Stream; + +enum McTest_LogLevel { + McTest_LogDebug = 0, + McTest_LogInfo = 1, + McTest_LogWarning = 2, + McTest_LogWarn = McTest_LogWarning, + McTest_LogError = 3, + McTest_LogFatal = 4, + McTest_LogCritical = McTest_LogFatal +}; + +/* Log a C string. */ +extern void McTest_Log(enum McTest_LogLevel level, const char *str); + +/* Log some formatted output. */ +extern void McTest_LogFormat(enum McTest_LogLevel level, + const char *format, ...); + +/* Log some formatted output. */ +extern void McTest_LogVFormat(enum McTest_LogLevel level, + const char *format, va_list args); + +MCTEST_END_EXTERN_C + +#endif /* SRC_INCLUDE_MCTEST_LOG_H_ */ diff --git a/src/include/mctest/McTest.h b/src/include/mctest/McTest.h index 8cdc192..06db23f 100644 --- a/src/include/mctest/McTest.h +++ b/src/include/mctest/McTest.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_MCTEST_H_ -#define INCLUDE_MCTEST_MCTEST_H_ +#ifndef SRC_INCLUDE_MCTEST_MCTEST_H_ +#define SRC_INCLUDE_MCTEST_MCTEST_H_ #include #include @@ -25,7 +25,9 @@ #include #include +#include #include +#include #ifdef assert # undef assert @@ -54,6 +56,12 @@ extern int McTest_IsTrue(int expr); /* Symbolize the data in the range `[begin, end)`. */ extern void McTest_SymbolizeData(void *begin, void *end); +/* Concretize some data i nthe range `[begin, end)`. 
*/ +extern void McTest_ConcretizeData(void *begin, void *end); + +/* Concretize a C string */ +extern void McTest_ConcretizeCStr(const char *begin); + MCTEST_INLINE static void *McTest_Malloc(size_t num_bytes) { void *data = malloc(num_bytes); uintptr_t data_end = ((uintptr_t) data) + num_bytes; @@ -96,6 +104,10 @@ extern void _McTest_Assume(int expr); #define McTest_Assume(x) _McTest_Assume(!!(x)) +/* Abandon this test. We've hit some kind of internal problem. */ +MCTEST_NORETURN +extern void McTest_Abandon(const char *reason); + MCTEST_NORETURN extern void McTest_Fail(void); @@ -121,20 +133,6 @@ MCTEST_INLINE static void McTest_Check(int expr) { } } -enum McTest_LogLevel { - McTest_LogDebug = 0, - McTest_LogInfo = 1, - McTest_LogWarning = 2, - McTest_LogWarn = McTest_LogWarning, - McTest_LogError = 3, - McTest_LogFatal = 4, - McTest_LogCritical = McTest_LogFatal -}; - -/* Outputs information to a log, using a specific log level. */ -extern void McTest_Log(enum McTest_LogLevel level, const char *begin, - const char *end); - /* Return a symbolic value in a the range `[low_inc, high_inc]`. */ #define MCTEST_MAKE_SYMBOLIC_RANGE(Tname, tname) \ MCTEST_INLINE static tname McTest_ ## Tname ## InRange( \ @@ -248,12 +246,21 @@ extern struct McTest_TestInfo *McTest_LastTestInfo; /* Set up McTest. */ extern void McTest_Setup(void); +/* Tear down McTest. */ +extern void McTest_Teardown(void); + +/* Notify that we're about to begin a test. */ +extern void McTest_Begin(struct McTest_TestInfo *info); + /* Return the first test case to run. */ extern struct McTest_TestInfo *McTest_FirstTest(void); /* Returns 1 if a failure was caught, otherwise 0. */ extern int McTest_CatchFail(void); +/* Returns 1 if this test case was abandoned. */ +extern int McTest_CatchAbandoned(void); + /* Jump buffer for returning to `McTest_Run`. 
*/ extern jmp_buf McTest_ReturnToRun; @@ -261,18 +268,11 @@ extern jmp_buf McTest_ReturnToRun; static int McTest_Run(void) { int num_failed_tests = 0; struct McTest_TestInfo *test = NULL; - char buff[1024]; - int num_buff_bytes_used = 0; McTest_Setup(); for (test = McTest_FirstTest(); test != NULL; test = test->prev) { - - /* Print the test that we're going to run. */ - num_buff_bytes_used = sprintf(buff, "Running: %s from %s(%u)", - test->test_name, test->file_name, - test->line_number); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + McTest_Begin(test); /* Run the test. */ if (!setjmp(McTest_ReturnToRun)) { @@ -291,23 +291,28 @@ static int McTest_Run(void) { } #endif /* __cplusplus */ + /* We caught a failure when running the test. */ } else if (McTest_CatchFail()) { ++num_failed_tests; + McTest_LogFormat(McTest_LogError, "Failed: %s", test->test_name); - num_buff_bytes_used = sprintf(buff, "Failed: %s", test->test_name); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + /* The test was abandoned. We may have gotten soft failures before + * abandoning, so we prefer to catch those first. */ + } else if (McTest_CatchAbandoned()) { + McTest_LogFormat(McTest_LogFatal, "Abandoned: %s", test->test_name); /* The test passed. */ } else { - num_buff_bytes_used = sprintf(buff, "Passed: %s", test->test_name); - McTest_Log(McTest_LogInfo, buff, &(buff[num_buff_bytes_used])); + McTest_LogFormat(McTest_LogInfo, "Passed: %s", test->test_name); } } + McTest_Teardown(); + return num_failed_tests; } MCTEST_END_EXTERN_C -#endif /* INCLUDE_MCTEST_MCTEST_H_ */ +#endif /* SRC_INCLUDE_MCTEST_MCTEST_H_ */ diff --git a/src/include/mctest/McTest.hpp b/src/include/mctest/McTest.hpp index 2fb96dc..ab23539 100644 --- a/src/include/mctest/McTest.hpp +++ b/src/include/mctest/McTest.hpp @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef INCLUDE_MCTEST_MCTEST_HPP_ -#define INCLUDE_MCTEST_MCTEST_HPP_ +#ifndef SRC_INCLUDE_MCTEST_MCTEST_HPP_ +#define SRC_INCLUDE_MCTEST_MCTEST_HPP_ #include @@ -191,4 +191,4 @@ MAKE_SYMBOL_SPECIALIZATION(Char, int8_t) } // namespace mctest -#endif // INCLUDE_MCTEST_MCTEST_HPP_ +#endif // SRC_INCLUDE_MCTEST_MCTEST_HPP_ diff --git a/src/include/mctest/McUnit.hpp b/src/include/mctest/McUnit.hpp index c0a96b2..8a605d5 100644 --- a/src/include/mctest/McUnit.hpp +++ b/src/include/mctest/McUnit.hpp @@ -14,59 +14,44 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_MCUNIT_HPP_ -#define INCLUDE_MCTEST_MCUNIT_HPP_ +#ifndef SRC_INCLUDE_MCTEST_MCUNIT_HPP_ +#define SRC_INCLUDE_MCTEST_MCUNIT_HPP_ #include - -#include +#include #define TEST(category, name) \ McTest_EntryPoint(category ## _ ## name) -namespace mctest { +#define LOG_DEBUG(cond) \ + ::mctest::Stream(McTest_LogDebug, (cond), __FILE__, __LINE__) -/* Base logger */ -class Logger { - public: - MCTEST_INLINE Logger(McTest_LogLevel level_, bool expr_, - const char *file_, unsigned line_) - : level(level_), - expr(!!McTest_IsTrue(expr_)), - file(file_), - line(line_) {} +#define LOG_INFO(cond) \ + ::mctest::Stream(McTest_LogInfo, (cond), __FILE__, __LINE__) - MCTEST_INLINE ~Logger(void) { - if (!expr) { - std::stringstream report_ss; - report_ss << file << "(" << line << "): " << ss.str(); - auto report_str = report_ss.str(); - auto report_c_str = report_str.c_str(); - McTest_Log(level, report_c_str, report_c_str + report_str.size()); - } - } +#define LOG_WARNING(cond) \ + ::mctest::Stream(McTest_LogWarning, (cond), __FILE__, __LINE__) - MCTEST_INLINE std::stringstream &stream(void) { - return ss; - } +#define LOG_WARN(cond) \ + ::mctest::Stream(McTest_LogWarning, (cond), __FILE__, __LINE__) - private: - Logger(void) = delete; - Logger(const Logger &) = delete; - Logger &operator=(const Logger &) = delete; +#define LOG_ERROR(cond) \ + ::mctest::Stream(McTest_LogError, (cond), __FILE__, __LINE__) - const 
McTest_LogLevel level; - const bool expr; - const char * const file; - const unsigned line; - std::stringstream ss; -}; +#define LOG_FATAL(cond) \ + ::mctest::Stream(McTest_LogFatal, (cond), __FILE__, __LINE__) + +#define LOG_CRITICAl(cond) \ + ::mctest::Stream(McTest_LogFatal, (cond), __FILE__, __LINE__) + +#define LOG(LEVEL) LOG_ ## LEVEL(true) + +#define LOG_IF(LEVEL, cond) LOG_ ## LEVEL(cond) -} // namespace mctest #define MCTEST_LOG_BINOP(a, b, op, level) \ - ::mctest::Logger( \ - level, ((a) op (b)), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + level, !((a) op (b)), __FILE__, __LINE__) #define ASSERT_EQ(a, b) MCTEST_LOG_BINOP(a, b, ==, McTest_LogFatal) #define ASSERT_NE(a, b) MCTEST_LOG_BINOP(a, b, !=, McTest_LogFatal) @@ -83,27 +68,26 @@ class Logger { #define CHECK_GE(a, b) MCTEST_LOG_BINOP(a, b, >=, McTest_LogError) #define ASSERT(expr) \ - ::mctest::Logger( \ - McTest_LogFatal, !!(expr), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + McTest_LogFatal, !(expr), __FILE__, __LINE__) #define ASSERT_TRUE ASSERT #define ASSERT_FALSE(expr) ASSERT(!(expr)) #define CHECK(expr) \ - ::mctest::Logger( \ - McTest_LogError, !!(expr), __FILE__, __LINE__).stream() + ::mctest::Stream( \ + McTest_LogError, !(expr), __FILE__, __LINE__) #define CHECK_TRUE CHECK #define CHECK_FALSE(expr) CHECK(!(expr)) #define ASSUME(expr) \ - McTest_Assume(expr), ::mctest::Logger( \ - McTest_LogInfo, false, __FILE__, __LINE__).stream() - + McTest_Assume(expr), ::mctest::Stream( \ + McTest_LogInfo, true, __FILE__, __LINE__) #define MCTEST_ASSUME_BINOP(a, b, op) \ - McTest_Assume(((a) op (b))), ::mctest::Logger( \ - McTest_LogInfo, false, __FILE__, __LINE__).stream() + McTest_Assume(((a) op (b))), ::mctest::Stream( \ + McTest_LogInfo, true, __FILE__, __LINE__) #define ASSUME_EQ(a, b) MCTEST_ASSUME_BINOP(a, b, ==) #define ASSUME_NE(a, b) MCTEST_ASSUME_BINOP(a, b, !=) @@ -112,4 +96,4 @@ class Logger { #define ASSUME_GT(a, b) MCTEST_ASSUME_BINOP(a, b, >) #define ASSUME_GE(a, b) 
MCTEST_ASSUME_BINOP(a, b, >=) -#endif // INCLUDE_MCTEST_MCUNIT_HPP_ +#endif // SRC_INCLUDE_MCTEST_MCUNIT_HPP_ diff --git a/src/include/mctest/Quantified.hpp b/src/include/mctest/Quantified.hpp index 543fbf6..bc5d372 100644 --- a/src/include/mctest/Quantified.hpp +++ b/src/include/mctest/Quantified.hpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INCLUDE_MCTEST_QUANTIFIED_HPP_ -#define INCLUDE_MCTEST_QUANTIFIED_HPP_ +#ifndef SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ +#define SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ #include @@ -35,4 +35,4 @@ inline static void ForAll(Closure func) { } // namespace mctest -#endif // INCLUDE_MCTEST_QUANTIFIED_HPP_ +#endif // SRC_INCLUDE_MCTEST_QUANTIFIED_HPP_ diff --git a/src/include/mctest/Stream.h b/src/include/mctest/Stream.h new file mode 100644 index 0000000..52560ac --- /dev/null +++ b/src/include/mctest/Stream.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_INCLUDE_MCTEST_STREAM_H_ +#define SRC_INCLUDE_MCTEST_STREAM_H_ + +#include +#include + +#include +#include + +MCTEST_BEGIN_EXTERN_C + +/* Flush the contents of the stream to a log. */ +extern void McTest_LogStream(enum McTest_LogLevel level); + +/* Stream a C string into the stream's message. */ +extern void McTest_StreamCStr(enum McTest_LogLevel level, const char *begin); + +/* TODO(pag): Implement `McTest_StreamWCStr` with `wchar_t`. 
*/ + +/* Stream a some data in the inclusive range `[begin, end]` into the + * stream's message. */ +/*extern void McTest_StreamData(enum McTest_LogLevel level, const void *begin, + const void *end);*/ + +/* Stream some formatted input */ +extern void McTest_StreamFormat(enum McTest_LogLevel level, const char *format, + ...); + +/* Stream some formatted input */ +extern void McTest_StreamVFormat(enum McTest_LogLevel level, const char *format, + va_list args); + +#define MCTEST_DECLARE_STREAMER(Type, type) \ + extern void McTest_Stream ## Type(enum McTest_LogLevel level, type val); + +MCTEST_DECLARE_STREAMER(Double, double); +MCTEST_DECLARE_STREAMER(Pointer, void *); + +MCTEST_DECLARE_STREAMER(UInt64, uint64_t) +MCTEST_DECLARE_STREAMER(Int64, int64_t) + +MCTEST_DECLARE_STREAMER(UInt32, uint32_t) +MCTEST_DECLARE_STREAMER(Int32, int32_t) + +MCTEST_DECLARE_STREAMER(UInt16, uint16_t) +MCTEST_DECLARE_STREAMER(Int16, int16_t) + +MCTEST_DECLARE_STREAMER(UInt8, uint8_t) +MCTEST_DECLARE_STREAMER(Int8, int8_t) + +#undef MCTEST_DECLARE_STREAMER + +MCTEST_INLINE static void McTest_StreamFloat(enum McTest_LogLevel level, + float val) { + McTest_StreamDouble(level, (double) val); +} + +/* Reset the formatting in a stream. */ +extern void McTest_StreamResetFormatting(enum McTest_LogLevel level); + +MCTEST_END_EXTERN_C + +#endif /* SRC_INCLUDE_MCTEST_STREAM_H_ */ diff --git a/src/include/mctest/Stream.hpp b/src/include/mctest/Stream.hpp new file mode 100644 index 0000000..54213c6 --- /dev/null +++ b/src/include/mctest/Stream.hpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SRC_INCLUDE_MCTEST_STREAM_HPP_ +#define SRC_INCLUDE_MCTEST_STREAM_HPP_ + +#include +#include + +#include +#include + +namespace mctest { + +/* Conditionally stream output to a log using the streaming APIs. */ +class Stream { + public: + MCTEST_INLINE Stream(McTest_LogLevel level_, bool do_log_, + const char *file, unsigned line) + : level(level_), + do_log(McTest_IsTrue(do_log_)) { + McTest_LogStream(level); + if (do_log) { + McTest_StreamFormat(level, "%s(%u): ", file, line); + } + } + + MCTEST_INLINE ~Stream(void) { + if (do_log) { + McTest_LogStream(level); + } + } + +#define MCTEST_DEFINE_STREAMER(Type, type, expr) \ + MCTEST_INLINE const Stream &operator<<(type val) const { \ + if (do_log) { \ + McTest_Stream ## Type(level, expr); \ + } \ + return *this; \ + } + + MCTEST_DEFINE_STREAMER(UInt64, uint64_t, val) + MCTEST_DEFINE_STREAMER(Int64, int64_t, val) + + MCTEST_DEFINE_STREAMER(UInt32, uint32_t, val) + MCTEST_DEFINE_STREAMER(Int32, int32_t, val) + + MCTEST_DEFINE_STREAMER(UInt16, uint16_t, val) + MCTEST_DEFINE_STREAMER(Int16, int16_t, val) + + MCTEST_DEFINE_STREAMER(UInt8, uint8_t, val) + MCTEST_DEFINE_STREAMER(Int8, int8_t, val) + + MCTEST_DEFINE_STREAMER(Float, float, val) + MCTEST_DEFINE_STREAMER(Double, double, val) + + MCTEST_DEFINE_STREAMER(CStr, const char *, val) + MCTEST_DEFINE_STREAMER(CStr, char *, const_cast(val)) + + MCTEST_DEFINE_STREAMER(Pointer, nullptr_t, nullptr) + + template + MCTEST_DEFINE_STREAMER(Pointer, T *, val); + + template + MCTEST_DEFINE_STREAMER(Pointer, const T *, const_cast(val)); + +#undef 
MCTEST_DEFINE_INT_STREAMER + + MCTEST_INLINE const Stream &operator<<(const std::string &str) const { + if (do_log && !str.empty()) { + McTest_StreamCStr(level, str.c_str()); + } + return *this; + } + + // TODO(pag): Implement a `std::wstring` streamer. + + private: + Stream(void) = delete; + Stream(const Stream &) = delete; + Stream &operator=(const Stream &) = delete; + + const McTest_LogLevel level; + const int do_log; +}; + +} // namespace mctest + +#endif // SRC_INCLUDE_MCTEST_STREAM_HPP_ diff --git a/src/lib/Log.c b/src/lib/Log.c new file mode 100644 index 0000000..aa739ae --- /dev/null +++ b/src/lib/Log.c @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2017 Trail of Bits, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +MCTEST_BEGIN_EXTERN_C + +/* Returns a printable string version of the log level. */ +static const char *McTest_LogLevelStr(enum McTest_LogLevel level) { + switch (level) { + case McTest_LogDebug: + return "DEBUG"; + case McTest_LogInfo: + return "INFO"; + case McTest_LogWarning: + return "WARNING"; + case McTest_LogError: + return "ERROR"; + case McTest_LogFatal: + return "FATAL"; + default: + return "UNKNOWN"; + } +} + +enum { + McTest_LogBufSize = 4096 +}; + +char McTest_LogBuf[McTest_LogBufSize + 1] = {}; + +/* Log a C string. 
*/ +void McTest_Log(enum McTest_LogLevel level, const char *str) { + memset(McTest_LogBuf, 0, McTest_LogBufSize); + snprintf(McTest_LogBuf, McTest_LogBufSize, "%s: %s", + McTest_LogLevelStr(level), str); + fputs(McTest_LogBuf, stderr); + + if (McTest_LogError == level) { + McTest_SoftFail(); + } else if (McTest_LogFatal == level) { + McTest_Fail(); + } +} + +/* Log some formatted output. */ +void McTest_LogFormat(enum McTest_LogLevel level, const char *format, ...) { + McTest_LogStream(level); + va_list args; + va_start(args, format); + McTest_StreamVFormat(level, format, args); + va_end(args); + McTest_LogStream(level); +} + +/* Log some formatted output. */ +void McTest_LogVFormat(enum McTest_LogLevel level, + const char *format, va_list args) { + McTest_LogStream(level); + McTest_StreamVFormat(level, format, args); + McTest_LogStream(level); +} + +/* Override libc! */ +int printf(const char *format, ...) { + McTest_LogStream(McTest_LogInfo); + va_list args; + va_start(args, format); + McTest_StreamVFormat(McTest_LogInfo, format, args); + va_end(args); + McTest_LogStream(McTest_LogInfo); + return 0; +} + +int fprintf(FILE *file, const char *format, ...) { + enum McTest_LogLevel level = McTest_LogInfo; + if (stderr == file) { + level = McTest_LogDebug; + } else if (stdout != file) { + return 0; /* TODO(pag): This is probably evil. */ + } + + McTest_LogStream(level); + va_list args; + va_start(args, format); + McTest_StreamVFormat(level, format, args); + va_end(args); + McTest_LogStream(level); + return 0; +} + +MCTEST_END_EXTERN_C diff --git a/src/lib/McTest.c b/src/lib/McTest.c index a19907c..d25343e 100644 --- a/src/lib/McTest.c +++ b/src/lib/McTest.c @@ -14,17 +14,13 @@ * limitations under the License. 
*/ +#include #include #include #include #include -#if defined(unix) || defined(__unix) || defined(__unix__) -# define _GNU_SOURCE -# include /* For `syscall` */ -#endif - MCTEST_BEGIN_EXTERN_C /* Pointer to the last registers McTest_TestInfo data structure */ @@ -45,8 +41,17 @@ static uint32_t McTest_InputIndex = 0; /* Jump buffer for returning to `McTest_Run`. */ jmp_buf McTest_ReturnToRun = {}; +static const char *McTest_TestAbandoned = NULL; static int McTest_TestFailed = 0; +/* Abandon this test. We've hit some kind of internal problem. */ +MCTEST_NORETURN +void McTest_Abandon(const char *reason) { + McTest_Log(McTest_LogFatal, reason); + McTest_TestAbandoned = reason; + longjmp(McTest_ReturnToRun, 1); +} + /* Mark this test as failing. */ MCTEST_NORETURN void McTest_Fail(void) { @@ -80,6 +85,16 @@ void McTest_SymbolizeData(void *begin, void *end) { } } +/* Concretize some data i nthe range `[begin, end)`. */ +void McTest_ConcretizeData(void *begin, void *end) { + (void) begin; +} + +/* Concretize a C string */ +void McTest_ConcretizeCStr(const char *begin) { + (void) begin; +} + MCTEST_NOINLINE int McTest_One(void) { return 1; } @@ -148,62 +163,15 @@ int McTest_IsSymbolicUInt(uint32_t x) { return 0; } -/* Returns a printable string version of the log level. 
*/ -static const char *McTest_LogLevelStr(enum McTest_LogLevel level) { - switch (level) { - case McTest_LogDebug: - return "DEBUG"; - case McTest_LogInfo: - return "INFO"; - case McTest_LogWarning: - return "WARNING"; - case McTest_LogError: - return "ERROR"; - case McTest_LogFatal: - return "FATAL"; - default: - return "UNKNOWN"; - } -} +/* Defined in Stream.c */ +extern void _McTest_StreamInt(enum McTest_LogLevel level, const char *format, + const char *unpack, uint64_t *val); -enum { - McTest_LogBufSize = 4096 -}; +extern void _McTest_StreamFloat(enum McTest_LogLevel level, const char *format, + const char *unpack, double *val); -char McTest_LogBuf[McTest_LogBufSize + 1] = {}; - - -void _McTest_Log(enum McTest_LogLevel level, const char *message) { - fprintf(stderr, "%s: %s\n", McTest_LogLevelStr(level), message); - if (McTest_LogError == level) { - McTest_SoftFail(); - - } else if (McTest_LogFatal == level) { - McTest_Fail(); - } -} - -/* Outputs information to a log, using a specific log level. */ -void McTest_Log(enum McTest_LogLevel level, const char *begin, - const char *end) { - if (end <= begin) { - return; - } - - size_t size = (size_t) (end - begin); - if (size > McTest_LogBufSize) { - size = McTest_LogBufSize; - } - - /* When we interpose on _McTest_Log, we are looking for the first non-symbolic - * zero byte as our end of string character, so we want to guarantee that we - * have a bunch of those */ - memset(McTest_LogBuf, 0, McTest_LogBufSize); - memcpy(McTest_LogBuf, begin, size); - McTest_LogBuf[McTest_LogBufSize] = '\0'; - - _McTest_Log(level, McTest_LogBuf); -} +extern void _McTest_StreamString(enum McTest_LogLevel level, const char *format, + const char *str); /* A McTest-specific symbol that is needed for hooking. */ struct McTest_IndexEntry { @@ -214,16 +182,36 @@ struct McTest_IndexEntry { /* An index of symbols that the symbolic executors will hook or * need access to. 
*/ const struct McTest_IndexEntry McTest_API[] = { + + /* Control-flow during the test. */ {"Pass", (void *) McTest_Pass}, {"Fail", (void *) McTest_Fail}, {"SoftFail", (void *) McTest_SoftFail}, - {"Log", (void *) _McTest_Log}, - {"Assume", (void *) _McTest_Assume}, - {"IsSymbolicUInt", (void *) McTest_IsSymbolicUInt}, + {"Abandon", (void *) McTest_Abandon}, + + /* Locating the tests. */ + {"LastTestInfo", (void *) &McTest_LastTestInfo}, + + /* Source of symbolic bytes. */ {"InputBegin", (void *) &(McTest_Input[0])}, {"InputEnd", (void *) &(McTest_Input[McTest_InputLength])}, {"InputIndex", (void *) &McTest_InputIndex}, - {"LastTestInfo", (void *) &McTest_LastTestInfo}, + + /* Solver APIs. */ + {"Assume", (void *) _McTest_Assume}, + {"IsSymbolicUInt", (void *) McTest_IsSymbolicUInt}, + {"ConcretizeData", (void *) McTest_ConcretizeData}, + {"ConcretizeCStr", (void *) McTest_ConcretizeCStr}, + + /* Logging API. */ + {"Log", (void *) McTest_Log}, + + /* Streaming API for deferred logging. */ + {"LogStream", (void *) McTest_LogStream}, + {"StreamInt", (void *) _McTest_StreamInt}, + {"StreamFloat", (void *) _McTest_StreamFloat}, + {"StreamString", (void *) _McTest_StreamString}, + {NULL, NULL}, }; @@ -232,6 +220,19 @@ void McTest_Setup(void) { /* TODO(pag): Sort the test cases by file name and line number. */ } +/* Tear down McTest. */ +void McTest_Teardown(void) { + +} + +/* Notify that we're about to begin a test. */ +void McTest_Begin(struct McTest_TestInfo *info) { + McTest_TestFailed = 0; + McTest_TestAbandoned = NULL; + McTest_LogFormat(McTest_LogInfo, "Running: %s from %s(%u)", + info->test_name, info->file_name, info->line_number); +} + /* Return the first test case to run. */ struct McTest_TestInfo *McTest_FirstTest(void) { return McTest_LastTestInfo; @@ -242,4 +243,9 @@ int McTest_CatchFail(void) { return McTest_TestFailed; } +/* Returns 1 if this test case was abandoned. 
*/
+int McTest_CatchAbandoned(void) {
+  /* `McTest_TestAbandoned` is reset to NULL by `McTest_Begin`; any non-NULL
+   * value is reported as "abandoned". */
+  return McTest_TestAbandoned != NULL;
+}
+
 MCTEST_END_EXTERN_C
diff --git a/src/lib/Stream.c b/src/lib/Stream.c
new file mode 100644
index 0000000..787d3ad
--- /dev/null
+++ b/src/lib/Stream.c
@@ -0,0 +1,563 @@
+/*
+ * Copyright (c) 2017 Trail of Bits, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+MCTEST_BEGIN_EXTERN_C
+
+/* Capacity, in characters, of each stream's message buffer. */
+enum {
+  McTest_StreamSize = 4096
+};
+
+/* Formatting options available to the streaming API. Each non-zero field
+ * adds the corresponding piece to the `printf`-style format built by
+ * `McTest_StreamIntFormat`. */
+struct McTest_StreamFormatOptions {
+  /* int radix; */
+  int hex;  /* Use the `x` conversion. Checked before `oct`. */
+  int oct;  /* Use the `o` conversion. */
+  int show_base;  /* Emit the `#` flag. */
+  /*int width; */
+  int left_justify;  /* Emit the `-` flag. */
+  int add_sign;  /* Emit the `+` flag. */
+  char fill;  /* If non-zero, copied verbatim into the flags position. */
+};
+
+/* Stream type that accumulates formatted data to be printed. This loosely
+ * mirrors C++ I/O streams, not because I/O streams are good, but instead
+ * because the ability to stream in data to things like the C++-backed
+ * `ASSERT` and `CHECK` macros is really nice. */
+struct McTest_Stream {
+  /* Number of characters currently held in `message`; also the offset at
+   * which the next piece of formatted output is written. */
+  int size;
+  struct McTest_StreamFormatOptions options;
+
+  /* Accumulated text. The two extra bytes leave room for the newline and
+   * NUL terminator appended by `McTest_LogStream`. */
+  char message[McTest_StreamSize + 2];
+
+  /* NOTE(review): not referenced by any function in this part of the
+   * file -- confirm whether it is still needed. */
+  char staging[32];
+
+  /* `printf`-style format produced by `McTest_StreamIntFormat`. */
+  char format[32];
+
+  /* Python `struct` unpack string produced by `McTest_StreamUnpack`: a
+   * byte-order character followed by one type character. */
+  char unpack[32];
+
+  /* Holding slot for the value being streamed; its address is what gets
+   * passed to `_McTest_StreamInt` / `_McTest_StreamFloat`. */
+  union {
+    uint64_t as_uint64;
+    double as_fp64;
+  } value;
+};
+
+/* Hard-coded streams for each log level.
*/
+static struct McTest_Stream McTest_Streams[McTest_LogFatal + 1] = {};
+
+/* Byte-order prefix placed at the front of every Python `struct` unpack
+ * string. Defaults to `=` (native order) and is overwritten by
+ * `DetectEndianness`. */
+static char McTest_EndianSpecifier = '=';
+
+/* Figure out what the Python `struct` endianness specifier should be.
+ *
+ * Inspects the lowest-addressed byte of an `int` holding 1: it is non-zero
+ * on a little-endian host and zero on a big-endian one. The byte itself must
+ * be read -- testing the (never NULL) pointer would always select `<`. */
+MCTEST_INITIALIZER(DetectEndianness) {
+  static const int one = 1;
+  if (*((const char *) &one)) {
+    McTest_EndianSpecifier = '<';  /* Little endian. */
+  } else {
+    McTest_EndianSpecifier = '>';  /* Big endian. */
+  }
+}
+
+/* Store the two-character Python `struct` unpack string for `type` (byte
+ * order followed by the type code) in `stream->unpack`. */
+static void McTest_StreamUnpack(struct McTest_Stream *stream, char type) {
+  stream->unpack[0] = McTest_EndianSpecifier;
+  stream->unpack[1] = type;
+  stream->unpack[2] = '\0';
+}
+
+/* Fill in the format for when we want to stream an integer.
+ *
+ * Builds a `printf`-style conversion in `stream->format` from the stream's
+ * current formatting options: optional flags, then a length modifier chosen
+ * from `val_size` (in bytes), then the conversion character. At most eight
+ * characters plus the terminator are written, well within the 32-byte
+ * buffer. */
+static void McTest_StreamIntFormat(struct McTest_Stream *stream,
+                                   size_t val_size, int is_unsigned) {
+  char *format = stream->format;
+  int i = 0;
+
+  format[i++] = '%';
+  if (stream->options.left_justify) {
+    format[i++] = '-';
+  }
+
+  if (stream->options.add_sign) {
+    format[i++] = '+';
+  }
+
+  if (stream->options.fill) {
+    format[i++] = stream->options.fill;
+  }
+
+  if (stream->options.show_base) {
+    format[i++] = '#';  /* Show the radix. */
+  }
+
+  /* Length modifier. Four-byte values need none. */
+  if (8 == val_size) {
+    format[i++] = 'l';
+    format[i++] = 'l';
+
+  } else if (2 == val_size) {
+    format[i++] = 'h';
+
+  } else if (1 == val_size) {
+    /* Signed bytes get no modifier because, absent `hex`/`oct`, they are
+     * printed with `%c` below. */
+    if (is_unsigned) {
+      format[i++] = 'h';
+      format[i++] = 'h';
+    }
+  }
+
+  /* Conversion character. `hex` wins over `oct`, which wins over the
+   * signedness of the value. */
+  if (stream->options.hex) {
+    format[i++] = 'x';
+  } else if (stream->options.oct) {
+    format[i++] = 'o';
+  } else if (is_unsigned) {
+    format[i++] = 'u';
+  } else {
+    if (1 == val_size) {
+      format[i++] = 'c';
+    } else {
+      format[i++] = 'd';
+    }
+  }
+
+  format[i++] = '\0';
+}
+
+/* Abandon the test if `num_chars_to_add` is negative (e.g. a failed
+ * `snprintf`) or if appending that many characters would reach or exceed the
+ * stream's capacity. */
+static void CheckCapacity(struct McTest_Stream *stream, int num_chars_to_add) {
+  if (0 > num_chars_to_add) {
+    McTest_Abandon("Can't add a negative number of characters to a stream.");
+  } else if ((stream->size + num_chars_to_add) >= McTest_StreamSize) {
+    McTest_Abandon("Exceeded capacity of stream buffer.");
+  }
+}
+
+/* Stream an integer into the stream's message.
This function is designed to
+ * be hooked by the symbolic executor, so that it can easily pull out the
+ * relevant data from `*val`, which may be symbolic, and defer the actual
+ * formatting until later.
+ *
+ * When run natively, `format` is applied with `snprintf` and the output is
+ * appended to the stream for `level`. The type character `unpack[1]` selects
+ * how much of `*val` is passed: `Q`/`q` pass all 64 bits, anything else
+ * passes the value truncated to 32 bits.
+ *
+ * NOTE(review): `level` indexes `McTest_Streams` without a range check. */
+MCTEST_NOINLINE
+void _McTest_StreamInt(enum McTest_LogLevel level, const char *format,
+                       const char *unpack, uint64_t *val) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  int size = 0;
+  int remaining_size = McTest_StreamSize - stream->size;
+  if (unpack[1] == 'Q' || unpack[1] == 'q') {
+    size = snprintf(&(stream->message[stream->size]),
+                    remaining_size, format, *val);
+  } else {
+    size = snprintf(&(stream->message[stream->size]),
+                    remaining_size, format, (uint32_t) *val);
+  }
+  /* `snprintf` reports the length it wanted to write, so truncated output
+   * is caught here rather than silently accepted. */
+  CheckCapacity(stream, size);
+  stream->size += size;
+}
+
+/* Append `*val`, formatted with `format`, to the stream for `level`.
+ * `unpack` is not read here. */
+MCTEST_NOINLINE
+void _McTest_StreamFloat(enum McTest_LogLevel level, const char *format,
+                         const char *unpack, double *val) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  int remaining_size = McTest_StreamSize - stream->size;
+  int size = snprintf(&(stream->message[stream->size]),
+                      remaining_size, format, *val);
+  CheckCapacity(stream, size);
+  stream->size += size;
+}
+
+/* Append `str`, formatted with `format`, to the stream for `level`. */
+MCTEST_NOINLINE
+void _McTest_StreamString(enum McTest_LogLevel level, const char *format,
+                          const char *str) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  int remaining_size = McTest_StreamSize - stream->size;
+  int size = snprintf(&(stream->message[stream->size]),
+                      remaining_size, format, str);
+  CheckCapacity(stream, size);
+  stream->size += size;
+}
+
+/* Defines `McTest_Stream<Type>(level, val)`: builds the `printf` format and
+ * the Python `struct` unpack string (`pack_kind`) for a value of C type
+ * `type`, widens the value into the stream's 64-bit holding slot, and hands
+ * everything to `_McTest_StreamInt`. */
+#define MAKE_INT_STREAMER(Type, type, is_unsigned, pack_kind) \
+    void McTest_Stream ## Type(enum McTest_LogLevel level, type val) { \
+      struct McTest_Stream *stream = &(McTest_Streams[level]); \
+      McTest_StreamIntFormat(stream, sizeof(val), is_unsigned); \
+      McTest_StreamUnpack(stream, pack_kind); \
+      stream->value.as_uint64 = (uint64_t) val; \
+      _McTest_StreamInt(level, stream->format, stream->unpack, \
+                        &(stream->value.as_uint64)); \
+    }
+
+/* Typed streamers. The last argument is the Python `struct` type code used
+ * to unpack the value: upper-case codes are unsigned, lower-case codes are
+ * signed. */
+MAKE_INT_STREAMER(Pointer, void *, 1, (sizeof(void *) == 8 ? 'Q' : 'I'))
+
+MAKE_INT_STREAMER(UInt64, uint64_t, 1, 'Q')
+MAKE_INT_STREAMER(Int64, int64_t, 0, 'q')
+
+MAKE_INT_STREAMER(UInt32, uint32_t, 1, 'I')
+MAKE_INT_STREAMER(Int32, int32_t, 0, 'i')
+
+/* `H` is `unsigned short` and `h` is `short`. */
+MAKE_INT_STREAMER(UInt16, uint16_t, 1, 'H')
+MAKE_INT_STREAMER(Int16, int16_t, 0, 'h')
+
+/* A signed byte is unpacked as a character (`c`) to match the `%c`
+ * conversion that `McTest_StreamIntFormat` selects for it.
+ * NOTE(review): with the `hex` or `oct` option set the format becomes
+ * `%x`/`%o`, for which `b` would be the matching code -- confirm. */
+MAKE_INT_STREAMER(UInt8, uint8_t, 1, 'B')
+MAKE_INT_STREAMER(Int8, int8_t, 0, 'c')
+
+#undef MAKE_INT_STREAMER
+
+/* Stream a C string into the stream's message. */
+void McTest_StreamCStr(enum McTest_LogLevel level, const char *begin) {
+  _McTest_StreamString(level, "%s", begin);
+}
+
+/* Stream some data in the inclusive range `[begin, end]` into the
+ * stream's message. */
+/*void McTest_StreamData(enum McTest_LogLevel level, const void *begin,
+                       const void *end) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  int remaining_size = McTest_StreamSize - stream->size;
+  int input_size = (int) ((uintptr_t) end - (uintptr_t) begin) + 1;
+  CheckCapacity(stream, input_size);
+  memcpy(&(stream->message[stream->size]), begin, (size_t) input_size);
+  stream->size += input_size;
+}*/
+
+/* Stream a `double` into the stream's message.
+ *
+ * The value is parked in the stream's holding slot and passed by pointer to
+ * `_McTest_StreamFloat`, which is the function exported for hooking, along
+ * with a fixed `%f` format and the `d` (double) unpack code. */
+void McTest_StreamDouble(enum McTest_LogLevel level, double val) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  const char *format = "%f";  /* TODO(pag): Support more? */
+  stream->value.as_fp64 = val;
+  McTest_StreamUnpack(stream, 'd');
+  _McTest_StreamFloat(level, format, stream->unpack, &(stream->value.as_fp64));
+}
+
+/* Flush the contents of the stream to a log.
*/
+void McTest_LogStream(enum McTest_LogLevel level) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+
+  /* An empty stream produces no log entry at all. */
+  if (stream->size) {
+    /* `message` is two bytes larger than the stream capacity, so the newline
+     * and terminator always fit. */
+    stream->message[stream->size] = '\n';
+    stream->message[stream->size + 1] = '\0';
+    stream->message[McTest_StreamSize] = '\0';
+    McTest_Log(level, stream->message);
+    memset(stream->message, 0, McTest_StreamSize);
+    stream->size = 0;
+  }
+}
+
+/* Reset the formatting in a stream. */
+void McTest_StreamResetFormatting(enum McTest_LogLevel level) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  memset(&(stream->options), 0, sizeof(stream->options));
+}
+
+/* Approximately do string format parsing and convert it into calls into our
+ * streaming API.
+ *
+ * `format` must point at the `%` that starts a conversion specification.
+ * The specification is copied into a local buffer (minus any `L` modifier,
+ * and with `%p` rewritten as a hexadecimal integer of pointer width), the
+ * matching argument is pulled from `*args`, and both are handed to
+ * `_McTest_StreamInt`, `_McTest_StreamFloat` or `_McTest_StreamString`.
+ *
+ * `args` is a pointer so that the caller observes how many arguments were
+ * consumed; a `va_list` passed by value is indeterminate in the caller once
+ * the callee has used `va_arg` on it.
+ *
+ * Returns the number of characters of `format` that were consumed, or 0
+ * after reporting a malformed or unsupported specification through
+ * `McTest_Abandon`. */
+static int McTest_StreamFormatValue(enum McTest_LogLevel level,
+                                    const char *format, va_list *args) {
+  struct McTest_Stream *stream = &(McTest_Streams[level]);
+  char format_buf[32] = {'\0'};
+  int i = 0;  /* Characters of `format` consumed so far. */
+  int k = 0;  /* Consumed characters that were dropped from `format_buf`. */
+  int spec = 0;  /* Index of the conversion character in `format_buf`. */
+  int length = 4;  /* Size in bytes of the integer argument. */
+  char ch = '\0';
+  int long_double = 0;
+  char extract = '\0';
+
+/* Copy the next character of `format` into `format_buf`. Three bytes are
+ * held back so that the `%p` rewrite (`llx`) and the terminator always
+ * fit. */
+#define READ_FORMAT_CHAR \
+  if ((i - k) >= ((int) sizeof(format_buf) - 3)) { \
+    McTest_Abandon("Format specifier is too long."); \
+    return 0; \
+  } \
+  ch = format[i]; \
+  format_buf[i - k] = ch; \
+  i++
+
+  READ_FORMAT_CHAR;  /* Read the '%' */
+
+  if ('%' != ch) {
+    McTest_Abandon("Invalid format.");
+    return 0;
+  }
+
+  /* Flags */
+get_flag_char:
+  READ_FORMAT_CHAR;
+  switch (ch) {
+    case '\0':
+      McTest_Abandon("Incomplete format (flags).");
+      return 0;
+    case '-':
+    case '+':
+    case ' ':
+    case '#':
+    case '0':
+      goto get_flag_char;
+    default:
+      break;
+  }
+
+  /* Width */
+get_width_char:
+  switch (ch) {
+    case '\0':
+      McTest_Abandon("Incomplete format (width).");
+      return 0;
+    case '*':
+      McTest_Abandon("Variable width printing not supported.");
+      return 0;
+    case '0':
+    case '1':
+    case '2':
+    case '3':
+    case '4':
+    case '5':
+    case '6':
+    case '7':
+    case '8':
+    case '9':
+      READ_FORMAT_CHAR;
+      goto get_width_char;
+    default:
+      break;
+  }
+
+  /* Precision */
+  if ('.' == ch) {
+  get_precision_char:
+    READ_FORMAT_CHAR;
+    switch (ch) {
+      case '\0':
+        McTest_Abandon("Incomplete format (precision).");
+        return 0;
+      case '*':
+        McTest_Abandon("Variable precision printing not supported.");
+        return 0;
+      case '0':
+      case '1':
+      case '2':
+      case '3':
+      case '4':
+      case '5':
+      case '6':
+      case '7':
+      case '8':
+      case '9':
+        goto get_precision_char;
+      default:
+        break;
+    }
+  }
+
+  /* Length.
+   * NOTE(review): a single `l` is treated as 8 bytes, which matches `long`
+   * on LP64 targets only. */
+get_length_char:
+  switch (ch) {
+    case '\0':
+      McTest_Abandon("Incomplete format (length).");
+      return 0;
+    case 'L':
+      long_double = 1;
+      k += 1;  /* Overwrite the `L`. */
+      READ_FORMAT_CHAR;
+      break;
+    case 'h':
+      length /= 2;
+      READ_FORMAT_CHAR;
+      goto get_length_char;
+    case 'l':
+      length *= 2;
+      READ_FORMAT_CHAR;
+      goto get_length_char;
+    case 'j':
+      length = (int) sizeof(intmax_t);
+      READ_FORMAT_CHAR;
+      break;
+    case 'z':
+      length = (int) sizeof(size_t);
+      READ_FORMAT_CHAR;
+      break;
+    case 't':
+      length = (int) sizeof(ptrdiff_t);
+      READ_FORMAT_CHAR;
+      break;
+    default:
+      break;
+  }
+
+  if (!length) {
+    length = 1;
+  } else if (8 < length) {
+    length = 8;
+  }
+
+  /* Terminate just past the last character stored, which is `k` short of
+   * the number of characters consumed. */
+  format_buf[i - k] = '\0';
+
+  /* Specifier */
+  switch(ch) {
+    case '\0':
+      McTest_Abandon("Incomplete format (specifier).");
+      return 0;
+
+    case 'n':
+      /* Nothing printed, but the pointer argument still has to be consumed
+       * so that later conversions see the right arguments. */
+      (void) va_arg(*args, int *);
+      return i;
+
+    /* Print a character. */
+    case 'c':
+      stream->value.as_uint64 = (uint64_t) (char) va_arg(*args, int);
+      extract = 'c';
+      goto common_stream_int;
+
+    /* Signed integer. */
+    case 'd':
+    case 'i':
+      if (1 == length) {
+        stream->value.as_uint64 = (uint64_t) (int8_t) va_arg(*args, int);
+        extract = 'b';
+      } else if (2 == length) {
+        stream->value.as_uint64 = (uint64_t) (int16_t) va_arg(*args, int);
+        extract = 'h';
+      } else if (4 == length) {
+        stream->value.as_uint64 = (uint64_t) (int32_t) va_arg(*args, int);
+        extract = 'i';
+      } else if (8 == length) {
+        stream->value.as_uint64 = (uint64_t) va_arg(*args, int64_t);
+        extract = 'q';
+      } else {
+        McTest_Abandon("Unsupported integer length.");
+        return 0;
+      }
+      goto common_stream_int;
+
+    /* Pointer. Printed as a hexadecimal integer; a 64-bit pointer needs the
+     * `ll` modifier, otherwise `%x` would be handed a 64-bit argument. */
+    case 'p':
+      length = (int) sizeof(void *);
+      spec = i - k - 1;
+      if (8 == length) {
+        format_buf[spec++] = 'l';
+        format_buf[spec++] = 'l';
+      }
+      format_buf[spec++] = 'x';
+      format_buf[spec] = '\0';
+      /* Note: Falls through. */
+
+    /* Unsigned, hex, octal */
+    case 'u':
+    case 'o':
+    case 'x':
+    case 'X':
+      if (1 == length) {
+        stream->value.as_uint64 = (uint64_t) (uint8_t) va_arg(*args, int);
+        extract = 'B';
+      } else if (2 == length) {
+        stream->value.as_uint64 = (uint64_t) (uint16_t) va_arg(*args, int);
+        extract = 'H';
+      } else if (4 == length) {
+        stream->value.as_uint64 = (uint64_t) (uint32_t) va_arg(*args, int);
+        extract = 'I';
+      } else if (8 == length) {
+        stream->value.as_uint64 = (uint64_t) va_arg(*args, uint64_t);
+        extract = 'Q';
+      } else {
+        McTest_Abandon("Unsupported integer length.");
+        return 0;
+      }
+
+    common_stream_int:
+      McTest_StreamUnpack(stream, extract);
+      _McTest_StreamInt(level, format_buf, stream->unpack,
+                        &(stream->value.as_uint64));
+      break;
+
+    /* Floating point, scientific notation, etc. */
+    case 'f':
+    case 'F':
+    case 'e':
+    case 'E':
+    case 'g':
+    case 'G':
+    case 'a':
+    case 'A':
+      /* A `long double` is narrowed to `double`; the `L` modifier was
+       * already dropped from `format_buf` to match. */
+      if (long_double) {
+        stream->value.as_fp64 = (double) va_arg(*args, long double);
+      } else {
+        stream->value.as_fp64 = va_arg(*args, double);
+      }
+      McTest_StreamUnpack(stream, 'd');
+      _McTest_StreamFloat(level, format_buf, stream->unpack,
+                          &(stream->value.as_fp64));
+      break;
+
+    case 's':
+      _McTest_StreamString(level, format_buf, va_arg(*args, const char *));
+      break;
+
+    default:
+      McTest_Abandon("Unsupported format specifier.");
+      return 0;
+  }
+
+  return i;
+}
+
+#undef READ_FORMAT_CHAR
+
+/* Scratch copy of the format string being processed. Being a single static
+ * buffer, it makes `McTest_StreamVFormat` non-reentrant. */
+static char McTest_Format[McTest_StreamSize + 1];
+
+/* Stream some formatted input.
+ *
+ * Copies `format_` (truncated to the stream capacity) into a scratch buffer,
+ * concretizes it, and walks it: runs of literal text are streamed as
+ * strings, `%%` is streamed as a single `%`, and every other `%...`
+ * specification is handed to `McTest_StreamFormatValue` together with the
+ * next argument(s). Processing stops early if a specification is
+ * rejected. */
+void McTest_StreamVFormat(enum McTest_LogLevel level, const char *format_,
+                          va_list args) {
+  va_list args_copy;
+  char *begin = NULL;
+  char *end = NULL;
+  char *format = McTest_Format;
+  int i = 0;
+  int consumed = 0;
+  char ch = '\0';
+  char next_ch = '\0';
+
+  strncpy(format, format_, McTest_StreamSize);
+  format[McTest_StreamSize] = '\0';
+
+  McTest_ConcretizeCStr(format);
+
+  /* Work on a local copy whose address can be passed down, so that each
+   * conversion continues from where the previous one stopped. */
+  va_copy(args_copy, args);
+
+  for (i = 0; '\0' != (ch = format[i]); ) {
+    if (!begin) {
+      begin = &(format[i]);
+    }
+
+    if ('%' == ch) {
+      if ('%' == format[i + 1]) {
+        /* Emit the pending literal text up to and including the first `%`,
+         * then skip the second one. */
+        end = &(format[i]);
+        next_ch = end[1];
+        end[1] = '\0';
+        McTest_StreamCStr(level, begin);
+        end[1] = next_ch;
+        begin = NULL;
+        end = NULL;
+        i += 2;
+
+      } else {
+        /* Emit any literal text that precedes this specification. */
+        if (end) {
+          next_ch = end[1];
+          end[1] = '\0';
+          McTest_StreamCStr(level, begin);
+          end[1] = next_ch;
+        }
+        begin = NULL;
+        end = NULL;
+        consumed = McTest_StreamFormatValue(level, &(format[i]), &args_copy);
+        if (0 >= consumed) {
+          break;  /* Already reported; not advancing would loop forever. */
+        }
+        i += consumed;
+      }
+    } else {
+      end = &(format[i]);
+      i += 1;
+    }
+  }
+
+  va_end(args_copy);
+
+  if (begin && begin[0]) {
+    McTest_StreamCStr(level, begin);
+  }
+}
+
+/* Stream some formatted input */
+void McTest_StreamFormat(enum McTest_LogLevel level, const char *format, ...) {
+  va_list args;
+  va_start(args, format);
+  McTest_StreamVFormat(level, format, args);
+  va_end(args);
+}
+
+MCTEST_END_EXTERN_C