# Source code for sympy.solvers.pde

"""
Analytical methods for solving Partial Differential Equations
Currently implemented methods:
- separation of variables - pde_separate

"""

from sympy import Eq, Equality
from sympy.simplify import simplify
from sympy.core.compatibility import reduce
from sympy.utilities.iterables import has_dups

import operator

[docs]def pde_separate(eq, fun, sep, strategy='mul'):
"""Separate variables in partial differential equation either by additive
or multiplicative separation approach. It tries to rewrite an equation so
that one of the specified variables occurs on a different side of the
equation than the others.

:param eq: Partial differential equation

:param fun: Original function F(x, y, z)

:param sep: List of separated functions [X(x), u(y, z)]

:param strategy: Separation strategy. You can choose between additive
separation ('add') and multiplicative separation ('mul') which is
default.

Examples
========

>>> from sympy import E, Eq, Function, pde_separate, Derivative as D
>>> from sympy.abc import x, t
>>> u, X, T = map(Function, 'uXT')

>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='add')
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]

>>> eq = Eq(D(u(x, t), x, 2), D(u(x, t), t, 2))
>>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='mul')
[Derivative(X(x), x, x)/X(x), Derivative(T(t), t, t)/T(t)]

========
"""

elif strategy == 'mul':
else:
assert ValueError('Unknown strategy: %s' % strategy)

if isinstance(eq, Equality):
if eq.rhs != 0:
return pde_separate(Eq(eq.lhs - eq.rhs), fun, sep, strategy)
assert eq.rhs == 0

# Handle arguments
orig_args = list(fun.args)
subs_args = []
for s in sep:
for j in range(0, len(s.args)):
subs_args.append(s.args[j])

else:
functions = reduce(operator.mul, sep)

# Check whether variables match
if len(subs_args) != len(orig_args):
raise ValueError("Variable counts do not match")
# Check for duplicate arguments like  [X(x), u(x, y)]
if has_dups(subs_args):
raise ValueError("Duplicate substitution arguments detected")
# Check whether the variables match
if set(orig_args) != set(subs_args):
raise ValueError("Arguments do not match")

# Substitute original function with separated...
result = eq.lhs.subs(fun, functions).doit()

# Divide by terms when doing multiplicative separation
eq = 0
for i in result.args:
eq += i/functions
result = eq

svar = subs_args[0]
dvar = subs_args[1:]
return _separate(result, svar, dvar)

"""
Helper function for searching additive separable solutions.

Consider an equation of two independent variables x, y and a dependent
variable w, we look for the product of two functions depending on different
arguments:

w(x, y, z) = X(x) + y(y, z)

Examples
========

>>> from sympy import E, Eq, Function, pde_separate_add, Derivative as D
>>> from sympy.abc import x, t
>>> u, X, T = map(Function, 'uXT')

>>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t))
>>> pde_separate_add(eq, u(x, t), [X(x), T(t)])
[exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)]

"""

[docs]def pde_separate_mul(eq, fun, sep):
"""
Helper function for searching multiplicative separable solutions.

Consider an equation of two independent variables x, y and a dependent
variable w, we look for the product of two functions depending on different
arguments:

w(x, y, z) = X(x)*u(y, z)

Examples
========

>>> from sympy import Function, Eq, pde_separate_mul, Derivative as D
>>> from sympy.abc import x, y
>>> u, X, Y = map(Function, 'uXY')

>>> eq = Eq(D(u(x, y), x, 2), D(u(x, y), y, 2))
>>> pde_separate_mul(eq, u(x, y), [X(x), Y(y)])
[Derivative(X(x), x, x)/X(x), Derivative(Y(y), y, y)/Y(y)]

"""
return pde_separate(eq, fun, sep, strategy='mul')

def _separate(eq, dep, others):
"""Separate expression into two parts based on dependencies of variables."""

# FIRST PASS
# Extract derivatives depending our separable variable...
terms = set()
for term in eq.args:
if term.is_Mul:
for i in term.args:
if i.is_Derivative and not i.has(*others):
continue
elif term.is_Derivative and not term.has(*others):
# Find the factor that we need to divide by
div = set()
for term in terms:
ext, sep = term.expand().as_independent(dep)
# Failed?
if sep.has(*others):
return None
# FIXME: Find lcm() of all the divisors and divide with it, instead of
# current hack :(
if len(div) > 0:
final = 0
for term in eq.args:
eqn = 0
for i in div:
eqn += term / i
final += simplify(eqn)
eq = final

# SECOND PASS - separate the derivatives
div = set()
lhs = rhs = 0
for term in eq.args:
# Check, whether we have already term with independent variable...
if not term.has(*others):
lhs += term
continue
# ...otherwise, try to separate
temp, sep = term.expand().as_independent(dep)
# Failed?
if sep.has(*others):
return None
# Extract the divisors