Graceful fallback if Z3 doesn't support optimization queries (#135)

* Configure Z3 depending on installed version

* Mocking for subprocess.check_output

* change get_value_fmt to private attribute _get_value_fmt

* Move memoized to utils

* Remove minimal version and simplify

* Re-add invalid versions/sanity checks

* Z3 version format checked on github

* New tests. Version is future proof

* Better logging
This commit is contained in:
feliam 2017-04-18 18:40:06 -03:00 committed by GitHub
parent f6f20b5210
commit a9711cf119
4 changed files with 109 additions and 121 deletions

View File

@ -23,9 +23,10 @@ import logging
import re
import time
from visitors import *
from ...utils.helpers import issymbolic
logger = logging.getLogger("SMT")
from ...utils.helpers import issymbolic, memoized
import collections
logger = logging.getLogger("SMT")
class Z3NotFoundError(EnvironmentError):
pass
@ -104,76 +105,73 @@ class Solver(object):
else:
return x, x
import collections
import functools
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, **kwargs):
key = args + tuple(sorted(kwargs.items()))
if not isinstance(key, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args, **kwargs)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#FixME move this \/
#FixME move this \/ This configuration should be registered as global config
consider_unknown_as_unsat = True
class SMTSolver(Solver):
Version = collections.namedtuple('Version', 'major minor patch')
class Z3Solver(Solver):
def __init__(self):
''' Build a solver intance.
This is implemented using an external native solver via a subprocess.
Everytime a new symbol or assertion is added a smtlibv2 command is
sent to the solver.
The actual state is also mantained in memory to be able to save and
restore the state.
The analisys may be saved to disk and continued after a while or
forked in memory or even sent over the network.
''' Build a Z3 solver intance.
This is implemented using an external z3 solver (via a subprocess).
'''
super(SMTSolver, self).__init__()
super(Z3Solver, self).__init__()
self._proc = None
self._log = '' #this should be enabled only if we are debugging
self.version = self._solver_version()
self.support_maximize = False
self.support_minimize = False
self.support_reset = True
logger.debug('Z3 version: %s', self.version)
if self.version >= Version(4, 4, 1):
self.support_maximize = True
self.support_minimize = True
else:
logger.debug(' Please install Z3 4.4.1 or newer to get optimization support')
self._command = 'z3 -t:30000 -smt2 -in'
self._init = ['(set-logic QF_AUFBV)', '(set-option :global-decls false)']
self._get_value_fmt = (re.compile('\(\((?P<expr>(.*))\ #x(?P<value>([0-9a-fA-F]*))\)\)'), 16)
@staticmethod
def _solver_version():
'''
If we
fail to parse the version, we assume z3's output has changed, meaning it's a newer
version than what's used now, and therefore ok.
Anticipated version_cmd_output format: 'Z3 version 4.4.2'
'Z3 version 4.4.5 - 64 bit - build hashcode $Z3GITHASH'
'''
their_version = Version(0,0,0)
try:
version_cmd_output = check_output(self.version_cmd.split())
version_cmd_output = check_output('z3 -version'.split())
except OSError:
raise Z3NotFoundError
self._check_solver_version(version_cmd_output)
self._proc = None
self._constraints = None
self._log = ''
def _check_solver_version(self, version_cmd_output):
''' Auxiliary method to check the version of the external solver
This will spawn the external solver with the configured parameters and check that the banner matches the expected version.
'''
pass
try:
version = version_cmd_output.split()[2]
their_version = Version(*map(int, version.split('.')))
except (IndexError, ValueError, TypeError):
pass
return their_version
def _start_proc(self):
''' Auxiliary method to spawn the external solver pocess'''
assert '_proc' not in dir(self) or self._proc is None
try:
self._proc = Popen(self.command.split(' '), stdin=PIPE, stdout=PIPE )
self._proc = Popen(self._command.split(' '), stdin=PIPE, stdout=PIPE )
except OSError:
#Z3 was removed from the system in the middle of operation
raise Z3NotFoundError # TODO(mark) don't catch this exception in two places
#run solver specific initializations
for cfg in self.init:
for cfg in self._init:
self._send(cfg)
def _stop_proc(self):
@ -210,7 +208,8 @@ class SMTSolver(Solver):
else:
if self.support_reset:
self._send("(reset)")
for cfg in self.init:
for cfg in self._init:
self._send(cfg)
else:
self._stop_proc()
@ -254,16 +253,15 @@ class SMTSolver(Solver):
raise Exception("Error in smtlib: {}".format(bufl[0]))
return buf
## UTILS: check-sat get-value simplify
## UTILS: check-sat get-value
def _check(self):
''' Check the satisfiability of the current state '''
logger.debug("!! Solver.check() ")
logger.debug("Solver.check() ")
start = time.time()
self._send('(check-sat)')
_status = self._recv()
logger.debug("Check took %s seconds (%s)", time.time()- start, _status)
if _status not in ('sat','unsat','unknown'):
#print "<"*100 + self._log +">"*100
raise SolverException(_status)
if consider_unknown_as_unsat:
if _status == 'unknown':
@ -297,7 +295,7 @@ class SMTSolver(Solver):
if isinstance(expression, Bool):
return {'true': True, 'false': False}[ret[2:-2].split(' ')[1]]
elif isinstance(expression, BitVec):
pattern, base = self.get_value_fmt
pattern, base = self._get_value_fmt
m = pattern.match(ret)
expr, value = m.group('expr'), m.group('value')
return int(value, base)
@ -450,7 +448,7 @@ class SMTSolver(Solver):
self._send('(get-value (%s))'%var[i].name)
ret = self._recv()
assert ret.startswith('((') and ret.endswith('))')
pattern, base = self.get_value_fmt
pattern, base = self._get_value_fmt
m = pattern.match(ret)
expr, value = m.group('expr'), m.group('value')
result += chr(int(value, base))
@ -470,54 +468,11 @@ class SMTSolver(Solver):
if isinstance(expression, Bool):
return {'true':True, 'false':False}[ret[2:-2].split(' ')[1]]
if isinstance(expression, BitVec):
pattern, base = self.get_value_fmt
pattern, base = self._get_value_fmt
m = pattern.match(ret)
expr, value = m.group('expr'), m.group('value')
return int(value, base)
raise NotImplementedError("get_value only implemented for Bool and BitVec")
def simplify(self):
''' Ask the solver to try to simplify the expression val.
This works only with z3.
:param val: a symbol or expression.
'''
simple_constraints = []
for exp in self._constraints:
new_constraint = exp.simplify()
if not isinstance(new_constraint, Bool):
simple_constraints.append(new_constraint)
self._constraints = set(simple_constraints)
Version = collections.namedtuple('Version', 'major minor patch')
class Z3Solver(SMTSolver):
def __init__(self):
self.command = 'z3 -t:30000 -smt2 -in'
self.init = ['(set-logic QF_AUFBV)', '(set-option :global-decls false)']
self.version_cmd = 'z3 -version'
self.min_version = Version(4, 4, 2)
self.get_value_fmt = (re.compile('\(\((?P<expr>(.*))\ #x(?P<value>([0-9a-fA-F]*))\)\)'), 16)
self.support_simplify = True
self.support_reset = False
self.support_maximize = True
self.support_minimize = True
super(Z3Solver, self).__init__()
def _check_solver_version(self, version_cmd_output):
'''
Check that the z3 version we're using is at least the minimum we need. If we
fail to parse the version, we assume z3's output has changed, meaning it's a newer
version than what's used now, and therefore ok.
Anticipated version_cmd_output format: 'Z3 version 4.4.2'
'''
try:
version = version_cmd_output.split()[2]
their_version = Version(*map(int, version.split('.')))
if their_version < self.min_version:
raise SolverException("Z3 Version >= {}.{}.{} required".format(*self.min_version))
except (IndexError, ValueError, TypeError):
pass
solver = Z3Solver()

View File

@ -1,4 +1,3 @@
import os
import sys
import time

View File

@ -12,3 +12,32 @@ def issymbolic(value):
'''
return isinstance(value, Expression)
import functools
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, **kwargs):
key = args + tuple(sorted(kwargs.items()))
if not isinstance(key, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args, **kwargs)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)

