Introduce a 'force' parameter to memory access functions (#632)

* whitespace cleanup

* Remove access check from COWMap.__setitem__/__getitem__

 * Access checks happen via read/write, so these checks are unecessary.

* Add force parameter to read/write functions

* Introduce  to AbstractCpu's accessors

* Add mem force tests

* Apply force param to symbolic operations

* Add symbolic force write tests

* Clean up test

* Fix symbolic write behavior; add tests
This commit is contained in:
Yan Ivnitskiy 2018-01-30 12:17:27 -05:00 committed by GitHub
parent faf1d16b99
commit c0068431c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 145 additions and 61 deletions

View File

@ -553,7 +553,7 @@ class Cpu(Eventful):
def memory(self): def memory(self):
return self._memory return self._memory
def write_int(self, where, expression, size=None): def write_int(self, where, expression, size=None, force=False):
''' '''
Writes int to memory Writes int to memory
@ -561,18 +561,20 @@ class Cpu(Eventful):
:param expr: value to write :param expr: value to write
:type expr: int or BitVec :type expr: int or BitVec
:param size: bit size of `expr` :param size: bit size of `expr`
:param force: whether to ignore memory permissions
''' '''
if size is None: if size is None:
size = self.address_bit_size size = self.address_bit_size
assert size in SANE_SIZES assert size in SANE_SIZES
self._publish('will_write_memory', where, expression, size) self._publish('will_write_memory', where, expression, size)
self.memory[where:where+size/8] = [Operators.CHR(Operators.EXTRACT(expression, offset, 8)) for offset in xrange(0, size, 8)] data = [Operators.CHR(Operators.EXTRACT(expression, offset, 8)) for offset in xrange(0, size, 8)]
self._memory.write(where, data, force)
self._publish('did_write_memory', where, expression, size) self._publish('did_write_memory', where, expression, size)
def read_int(self, where, size=None): def read_int(self, where, size=None, force=False):
''' '''
Reads int from memory Reads int from memory
@ -580,13 +582,14 @@ class Cpu(Eventful):
:param size: number of bits to read :param size: number of bits to read
:return: the value read :return: the value read
:rtype: int or BitVec :rtype: int or BitVec
:param force: whether to ignore memory permissions
''' '''
if size is None: if size is None:
size = self.address_bit_size size = self.address_bit_size
assert size in SANE_SIZES assert size in SANE_SIZES
self._publish('will_read_memory', where, size) self._publish('will_read_memory', where, size)
data = self.memory[where:where + size / 8] data = self._memory.read(where, size/8, force)
assert (8 * len(data)) == size assert (8 * len(data)) == size
value = Operators.CONCAT(size, *map(Operators.ORD, reversed(data))) value = Operators.CONCAT(size, *map(Operators.ORD, reversed(data)))
@ -594,32 +597,34 @@ class Cpu(Eventful):
return value return value
def write_bytes(self, where, data): def write_bytes(self, where, data, force=False):
''' '''
Write a concrete or symbolic (or mixed) buffer to memory Write a concrete or symbolic (or mixed) buffer to memory
:param int where: address to write to :param int where: address to write to
:param data: data to write :param data: data to write
:type data: str or list :type data: str or list
:param force: whether to ignore memory permissions
''' '''
for i in xrange(len(data)): for i in xrange(len(data)):
self.write_int(where + i, Operators.ORD(data[i]), 8) self.write_int(where + i, Operators.ORD(data[i]), 8, force)
def read_bytes(self, where, size): def read_bytes(self, where, size, force=False):
''' '''
Read from memory. Read from memory.
:param int where: address to read data from :param int where: address to read data from
:param int size: number of bytes :param int size: number of bytes
:param force: whether to ignore memory permissions
:return: data :return: data
:rtype: list[int or Expression] :rtype: list[int or Expression]
''' '''
result = [] result = []
for i in xrange(size): for i in xrange(size):
result.append(Operators.CHR(self.read_int(where + i, 8))) result.append(Operators.CHR(self.read_int(where + i, 8, force)))
return result return result
def write_string(self, where, string, max_length=None): def write_string(self, where, string, max_length=None, force=False):
''' '''
Writes a string to memory, appending a NULL-terminator at the end. Writes a string to memory, appending a NULL-terminator at the end.
:param int where: Address to write the string to :param int where: Address to write the string to
@ -627,14 +632,15 @@ class Cpu(Eventful):
:param int max_length: :param int max_length:
The size in bytes to cap the string at, or None [default] for no The size in bytes to cap the string at, or None [default] for no
limit. This includes the NULL terminator. limit. This includes the NULL terminator.
:param force: whether to ignore memory permissions
''' '''
if max_length is not None: if max_length is not None:
string = string[:max_length-1] string = string[:max_length-1]
self.write_bytes(where, string + '\x00') self.write_bytes(where, string + '\x00', force)
def read_string(self, where, max_length=None): def read_string(self, where, max_length=None, force=False):
''' '''
Read a NUL-terminated concrete buffer from memory. Stops reading at first symbolic byte. Read a NUL-terminated concrete buffer from memory. Stops reading at first symbolic byte.
@ -642,12 +648,13 @@ class Cpu(Eventful):
:param int max_length: :param int max_length:
The size in bytes to cap the string at, or None [default] for no The size in bytes to cap the string at, or None [default] for no
limit. limit.
:param force: whether to ignore memory permissions
:return: string read :return: string read
:rtype: str :rtype: str
''' '''
s = StringIO.StringIO() s = StringIO.StringIO()
while True: while True:
c = self.read_int(where, 8) c = self.read_int(where, 8, force)
if issymbolic(c) or c == 0: if issymbolic(c) or c == 0:
break break
@ -660,46 +667,50 @@ class Cpu(Eventful):
where += 1 where += 1
return s.getvalue() return s.getvalue()
def push_bytes(self, data): def push_bytes(self, data, force=False):
''' '''
Write `data` to the stack and decrement the stack pointer accordingly. Write `data` to the stack and decrement the stack pointer accordingly.
:param str data: Data to write :param str data: Data to write
:param force: whether to ignore memory permissions
''' '''
self.STACK -= len(data) self.STACK -= len(data)
self.write_bytes(self.STACK, data) self.write_bytes(self.STACK, data, force)
return self.STACK return self.STACK
def pop_bytes(self, nbytes): def pop_bytes(self, nbytes, force=False):
''' '''
Read `nbytes` from the stack, increment the stack pointer, and return Read `nbytes` from the stack, increment the stack pointer, and return
data. data.
:param int nbytes: How many bytes to read :param int nbytes: How many bytes to read
:param force: whether to ignore memory permissions
:return: Data read from the stack :return: Data read from the stack
''' '''
data = self.read_bytes(self.STACK, nbytes) data = self.read_bytes(self.STACK, nbytes, force=force)
self.STACK += nbytes self.STACK += nbytes
return data return data
def push_int(self, value): def push_int(self, value, force=False):
''' '''
Decrement the stack pointer and write `value` to the stack. Decrement the stack pointer and write `value` to the stack.
:param int value: The value to write :param int value: The value to write
:param force: whether to ignore memory permissions
:return: New stack pointer :return: New stack pointer
''' '''
self.STACK -= self.address_bit_size / 8 self.STACK -= self.address_bit_size / 8
self.write_int(self.STACK, value) self.write_int(self.STACK, value, force=force)
return self.STACK return self.STACK
def pop_int(self): def pop_int(self, force=False):
''' '''
Read a value from the stack and increment the stack pointer. Read a value from the stack and increment the stack pointer.
:param force: whether to ignore memory permissions
:return: Value read :return: Value read
''' '''
value = self.read_int(self.STACK) value = self.read_int(self.STACK, force=force)
self.STACK += self.address_bit_size / 8 self.STACK += self.address_bit_size / 8
return value return value

View File

@ -371,7 +371,6 @@ class COWMap(Map):
def __setitem__(self, index, value): def __setitem__(self, index, value):
assert self._in_range(index) assert self._in_range(index)
assert self.access_ok('w')
if isinstance(index, slice): if isinstance(index, slice):
for i in xrange(index.stop-index.start): for i in xrange(index.stop-index.start):
self._cow[index.start+i] = value[i] self._cow[index.start+i] = value[i]
@ -380,7 +379,6 @@ class COWMap(Map):
def __getitem__(self, index): def __getitem__(self, index):
assert self._in_range(index) assert self._in_range(index)
assert self.access_ok('r')
if isinstance(index, slice): if isinstance(index, slice):
result = [] result = []
@ -653,7 +651,6 @@ class Memory(object):
raise MemoryException("Page not mapped", address) raise MemoryException("Page not mapped", address)
return self._page2map[page_offset] return self._page2map[page_offset]
def mappings(self): def mappings(self):
''' '''
Returns a sorted list of all the mappings for this memory. Returns a sorted list of all the mappings for this memory.
@ -690,7 +687,6 @@ class Memory(object):
yield m yield m
addr = m.end addr = m.end
def munmap(self, start, size): def munmap(self, start, size):
''' '''
Deletes the mappings for the specified address range and causes further Deletes the mappings for the specified address range and causes further
@ -736,7 +732,6 @@ class Memory(object):
if tail: if tail:
self._add(tail) self._add(tail)
#Permissions #Permissions
def __contains__(self, address): def __contains__(self, address):
return self._page(address) in self._page2map return self._page(address) in self._page2map
@ -749,7 +744,7 @@ class Memory(object):
else: else:
return self.map_containing(index).perms return self.map_containing(index).perms
def access_ok(self, index, access): def access_ok(self, index, access, force=False):
if isinstance(index, slice): if isinstance(index, slice):
assert index.stop - index.start >= 0 assert index.stop - index.start >= 0
addr = index.start addr = index.start
@ -757,27 +752,27 @@ class Memory(object):
if addr not in self: if addr not in self:
return False return False
m = self.map_containing(addr) m = self.map_containing(addr)
size = min(m.end-addr, index.stop-addr)
if not m.access_ok(access): if not force and not m.access_ok(access):
return False return False
addr += size
until_next_page = min(m.end-addr, index.stop-addr)
addr += until_next_page
assert addr == index.stop assert addr == index.stop
return True return True
else: else:
if index not in self: if index not in self:
return False return False
m = self.map_containing(index) m = self.map_containing(index)
return m.access_ok(access) return force or m.access_ok(access)
#write and read potentially symbolic bytes at symbolic indexes #write and read potentially symbolic bytes at symbolic indexes
def read(self, addr, size): def read(self, addr, size, force=False):
if not self.access_ok(slice(addr, addr+size), 'r'): if not self.access_ok(slice(addr, addr+size), 'r', force):
raise InvalidMemoryAccess(addr, 'r') raise InvalidMemoryAccess(addr, 'r')
assert size > 0 assert size > 0
result = [] result = []
start = addr
stop = addr+size stop = addr+size
p = addr p = addr
while p < stop: while p < stop:
@ -824,9 +819,9 @@ class Memory(object):
self._recording_stack[-1].extend(lst) self._recording_stack[-1].extend(lst)
return lst return lst
def write(self, addr, buf): def write(self, addr, buf, force=False):
size = len(buf) size = len(buf)
if not self.access_ok(slice(addr, addr + size), 'w'): if not self.access_ok(slice(addr, addr + size), 'w', force):
raise InvalidMemoryAccess(addr, 'w') raise InvalidMemoryAccess(addr, 'w')
assert size > 0 assert size > 0
stop = addr + size stop = addr + size
@ -917,13 +912,14 @@ class SMemory(Memory):
del self._symbols[addr] del self._symbols[addr]
super(SMemory, self).munmap(start,size) super(SMemory, self).munmap(start,size)
def read(self, address, size): def read(self, address, size, force=False):
''' '''
Read a stream of potentially symbolic bytes from a potentially symbolic Read a stream of potentially symbolic bytes from a potentially symbolic
address address
:param address: Where to read from :param address: Where to read from
:param size: How many bytes :param size: How many bytes
:param force: Whether to ignore permissions
:rtype: list :rtype: list
''' '''
size = self._get_size(size) size = self._get_size(size)
@ -933,11 +929,13 @@ class SMemory(Memory):
assert solver.check(self.constraints) assert solver.check(self.constraints)
logger.debug('Reading %d bytes from symbolic address %s', size, address) logger.debug('Reading %d bytes from symbolic address %s', size, address)
try: try:
solutions = solver.get_all_values(self.constraints, address, maxcnt=0x1000) #if more than 0x3000 exception solutions = self._try_get_solutions(address, size, 'r', force=force)
assert len(solutions) > 0
except TooManySolutions as e: except TooManySolutions as e:
m, M = solver.minmax(self.constraints, address) m, M = solver.minmax(self.constraints, address)
logger.debug('Got TooManySolutions on a symbolic read. Range [%x, %x]. Not crashing!', m, M) logger.debug('Got TooManySolutions on a symbolic read. Range [%x, %x]. Not crashing!', m, M)
# The force param shouldn't affect this, as this is checking for unmapped reads, not bad perms
crashing_condition = True crashing_condition = True
for start, end, perms, offset, name in self.mappings(): for start, end, perms, offset, name in self.mappings():
if start <= M+size and end >= m : if start <= M+size and end >= m :
@ -957,16 +955,6 @@ class SMemory(Memory):
raise ForkState("Forking state on incomplete result", condition) raise ForkState("Forking state on incomplete result", condition)
#So here we have all potential solutions to address #So here we have all potential solutions to address
assert len(solutions) > 0
crashing_condition = False
for base in solutions:
if any(not self.access_ok(i, 'r') for i in xrange(base, base + size, self.page_size)):
crashing_condition = Operators.OR(address == base, crashing_condition)
if solver.can_be_true(self.constraints, crashing_condition):
raise InvalidSymbolicMemoryAccess(address, 'r', size, crashing_condition)
condition = False condition = False
for base in solutions: for base in solutions:
@ -989,7 +977,7 @@ class SMemory(Memory):
assert len(result) == offset+1 assert len(result) == offset+1
return map(Operators.CHR, result) return map(Operators.CHR, result)
else: else:
result = map(Operators.ORD, super(SMemory, self).read(address, size)) result = map(Operators.ORD, super(SMemory, self).read(address, size, force))
for offset in range(size): for offset in range(size):
if address+offset in self._symbols: if address+offset in self._symbols:
for condition, value in self._symbols[address+offset]: for condition, value in self._symbols[address+offset]:
@ -999,44 +987,62 @@ class SMemory(Memory):
result[offset] = Operators.ITEBV(8, condition, Operators.ORD(value), result[offset]) result[offset] = Operators.ITEBV(8, condition, Operators.ORD(value), result[offset])
return map(Operators.CHR, result) return map(Operators.CHR, result)
def write(self, address, value): def write(self, address, value, force=False):
''' '''
Write a value at address. Write a value at address.
:param address: The address at which to write :param address: The address at which to write
:type address: int or long or Expression :type address: int or long or Expression
:param value: Bytes to write :param value: Bytes to write
:type value: str or list :type value: str or list
:param force: Whether to ignore permissions
''' '''
size = len(value) size = len(value)
if issymbolic(address): if issymbolic(address):
solutions = solver.get_all_values(self.constraints, address, maxcnt=0x1000) #if more than 0x3000 exception solutions = self._try_get_solutions(address, size, 'w', force=force)
crashing_condition = False
for base in solutions:
if any(not self.access_ok(i, 'w') for i in xrange(base, base + size, self.page_size)):
crashing_condition = Operators.OR(address == base, crashing_condition)
if solver.can_be_true(self.constraints, crashing_condition):
raise InvalidSymbolicMemoryAccess(address, 'w', size, crashing_condition)
for offset in xrange(size): for offset in xrange(size):
for base in solutions: for base in solutions:
condition = base == address condition = base == address
self._symbols.setdefault(base+offset, []).append((condition, value[offset])) self._symbols.setdefault(base+offset, []).append((condition, value[offset]))
else: else:
for offset in xrange(size): for offset in xrange(size):
if issymbolic(value[offset]): if issymbolic(value[offset]):
if not self.access_ok(address+offset, 'w'): if not self.access_ok(address+offset, 'w', force):
raise InvalidMemoryAccess(address+offset, 'w') raise InvalidMemoryAccess(address+offset, 'w')
self._symbols[address+offset] = [(True, value[offset])] self._symbols[address+offset] = [(True, value[offset])]
else: else:
# overwrite all previous items # overwrite all previous items
if address+offset in self._symbols: if address+offset in self._symbols:
del self._symbols[address+offset] del self._symbols[address+offset]
super(SMemory, self).write(address+offset, [value[offset]]) super(SMemory, self).write(address+offset, [value[offset]], force)
def _try_get_solutions(self, address, size, access, max_solutions=0x1000, force=False):
'''
Try to solve for a symbolic address, checking permissions when reading/writing size bytes.
:param Expression address: The address to solve for
:param int size: How many bytes to check permissions for
:param str access: 'r' or 'w'
:param int max_solutions: Will raise if more solutions are found
:param force: Whether to ignore permission failure
:rtype: list
'''
assert issymbolic(address)
solutions = solver.get_all_values(self.constraints, address, maxcnt=max_solutions)
crashing_condition = False
for base in solutions:
if not self.access_ok(slice(base,base+size), access, force):
crashing_condition = Operators.OR(address == base, crashing_condition)
if solver.can_be_true(self.constraints, crashing_condition):
raise InvalidSymbolicMemoryAccess(address, access, size, crashing_condition)
return solutions
class Memory32(Memory): class Memory32(Memory):

View File

@ -1693,6 +1693,73 @@ class MemoryTest(unittest.TestCase):
# Make sure erroring writes don't get recorded # Make sure erroring writes don't get recorded
self.assertEqual(len(trace), 0) self.assertEqual(len(trace), 0)
def test_force_access(self):
mem = Memory32()
ro = mem.mmap(0x1000, 0x1000, 'r')
wo = mem.mmap(0x2000, 0x1000, 'w')
xo = mem.mmap(0x3000, 0x1000, 'x')
nul = mem.mmap(0x4000, 0x1000, '')
self.assertEqual(len(mem.mappings()), 4)
self.assertItemsEqual((ro,wo,xo, nul), (0x1000,0x2000,0x3000, 0x4000))
self.assertTrue(mem.access_ok(ro, 'r'))
self.assertFalse(mem.access_ok(ro, 'w'))
with self.assertRaises(InvalidMemoryAccess):
mem.write(ro, 'hello')
mem.write(ro, 'hello', force=True) # Would raise if fails, failing this test
with self.assertRaises(InvalidMemoryAccess):
mem.read(wo, 4)
mem.read(wo, 4, force=True) # Would raise if fails, failing this test
with self.assertRaises(InvalidMemoryAccess):
mem.read(nul, 4)
mem.write(nul, 'hello')
mem.read(nul, 4, force=True)
mem.write(nul, 'hello', force=True)
def test_symbolic_force_access(self):
cs = ConstraintSet()
mem = SMemory32(cs)
msg = 'hello'
ro = mem.mmap(0x1000, 0x1000, 'r')
nul = mem.mmap(0x2000, 0x1000, '')
nul_end = nul + 0x1000
# 1. Should raise if a value is entirely outside of mapped memory
addr1 = cs.new_bitvec(32)
cs.add(addr1 > (ro-16)) # 16 > len(msg)
cs.add(addr1 <= (ro+16))
# Can write to unmapped memory, should raise despite force
with self.assertRaises(InvalidSymbolicMemoryAccess):
mem.write(addr1, msg, force=True)
# 2. Force write to mapped memory, should not raise; no force should
addr2 = cs.new_bitvec(32)
cs.add(addr2 > (nul_end - 16))
cs.add(addr2 <= (nul_end-len(msg)))
mem.write(addr2, msg, force=True)
with self.assertRaises(InvalidSymbolicMemoryAccess):
mem.write(addr2, msg)
# 3. Forced write spans from unmapped to mapped memory, should raise
addr3 = cs.new_bitvec(32)
cs.add(addr3 > (nul_end - 16))
# single byte into unmapped memory
cs.add(addr3 <= (nul_end-len(msg)+1))
with self.assertRaises(InvalidSymbolicMemoryAccess):
mem.write(addr3, msg, force=True)
# 4. Try to force-read a span from mapped, but unreadable memory, should not raise
mem.read(addr2, 5, force=True)
# , but without force should
with self.assertRaises(InvalidSymbolicMemoryAccess):
mem.read(addr2, 5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()