Kind of feature parity between Manticore and Angr on these tests.

This commit is contained in:
Peter Goodman 2017-10-30 00:45:59 -04:00
parent 4b786adc70
commit e4f4cfe0db
4 changed files with 204 additions and 107 deletions

View File

@ -20,6 +20,7 @@ import logging
import manticore
import multiprocessing
import sys
import traceback
from manticore.core.state import TerminateState
from manticore.utils.helpers import issymbolic
@ -105,8 +106,12 @@ def read_api_table(state, ea):
def make_symbolic_input(state, input_begin_ea, input_end_ea):
"""Fill in the input data array with symbolic data."""
input_size = input_end_ea - input_begin_ea
data = state.new_symbolic_buffer(nbytes=input_size, name='MCTEST_INPUT')
state.cpu.write_bytes(input_begin_ea, data)
data = []
for i in xrange(input_end_ea - input_begin_ea):
input_byte = state.new_symbolic_value(8, "MCTEST_INPUT_{}".format(i))
data.append(input_byte)
state.cpu.write_int(input_begin_ea + i, input_byte, 8)
return data
@ -131,25 +136,32 @@ def hook_Assume(state, arg):
state.constrain(constraint)
OUR_TERMINATION_REASON = "I McTest'd it"
def report_state(state):
test = state.context['test']
if state.context['failed']:
message = (3, "Failed: {}".format(test.name))
else:
message = (1, "Passed: {}".format(test.name))
state.context['log_messages'].append(message)
raise TerminateState(OUR_TERMINATION_REASON, testcase=False)
def hook_Pass(state):
"""Implements McTest_Pass, which notifies us of a passing test."""
L.info("Passed test case")
if state.context['failed']:
raise TerminateState("Got to end of failing test case.")
else:
raise TerminateState("Passed test case")
report_state(state)
def hook_Fail(state):
"""Implements McTest_Fail, which notifies us of a passing test."""
L.error("Failed test case")
state.context['failed'] = 1
raise TerminateState("Failed test case")
report_state(state)
def hook_SoftFail(state):
"""Implements McTest_Fail, which notifies us of a passing test."""
L.error("Soft failure in test case, continuing")
state.context['failed'] = 1
@ -162,41 +174,112 @@ LEVEL_TO_LOGGER = {
}
def hook_Log(state):
def hook_Log(state, level, begin_ea, end_ea):
"""Implements McTest_Log, which lets Manticore intercept and handle the
printing of log messages from the simulated tests."""
pass
level = state.solve_one(level)
assert level in LEVEL_TO_LOGGER
begin_ea = state.solve_one(begin_ea)
end_ea = state.solve_one(end_ea)
assert begin_ea <= end_ea
message_bytes = []
for i in xrange(end_ea - begin_ea):
message_bytes.append(state.cpu.memory[begin_ea + i])
state.context['log_messages'].append((level, message_bytes))
def hook(func):
return lambda state: state.invoke_model(func)
def run_test(state, apis, test):
def done_test(_, state, state_id, reason):
"""Called when a state is terminated."""
if OUR_TERMINATION_REASON not in reason:
L.error("State {} terminated for unknown reason: {}".format(
state_id, reason))
return
test = state.context['test']
input_length, _ = read_uint32_t(state, state.context['InputIndex'])
# Dump out any pending log messages reported by `McTest_Log`.
for level, message_bytes in state.context['log_messages']:
message = []
for b in message_bytes:
if issymbolic(b):
b_ord = state.solve_one(b)
state.constrain(b == b_ord)
message.append(chr(b_ord))
elif isinstance(b, (int, long)):
message.append(chr(b))
else:
message.append(b)
LEVEL_TO_LOGGER[level]("".join(message))
max_length = state.context['InputEnd'] - state.context['InputBegin']
if input_length > max_length:
L.critical("Test used too many input bytes ({} vs. {})".format(
input_length, max_length))
return
# Solve for the input bytes.
output = []
for i in xrange(input_length):
b = state.cpu.read_int(state.context['InputBegin'] + i, 8)
if issymbolic(b):
b = state.solve_one(b)
output.append("{:2x}".format(b))
L.info("Input: {}".format(" ".join(output)))
def do_run_test(state, apis, test):
"""Run an individual test case."""
state.cpu.PC = test.ea
m = manticore.Manticore(state, sys.argv[1:])
m.verbosity(1)
state = m.initial_state
messages = [(1, "Running {} from {}:{}".format(
test.name, test.file_name, test.line_number))]
state.context['InputBegin'] = apis['InputBegin']
state.context['InputEnd'] = apis['InputEnd']
state.context['InputIndex'] = apis['InputIndex']
state.context['test'] = test
state.context['failed'] = 0
state.context['log_messages'] = []
state.context['input'] = make_symbolic_input(
state, apis['InputBegin'], apis['InputEnd'])
state.context['log_messages'] = messages
make_symbolic_input(state, apis['InputBegin'], apis['InputEnd'])
m.add_hook(apis['IsSymbolicUInt'], hook(hook_IsSymbolicUInt))
m.add_hook(apis['Assume'], hook(hook_Assume))
m.add_hook(apis['Pass'], hook(hook_Pass))
m.add_hook(apis['Fail'], hook(hook_Fail))
m.add_hook(apis['SoftFail'], hook(hook_SoftFail))
m.add_hook(apis['Log'], hook(hook_Log))
m.subscribe('will_terminate_state', done_test)
m.run()
def run_test(state, apis, test):
try:
do_run_test(state, apis, test)
except:
L.error("Uncaught exception: {}\n{}".format(
sys.exc_info()[0], traceback.format_exc()))
def run_tests(args, state, apis):
"""Run all of the test cases."""
pool = multiprocessing.Pool(processes=max(1, args.num_workers))
results = []
tests = find_test_cases(state, apis['LastTestInfo'])
for test in tests:
print "Found test", test.name
res = pool.apply_async(run_test, (state, apis, test))
results.append(res)

