Add memory tracing (#203)

* implement memory write tracing

* Comment updates

* Add memory trace tests

* make sure we ignore erroring writes

* Address comments

* remove superfluous return annotation
This commit is contained in:
Yan 2017-05-08 14:32:54 -04:00 committed by GitHub
parent 868bdd80ce
commit 395a40a646
2 changed files with 109 additions and 0 deletions

View File

@ -387,6 +387,7 @@ class Memory(object):
self._maps = set(maps)
self._page2map = WeakValueDictionary() #{page -> ref{MAP}}
self._callbacks = {}
self._recording_stack = []
for m in self._maps:
for i in range(self._page(m.start), self._page(m.end)):
assert i not in self._page2map
@ -771,6 +772,40 @@ class Memory(object):
return result
def push_record_writes(self):
'''
Begin recording all writes. Retrieve all writes with `pop_record_writes()`
'''
self._recording_stack.append([])
def pop_record_writes(self):
'''
Stop recording trace and return a `list[(address, value)]` of all the writes
that occurred, where `value` is of type list[str]. Can be called without
intermediate `pop_record_writes()`.
For example::
mem.push_record_writes()
mem.write(1, 'a')
mem.push_record_writes()
mem.write(2, 'b')
mem.pop_record_writes() # Will return [(2, 'b')]
mem.pop_record_writes() # Will return [(1, 'a'), (2, 'b')]
Multiple writes to the same address will all be included in the trace in the
same order they occurred.
:return: list[tuple]
'''
lst = self._recording_stack.pop()
# Append the current list to a previously-started trace.
if self._recording_stack:
self._recording_stack[-1].extend(lst)
return lst
def write(self, addr, buf):
size = len(buf)
if not self.access_ok(slice(addr, addr + size), 'w'):
@ -778,6 +813,10 @@ class Memory(object):
assert size > 0
stop = addr + size
start = addr
if self._recording_stack:
self._recording_stack[-1].append((addr, buf))
while addr < stop:
m = self.map_containing(addr)
size = min(m.end-addr, stop-addr)

View File

@ -1614,6 +1614,76 @@ class MemoryTest(unittest.TestCase):
m = pickle.loads(pickle.dumps(m))
self.assertItemsEqual(m[0x10000000:0x10003000], 'X'*0x27f0 + 'Y'*0x20 + '\x00'*0x7f0)
def test_mem_basic_trace(self):
cs = ConstraintSet()
mem = SMemory32(cs)
addr = mem.mmap(None, 0x1000, 'rw')
mem.push_record_writes()
mem.write(addr, 'a')
mem.write(addr+1, 'b')
writes = mem.pop_record_writes()
self.assertIn((addr, ['a']), writes)
self.assertIn((addr+1, ['b']), writes)
def test_mem_trace_no_overwrites(self):
cs = ConstraintSet()
mem = SMemory32(cs)
addr = mem.mmap(None, 0x1000, 'rw')
mem.push_record_writes()
mem.write(addr, 'a')
mem.write(addr, 'b')
writes = mem.pop_record_writes()
self.assertIn((addr, ['a']), writes)
self.assertIn((addr, ['b']), writes)
def test_mem_trace_nested(self):
cs = ConstraintSet()
mem = SMemory32(cs)
addr = mem.mmap(None, 0x1000, 'rw')
mem.push_record_writes()
mem.write(addr, 'a')
mem.write(addr+1, 'b')
mem.push_record_writes()
mem.write(addr+2, 'c')
mem.write(addr+3, 'd')
inner_writes = mem.pop_record_writes()
outer_writes = mem.pop_record_writes()
# Make sure writes do not appear in a trace started after them
self.assertNotIn((addr, ['a']), inner_writes)
self.assertNotIn((addr+1, ['b']), inner_writes)
# Make sure the first two are in the outer write
self.assertIn((addr, ['a']), outer_writes)
self.assertIn((addr+1, ['b']), outer_writes)
# Make sure the last two are in the inner write
self.assertIn((addr+2, ['c']), inner_writes)
self.assertIn((addr+3, ['d']), inner_writes)
# Make sure the last two are also in the outer write
self.assertIn((addr+2, ['c']), outer_writes)
self.assertIn((addr+3, ['d']), outer_writes)
def test_mem_trace_ignores_failing(self):
cs = ConstraintSet()
mem = SMemory32(cs)
addr = mem.mmap(None, 0x1000, 'rw')
mem.push_record_writes()
with self.assertRaises(MemoryException):
mem.write(addr-0x5000, 'a')
trace = mem.pop_record_writes()
# Make sure erroring writes don't get recorded
self.assertEqual(len(trace), 0)
if __name__ == '__main__':