View File

@ -658,28 +658,33 @@ class ExpressionTest(unittest.TestCase):
self.assertTrue(solver.check(cs))
self.assertEqual(solver.get_value(cs, a), -7&0xFF)
import importlib
class Z3Test(unittest.TestCase):
def setUp(self):
self.z3 = Z3Solver()
#Manual mock for check_output
self.module = importlib.import_module('manticore.core.smtlib.solver')
self.module.check_output = lambda *args, **kwargs: self.version
self.z3 = self.module.Z3Solver
def test_check_solver_min(self):
self.z3._check_solver_version('Z3 version 4.4.2')
def test_check_solver_too_old(self):
with self.assertRaises(SolverException):
self.z3._check_solver_version('Z3 version 4.4.1')
self.version = 'Z3 version 4.4.1'
self.assertTrue(self.z3._solver_version() == Version(major=4, minor=4, patch=1))
def test_check_solver_newer(self):
self.z3._check_solver_version('Z3 version 4.5.0')
self.version = 'Z3 version 4.5.0'
self.assertTrue(self.z3._solver_version() > Version(major=4, minor=4, patch=1))
def test_check_solver_badfmt(self):
self.z3._check_solver_version('Z3 4.5.0')
def test_check_solver_optimize(self):
self.version = 'Z3 version 4.5.0'
solver = self.z3()
self.assertTrue(solver.support_maximize)
self.assertTrue(solver.support_minimize)
def test_check_solver_badfmt2(self):
self.z3._check_solver_version('Z3 version 4.5.0.8')
def test_check_solver_badfmt3(self):
self.z3._check_solver_version('Z3 version 4.5.what')
def test_check_solver_optimize(self):
self.version = 'Z3 version 4.4.0'
solver = self.z3()
self.assertFalse(solver.support_maximize)
self.assertFalse(solver.support_minimize)
if __name__ == '__main__':
unittest.main()