Optimizations and bug fixes in smt formulas (#249)
* Remove the use of the incremental mode in get_all_values
* Improvement and bug fixes in visitor.py
* PretyPrinter: Remove dupplicate printed information
* TranslatorSmtLib: Remove dupplicate operands in the generated formulas
* ConstantFolderSimplifier:
* Add new operators (BitVecZeroExtand / BitVecExtract)
* Add no-trivial operators into the operations dict
-> users can known that these operators are handled
* ArithmeticSimplifier:
* Use of a recursive approach to visit expression (faster)
* Fix the missing calls to no-trivial operators for constant folding
* Add new operator (ArraySelect)
* Fix issues on corner cases (tests/travis_test.sh ok)
ArithmeticSimplifier: recursive approach less interesting in travis_test.sh
Use a stack approach, which includes bugfixes of the original
* - Call direclty constant_folder() in ArithmeticSimplifier
- Remove None bindings in ConstantFolderSimplifier.operators
- Move the fixed point thing direclty to the Visitors.visit (new parameter: use_fixed_point)
* - Calling constant_folder only if all operands are constants
- Add None default attribute to getattr
This commit is contained in:
parent
0b710dd86a
commit
aeca64285a
@ -351,7 +351,18 @@ class Z3Solver(Solver):
|
||||
while self._check() == 'sat':
|
||||
value = self._getvalue(var)
|
||||
result.append(value)
|
||||
self._assert( var != value )
|
||||
|
||||
# Reset the solver to avoid the incremental mode
|
||||
# Triggered with two consecutive calls to check-sat
|
||||
# Yet, if the number of solution is large, sending back
|
||||
# the whole formula is more expensive
|
||||
if len(result) < 50:
|
||||
self._reset(temp_cs.related_to(var) )
|
||||
for value in result:
|
||||
self._assert( var != value )
|
||||
else:
|
||||
self._assert(var != value)
|
||||
|
||||
if len(result) >= maxcnt:
|
||||
if silent:
|
||||
# do not throw an exception if set to silent
|
||||
|
||||
@ -60,7 +60,18 @@ class Visitor(object):
|
||||
return value
|
||||
return expression
|
||||
|
||||
def visit(self, node):
|
||||
def visit(self, node, use_fixed_point=False):
|
||||
'''
|
||||
The entry point of the visitor.
|
||||
The exploration algorithm is a DFS post-order traversal
|
||||
The implementation used two stacks instead of a recursion
|
||||
The final result is store in self.result
|
||||
|
||||
:param node: Node to explore
|
||||
:type node: Expression
|
||||
:param use_fixed_point: if True, it runs _methods until a fixed point is found
|
||||
:type use_fixed_point: Bool
|
||||
'''
|
||||
cache = self._cache
|
||||
|
||||
visited = set()
|
||||
@ -74,7 +85,15 @@ class Visitor(object):
|
||||
elif isinstance(node, Operation):
|
||||
if node in visited:
|
||||
operands = [self.pop() for _ in xrange(len(node.operands))]
|
||||
value = self._method(node, *operands)
|
||||
if use_fixed_point:
|
||||
new_node = self._rebuild(node, operands)
|
||||
value = self._method(new_node, *operands)
|
||||
while value is not new_node:
|
||||
new_node = value
|
||||
if isinstance(new_node, Operation):
|
||||
value = self._method(new_node, *new_node.operands)
|
||||
else:
|
||||
value = self._method(node, *operands)
|
||||
visited.remove(node)
|
||||
self.push(value)
|
||||
cache[node] = value
|
||||
@ -85,6 +104,15 @@ class Visitor(object):
|
||||
else:
|
||||
self.push(self._method(node))
|
||||
|
||||
@staticmethod
|
||||
def _rebuild(expression, operands):
|
||||
if isinstance(expression, Operation):
|
||||
import copy
|
||||
aux = copy.copy(expression)
|
||||
aux._operands = operands
|
||||
return aux
|
||||
return type(expression)(*operands, taint=expression.taint)
|
||||
|
||||
class GetDeclarations(Visitor):
|
||||
''' Simple visitor to collect all variables in an expression or set of
|
||||
expressions
|
||||
@ -136,8 +164,29 @@ class PrettyPrinter(Visitor):
|
||||
self.output += '\n'
|
||||
|
||||
def visit(self, expression):
|
||||
'''
|
||||
Overload Visitor.visit because:
|
||||
- We need a pre-order traversal
|
||||
- We use a recursion as it makes eaiser to keep track of the indentation
|
||||
|
||||
'''
|
||||
self._method(expression)
|
||||
|
||||
def _method(self, expression, *args):
|
||||
'''
|
||||
Overload Visitor._method because we want to stop to iterate over the
|
||||
visit_ functions as soon as a valide visit_ function is found
|
||||
'''
|
||||
assert expression.__class__.__mro__[-1] is object
|
||||
for cls in expression.__class__.__mro__:
|
||||
sort = cls.__name__
|
||||
methodname = 'visit_%s' % sort
|
||||
method = getattr(self, methodname, None)
|
||||
if method is not None:
|
||||
method(expression, *args)
|
||||
return
|
||||
return
|
||||
|
||||
def visit_Operation(self, expression, *operands):
|
||||
self._print(expression.__class__.__name__, expression)
|
||||
self.indent += 2
|
||||
@ -147,6 +196,7 @@ class PrettyPrinter(Visitor):
|
||||
else:
|
||||
self._print('...')
|
||||
self.indent -= 2
|
||||
return ''
|
||||
|
||||
def visit_BitVecExtract(self, expression):
|
||||
self._print(expression.__class__.__name__+'{%d:%d}'%(expression.begining,expression.end), expression)
|
||||
@ -157,12 +207,15 @@ class PrettyPrinter(Visitor):
|
||||
else:
|
||||
self._print('...')
|
||||
self.indent -= 2
|
||||
return ''
|
||||
|
||||
def visit_Constant(self, expression):
|
||||
self._print(expression.value)
|
||||
return ''
|
||||
|
||||
def visit_Variable(self, expression):
|
||||
self._print(expression.name)
|
||||
return ''
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
@ -205,6 +258,19 @@ class ConstantFolderSimplifier(Visitor):
|
||||
result |= o.value
|
||||
return BitVecConstant(expression.size, result, taint=expression.taint)
|
||||
|
||||
def visit_BitVecZeroExtend(self, expression, *operands):
|
||||
if all( isinstance(o, Constant) for o in operands):
|
||||
return BitVecConstant(expression.size, operands[0].value, taint=expression.taint)
|
||||
|
||||
def visit_BitVecExtract(self, expression, *operands):
|
||||
if all( isinstance(o, Constant) for o in expression.operands):
|
||||
value=expression.operands[0].value
|
||||
begining = expression.begining
|
||||
end = expression.end
|
||||
value = value >> begining
|
||||
mask = 2**(end - begining +1) - 1
|
||||
value = value & mask
|
||||
return BitVecConstant(expression.size, value, taint=expression.taint)
|
||||
|
||||
def visit_Operation(self, expression, *operands):
|
||||
''' constant folding, if all operands of an expression are a Constant do the math '''
|
||||
@ -232,30 +298,6 @@ class ArithmeticSimplifier(Visitor):
|
||||
def __init__(self, parent=None, **kw):
|
||||
super(ArithmeticSimplifier, self).__init__(**kw)
|
||||
|
||||
def _method(self, expression, *operands):
|
||||
value = super(ArithmeticSimplifier, self)._method(expression, *operands)
|
||||
#while value is not expression:
|
||||
# expression = value
|
||||
# if isinstance(expression, Operation):
|
||||
# print "A", expression.operands
|
||||
# operands = [self._method(op, *op.operands) for op in expression.operands]
|
||||
# print "B", operands
|
||||
# value = super(ArithmeticSimplifier, self)._method(expression, *operands)
|
||||
#something changed recursively visit the new node.
|
||||
if expression is not value:
|
||||
self.visit(value)
|
||||
value = self.pop()
|
||||
|
||||
|
||||
#if value is not expression
|
||||
# if isinstance(value, Operation):
|
||||
# for i in xrange(len(value.operands)):
|
||||
# self.visit(op)
|
||||
# new_operands = reversed([self.pop() for _ in range(len(value.operands))])
|
||||
# value = super(ArithmeticSimplifier, self)._method(value, *new_operands)
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _same_constant(a,b):
|
||||
return isinstance(a, Constant) and\
|
||||
@ -267,32 +309,14 @@ class ArithmeticSimplifier(Visitor):
|
||||
arity = len(operands)
|
||||
return any( operands[i] is not expression.operands[i] for i in range(arity))
|
||||
|
||||
@staticmethod
|
||||
def _rebuild(expression, operands):
|
||||
if isinstance(expression, Operation):
|
||||
import copy
|
||||
aux = copy.copy(expression)
|
||||
aux._operands = operands
|
||||
return aux
|
||||
return type(expression)(*operands, taint=expression.taint)
|
||||
|
||||
def visit_Operation(self, expression, *operands):
|
||||
''' constant folding, if all operands of an expression are a Constant do the math '''
|
||||
if all( isinstance(o, Constant) for o in operands) :
|
||||
if type(expression) in ConstantFolderSimplifier.operations:
|
||||
operation = ConstantFolderSimplifier.operations[type(expression)]
|
||||
value = operation(*(x.value for x in operands))
|
||||
if isinstance(expression, BitVec):
|
||||
return BitVecConstant(expression.size, value, taint=expression.taint)
|
||||
else:
|
||||
isinstance(expression, Bool)
|
||||
return BoolConstant(value, taint=expression.taint)
|
||||
else:
|
||||
if self._changed(expression, operands):
|
||||
expression = self._rebuild(expression, operands)
|
||||
expression = constant_folder(expression)
|
||||
if self._changed(expression, operands):
|
||||
expression = self._rebuild(expression, operands)
|
||||
return expression
|
||||
|
||||
|
||||
def visit_BitVecZeroExtend(self, expression, *operands):
|
||||
if self._changed(expression, operands):
|
||||
return BitVecZeroExtend(expression.size, *operands, taint=expression.taint)
|
||||
@ -418,6 +442,25 @@ class ArithmeticSimplifier(Visitor):
|
||||
elif right.value >= right.size:
|
||||
return left
|
||||
|
||||
def visit_ArraySelect(self, expression, *operands):
|
||||
''' ArraySelect (ArrayStore((ArrayStore(x0,v0) ...),xn, vn), x0)
|
||||
-> v0
|
||||
'''
|
||||
arr = expression.array
|
||||
index = expression.index
|
||||
if isinstance(index, BitVecConstant):
|
||||
index = index.value
|
||||
prev_arr = arr
|
||||
while isinstance(arr, ArrayStore):
|
||||
prev_arr = arr
|
||||
index_store = arr.index
|
||||
if isinstance(index_store, BitVecConstant):
|
||||
if index_store.value == index:
|
||||
return arr.byte
|
||||
arr = arr.array
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def visit_Expression(self, expression, *operands):
|
||||
assert len(operands) == 0
|
||||
@ -426,7 +469,7 @@ class ArithmeticSimplifier(Visitor):
|
||||
|
||||
def arithmetic_simplifier(expression):
|
||||
simp = ArithmeticSimplifier()
|
||||
simp.visit(expression)
|
||||
simp.visit(expression, use_fixed_point=True)
|
||||
return simp.result
|
||||
|
||||
class TranslatorSmtlib(Visitor):
|
||||
@ -526,9 +569,6 @@ class TranslatorSmtlib(Visitor):
|
||||
elif isinstance(expression, BitVecExtract):
|
||||
operation = operation % (expression.end, expression.begining)
|
||||
|
||||
for x in zip(expression.operands, operands):
|
||||
self._add_binding(*x)
|
||||
|
||||
operands = map(lambda x: self._add_binding(*x), zip(expression.operands, operands))
|
||||
smtlib = '(%s %s)' % (operation, ' '.join(operands))
|
||||
return smtlib
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user