# Source code for sympy.parsing.mathematica

from __future__ import print_function, division

from itertools import product
import re
from sympy import sympify

def mathematica(s, additional_translations=None):
    '''Translate the Mathematica expression string ``s`` into a SymPy
    expression.

    Users can add their own translation dictionary
    # Example
    In [1]: mathematica('Log3[9]', {'Log3[x]':'log(x,3)'})
    Out[1]: 2
    In [2]: mathematica('F[7,5,3]', {'F[*x]':'Max(*x)*Min(*x)'})
    Out[2]: 21
    variable-length argument needs '*' character

    Raises ValueError for malformed input or for a function that has no
    translation.
    '''
    # the parser turns the Mathematica text into a SymPy-style string,
    # which sympify then evaluates into a SymPy object
    parser = MathematicaParser(additional_translations)
    return sympify(parser.parse(s))

def _deco(cls):
cls._initialize_class()
return cls

@_deco
class MathematicaParser(object):
    '''An instance of this class converts a string of a basic Mathematica
    expression to SymPy style. Output is string type.

    Built-in translations come from ``CORRESPONDENCES``; they are compiled
    into ``TRANSLATIONS`` by ``_initialize_class`` when the class is
    created. Extra translations may be passed to the constructor as a dict
    such as ``{'F[x,y]': 'f(y,x)'}``; a model argument written as ``*name``
    in the last position absorbs any number of actual arguments.
    '''

    # left: Mathematica, right: SymPy
    CORRESPONDENCES = {
        'Sqrt[x]': 'sqrt(x)',
        'Exp[x]': 'exp(x)',
        'Log[x]': 'log(x)',
        'Log[x,y]': 'log(y,x)',
        'Log2[x]': 'log(x,2)',
        'Log10[x]': 'log(x,10)',
        'Mod[x,y]': 'Mod(x,y)',
        'Max[*x]': 'Max(*x)',
        'Min[*x]': 'Min(*x)',
    }

    # trigonometric, e.t.c.
    for arc, tri, h in product(('', 'Arc'), (
            'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')):
        fm = arc + tri + h + '[x]'
        if arc:  # arc func
            fs = 'a' + tri.lower() + h + '(x)'
        else:    # non-arc func
            fs = tri.lower() + h + '(x)'
        CORRESPONDENCES.update({fm: fs})
    # drop the loop variables so they do not linger as class attributes
    del arc, tri, h, fm, fs

    REPLACEMENTS = {
        ' ': '',
        '^': '**',
        '{': '[',
        '}': ']',
    }

    RULES = {
        # a single whitespace to '*'
        'whitespace': (
            re.compile(r'''
                (?<=[a-zA-Z\d])     # a letter or a number
                \                   # a whitespace
                (?=[a-zA-Z\d])      # a letter or a number
                ''', re.VERBOSE),
            '*'),

        # add omitted '*' character (e.g. '2x' -> '2*x', ']x' -> ']*x')
        'add*_1': (
            re.compile(r'''
                (?<=[])\d])         # ], ) or a number
                                    # ''
                (?=[(a-zA-Z])       # ( or a single letter
                ''', re.VERBOSE),
            '*'),

        # add omitted '*' character (variable letter preceding)
        'add*_2': (
            re.compile(r'''
                (?<=[a-zA-Z])       # a letter
                \(                  # ( as a character
                (?=.)               # any characters
                ''', re.VERBOSE),
            '*('),

        # convert 'Pi' to 'pi'
        'Pi': (
            re.compile(r'''
                (?:
                \A|(?<=[^a-zA-Z])
                )
                Pi                  # 'Pi' is 3.14159... in Mathematica
                (?![a-zA-Z])        # not followed by a letter (end is fine)
                ''', re.VERBOSE),
            'pi'),
    }

    # Mathematica function name pattern
    FM_PATTERN = re.compile(r'''
                (?:
                \A|(?<=[^a-zA-Z])   # at the top or a non-letter
                )
                [A-Z][a-zA-Z\d]*    # Function
                (?=\[)              # [ as a character
                ''', re.VERBOSE)

    # list or matrix pattern (for future usage)
    ARG_MTRX_PATTERN = re.compile(r'''
                \{.*\}
                ''', re.VERBOSE)

    # regex string for function argument pattern
    ARGS_PATTERN_TEMPLATE = r'''
                (?:
                \A|(?<=[^a-zA-Z])
                )
                {arguments}         # model argument like x, y,...
                (?![a-zA-Z])        # not followed by a letter (end is fine)
                '''

    # will contain transformed CORRESPONDENCES dictionary
    TRANSLATIONS = {}

    # cache for a raw users' translation dictionary
    cache_original = {}

    # cache for a compiled users' translation dictionary
    cache_compiled = {}

    @classmethod
    def _initialize_class(cls):
        '''Compile CORRESPONDENCES into the class-wide TRANSLATIONS table.'''
        # get a transformed CORRESPONDENCES dictionary
        d = cls._compile_dictionary(cls.CORRESPONDENCES)
        cls.TRANSLATIONS.update(d)

    def __init__(self, additional_translations=None):
        '''Build the translation table for this parser.

        additional_translations -- optional dict of user translations in
        the same {'Mathematica form': 'SymPy template'} format as
        CORRESPONDENCES; user entries override built-in ones with the
        same key. Raises ValueError if it is not a dict or if one of its
        forms is malformed.

        The compiled user dictionary is cached on the class and reused
        while equal translations are passed again.
        '''
        self.translations = {}

        # update with TRANSLATIONS (class constant)
        self.translations.update(self.TRANSLATIONS)

        if additional_translations is None:
            additional_translations = {}

        if not isinstance(additional_translations, dict):
            raise ValueError('The argument must be dict type')

        # check the latest added translations
        if self.__class__.cache_original != additional_translations:
            # get a transformed additional_translations dictionary
            d = self._compile_dictionary(additional_translations)

            # update cache; keep a copy so that later in-place edits of the
            # caller's dict are noticed on the next construction
            self.__class__.cache_original = dict(additional_translations)
            self.__class__.cache_compiled = d

        # merge user's own translations
        self.translations.update(self.__class__.cache_compiled)

    @classmethod
    def _compile_dictionary(cls, dic):
        '''Compile a {Mathematica form: SymPy template} dict.

        Returns a dict keyed by ``(name, arity)`` -- or ``(name, '*')`` when
        the last model argument is variable-length -- whose values hold the
        SymPy template ``'fs'``, the model argument names ``'args'`` and a
        compiled regex ``'pat'`` that finds those names in the template.
        Raises ValueError for malformed function forms.
        '''
        # for return
        d = {}

        for fm, fs in dic.items():
            # check function form
            cls._check_input(fm)
            cls._check_input(fs)

            # uncover '*' hiding behind a whitespace
            fm = cls._apply_rules(fm, 'whitespace')
            fs = cls._apply_rules(fs, 'whitespace')

            # remove whitespace(s)
            fm = cls._replace(fm, ' ')
            fs = cls._replace(fs, ' ')

            err = "'{f}' function form is invalid.".format(f=fm)

            # search Mathematica function name
            m = cls.FM_PATTERN.search(fm)

            # if no-hit
            if m is None:
                raise ValueError(err)

            # get Mathematica function name like 'Log'
            fm_name = m.group()

            # get arguments of Mathematica function
            args, end = cls._get_args(m)

            # function side check. (e.g.) '2*Func[x]' is invalid.
            if m.start() != 0 or end != len(fm):
                raise ValueError(err)

            # empty model arguments ('F[]', 'F[x,]') can never be found in
            # the template and would make the substitution regex match ''
            if not all(args):
                raise ValueError(err)

            # check the last argument's 1st character
            if args[-1].startswith('*'):
                key_arg = '*'
            else:
                key_arg = len(args)

            key = (fm_name, key_arg)

            # convert '*x' to '\\*x' for regex
            re_args = ['\\' + x if x.startswith('*') else x for x in args]

            # longest names first so that e.g. 'x1' is not matched as 'x'
            re_args.sort(key=len, reverse=True)

            # for regex. Example: (?:(x|y|z))
            xyz = '(?:(' + '|'.join(re_args) + '))'

            # string for regex compile
            pat_str = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)

            # update dictionary
            d[key] = {
                'fs': fs,           # SymPy function template
                'args': args,       # args are ['x', 'y'] for example
                'pat': re.compile(pat_str, re.VERBOSE),
            }

        return d

    def _convert_function(self, s):
        '''Parse Mathematica function to SymPy one

        Repeatedly finds the leftmost Mathematica call in ``s``, rewrites it
        with ``_convert_one_function`` and keeps scanning from the start of
        the rewritten call, so calls nested in its arguments are converted
        too.
        '''
        # compiled regex object
        pat = self.FM_PATTERN

        scanned = ''                # converted string
        cur = 0                     # position cursor
        while True:
            m = pat.search(s)

            if m is None:
                # append the rest of string
                scanned += s
                break

            # get Mathematica function name
            fm = m.group()

            # get arguments, and the end position of fm function
            args, end = self._get_args(m)

            # the start position of fm function
            bgn = m.start()

            # convert Mathematica function to SymPy one
            s = self._convert_one_function(s, fm, args, bgn, end)

            # update cursor
            cur = bgn

            # append converted part
            scanned += s[:cur]

            # shrink s
            s = s[cur:]

        return scanned

    def _convert_one_function(self, s, fm, args, bgn, end):
        '''Replace the Mathematica call ``s[bgn:end]`` with its SymPy form.

        fm   -- Mathematica function name, e.g. 'Log'
        args -- actual argument strings of the call
        Raises ValueError when no translation matches ``fm`` with this
        number of arguments, or when too few arguments are given for a
        variable-length translation.
        '''
        # no variable-length argument
        if (fm, len(args)) in self.translations:
            key = (fm, len(args))

            # x, y,... model arguments
            x_args = self.translations[key]['args']

            # map model arguments to actual ones
            d = dict(zip(x_args, args))

        # with variable-length argument
        elif (fm, '*') in self.translations:
            key = (fm, '*')

            # x, y,..*args (model arguments)
            x_args = self.translations[key]['args']

            # map model arguments to actual ones; the starred model
            # argument absorbs all remaining actual arguments
            d = {}
            for i, x in enumerate(x_args):
                if x[0] == '*':
                    d[x] = ','.join(args[i:])
                    break
                if i >= len(args):
                    err = "'{f}' has too few arguments.".format(f=fm)
                    raise ValueError(err)
                d[x] = args[i]

        # out of self.translations
        else:
            err = "'{f}' is out of the whitelist.".format(f=fm)
            raise ValueError(err)

        # template string of converted function
        template = self.translations[key]['fs']

        # regex pattern for x_args
        pat = self.translations[key]['pat']

        scanned = ''
        while True:
            m = pat.search(template)

            if m is None:
                scanned += template
                break

            # get model argument
            x = m.group()

            # copy the text before the model argument, then the actual one
            scanned += template[:m.start()] + d[x]

            # shrink template to what follows the model argument
            template = template[m.end():]

        # update to swapped string
        return s[:bgn] + scanned + s[end:]

    @classmethod
    def _get_args(cls, m):
        '''Get arguments of a Mathematica function

        ``m`` is an FM_PATTERN match whose name is followed by '['.
        Returns ``(args, func_end)``: the top-level, comma-separated
        argument strings and the index just past the matching ']'.
        Raises ValueError if the call has no closing ']'.
        '''
        s = m.string                # whole string
        anc = m.end() + 1           # pointing the first letter of arguments
        square, curly = [], []      # stack for brackets
        args = []

        # current cursor
        cur = anc
        for i, c in enumerate(s[anc:], anc):
            # extract one argument at a top-level comma
            if c == ',' and (not square) and (not curly):
                args.append(s[cur:i])       # add an argument
                cur = i + 1                 # move cursor

            # handle list or matrix (for future usage)
            if c == '{':
                curly.append(c)
            elif c == '}':
                curly.pop()

            # seek corresponding ']' with skipping irrelevant ones
            if c == '[':
                square.append(c)
            elif c == ']':
                if square:
                    square.pop()
                else:   # empty stack: this ']' closes the function
                    args.append(s[cur:i])
                    break
        else:
            # no closing ']' (this also covers an empty tail, where the
            # loop body never runs and ``i`` would be unbound)
            err = "'{f}' function form is invalid.".format(f=s)
            raise ValueError(err)

        # the next position to ']' bracket (the function end)
        func_end = i + 1

        return args, func_end

    @classmethod
    def _replace(cls, s, bef):
        '''Replace every ``bef`` in ``s`` with REPLACEMENTS[bef].'''
        aft = cls.REPLACEMENTS[bef]
        s = s.replace(bef, aft)
        return s

    @classmethod
    def _apply_rules(cls, s, bef):
        '''Apply the RULES[bef] regex substitution to ``s``.'''
        pat, aft = cls.RULES[bef]
        return pat.sub(aft, s)

    @classmethod
    def _check_input(cls, s):
        '''Raise ValueError if the bracket counts in ``s`` differ or if
        ``s`` contains a list ('{').'''
        for bracket in (('[', ']'), ('{', '}'), ('(', ')')):
            if s.count(bracket[0]) != s.count(bracket[1]):
                err = "'{f}' function form is invalid.".format(f=s)
                raise ValueError(err)

        if '{' in s:
            err = "Currently list is not supported.".format(f=s)
            raise ValueError(err)

    def parse(self, s):
        '''Convert the Mathematica expression string ``s`` to a SymPy-style
        string (not yet sympified). Raises ValueError on malformed input or
        on a function without a translation.'''
        # input check
        self._check_input(s)

        # uncover '*' hiding behind a whitespace
        s = self._apply_rules(s, 'whitespace')

        # remove whitespace(s)
        s = self._replace(s, ' ')

        # add omitted '*' character
        s = self._apply_rules(s, 'add*_1')
        s = self._apply_rules(s, 'add*_2')

        # translate function
        s = self._convert_function(s)

        # '^' to '**'
        s = self._replace(s, '^')

        # 'Pi' to 'pi'
        s = self._apply_rules(s, 'Pi')

        # '{', '}' to '[', ']', respectively -- currently disabled because
        # lists are rejected by _check_input:
        # s = self._replace(s, '{')
        # s = self._replace(s, '}')

        return s