# Source code for sympy.functions.elementary.piecewise

from sympy.core import Basic, S, Function, diff, Tuple
from sympy.core.relational import Equality, Relational
from sympy.core.symbol import Dummy
from sympy.functions.elementary.miscellaneous import Max, Min
from sympy.logic.boolalg import And, Boolean, Or
from sympy.utilities.misc import default_sort_key

class ExprCondPair(Tuple):
    """Represents an expression, condition pair.

    A literal Python ``True`` condition is stored in ``args`` as the
    ``true_sentinel`` Dummy; the ``cond`` property and iteration map it
    back to ``True``.
    """

    # Stand-in stored in args[1] for a literal ``True`` condition.
    true_sentinel = Dummy('True')

    def __new__(cls, expr, cond):
        if cond is True:
            cond = ExprCondPair.true_sentinel
        return Tuple.__new__(cls, expr, cond)

    @property
    def expr(self):
        """
        Returns the expression of this pair.
        """
        return self.args[0]

    @property
    def cond(self):
        """
        Returns the condition of this pair.

        The ``true_sentinel`` placeholder is reported as the Python bool
        ``True``.
        """
        if self.args[1] == ExprCondPair.true_sentinel:
            return True
        return self.args[1]

    @property
    def free_symbols(self):
        """
        Return the free symbols of this pair.

        A new set is returned; the set owned by ``self.expr`` is not
        modified.
        """
        # Overload Basic.free_symbols: walking self.args directly would report
        # the true_sentinel Dummy as a free symbol. self.cond yields the plain
        # bool True in that case, which has no free_symbols attribute.
        result = set(self.expr.free_symbols)
        if hasattr(self.cond, 'free_symbols'):
            result |= self.cond.free_symbols
        return result

    @property
    def is_commutative(self):
        # Commutativity is taken from the expression alone.
        return self.expr.is_commutative

    def __iter__(self):
        # Unpack as (expr, cond), with a literal True condition restored.
        yield self.expr
        yield self.cond

