# Source code for sympy.simplify.cse_main

""" Tools for doing common subexpression elimination.
"""
import bisect
import difflib

from sympy import Basic, Mul, Add
from sympy.utilities.iterables import preorder_traversal, numbered_symbols

import cse_opts

# (preprocessor, postprocessor) pairs which are commonly useful. Each function
# should take a sympy expression and return a possibly transformed expression.
# When used in the function ``cse()``, the target expressions are transformed
# by each of the preprocessor functions in order. After the common
# subexpressions have been eliminated, the postprocessor functions are applied
# to each resulting expression in *reverse* order, to undo the transformations
# where necessary. This allows the algorithm to operate on a representation of
# the expressions that offers more optimization opportunities.
# ``None`` can be used in place of either the preprocessor or the
# postprocessor to specify no transformation.

# Built as a new list from ``cse_opts.default_optimizations`` rather than
# bound to that object directly.
cse_optimizations = list(cse_opts.default_optimizations)

def preprocess_for_cse(expr, optimizations):
    """ Run the preprocessors over an expression ahead of common
    subexpression elimination.

    Parameters
    ----------
    expr : sympy expression
        The target expression to optimize.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs.  Only the first member of
        each pair is used here; a ``None`` preprocessor is skipped.

    Returns
    -------
    expr : sympy expression
        The expression after every preprocessor has been applied, in the
        order the pairs are listed.
    """
    transformed = expr
    for pair in optimizations:
        preprocessor, _ = pair
        if preprocessor is None:
            # This pair requests no transformation on the way in.
            continue
        transformed = preprocessor(transformed)
    return transformed

def postprocess_for_cse(expr, optimizations=None):
    """ Postprocess an expression after common subexpression elimination to
    return the expression to canonical sympy form.

    Parameters
    ----------
    expr : sympy expression
        The target expression to transform.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs.  The postprocessors will be
        applied in reversed order to undo the effects of the preprocessors
        correctly.  If omitted or ``None``, the module-level
        ``cse_optimizations`` list is used.

    Returns
    -------
    expr : sympy expression
        The transformed expression.
    """
    if optimizations is None:
        optimizations = cse_optimizations
    # Walk the pairs back to front so the transformation applied last by the
    # preprocessors is the first one to be undone.
    for _, post in reversed(optimizations):
        if post is not None:
            expr = post(expr)
    return expr