View File

@ -21,6 +21,7 @@ import collections
import logging
import multiprocessing
import sys
import traceback
L = logging.getLogger("mctest")
@ -143,120 +144,137 @@ class Assume(angr.SimProcedure):
self.exit(2)
def report_state(state):
test = state.globals['test']
if state.globals['failed']:
message = (3, "Failed: {}".format(test.name))
else:
message = (1, "Passed: {}".format(test.name))
state.globals['log_messages'].append(message)
class Pass(angr.SimProcedure):
"""Implements McTest_Pass, which notifies us of a passing test."""
def run(self):
L.info("Passed test case")
report_state(self.state)
self.exit(self.state.globals['failed'])
class Fail(angr.SimProcedure):
"""Implements McTest_Fail, which notifies us of a passing test."""
def run(self):
L.error("Failed test case")
self.state.globals['failed'] = 1
report_state(self.state)
self.exit(1)
class SoftFail(angr.SimProcedure):
"""Implements McTest_SoftFail, which notifies us of a passing test."""
def run(self):
L.error("Soft failure in test case, continuing")
self.state.globals['failed'] = 1
LEVEL_TO_LOGGER = {
0: L.debug,
1: L.info,
2: L.warning,
3: L.error,
4: L.critical
}
class Log(angr.SimProcedure):
"""Implements McTest_Log, which lets Angr intercept and handle the
printing of log messages from the simulated tests."""
LEVEL_TO_LOGGER = {
0: L.debug,
1: L.info,
2: L.warning,
3: L.error,
4: L.critical
}
def run(self, level, begin_ea, end_ea):
print self.state.regs.rdi, level
print self.state.regs.rsi, begin_ea
print self.state.regs.rdx, end_ea
level = self.state.solver.eval(level, cast_to=int)
assert level in self.LEVEL_TO_LOGGER
assert level in LEVEL_TO_LOGGER
begin_ea = self.state.solver.eval(begin_ea, cast_to=int)
end_ea = self.state.solver.eval(end_ea, cast_to=int)
assert begin_ea <= end_ea
print level, begin_ea, end_ea
# Read the message from memory.
message = ""
if begin_ea < end_ea:
size = end_ea - begin_ea
# If this is an error message, then concretize the message, adding the
# concretizations to the state so that errors we output will eventually
# match the concrete inputs that we generate.
if 3 <= level:
message_bytes = []
for i in xrange(size):
byte = self.state.memory.load(begin_ea + i, size=8)
byte_as_ord = self.state.solver.eval(byte, cast_to=int)
if self.state.se.symbolic(byte):
self.state.solver.add(byte == byte_as_ord)
message_bytes.append(chr(byte_as_ord))
message = "".join(message_bytes)
# Warning/informational message, don't assert any new constraints.
else:
data = self.state.memory.load(begin_ea, size=size)
message = self.state.solver.eval(data, cast_to=str)
# Log the message (produced by the program) through to the Python-based
# logger.
self.LEVEL_TO_LOGGER[level](message)
size = end_ea - begin_ea
data = self.state.memory.load(begin_ea, size=size)
self.state.globals['log_messages'].append((level, data))
if 3 == level:
L.error("Soft failure in test case, continuing")
self.state.globals['failed'] = 1 # Soft failure on an error log message.
elif 4 == level:
L.error("Failed test case")
self.state.globals['failed'] = 1
self.exit(1) # Hard failure on a fatal/critical log message.
report_state(self.state)
self.exit(1)
def run_test(project, test, run_state):
def done_test(state):
test = state.globals['test']
input_length, _ = read_uint32_t(state, state.globals['InputIndex'])
# Dump out any pending log messages reported by `McTest_Log`.
for level, message in state.globals['log_messages']:
if not isinstance(message, str):
message = state.solver.eval(message, cast_to=str)
LEVEL_TO_LOGGER[level]("".join(message))
max_length = state.globals['InputEnd'] - state.globals['InputBegin']
if input_length > max_length:
L.critical("Test used too many input bytes ({} vs. {})".format(
input_length, max_length))
return
# Solve for the input bytes.
output = []
data = state.memory.load(state.globals['InputBegin'], size=input_length)
data = state.solver.eval(data, cast_to=str)
for i in xrange(input_length):
output.append("{:2x}".format(ord(data[i])))
L.info("Input: {}".format(" ".join(output)))
def do_run_test(project, test, apis, run_state):
"""Symbolically executes a single test function."""
test_state = project.factory.call_state(
test.ea,
base_state=run_state)
messages = [(1, "Running {} from {}:{}".format(
test.name, test.file_name, test.line_number))]
test_state.globals['InputBegin'] = apis['InputBegin']
test_state.globals['InputEnd'] = apis['InputEnd']
test_state.globals['InputIndex'] = apis['InputIndex']
test_state.globals['test'] = test
test_state.globals['failed'] = 0
test_state.globals['log_messages'] = messages
make_symbolic_input(test_state, apis['InputBegin'], apis['InputEnd'])
errored = []
test_manager = angr.SimulationManager(
project=project,
active_states=[test_state],
errored=errored)
L.info("Running test case {} from {}:{}".format(
test.name, test.file_name, test.line_number))
print 'running...'
try:
test_manager.run()
except Exception as e:
import traceback
print e
print traceback.format_exc()
print test_manager
print 'done running'
print errored
L.error("Uncaught exception: {}\n{}".format(
sys.exc_info()[0], traceback.format_exc()))
for state in test_manager.deadended:
last_event = state.history.events[-1]
if 'terminate' == last_event.type:
code = last_event.objects['exit_code']._model_concrete.value
print '???'
print test_manager
done_test(state)
def run_test(project, test, apis, run_state):
"""Symbolically executes a single test function."""
try:
do_run_test(project, test, apis, run_state)
except Exception as e:
L.error("Uncaught exception: {}\n{}".format(
sys.exc_info()[0], traceback.format_exc()))
def main():
"""Run McTest."""
@ -278,9 +296,12 @@ def main():
use_sim_procedures=True,
translation_cache=True,
support_selfmodifying_code=False,
auto_load_libs=False)
auto_load_libs=True)
entry_state = project.factory.entry_state(
add_options={angr.options.ZERO_FILL_UNCONSTRAINED_MEMORY,
angr.options.STRICT_PAGE_ACCESS})
entry_state = project.factory.entry_state()
addr_size_bits = entry_state.arch.bits
# Concretely execute up until `McTest_InjectAngr`.
@ -309,15 +330,6 @@ def main():
# Find the test cases that we want to run.
tests = find_test_cases(run_state, apis['LastTestInfo'])
# This will track whether or not a particular state has failed. Some states
# will soft fail, but continue their execution, and so we want to make sure
# that if they continue to a pass function, that they nonetheless are treated
# as failing.
run_state.globals['failed'] = 0
run_state.globals['log_messages'] = []
run_state.globals['input'] = make_symbolic_input(
run_state, apis['InputBegin'], apis['InputEnd'])
pool = multiprocessing.Pool(processes=max(1, args.num_workers))
results = []
@ -325,7 +337,7 @@ def main():
# the test case function.
test_managers = []
for test in tests:
res = pool.apply_async(run_test, (project, test, run_state))
res = pool.apply_async(run_test, (project, test, apis, run_state))
results.append(res)
pool.close()

