Source code for sympy.functions.elementary.piecewise

from sympy.core import Basic, S, Function, diff, Tuple
from sympy.core.relational import Equality, Relational
from sympy.core.symbol import Dummy
from sympy.functions.elementary.miscellaneous import Max, Min
from sympy.logic.boolalg import And, Boolean, Or
from sympy.utilities.misc import default_sort_key

class ExprCondPair(Tuple):
    """Represents an expression, condition pair.

    The pair is stored as a two-element ``Tuple``.  Because a plain Python
    ``True`` cannot be stored directly as an argument, a condition of ``True``
    is replaced by the class-level ``true_sentinel`` on construction and
    translated back to ``True`` by the ``cond`` property.
    """

    # Stand-in stored in ``args`` in place of the Python bool ``True``.
    true_sentinel = Dummy('True')

    def __new__(cls, expr, cond):
        """Create the pair, swapping a ``True`` condition for the sentinel."""
        if cond is True:
            cond = ExprCondPair.true_sentinel
        return Tuple.__new__(cls, expr, cond)

    @property
    def expr(self):
        """
        Returns the expression of this pair.
        """
        return self.args[0]

    # ``cond`` must be a property: callers in this module read it as an
    # attribute (``pair.cond``) and type-check the resulting value.
    @property
    def cond(self):
        """
        Returns the condition of this pair.

        The stored sentinel is mapped back to the Python bool ``True``.
        """
        if self.args[1] == ExprCondPair.true_sentinel:
            return True
        return self.args[1]

    @property
    def free_symbols(self):
        """
        Return the free symbols of this pair.
        """
        # Overload Basic.free_symbols because self.args[1] may contain
        # non-Basic (``cond`` can be the plain bool ``True``, which has no
        # ``free_symbols`` attribute).
        result = self.expr.free_symbols
        if hasattr(self.cond, 'free_symbols'):
            result |= self.cond.free_symbols
        return result

    @property
    def is_commutative(self):
        """Commutativity of the pair is that of its expression."""
        return self.expr.is_commutative

    def __iter__(self):
        """Yield ``expr`` then ``cond`` so the pair unpacks as ``e, c``."""
        yield self.expr
        yield self.cond