class Piecewise(Function):
    """
    Represents a piecewise function.

    Usage:

      Piecewise( (expr,cond), (expr,cond), ... )
        - Each argument is a 2-tuple defining an expression and condition
        - The conds are evaluated in turn returning the first that is True.
          If any of the evaluated conds are not determined explicitly False,
          e.g. x < 1, the function is returned in symbolic form.
        - If the function is evaluated at a place where all conditions are False,
          a ValueError exception will be raised.
        - Pairs where the cond is explicitly False, will be removed.

    Examples
    ========

    >>> from sympy import Piecewise, log
    >>> from sympy.abc import x
    >>> f = x**2
    >>> g = log(x)
    >>> p = Piecewise( (0, x<-1), (f, x<=1), (g, True))
    >>> p.subs(x,1)
    1
    >>> p.subs(x,5)
    log(5)

    See Also
    ========

    piecewise_fold
    """

    nargs = None
    is_Piecewise = True

    def __new__(cls, *args, **options):
        # (Try to) sympify args first
        newargs = []
        for ec in args:
            pair = ExprCondPair(*ec)
            cond = pair.cond
            if cond is False:
                # A branch whose condition is explicitly False is dropped.
                continue
            if not isinstance(cond, (bool, Relational, Boolean)):
                raise TypeError(
                    "Cond %s is of type %s, but must be a Relational,"
                    " Boolean, or a built-in bool." % (cond, type(cond)))
            newargs.append(pair)
            if cond is True:
                # Pairs after an unconditional branch can never be reached.
                break

        if options.pop('evaluate', True):
            r = cls.eval(*newargs)
        else:
            r = None

        if r is None:
            return Basic.__new__(cls, *newargs, **options)
        return r

    @classmethod
    def eval(cls, *args):
        """
        Try to simplify a sequence of ExprCondPair ``args``.

        Returns the expression of a pair whose condition evaluates to True
        if no earlier condition was undecidable, a rebuilt Piecewise if
        adjacent pairs were merged or a nested Piecewise was collapsed, and
        None when ``args`` should be kept unchanged.
        """
        # Check for situations where we can evaluate the Piecewise object.
        # 1) Hit an unevaluable cond (e.g. x<1) -> keep object
        # 2) Hit a true condition -> return that expr
        # 3) Merge adjacent pairs sharing the same expr by OR-ing their conds.
        # NOTE(review): explicitly False conds are already dropped in
        # __new__, and nothing here raises ValueError when no pairs remain,
        # although the class docstring promises one -- confirm intent.
        all_conds_evaled = True    # Do all conds eval to a bool?
        piecewise_again = False    # Should we pass args to Piecewise again?
        non_false_ecpairs = []
        or1 = Or(*[cond for (_, cond) in args if cond is not True])
        for expr, cond in args:
            # Check here if expr is a Piecewise and collapse if one of
            # the conds in expr matches cond. This allows the collapsing
            # of Piecewise((Piecewise(x,x<0),x<0)) to Piecewise((x,x<0)).
            # This is important when using piecewise_fold to simplify
            # multiple Piecewise instances having the same conds.
            # Eventually, this code should be able to collapse Piecewise's
            # having different intervals, but this will probably require
            # using the new assumptions.
            if isinstance(expr, Piecewise):
                or2 = Or(*[c for (_, c) in expr.args if c is not True])
                for e, c in expr.args:
                    # Don't collapse if cond is "True" as this leads to
                    # incorrect simplifications with nested Piecewises.
                    # NOTE(review): there is no break, so when several inner
                    # pairs match cond the last one wins.
                    if c == cond and (or1 == or2 or cond is not True):
                        expr = e
                        piecewise_again = True
            cond_eval = cls.__eval_cond(cond)
            if cond_eval is None:
                all_conds_evaled = False
            elif cond_eval:
                if all_conds_evaled:
                    return expr
            if non_false_ecpairs and non_false_ecpairs[-1].expr == expr:
                # Same expression as the previous pair: widen its condition.
                non_false_ecpairs[-1] = ExprCondPair(
                    expr, Or(cond, non_false_ecpairs[-1].cond))
            else:
                non_false_ecpairs.append(ExprCondPair(expr, cond))
        if len(non_false_ecpairs) != len(args) or piecewise_again:
            return Piecewise(*non_false_ecpairs)

        return None

    def doit(self, **hints):
        """
        Evaluate this piecewise function.

        With the ``deep`` hint (default True), ``doit`` is applied to every
        expression and every Basic condition before the Piecewise is rebuilt.
        """
        newargs = []
        for e, c in self.args:
            if hints.get('deep', True):
                if isinstance(e, Basic):
                    e = e.doit(**hints)
                if isinstance(c, Basic):
                    c = c.doit(**hints)
            newargs.append((e, c))
        return Piecewise(*newargs)

    def _eval_as_leading_term(self, x):
        # Use the first branch that is unconditional or whose condition holds
        # at x = 0; falls through (returns None) when no branch qualifies.
        for e, c in self.args:
            if c is True or c.subs(x, 0) is True:
                return e.as_leading_term(x)

    def _eval_conjugate(self):
        from sympy.functions.elementary.complexes import conjugate
        return Piecewise(*[(conjugate(e), c) for e, c in self.args])

    def _eval_derivative(self, x):
        # Differentiate branch-wise; conditions are kept unchanged.
        return Piecewise(*[(diff(e, x), c) for e, c in self.args])

    def _eval_evalf(self, prec):
        return Piecewise(*[(e.evalf(prec), c) for e, c in self.args])

    def _eval_integral(self, x):
        # Integrate branch-wise; conditions are kept unchanged.
        from sympy.integrals import integrate
        return Piecewise(*[(integrate(e, x), c) for e, c in self.args])

    def _eval_interval(self, sym, a, b):
        """Evaluate the function along ``sym`` over the interval [a, b].

        When the order of ``a`` and ``b`` is known, every piece contributes
        its own ``_eval_interval`` over the part of the interval where its
        condition applies; the contributions are summed and the sum is
        negated if the limits had to be swapped (a > b).  When the order of
        the limits cannot be decided, a Piecewise of the possible results is
        returned instead.

        Raises NotImplementedError when neither limit can be placed relative
        to a piece's interval (or a condition cannot be turned into an
        interval), and ValueError (from ``_sort_expr_cond``) when part of the
        interval is covered by no condition and there is no default.
        """
        # FIXME: Currently complex intervals are not supported.  A possible
        # replacement algorithm, discussed in issue 2128, can be found in the
        # following papers;
        #     http://portal.acm.org/citation.cfm?id=281649
        mul = 1
        if (a == b) is True:
            return S.Zero
        elif (a > b) is True:
            a, b, mul = b, a, -1
        elif (a <= b) is not True:
            # The order of the limits is unknown: for every piece, compute a
            # value per sub-interval of its condition, placing whichever limit
            # can be compared; ``rep`` is the other (still unplaced) limit.
            newargs = []
            for e, c in self.args:
                intervals = self._sort_expr_cond(
                    sym, S.NegativeInfinity, S.Infinity, c)
                values = []
                for lower, upper in intervals:
                    if (a < lower) is True:
                        mid = lower
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, mid)
                        val += self._eval_interval(sym, a, mid)
                    elif (a > upper) is True:
                        mid = upper
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, mid)
                        val += self._eval_interval(sym, a, mid)
                    elif (a >= lower) is True and (a <= upper) is True:
                        rep = b
                        val = e.subs(sym, b) - e.subs(sym, a)
                    elif (b < lower) is True:
                        mid = lower
                        rep = a
                        val = e.subs(sym, mid) - e.subs(sym, a)
                        val += self._eval_interval(sym, mid, b)
                    elif (b > upper) is True:
                        mid = upper
                        rep = a
                        val = e.subs(sym, mid) - e.subs(sym, a)
                        val += self._eval_interval(sym, mid, b)
                    elif ((b >= lower) is True) and ((b <= upper) is True):
                        rep = a
                        val = e.subs(sym, b) - e.subs(sym, a)
                    else:
                        raise NotImplementedError(
                            "The evaluation of a Piecewise interval when both "
                            "the lower\nand the upper limit are symbolic is "
                            "not yet implemented.")
                    values.append(val)
                if len(set(values)) == 1:
                    # All sub-intervals agree: keep the piece's condition,
                    # with the unplaced limit substituted for sym.
                    try:
                        c = c.subs(sym, rep)
                    except AttributeError:
                        # c is the plain bool True, which has no .subs
                        pass
                    e = values[0]
                    newargs.append((e, c))
                else:
                    # One pair per sub-interval, conditioned on where the
                    # unplaced limit falls; the last sub-interval of the
                    # default (True) branch keeps the condition True.
                    last = len(values) - 1
                    for i, value in enumerate(values):
                        if c is True and i == last:
                            newcond = True
                        else:
                            newcond = And(rep >= intervals[i][0],
                                          rep <= intervals[i][1])
                        newargs.append((value, newcond))
            return Piecewise(*newargs)

        # Determine what intervals the expr,cond pairs affect.
        int_expr = self._sort_expr_cond(sym, a, b)

        # Finally run through the intervals and sum the evaluation.
        ret_fun = 0
        for int_a, int_b, expr in int_expr:
            ret_fun += expr._eval_interval(sym, Max(a, int_a), Min(b, int_b))
        return mul * ret_fun

    def _sort_expr_cond(self, sym, a, b, targetcond=None):
        """Determine what intervals the expr, cond pairs affect.

        1) If cond is True, then log it as default
        1.1) Currently if cond can't be evaluated, throw NotImplementedError.
        2) For each inequality, if previous cond defines part of the interval
           update the new conds interval.
           -  eg x < 1, x < 3 -> [-oo,1],[1,3] instead of [-oo,1],[-oo,3]
        3) Sort the intervals to make it easier to find correct exprs

        Under normal use, we return a list of ``[lower, upper, expr]``
        entries restricted to [a, b]: the intervals found for the conditions,
        ordered by the ``sort_key`` of their bounds, followed by any gaps
        ("holes") assigned to the default (True) expression.  If targetcond
        is given, we return a list of (lowerbound, upperbound) pairs for
        this condition (an empty list if targetcond is True and there are no
        holes).

        Raises ValueError if part of [a, b] is covered by no condition and
        there is no default expression.
        """
        default = None
        int_expr = []
        expr_cond = []
        or_cond = False
        or_intervals = []
        # Split Or conditions so that every disjunct gets its own interval.
        for expr, cond in self.args:
            if isinstance(cond, Or):
                for cond2 in sorted(cond.args, key=default_sort_key):
                    expr_cond.append((expr, cond2))
            else:
                expr_cond.append((expr, cond))
            if cond is True:
                break
        for expr, cond in expr_cond:
            if cond is True:
                default = expr
                break
            elif isinstance(cond, Equality):
                # Equalities are skipped; they define no interval here.
                continue
            elif isinstance(cond, And):
                # Intersect the bounds given by each inequality of the And.
                lower = S.NegativeInfinity
                upper = S.Infinity
                for cond2 in cond.args:
                    if cond2.lts.has(sym):
                        upper = Min(cond2.gts, upper)
                    elif cond2.gts.has(sym):
                        lower = Max(cond2.lts, lower)
            else:
                lower, upper = cond.lts, cond.gts  # part 1: initialize with givens
                if cond.lts.has(sym):     # part 1a: expand the side ...
                    lower = S.NegativeInfinity   # e.g. x <= 0 ---> -oo <= 0
                elif cond.gts.has(sym):   # part 1a: ... that can be expanded
                    upper = S.Infinity           # e.g. x >= 0 --->  oo >= 0
                else:
                    raise NotImplementedError(
                        "Unable to handle interval evaluation of expression.")

            # part 1b: Reduce (-)infinity to what was passed in.
            lower, upper = Max(a, lower), Min(b, upper)

            for n in range(len(int_expr)):
                # Part 2: remove any overlap with the intervals collected so
                # far (earlier conditions).  Where the comparison is decidable
                # the incoming bounds are trimmed; where an existing bound is
                # symbolic, that existing bound may be tightened instead.
                if self.__eval_cond(lower < int_expr[n][1]) and \
                        self.__eval_cond(lower >= int_expr[n][0]):
                    lower = int_expr[n][1]
                elif len(int_expr[n][1].free_symbols) and \
                        self.__eval_cond(lower >= int_expr[n][0]):
                    if self.__eval_cond(lower == int_expr[n][0]):
                        lower = int_expr[n][1]
                    else:
                        int_expr[n][1] = Min(lower, int_expr[n][1])
                elif len(int_expr[n][1].free_symbols) and \
                        (lower < int_expr[n][0]) is not True:
                    # Parenthesized on purpose: unparenthesized, Python would
                    # chain this as (lower < X) and (X is not True).
                    upper = Min(upper, int_expr[n][0])
                elif self.__eval_cond(upper > int_expr[n][0]) and \
                        self.__eval_cond(upper <= int_expr[n][1]):
                    upper = int_expr[n][0]
                elif len(int_expr[n][0].free_symbols) and \
                        self.__eval_cond(upper < int_expr[n][1]):
                    int_expr[n][0] = Max(upper, int_expr[n][0])

            if self.__eval_cond(lower >= upper) is not True:  # Is it still an interval?
                int_expr.append([lower, upper, expr])
            if cond is targetcond:
                return [(lower, upper)]
            elif isinstance(targetcond, Or) and cond in targetcond.args:
                or_cond = Or(or_cond, cond)
                or_intervals.append((lower, upper))
                if or_cond == targetcond:
                    or_intervals.sort(key=lambda bounds: bounds[0])
                    return or_intervals
        # NOTE(review): if targetcond is an Or (or Equality) that is never
        # fully matched above -- e.g. one disjunct is an Equality, which is
        # skipped -- execution falls through to the generic result below.

        # Sort by upper bound, then (stable sort) by lower bound.  Non-numeric
        # bounds are keyed as -oo (upper) and oo (lower) respectively.
        int_expr.sort(key=lambda entry: entry[1].sort_key()
                      if entry[1].is_number
                      else S.NegativeInfinity.sort_key())
        int_expr.sort(key=lambda entry: entry[0].sort_key()
                      if entry[0].is_number
                      else S.Infinity.sort_key())
        # For entries with a symbolic bound, replace one bound by the Min/Max
        # of the entry's own bounds and propagate it to the neighbouring
        # entry that shares that endpoint.
        for n in range(len(int_expr)):
            if len(int_expr[n][0].free_symbols) or len(int_expr[n][1].free_symbols):
                if isinstance(int_expr[n][1], Min) or int_expr[n][1] == b:
                    newval = Min(*int_expr[n][:-1])
                    if n > 0 and int_expr[n][0] == int_expr[n-1][1]:
                        int_expr[n-1][1] = newval
                    int_expr[n][0] = newval
                else:
                    newval = Max(*int_expr[n][:-1])
                    if n < len(int_expr) - 1 and int_expr[n][1] == int_expr[n+1][0]:
                        int_expr[n+1][0] = newval
                    int_expr[n][1] = newval

        # Add holes to list of intervals if there is a default value,
        # otherwise raise a ValueError.  A gap is recorded whenever it is not
        # provably empty.
        holes = []
        curr_low = a
        for int_a, int_b, expr in int_expr:
            if (curr_low < int_a) is True:
                holes.append([curr_low, Min(b, int_a), default])
            elif (curr_low >= int_a) is not True:
                holes.append([curr_low, Min(b, int_a), default])
            curr_low = Min(b, int_b)
        if (curr_low < b) is True:
            holes.append([Min(b, curr_low), b, default])
        elif (curr_low >= b) is not True:
            holes.append([Min(b, curr_low), b, default])

        if holes and default is not None:
            int_expr.extend(holes)
        elif holes:
            raise ValueError("Called interval evaluation over piecewise "
                             "function on undefined intervals %s" %
                             ", ".join([str((h[0], h[1])) for h in holes]))

        if targetcond is True:
            # The default branch covers exactly the holes -- possibly none,
            # in which case an empty list (not the triples) is returned.
            return [(h[0], h[1]) for h in holes]

        return int_expr

    def _eval_power(self, s):
        return Piecewise(*[(e**s, c) for e, c in self.args])

    def _eval_subs(self, old, new):
        """
        Piecewise conditions may contain bool which are not of Basic type.

        Substitution is applied to every expression and to every Basic
        condition; plain bool conditions are kept unchanged.
        """
        args = list(self.args)
        for i, (e, c) in enumerate(args):
            e = e._subs(old, new)
            # A bool is never a Basic, so this also leaves bool conds alone.
            if isinstance(c, Basic):
                c = c._subs(old, new)
            args[i] = e, c

        return Piecewise(*args)

    def _eval_nseries(self, x, n, logx):
        # Expand every branch expression; conditions are kept unchanged.
        args = [(ec.expr._eval_nseries(x, n, logx), ec.cond)
                for ec in self.args]
        return self.func(*args)

    def _eval_template_is_attr(self, is_attr, when_multiple=None):
        """
        Combine ``getattr(expr, is_attr)`` over all branch expressions.

        Returns None as soon as one value is None, ``when_multiple`` when two
        values differ, and otherwise the common value.
        """
        b = None
        for expr, _ in self.args:
            a = getattr(expr, is_attr)
            if a is None:
                return None
            if b is None:
                b = a
            elif b is not a:
                return when_multiple
        return b

    _eval_is_bounded = lambda self: self._eval_template_is_attr('is_bounded', when_multiple=False)
    _eval_is_complex = lambda self: self._eval_template_is_attr('is_complex')
    _eval_is_even = lambda self: self._eval_template_is_attr('is_even')
    _eval_is_imaginary = lambda self: self._eval_template_is_attr('is_imaginary')
    _eval_is_integer = lambda self: self._eval_template_is_attr('is_integer')
    _eval_is_irrational = lambda self: self._eval_template_is_attr('is_irrational')
    _eval_is_negative = lambda self: self._eval_template_is_attr('is_negative')
    _eval_is_nonnegative = lambda self: self._eval_template_is_attr('is_nonnegative')
    _eval_is_nonpositive = lambda self: self._eval_template_is_attr('is_nonpositive')
    _eval_is_nonzero = lambda self: self._eval_template_is_attr('is_nonzero', when_multiple=True)
    _eval_is_odd = lambda self: self._eval_template_is_attr('is_odd')
    _eval_is_polar = lambda self: self._eval_template_is_attr('is_polar')
    _eval_is_positive = lambda self: self._eval_template_is_attr('is_positive')
    _eval_is_real = lambda self: self._eval_template_is_attr('is_real')
    _eval_is_zero = lambda self: self._eval_template_is_attr('is_zero', when_multiple=False)

    @classmethod
    def __eval_cond(cls, cond):
        """Return the truth value of the condition.

        Only the literal ``True`` is recognised; anything else yields None
        (undecided).  This never returns False.
        """
        if cond is True:
            return True
        return None

