# Source code for sympy.simplify.combsimp

from __future__ import print_function, division

from sympy.core import Mul, Pow
from sympy.core.basic import preorder_traversal
from sympy.core.function import count_ops
from sympy.functions.combinatorial.factorials import (binomial,
CombinatorialFunction, factorial)
from sympy.functions import gamma
from sympy.simplify.gammasimp import gammasimp, _gammasimp

from sympy.utilities.timeutils import timethis

@timethis('combsimp')
def combsimp(expr):
    r"""
    Simplify combinatorial expressions.

    This function takes as input an expression containing factorials,
    binomials, the Pochhammer symbol and other "combinatorial" functions,
    and tries to minimize the number of those functions and reduce
    the size of their arguments.

    The algorithm works by rewriting all combinatorial functions as
    gamma functions and applying gammasimp() except simplification
    steps that may make an integer argument non-integer. See docstring
    of gammasimp for more information.

    Then it rewrites the expression in terms of factorials and binomials
    by rewriting gammas as factorials and converting (a+b)!/(a!*b!) into
    binomials.

    If, after rewriting in terms of gamma, the expression contains a gamma
    function whose argument is not known to be an integer, the expression
    is passed to gammasimp() instead and that result is returned.

    Examples
    ========

    >>> from sympy.simplify import combsimp
    >>> from sympy import factorial, binomial, symbols
    >>> n, k = symbols('n k', integer=True)

    >>> combsimp(factorial(n)/factorial(n - 3))
    n*(n - 2)*(n - 1)
    >>> combsimp(binomial(n+1, k+1)/binomial(n, k))
    (n + 1)/(k + 1)

    """
    expr = expr.rewrite(gamma)

    # combsimp only handles integer arguments; if any gamma argument is not
    # provably an integer (is_integer is False or None), defer to gammasimp.
    if any(isinstance(node, gamma) and not node.args[0].is_integer
           for node in preorder_traversal(expr)):
        return gammasimp(expr)

    # as_comb=True skips the gammasimp steps that may turn an integer
    # argument into a non-integer one.
    expr = _gammasimp(expr, as_comb=True)
    return _gamma_as_comb(expr)

def _gamma_as_comb(expr):
    """
    Helper function for combsimp.

    Rewrites expression in terms of factorials and binomials.

    The expression is first rewritten with ``factorial``. Then, bottom-up,
    every product is scanned for factorials raised to integer powers.
    Whenever two factorial arguments ``a`` and ``b`` occur on one side of
    the fraction and ``(a + b)`` occurs as a factorial argument on the
    other side, those three factorials are replaced by
    ``binomial(a + b, a)`` (or its reciprocal).
    """

    expr = expr.rewrite(factorial)

    # Function-level import; presumably avoids a circular import with
    # sympy.simplify.simplify -- TODO confirm.
    from .simplify import bottom_up

    def f(rv):
        # Only a product can hold a ratio of factorials.
        if not rv.is_Mul:
            return rv
        rvd = rv.as_powers_dict()
        # Factorial arguments, repeated once per unit of exponent:
        # index 0 = numerator, index 1 = denominator.
        nd_fact_args = [[], []]

        for k in rvd:
            e = rvd[k]
            if isinstance(k, factorial) and e.is_Integer:
                if e.is_positive:
                    nd_fact_args[0].extend([k.args[0]]*e)
                else:
                    nd_fact_args[1].extend([k.args[0]]*-e)
                # This factorial is now tracked in nd_fact_args; zero its
                # exponent so it is not emitted twice when rebuilding.
                rvd[k] = 0
        if not nd_fact_args[0] or not nd_fact_args[1]:
            return rv

        hit = False
        for m in range(2):
            # m == 0: a!*b! in numerator and (a + b)! in denominator
            #         -> 1/binomial(a + b, a)
            # m == 1: a!*b! in denominator and (a + b)! in numerator
            #         -> binomial(a + b, a)
            same = nd_fact_args[m]
            other = nd_fact_args[1 - m]
            i = 0
            while i < len(same):
                ai = same[i]
                for j in range(i + 1, len(same)):
                    aj = same[j]

                    total = ai + aj
                    if total in other:
                        hit = True

                        other.remove(total)
                        # Delete the higher index first (j > i) so that
                        # index i still refers to ai.
                        del same[j]
                        del same[i]

                        # Use the structurally simpler argument as the
                        # lower index of the binomial.
                        lower = ai if count_ops(ai) < count_ops(aj) else aj
                        rvd[binomial(total, lower)] += (-1 if m == 0 else 1)
                        break
                else:
                    # No partner for ai: advance. After a hit, index i
                    # already holds a new element, so i is not advanced.
                    i += 1

        if hit:
            numer = [k**rvd[k] for k in rvd]
            numer.extend(factorial(a) for a in nd_fact_args[0])
            denom = [factorial(a) for a in nd_fact_args[1]]
            return Mul(*numer)/Mul(*denom)
        return rv

    return bottom_up(expr, f)