class Piecewise(Function):
    """
    Represents a piecewise function.

    Usage:

      Piecewise( (expr,cond), (expr,cond), ... )
        - Each argument is a 2-tuple defining a expression and condition
        - The conds are evaluated in turn returning the first that is True.
          If any of the evaluated conds are not determined explicitly False,
          e.g. x < 1, the function is returned in symbolic form.
        - If the function is evaluated at a place where all conditions are
          False, a ValueError exception will be raised.
        - Pairs where the cond is explicitly False, will be removed.

    Examples
    ========

      >>> from sympy import Piecewise, log
      >>> from sympy.abc import x
      >>> f = x**2
      >>> g = log(x)
      >>> p = Piecewise( (0, x<-1), (f, x<=1), (g, True))
      >>> p.subs(x,1)
      1
      >>> p.subs(x,5)
      log(5)

    See Also
    ========

    piecewise_fold
    """

    nargs = None
    is_Piecewise = True

    def __new__(cls, *args, **options):
        """Build a Piecewise from ``(expr, cond)`` pairs.

        Pairs whose condition is ``False`` are dropped, and nothing after the
        first pair whose condition is ``True`` is kept.  Unless
        ``evaluate=False`` is passed, ``eval`` gets a chance to simplify the
        result.

        Raises ``TypeError`` if a condition is not a ``bool``, ``Relational``
        or ``Boolean``.
        """
        # (Try to) sympify args first
        newargs = []
        for ec in args:
            pair = ExprCondPair(*ec)
            cond = pair.cond
            if cond is False:
                continue
            if not isinstance(cond, (bool, Relational, Boolean)):
                raise TypeError(
                    "Cond %s is of type %s, but must be a Relational,"
                    " Boolean, or a built-in bool." % (cond, type(cond)))
            newargs.append(pair)
            if cond is True:
                # Anything after an unconditional branch is unreachable.
                break

        if options.pop('evaluate', True):
            r = cls.eval(*newargs)
        else:
            r = None

        if r is None:
            return Basic.__new__(cls, *newargs, **options)
        else:
            return r

    @classmethod
    def eval(cls, *args):
        """Try to simplify the given pairs.

        Returns an expression or a new ``Piecewise`` when a simplification was
        possible, otherwise ``None`` (meaning: keep the arguments as given).
        """
        # Check for situations where we can evaluate the Piecewise object.
        # 1) Hit an unevaluable cond (e.g. x<1) -> keep object
        # 2) Hit a true condition -> return that expr
        # 3) Remove false conditions, if no conditions left -> raise ValueError
        all_conds_evaled = True    # Do all conds eval to a bool?
        piecewise_again = False    # Should we pass args to Piecewise again?
        non_false_ecpairs = []
        or1 = Or(*[cond for (_, cond) in args if cond is not True])
        for expr, cond in args:
            # Check here if expr is a Piecewise and collapse if one of
            # the conds in expr matches cond. This allows the collapsing
            # of Piecewise((Piecewise(x,x<0),x<0)) to Piecewise((x,x<0)).
            # This is important when using piecewise_fold to simplify
            # multiple Piecewise instances having the same conds.
            # Eventually, this code should be able to collapse Piecewise's
            # having different intervals, but this will probably require
            # using the new assumptions.
            if isinstance(expr, Piecewise):
                or2 = Or(*[c for (_, c) in expr.args if c is not True])
                for e, c in expr.args:
                    # Don't collapse if cond is "True" as this leads to
                    # incorrect simplifications with nested Piecewises.
                    if c == cond and (or1 == or2 or cond is not True):
                        expr = e
                        piecewise_again = True
            cond_eval = cls.__eval_cond(cond)
            if cond_eval is None:
                all_conds_evaled = False
            elif cond_eval:
                if all_conds_evaled:
                    return expr
            # Merge with the previous pair when both carry the same
            # expression, OR-ing their conditions together.
            if len(non_false_ecpairs) != 0 and \
                    non_false_ecpairs[-1].expr == expr:
                non_false_ecpairs[-1] = ExprCondPair(
                    expr, Or(cond, non_false_ecpairs[-1].cond))
            else:
                non_false_ecpairs.append(ExprCondPair(expr, cond))
        if len(non_false_ecpairs) != len(args) or piecewise_again:
            return Piecewise(*non_false_ecpairs)

        return None

    def doit(self, **hints):
        """
        Evaluate this piecewise function.

        With ``deep=True`` (the default) ``doit`` is also called on every
        expression and condition that is a ``Basic`` instance.
        """
        newargs = []
        for e, c in self.args:
            if hints.get('deep', True):
                if isinstance(e, Basic):
                    e = e.doit(**hints)
                if isinstance(c, Basic):
                    c = c.doit(**hints)
            newargs.append((e, c))
        return Piecewise(*newargs)

    def _eval_as_leading_term(self, x):
        """Leading term of the first branch that is active at ``x = 0``.

        Falls through (returning ``None``) when no condition is ``True`` or
        evaluates to ``True`` at ``x = 0``.
        """
        for e, c in self.args:
            if c is True or c.subs(x, 0) is True:
                return e.as_leading_term(x)

    def _eval_conjugate(self):
        """Conjugate every branch expression; conditions are unchanged."""
        from sympy.functions.elementary.complexes import conjugate
        return Piecewise(*[(conjugate(e), c) for e, c in self.args])

    def _eval_derivative(self, x):
        """Differentiate every branch expression with respect to ``x``."""
        return Piecewise(*[(diff(e, x), c) for e, c in self.args])

    def _eval_evalf(self, prec):
        """Numerically evaluate every branch expression to ``prec``."""
        return Piecewise(*[(e.evalf(prec), c) for e, c in self.args])

    def _eval_integral(self, x):
        """Integrate every branch expression with respect to ``x``."""
        from sympy.integrals import integrate
        return Piecewise(*[(integrate(e, x), c) for e, c in self.args])

    def _eval_interval(self, sym, a, b):
        """Evaluates the function along the sym in a given interval ab

        - If ``a == b`` is known to hold, the result is ``S.Zero``.
        - If ``a > b`` is known to hold, the limits are swapped and the result
          is negated.
        - If ``a <= b`` cannot be decided, a new ``Piecewise`` covering the
          possible positions of the symbolic limit is returned.

        Raises ``NotImplementedError`` when neither limit can be placed
        relative to a branch's bounds.
        """
        # FIXME: Currently complex intervals are not supported.  A possible
        # replacement algorithm, discussed in issue 2128, can be found in the
        # following papers;
        #     http://portal.acm.org/citation.cfm?id=281649
        #     http://citeseerx.ist.psu.edu/viewdoc/download?doi=
        mul = 1
        if (a == b) is True:
            return S.Zero
        elif (a > b) is True:
            a, b, mul = b, a, -1
        elif (a <= b) is not True:
            newargs = []
            for e, c in self.args:
                intervals = self._sort_expr_cond(
                    sym, S.NegativeInfinity, S.Infinity, c)
                values = []
                for lower, upper in intervals:
                    if (a < lower) is True:
                        mid = lower
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, mid)
                        val += self._eval_interval(sym, a, mid)
                    elif (a > upper) is True:
                        mid = upper
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, mid)
                        val += self._eval_interval(sym, a, mid)
                    elif (a >= lower) is True and (a <= upper) is True:
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, a)
                    elif (b < lower) is True:
                        mid = lower
                        rep = a
                        val = e.subs(sym, mid) - e.subs(sym, a)
                        val += self._eval_interval(sym, mid, b)
                    elif (b > upper) is True:
                        mid = upper
                        rep = a
                        val = e.subs(sym, mid) - e.subs(sym, a)
                        val += self._eval_interval(sym, mid, b)
                    elif ((b >= lower) is True) and ((b <= upper) is True):
                        rep = a
                        val = e.subs(sym, b) - e.subs(sym, a)
                    else:
                        raise NotImplementedError(
                            "The evaluation of a Piecewise interval when both the \n"
                            "lower and the upper limit are symbolic is not yet implemented.")
                    values.append(val)
                if len(set(values)) == 1:
                    # Every sub-interval gave the same value: keep one branch
                    # and re-express its condition in the symbolic limit.
                    try:
                        c = c.subs(sym, rep)
                    except AttributeError:
                        # ``c`` is a plain bool and has no ``subs``.
                        pass
                    e = values[0]
                    newargs.append((e, c))
                else:
                    last = len(values) - 1
                    for i, val in enumerate(values):
                        branch_cond = (c is True and i == last) or \
                            And(rep >= intervals[i][0], rep <= intervals[i][1])
                        newargs.append((val, branch_cond))
            return Piecewise(*newargs)

        # Determine what intervals the expr,cond pairs affect.
        int_expr = self._sort_expr_cond(sym, a, b)

        # Finally run through the intervals and sum the evaluation.
        ret_fun = 0
        for int_a, int_b, expr in int_expr:
            ret_fun += expr._eval_interval(sym, Max(a, int_a), Min(b, int_b))
        return mul * ret_fun

    def _sort_expr_cond(self, sym, a, b, targetcond=None):
        """Determine what intervals the expr, cond pairs affect.

        1) If cond is True, then log it as default
        1.1) Currently if cond can't be evaluated, throw NotImplementedError.
        2) For each inequality, if previous cond defines part of the interval
           update the new conds interval.
           -  eg x < 1, x < 3 -> [oo,1],[1,3] instead of [oo,1],[oo,3]
        3) Sort the intervals to make it easier to find correct exprs

        Under normal use, we return the expr,cond pairs in increasing order
        along the real axis corresponding to the symbol sym.  If targetcond
        is given, we return a list of (lowerbound, upperbound) pairs for
        this condition.

        Raises ``ValueError`` if part of ``[a, b]`` is not covered by any
        branch and there is no default (``True``) branch.
        """
        default = None
        int_expr = []
        expr_cond = []
        or_cond = False
        or_intervals = []
        # Flatten Or-conditions into one (expr, cond) entry per alternative.
        for expr, cond in self.args:
            if isinstance(cond, Or):
                for cond2 in sorted(cond.args, key=default_sort_key):
                    expr_cond.append((expr, cond2))
            else:
                expr_cond.append((expr, cond))
            if cond is True:
                break
        for expr, cond in expr_cond:
            if cond is True:
                default = expr
                break
            elif isinstance(cond, Equality):
                continue
            elif isinstance(cond, And):
                lower = S.NegativeInfinity
                upper = S.Infinity
                for cond2 in cond.args:
                    if cond2.lts.has(sym):
                        upper = Min(cond2.gts, upper)
                    elif cond2.gts.has(sym):
                        lower = Max(cond2.lts, lower)
            else:
                # part 1: initialize with givens
                lower, upper = cond.lts, cond.gts
                if cond.lts.has(sym):      # part 1a: expand the side ...
                    lower = S.NegativeInfinity   # e.g. x <= 0 ---> -oo <= 0
                elif cond.gts.has(sym):    # part 1a: ... that can be expanded
                    upper = S.Infinity           # e.g. x >= 0 --->  oo >= 0
                else:
                    raise NotImplementedError(
                        "Unable to handle interval evaluation of expression.")

                # part 1b: Reduce (-)infinity to what was passed in.
                lower, upper = Max(a, lower), Min(b, upper)

            for n in range(len(int_expr)):
                # Part 2: remove any interval overlap.  For any conflicts, the
                # interval already there wins, and the incoming interval
                # updates its bounds accordingly.
                if self.__eval_cond(lower < int_expr[n][1]) and \
                        self.__eval_cond(lower >= int_expr[n][0]):
                    lower = int_expr[n][1]
                elif len(int_expr[n][1].free_symbols) and \
                        self.__eval_cond(lower >= int_expr[n][0]):
                    if self.__eval_cond(lower == int_expr[n][0]):
                        lower = int_expr[n][1]
                    else:
                        int_expr[n][1] = Min(lower, int_expr[n][1])
                elif len(int_expr[n][1].free_symbols) and \
                        (lower < int_expr[n][0]) is not True:
                    # The parentheses matter: without them Python chains the
                    # comparison as ``(lower < x) and (x is not True)``.
                    upper = Min(upper, int_expr[n][0])
                elif self.__eval_cond(upper > int_expr[n][0]) and \
                        self.__eval_cond(upper <= int_expr[n][1]):
                    upper = int_expr[n][0]
                elif len(int_expr[n][0].free_symbols) and \
                        self.__eval_cond(upper < int_expr[n][1]):
                    int_expr[n][0] = Max(upper, int_expr[n][0])

            # Is it still an interval?
            if self.__eval_cond(lower >= upper) is not True:
                int_expr.append([lower, upper, expr])
            if cond is targetcond:
                return [(lower, upper)]
            elif isinstance(targetcond, Or) and cond in targetcond.args:
                or_cond = Or(or_cond, cond)
                or_intervals.append((lower, upper))
                if or_cond == targetcond:
                    or_intervals.sort(key=lambda x: x[0])
                    return or_intervals

        # Two sort passes: first by upper bound, then by lower bound.
        # Non-numeric bounds are keyed as -oo / oo respectively.
        int_expr.sort(key=lambda x: x[1].sort_key()
                      if x[1].is_number else S.NegativeInfinity.sort_key())
        int_expr.sort(key=lambda x: x[0].sort_key()
                      if x[0].is_number else S.Infinity.sort_key())

        for n in range(len(int_expr)):
            if len(int_expr[n][0].free_symbols) or \
                    len(int_expr[n][1].free_symbols):
                if isinstance(int_expr[n][1], Min) or int_expr[n][1] == b:
                    newval = Min(*int_expr[n][:-1])
                    if n > 0 and int_expr[n][0] == int_expr[n-1][1]:
                        int_expr[n-1][1] = newval
                    int_expr[n][0] = newval
                else:
                    newval = Max(*int_expr[n][:-1])
                    if n < len(int_expr) - 1 and \
                            int_expr[n][1] == int_expr[n+1][0]:
                        int_expr[n+1][0] = newval
                    int_expr[n][1] = newval

        # Add holes to list of intervals if there is a default value,
        # otherwise raise a ValueError.
        holes = []
        curr_low = a
        for int_a, int_b, expr in int_expr:
            if (curr_low < int_a) is True:
                holes.append([curr_low, Min(b, int_a), default])
            elif (curr_low >= int_a) is not True:
                holes.append([curr_low, Min(b, int_a), default])
            curr_low = Min(b, int_b)
        if (curr_low < b) is True:
            holes.append([Min(b, curr_low), b, default])
        elif (curr_low >= b) is not True:
            holes.append([Min(b, curr_low), b, default])

        if holes and default is not None:
            int_expr.extend(holes)
            if targetcond is True:
                return [(h[0], h[1]) for h in holes]
        elif holes and default is None:
            raise ValueError(
                "Called interval evaluation over piecewise "
                "function on undefined intervals %s" %
                ", ".join([str((h[0], h[1])) for h in holes]))

        return int_expr

    def _eval_power(self, s):
        """Raise every branch expression to the power ``s``."""
        return Piecewise(*[(e**s, c) for e, c in self.args])

    def _eval_subs(self, old, new):
        """
        Piecewise conditions may contain bool which are not of Basic type.
        """
        args = list(self.args)
        for i, (e, c) in enumerate(args):
            e = e._subs(old, new)

            if isinstance(c, bool):
                pass
            elif isinstance(c, Basic):
                c = c._subs(old, new)

            args[i] = e, c

        return Piecewise(*args)

    def _eval_nseries(self, x, n, logx):
        """Expand every branch expression as a series; keep the conditions."""
        args = [(ec.expr._eval_nseries(x, n, logx), ec.cond)
                for ec in self.args]
        return self.func(*args)

    def _eval_template_is_attr(self, is_attr, when_multiple=None):
        """Combine the assumption ``is_attr`` over all branch expressions.

        Returns ``None`` as soon as any branch reports ``None``, the common
        value when all branches agree, and ``when_multiple`` when two
        branches disagree.
        """
        b = None
        for expr, _ in self.args:
            a = getattr(expr, is_attr)
            if a is None:
                return None
            if b is None:
                b = a
            elif b is not a:
                return when_multiple
        return b

    _eval_is_bounded = lambda self: self._eval_template_is_attr('is_bounded', when_multiple=False)
    _eval_is_complex = lambda self: self._eval_template_is_attr('is_complex')
    _eval_is_even = lambda self: self._eval_template_is_attr('is_even')
    _eval_is_imaginary = lambda self: self._eval_template_is_attr('is_imaginary')
    _eval_is_integer = lambda self: self._eval_template_is_attr('is_integer')
    _eval_is_irrational = lambda self: self._eval_template_is_attr('is_irrational')
    _eval_is_negative = lambda self: self._eval_template_is_attr('is_negative')
    _eval_is_nonnegative = lambda self: self._eval_template_is_attr('is_nonnegative')
    _eval_is_nonpositive = lambda self: self._eval_template_is_attr('is_nonpositive')
    _eval_is_nonzero = lambda self: self._eval_template_is_attr('is_nonzero', when_multiple=True)
    _eval_is_odd = lambda self: self._eval_template_is_attr('is_odd')
    _eval_is_polar = lambda self: self._eval_template_is_attr('is_polar')
    _eval_is_positive = lambda self: self._eval_template_is_attr('is_positive')
    _eval_is_real = lambda self: self._eval_template_is_attr('is_real')
    _eval_is_zero = lambda self: self._eval_template_is_attr('is_zero', when_multiple=False)

    @classmethod
    def __eval_cond(cls, cond):
        """Return the truth value of the condition.

        Only the Python bool ``True`` is recognised; anything else, including
        ``False`` and symbolic relationals, yields ``None``.
        """
        if cond is True:
            return True
        return None