def piecewise_fold(expr):
    """
    Takes an expression containing a piecewise function and returns the
    expression in piecewise form.

    Examples
    ========

    >>> from sympy import Piecewise, piecewise_fold, sympify as S
    >>> from sympy.abc import x
    >>> p = Piecewise((x, x < 1), (1, S(1) <= x))
    >>> piecewise_fold(x*p)
    Piecewise((x**2, x < 1), (x, 1 <= x))

    See Also
    ========

    Piecewise
    """
    if not isinstance(expr, Basic) or not expr.has(Piecewise):
        return expr
    # Fold the arguments first.  This must be a real list: it is enumerated
    # and then sliced below, which a lazy map object (Python 3) cannot do.
    new_args = [piecewise_fold(arg) for arg in expr.args]
    if expr.func is ExprCondPair:
        return ExprCondPair(*new_args)
    piecewise_args = [n for n, arg in enumerate(new_args)
                      if arg.func is Piecewise]
    if piecewise_args:
        # Distribute expr.func over the branches of the first Piecewise
        # argument: f(a, Piecewise((e1, c1), ...), b) becomes
        # Piecewise((f(a, e1, b), c1), ...).
        n = piecewise_args[0]
        new_args = [(expr.func(*(new_args[:n] + [e] + new_args[n + 1:])), c)
                    for e, c in new_args[n].args]
        if len(piecewise_args) > 1:
            # Further arguments are still Piecewise; fold the result again.
            return piecewise_fold(Piecewise(*new_args))
    # When no argument folded to a Piecewise, expr is expected to be a
    # Piecewise itself, so new_args are its (folded) expr/cond pairs.
    return Piecewise(*new_args)