Source code for sympy.printing.lambdarepr

from __future__ import print_function, division
from distutils.version import LooseVersion as V

from .str import StrPrinter
from .pycode import (
    PythonCodePrinter,
    MpmathPrinter,  # MpmathPrinter is imported for backward compatibility
    NumPyPrinter  # NumPyPrinter is imported for backward compatibility
)
from sympy.external import import_module
from sympy.utilities import default_sort_key

# Probe for TensorFlow once at import time.  The truthiness check below
# guards the version lookup when the probe did not yield a module.
tensorflow = import_module('tensorflow')

# Name of the element-wise selection function emitted by
# TensorflowPrinter._print_Piecewise: "select" for TensorFlow releases
# older than 1.0, "where" otherwise (including when TensorFlow is absent).
tensorflow_piecewise = (
    "select"
    if tensorflow and V(tensorflow.__version__) < '1.0'
    else "where"
)

class LambdaPrinter(PythonCodePrinter):
    """
    This printer converts expressions into strings that can be used by
    lambdify.
    """
    printmethod = "_lambdacode"

    def _join_operands(self, args, keyword, empty):
        """
        Print ``args`` joined by a Python boolean keyword.

        The operands are sorted with ``default_sort_key`` so the output is
        deterministic, each one is wrapped in parentheses, and the whole
        expression is parenthesised again, e.g. ``((a) and (b))``.

        ``keyword`` is the bare operator (``'and'`` / ``'or'``).  ``empty``
        is returned when there are no operands at all, so the result is
        always a well-formed expression rather than a dangling ``)``.
        """
        operands = [
            '(' + self._print(arg) + ')'
            for arg in sorted(args, key=default_sort_key)
        ]
        if not operands:
            return empty
        return '(' + (' %s ' % keyword).join(operands) + ')'

    def _print_And(self, expr):
        """Print a conjunction using the Python ``and`` keyword."""
        # An empty conjunction is vacuously true.
        return self._join_operands(expr.args, 'and', 'True')

    def _print_Or(self, expr):
        """Print a disjunction using the Python ``or`` keyword."""
        # An empty disjunction is vacuously false.
        return self._join_operands(expr.args, 'or', 'False')

    def _print_Not(self, expr):
        """Print a negation of the first argument as ``(not (...))``."""
        return '(not (' + self._print(expr.args[0]) + '))'

    def _print_BooleanTrue(self, expr):
        """Print the true constant as the Python literal ``True``."""
        return "True"

    def _print_BooleanFalse(self, expr):
        """Print the false constant as the Python literal ``False``."""
        return "False"

    def _print_ITE(self, expr):
        """
        Print an if-then-else as a Python conditional expression.

        ``expr.args`` is read as ``(condition, then_value, else_value)``.
        """
        condition, if_true, if_false = (
            expr.args[0], expr.args[1], expr.args[2])
        # Print in the order the pieces appear in the output string.
        then_str = self._print(if_true)
        cond_str = self._print(condition)
        else_str = self._print(if_false)
        return '((' + then_str + ') if (' + cond_str + ') else (' + else_str + '))'

    def _print_NumberSymbol(self, expr):
        """Print a number symbol via its plain ``str`` representation."""
        return str(expr)