def cse(exprs, symbols=None, optimizations=None):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ----------
    exprs : list of sympy expressions, or a single sympy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator. A plain iterable is also accepted; it is passed
        through ``iter()`` first.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. If not provided, a copy of
        the module-level ``cse_optimizations`` list is used.

    Returns
    -------
    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of sympy expressions
        The reduced expressions with all of the replacements above.
    """
    if symbols is None:
        symbols = numbered_symbols()
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)
    seen_subexp = set()
    muls = set()
    adds = set()
    # ``to_eliminate`` and ``to_eliminate_ops_count`` are parallel lists that
    # are kept sorted by op count (see ``insert`` below).
    to_eliminate = []
    to_eliminate_ops_count = []

    if optimizations is None:
        # Pull out the default here just in case there are some weird
        # manipulations of the module-level list in some other thread.
        optimizations = list(cse_optimizations)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, Basic):
        exprs = [exprs]
    # Preprocess the expressions to give us better optimization opportunities.
    exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    # Find all of the repeated subexpressions.
    def insert(subtree):
        """This helper will insert the subtree into to_eliminate while
        maintaining the ordering by op count and will skip the insertion
        if subtree is already present."""
        ops_count = subtree.count_ops()
        index_to_insert = bisect.bisect(to_eliminate_ops_count, ops_count)
        # all i up to this index have op count <= the current op count
        # so check that subtree is not yet present from this index down
        # (if necessary) to zero.
        for i in range(index_to_insert - 1, -1, -1):
            if (to_eliminate_ops_count[i] == ops_count and
                    subtree == to_eliminate[i]):
                return  # already have it
        to_eliminate_ops_count.insert(index_to_insert, ops_count)
        to_eliminate.insert(index_to_insert, subtree)

    for expr in exprs:
        pt = preorder_traversal(expr)
        for subtree in pt:
            if subtree.is_Atom:
                # Exclude atoms, since there is no point in renaming them.
                continue

            if subtree in seen_subexp:
                # Second sighting: record it and ask the traversal to skip
                # ahead instead of continuing as usual.
                insert(subtree)
                pt.skip()
                continue

            if subtree.is_Mul:
                muls.add(subtree)
            elif subtree.is_Add:
                adds.add(subtree)

            seen_subexp.add(subtree)

    # process adds - any adds that weren't repeated might contain
    # subpatterns that are repeated, e.g. x+y+z and x+y have x+y in common
    adds = [set(a.args) for a in adds]
    for i in range(len(adds)):
        for j in range(i + 1, len(adds)):
            com = adds[i].intersection(adds[j])
            if len(com) > 1:
                insert(Add(*com))

                # remove this set of symbols so it doesn't appear again
                adds[i] = adds[i].difference(com)
                adds[j] = adds[j].difference(com)
                for k in range(j + 1, len(adds)):
                    # Only strip ``com`` from later sums that contain all
                    # of it.
                    if not com.difference(adds[k]):
                        adds[k] = adds[k].difference(com)

    # process muls - any muls that weren't repeated might contain
    # subpatterns that are repeated, e.g. x*y*z and x*y have x*y in common

    # use SequenceMatcher on the nc part to find the longest common expression
    # in common between the two nc parts
    sm = difflib.SequenceMatcher()

    # Each entry is indexed as [commutative part, non-commutative part]; the
    # code below uses element 0 as a set (intersection/difference) and
    # element 1 as an ordered sequence (sliced and sequence-matched).
    muls = [a.args_cnc() for a in muls]
    for i in range(len(muls)):
        if muls[i][1]:
            sm.set_seq1(muls[i][1])
        for j in range(i + 1, len(muls)):

            # the commutative part in common
            ccom = muls[i][0].intersection(muls[j][0])

            # the non-commutative part in common
            if muls[i][1] and muls[j][1]:
                # see if there is any chance of an nc match
                ncom = set(muls[i][1]).intersection(set(muls[j][1]))
                if len(ccom) + len(ncom) < 2:
                    continue

                # now work harder to find the match
                sm.set_seq2(muls[j][1])
                i1, _, n = sm.find_longest_match(0, len(muls[i][1]),
                                                 0, len(muls[j][1]))
                ncom = muls[i][1][i1:i1 + n]
            else:
                ncom = []

            com = list(ccom) + ncom
            if len(com) < 2:
                continue

            insert(Mul(*com))

            # remove ccom from all if there was no ncom; to update the nc part
            # would require finding the subexpr and then replacing it with a
            # dummy to keep bounding nc symbols from being identified as a
            # subexpr, e.g. removing B*C from A*B*C*D might allow A*D to be
            # identified as a subexpr which would not be right.
            if not ncom:
                muls[i][0] = muls[i][0].difference(ccom)
                # Starts at j (not j + 1) so that muls[j] itself is updated
                # by this loop as well.
                for k in range(j, len(muls)):
                    if not ccom.difference(muls[k][0]):
                        muls[k][0] = muls[k][0].difference(ccom)

    # Substitute symbols for all of the repeated subexpressions.
    replacements = []
    reduced_exprs = list(exprs)
    for i, subtree in enumerate(to_eliminate):
        # The builtin ``next`` works with any iterator, whichever of
        # ``.next()`` / ``.__next__()`` it implements.
        sym = next(symbols)
        replacements.append((sym, subtree))
        # Make the substitution in all of the target expressions.
        for j, expr in enumerate(reduced_exprs):
            reduced_exprs[j] = expr.subs(subtree, sym)
        # Make the substitution in all of the subsequent substitutions.
        for j in range(i + 1, len(to_eliminate)):
            to_eliminate[j] = to_eliminate[j].subs(subtree, sym)

    # Postprocess the expressions to return the expressions to canonical form.
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    return replacements, reduced_exprs