def piecewise_fold(expr):
    """
    Takes an expression containing a piecewise function and returns the
    expression in piecewise form.

    Anything that is not a ``Basic`` instance, or that contains no
    ``Piecewise``, is returned unchanged.

    Examples
    ========

    >>> from sympy import Piecewise, piecewise_fold, sympify as S
    >>> from sympy.abc import x
    >>> p = Piecewise((x, x < 1), (1, S(1) <= x))
    >>> piecewise_fold(x*p)
    Piecewise((x**2, x < 1), (x, 1 <= x))

    See Also
    ========

    Piecewise
    """
    if not (isinstance(expr, Basic) and expr.has(Piecewise)):
        return expr

    # Fold the children first so any nested Piecewise surfaces as a direct
    # argument of ``expr``.
    folded = [piecewise_fold(child) for child in expr.args]
    if expr.func is ExprCondPair:
        return ExprCondPair(*folded)

    positions = [idx for idx, child in enumerate(folded)
                 if child.func is Piecewise]
    if not positions:
        return Piecewise(*folded)

    # Distribute ``expr.func`` over the branches of the first Piecewise
    # argument, keeping all other arguments in place.
    first = positions[0]
    before = folded[:first]
    after = folded[first + 1:]
    branches = [(expr.func(*(before + [e] + after)), c)
                for e, c in folded[first].args]

    if len(positions) > 1:
        # More Piecewise arguments remain inside each branch; fold again.
        return piecewise_fold(Piecewise(*branches))
    return Piecewise(*branches)