class TensorflowPrinter(LambdaPrinter):
    """
    Tensorflow printer which handles vectorized piecewise functions,
    logical operators, max/min, and relational operators.
    """
    printmethod = "_tensorflowcode"

    # Relational operator symbol -> name of the element-wise comparison
    # function emitted in its place.
    _relational_function_names = {
        '==': 'equal',
        '!=': 'not_equal',
        '<': 'less',
        '<=': 'less_equal',
        '>': 'greater',
        '>=': 'greater_equal',
    }

    def _nest_binary(self, func_name, args):
        """
        Print an n-ary operation as right-nested two-argument calls.

        The emitted logical functions take exactly two operands, so
        ``(a, b, c)`` becomes ``f(a,f(b,c))`` instead of ``f(a,b,c)``,
        where the third operand would be passed positionally as something
        other than an operand.  A single operand is printed as-is; with no
        operands the bare call ``f()`` is emitted, as it was before.
        """
        printed = [self._print(arg) for arg in args]
        if not printed:
            return '{0}()'.format(func_name)
        nested = printed[-1]
        for operand in reversed(printed[:-1]):
            nested = '{0}({1},{2})'.format(func_name, operand, nested)
        return nested

    def _print_And(self, expr):
        "Logical And printer"
        # We have to override LambdaPrinter because it uses Python 'and' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_and' to TENSORFLOW_TRANSLATIONS.
        return self._nest_binary('logical_and', expr.args)

    def _print_Or(self, expr):
        "Logical Or printer"
        # We have to override LambdaPrinter because it uses Python 'or' keyword.
        # If LambdaPrinter didn't define it, we could use StrPrinter's
        # version of the function and add 'logical_or' to TENSORFLOW_TRANSLATIONS.
        return self._nest_binary('logical_or', expr.args)

    def _print_Not(self, expr):
        "Logical Not printer"
        # We have to override LambdaPrinter because it uses Python 'not' keyword.
        # If LambdaPrinter didn't define it, we would still have to define our
        # own because StrPrinter doesn't define it.
        return '{0}({1})'.format(
            'logical_not', ','.join(self._print(i) for i in expr.args))

    def _print_Min(self, expr, **kwargs):
        """Print ``Min`` as nested two-argument ``minimum`` calls."""
        from sympy import Min
        if len(expr.args) == 1:
            return self._print(expr.args[0], **kwargs)

        # Peel off the first argument and recurse on the remainder.
        return 'minimum({0}, {1})'.format(
            self._print(expr.args[0], **kwargs),
            self._print(Min(*expr.args[1:]), **kwargs))

    def _print_Max(self, expr, **kwargs):
        """Print ``Max`` as nested two-argument ``maximum`` calls."""
        from sympy import Max
        if len(expr.args) == 1:
            return self._print(expr.args[0], **kwargs)

        # Peel off the first argument and recurse on the remainder.
        return 'maximum({0}, {1})'.format(
            self._print(expr.args[0], **kwargs),
            self._print(Max(*expr.args[1:]), **kwargs))

    def _print_Piecewise(self, expr, **kwargs):
        """
        Print a ``Piecewise`` as nested selection calls.

        Each ``(value, condition)`` pair becomes
        ``<select-or-where>(condition, value, <rest>)``, where ``<rest>`` is
        the printed remainder of the pairs, or the literal ``0`` once no
        pairs are left.
        """
        from sympy import Piecewise
        e, cond = expr.args[0].args
        if len(expr.args) == 1:
            return '{0}({1}, {2}, {3})'.format(
                tensorflow_piecewise,
                self._print(cond, **kwargs),
                self._print(e, **kwargs),
                0)

        return '{0}({1}, {2}, {3})'.format(
            tensorflow_piecewise,
            self._print(cond, **kwargs),
            self._print(e, **kwargs),
            self._print(Piecewise(*expr.args[1:]), **kwargs))

    def _print_Relational(self, expr):
        """
        Relational printer for ``==``, ``!=``, ``<``, ``<=``, ``>`` and ``>=``.

        Known operators are printed as a function call on the printed
        left- and right-hand sides; anything else is delegated to the
        parent class.
        """
        func_name = self._relational_function_names.get(expr.rel_op)
        if func_name is None:
            return super(TensorflowPrinter, self)._print_Relational(expr)
        lhs = self._print(expr.lhs)
        rhs = self._print(expr.rhs)
        return '{op}({lhs}, {rhs})'.format(op=func_name, lhs=lhs, rhs=rhs)


