Optimizations and bug fixes in smt formulas (#249)

* Remove the use of the incremental mode in get_all_values

* Improvement and bug fixes in visitor.py
* PretyPrinter: Remove dupplicate printed information
* TranslatorSmtLib: Remove dupplicate operands in the generated formulas
* ConstantFolderSimplifier:
  * Add new operators (BitVecZeroExtand / BitVecExtract)
  * Add no-trivial operators into the operations dict
      -> users can known that these operators are handled
* ArithmeticSimplifier:
  * Use of a recursive approach to visit expression (faster)
  * Fix the missing calls to no-trivial operators for constant folding
  * Add new operator (ArraySelect)

* Fix issues on corner cases (tests/travis_test.sh ok)

ArithmeticSimplifier: recursive approach less interesting in travis_test.sh
Use a stack approach, which includes bugfixes of the original

* - Call direclty constant_folder() in ArithmeticSimplifier
- Remove None bindings in ConstantFolderSimplifier.operators
- Move the fixed point thing direclty to the Visitors.visit (new parameter: use_fixed_point)

* - Calling constant_folder only if all operands are constants
- Add None default attribute to getattr
This commit is contained in:
Feist Josselin 2017-05-26 22:28:22 +03:00 committed by feliam
parent 0b710dd86a
commit aeca64285a
2 changed files with 103 additions and 52 deletions

View File

@ -351,7 +351,18 @@ class Z3Solver(Solver):
while self._check() == 'sat':
value = self._getvalue(var)
result.append(value)
self._assert( var != value )
# Reset the solver to avoid the incremental mode
# Triggered with two consecutive calls to check-sat
# Yet, if the number of solution is large, sending back
# the whole formula is more expensive
if len(result) < 50:
self._reset(temp_cs.related_to(var) )
for value in result:
self._assert( var != value )
else:
self._assert(var != value)
if len(result) >= maxcnt:
if silent:
# do not throw an exception if set to silent

View File

@ -60,7 +60,18 @@ class Visitor(object):
return value
return expression
def visit(self, node):
def visit(self, node, use_fixed_point=False):
'''
The entry point of the visitor.
The exploration algorithm is a DFS post-order traversal
The implementation used two stacks instead of a recursion
The final result is store in self.result
:param node: Node to explore
:type node: Expression
:param use_fixed_point: if True, it runs _methods until a fixed point is found
:type use_fixed_point: Bool
'''
cache = self._cache
visited = set()
@ -74,7 +85,15 @@ class Visitor(object):
elif isinstance(node, Operation):
if node in visited:
operands = [self.pop() for _ in xrange(len(node.operands))]
value = self._method(node, *operands)
if use_fixed_point:
new_node = self._rebuild(node, operands)
value = self._method(new_node, *operands)
while value is not new_node:
new_node = value
if isinstance(new_node, Operation):
value = self._method(new_node, *new_node.operands)
else:
value = self._method(node, *operands)
visited.remove(node)
self.push(value)
cache[node] = value
@ -85,6 +104,15 @@ class Visitor(object):
else:
self.push(self._method(node))
@staticmethod
def _rebuild(expression, operands):
if isinstance(expression, Operation):
import copy
aux = copy.copy(expression)
aux._operands = operands
return aux
return type(expression)(*operands, taint=expression.taint)
class GetDeclarations(Visitor):
''' Simple visitor to collect all variables in an expression or set of
expressions
@ -136,8 +164,29 @@ class PrettyPrinter(Visitor):
self.output += '\n'
def visit(self, expression):
'''
Overload Visitor.visit because:
- We need a pre-order traversal
- We use a recursion as it makes eaiser to keep track of the indentation
'''
self._method(expression)
def _method(self, expression, *args):
'''
Overload Visitor._method because we want to stop to iterate over the
visit_ functions as soon as a valide visit_ function is found
'''
assert expression.__class__.__mro__[-1] is object
for cls in expression.__class__.__mro__:
sort = cls.__name__
methodname = 'visit_%s' % sort
method = getattr(self, methodname, None)
if method is not None:
method(expression, *args)
return
return
def visit_Operation(self, expression, *operands):
self._print(expression.__class__.__name__, expression)
self.indent += 2
@ -147,6 +196,7 @@ class PrettyPrinter(Visitor):
else:
self._print('...')
self.indent -= 2
return ''
def visit_BitVecExtract(self, expression):
self._print(expression.__class__.__name__+'{%d:%d}'%(expression.begining,expression.end), expression)
@ -157,12 +207,15 @@ class PrettyPrinter(Visitor):
else:
self._print('...')
self.indent -= 2
return ''
def visit_Constant(self, expression):
self._print(expression.value)
return ''
def visit_Variable(self, expression):
self._print(expression.name)
return ''
@property
def result(self):
@ -205,6 +258,19 @@ class ConstantFolderSimplifier(Visitor):
result |= o.value
return BitVecConstant(expression.size, result, taint=expression.taint)
def visit_BitVecZeroExtend(self, expression, *operands):
if all( isinstance(o, Constant) for o in operands):
return BitVecConstant(expression.size, operands[0].value, taint=expression.taint)
def visit_BitVecExtract(self, expression, *operands):
if all( isinstance(o, Constant) for o in expression.operands):
value=expression.operands[0].value
begining = expression.begining
end = expression.end
value = value >> begining
mask = 2**(end - begining +1) - 1
value = value & mask
return BitVecConstant(expression.size, value, taint=expression.taint)
def visit_Operation(self, expression, *operands):
''' constant folding, if all operands of an expression are a Constant do the math '''
@ -232,30 +298,6 @@ class ArithmeticSimplifier(Visitor):
def __init__(self, parent=None, **kw):
super(ArithmeticSimplifier, self).__init__(**kw)
def _method(self, expression, *operands):
value = super(ArithmeticSimplifier, self)._method(expression, *operands)
#while value is not expression:
# expression = value
# if isinstance(expression, Operation):
# print "A", expression.operands
# operands = [self._method(op, *op.operands) for op in expression.operands]
# print "B", operands
# value = super(ArithmeticSimplifier, self)._method(expression, *operands)
#something changed recursively visit the new node.
if expression is not value:
self.visit(value)
value = self.pop()
#if value is not expression
# if isinstance(value, Operation):
# for i in xrange(len(value.operands)):
# self.visit(op)
# new_operands = reversed([self.pop() for _ in range(len(value.operands))])
# value = super(ArithmeticSimplifier, self)._method(value, *new_operands)
return value
@staticmethod
def _same_constant(a,b):
return isinstance(a, Constant) and\
@ -267,32 +309,14 @@ class ArithmeticSimplifier(Visitor):
arity = len(operands)
return any( operands[i] is not expression.operands[i] for i in range(arity))
@staticmethod
def _rebuild(expression, operands):
if isinstance(expression, Operation):
import copy
aux = copy.copy(expression)
aux._operands = operands
return aux
return type(expression)(*operands, taint=expression.taint)
def visit_Operation(self, expression, *operands):
''' constant folding, if all operands of an expression are a Constant do the math '''
if all( isinstance(o, Constant) for o in operands) :
if type(expression) in ConstantFolderSimplifier.operations:
operation = ConstantFolderSimplifier.operations[type(expression)]
value = operation(*(x.value for x in operands))
if isinstance(expression, BitVec):
return BitVecConstant(expression.size, value, taint=expression.taint)
else:
isinstance(expression, Bool)
return BoolConstant(value, taint=expression.taint)
else:
if self._changed(expression, operands):
expression = self._rebuild(expression, operands)
expression = constant_folder(expression)
if self._changed(expression, operands):
expression = self._rebuild(expression, operands)
return expression
def visit_BitVecZeroExtend(self, expression, *operands):
if self._changed(expression, operands):
return BitVecZeroExtend(expression.size, *operands, taint=expression.taint)
@ -418,6 +442,25 @@ class ArithmeticSimplifier(Visitor):
elif right.value >= right.size:
return left
def visit_ArraySelect(self, expression, *operands):
''' ArraySelect (ArrayStore((ArrayStore(x0,v0) ...),xn, vn), x0)
-> v0
'''
arr = expression.array
index = expression.index
if isinstance(index, BitVecConstant):
index = index.value
prev_arr = arr
while isinstance(arr, ArrayStore):
prev_arr = arr
index_store = arr.index
if isinstance(index_store, BitVecConstant):
if index_store.value == index:
return arr.byte
arr = arr.array
else:
break
def visit_Expression(self, expression, *operands):
assert len(operands) == 0
@ -426,7 +469,7 @@ class ArithmeticSimplifier(Visitor):
def arithmetic_simplifier(expression):
simp = ArithmeticSimplifier()
simp.visit(expression)
simp.visit(expression, use_fixed_point=True)
return simp.result
class TranslatorSmtlib(Visitor):
@ -526,9 +569,6 @@ class TranslatorSmtlib(Visitor):
elif isinstance(expression, BitVecExtract):
operation = operation % (expression.end, expression.begining)
for x in zip(expression.operands, operands):
self._add_binding(*x)
operands = map(lambda x: self._add_binding(*x), zip(expression.operands, operands))
smtlib = '(%s %s)' % (operation, ' '.join(operands))
return smtlib