View File

@ -19,23 +19,23 @@
using namespace mctest;
// MCTEST_NOINLINE int add(int x, int y) {
// return x + y;
// }
MCTEST_NOINLINE int add(int x, int y) {
return x + y;
}
// TEST(Arithmetic, AdditionIsCommutative) {
// ForAll<int, int>([] (int x, int y) {
// ASSERT_EQ(add(x, y), add(y, x))
// << "Addition of signed integers must commute.";
// });
// }
TEST(Arithmetic, AdditionIsCommutative) {
ForAll<int, int>([] (int x, int y) {
ASSERT_EQ(add(x, y), add(y, x))
<< "Addition of signed integers must commute.";
});
}
// TEST(Arithmetic, AdditionIsAssociative) {
// ForAll<int, int, int>([] (int x, int y, int z) {
// ASSERT_EQ(add(x, add(y, z)), add(add(x, y), z))
// << "Addition of signed integers must associate.";
// });
// }
TEST(Arithmetic, AdditionIsAssociative) {
ForAll<int, int, int>([] (int x, int y, int z) {
ASSERT_EQ(add(x, add(y, z)), add(add(x, y), z))
<< "Addition of signed integers must associate.";
});
}
TEST(Arithmetic, InvertibleMultiplication_CanFail) {
ForAll<int, int>([] (int x, int y) {

View File

@ -125,8 +125,10 @@ enum McTest_LogLevel {
McTest_LogDebug = 0,
McTest_LogInfo = 1,
McTest_LogWarning = 2,
McTest_LogWarn = McTest_LogWarning,
McTest_LogError = 3,
McTest_LogFatal = 4,
McTest_LogCritical = McTest_LogFatal
};
/* Outputs information to a log, using a specific log level. */