# numexpr works by altering the string passed to numexpr.evaluate
# rather than by populating a namespace.  Thus a special printer...
class NumExprPrinter(LambdaPrinter):
    """
    Printer producing an expression string wrapped in an ``evaluate(...)``
    call.

    Function names are translated through ``_numexpr_functions``; matrix
    and Python container printing is rejected with ``TypeError``.
    """
    # key, value pairs correspond to sympy name and numexpr name
    # functions not appearing in this dict will raise a TypeError
    printmethod = "_numexprcode"

    _numexpr_functions = {
        'sin' : 'sin',
        'cos' : 'cos',
        'tan' : 'tan',
        'asin': 'arcsin',
        'acos': 'arccos',
        'atan': 'arctan',
        'atan2' : 'arctan2',
        'sinh' : 'sinh',
        'cosh' : 'cosh',
        'tanh' : 'tanh',
        'asinh': 'arcsinh',
        'acosh': 'arccosh',
        'atanh': 'arctanh',
        'ln' : 'log',
        'log': 'log',
        'exp': 'exp',
        'sqrt' : 'sqrt',
        'Abs' : 'abs',
        'conjugate' : 'conj',
        'im' : 'imag',
        're' : 'real',
        'where' : 'where',
        'complex' : 'complex',
        'contains' : 'contains',
    }

    def _print_ImaginaryUnit(self, expr):
        """Print the imaginary unit as the literal ``1j``."""
        return '1j'

    def _print_seq(self, seq, delimiter=', '):
        """Print every item of ``seq`` and join the results with ``delimiter``."""
        # simplified _print_seq taken from pretty.py
        printed_items = [self._print(item) for item in seq]
        return delimiter.join(printed_items) if printed_items else ""

    def _print_Function(self, e):
        """
        Print a function call using its translated name.

        Functions missing from ``_numexpr_functions`` are printed through
        their ``_imp_`` implementation when they have one; otherwise a
        ``TypeError`` is raised.
        """
        sympy_name = e.func.__name__
        numexpr_name = self._numexpr_functions.get(sympy_name)

        if numexpr_name is not None:
            return "%s(%s)" % (numexpr_name, self._print_seq(e.args))

        # check for implemented_function
        if hasattr(e, '_imp_'):
            return "(%s)" % self._print(e._imp_(*e.args))

        raise TypeError("numexpr does not support function '%s'" %
                        sympy_name)

    def blacklisted(self, expr):
        """Reject ``expr`` by raising ``TypeError`` naming its class."""
        raise TypeError("numexpr cannot be used with %s" %
                        expr.__class__.__name__)

    # blacklist all Matrix printing
    _print_SparseMatrix = blacklisted
    _print_MutableSparseMatrix = blacklisted
    _print_ImmutableSparseMatrix = blacklisted
    _print_Matrix = blacklisted
    _print_DenseMatrix = blacklisted
    _print_MutableDenseMatrix = blacklisted
    _print_ImmutableMatrix = blacklisted
    _print_ImmutableDenseMatrix = blacklisted

    # blacklist some python expressions
    _print_list = blacklisted
    _print_tuple = blacklisted
    _print_Tuple = blacklisted
    _print_dict = blacklisted
    _print_Dict = blacklisted

    def doprint(self, expr):
        """Print ``expr`` and wrap it in ``evaluate('...', truediv=True)``."""
        inner = super(NumExprPrinter, self).doprint(expr)
        return "evaluate('%s', truediv=True)" % inner


# Route every translated function name to the generic function printer by
# installing a ``_print_<name>`` attribute for each key of the table.
for k in NumExprPrinter._numexpr_functions:
    setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function)
def lambdarepr(expr, **settings):
    """
    Returns a string usable for lambdifying.

    ``expr`` is printed with a ``LambdaPrinter`` that receives the keyword
    arguments collected in ``settings`` as a single dict.
    """
    printer = LambdaPrinter(settings)
    return printer.doprint(expr)