# Source code for sympy.matrices.expressions.matmul

from __future__ import print_function, division

from sympy import Number
from sympy.core import Mul, Basic, sympify, Add
from sympy.core.compatibility import range
from sympy.functions import adjoint
from sympy.matrices.expressions.transpose import transpose
from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
        do_one, new)
from sympy.matrices.expressions.matexpr import (MatrixExpr, ShapeError,
        Identity, ZeroMatrix)
from sympy.matrices.matrices import MatrixBase


class MatMul(MatrixExpr):
    """
    A product of matrix expressions

    Scalar (non-matrix) arguments may appear alongside the matrix factors;
    ``as_coeff_matrices`` separates the two groups.

    Examples
    ========

    >>> from sympy import MatMul, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 4)
    >>> B = MatrixSymbol('B', 4, 3)
    >>> C = MatrixSymbol('C', 3, 6)
    >>> MatMul(A, B, C)
    A*B*C
    """
    is_MatMul = True

    def __new__(cls, *args, **kwargs):
        """Build the product of the sympified ``args``.

        Shapes of neighbouring matrix factors are checked with ``validate``
        unless ``check=False`` is passed.  When none of the arguments is a
        matrix, the scalar product of the arguments is returned instead of
        a ``MatMul`` instance.
        """
        operands = [sympify(a) for a in args]
        obj = Basic.__new__(cls, *operands)
        scalar_part, matrix_part = obj.as_coeff_matrices()
        if not matrix_part:
            # Nothing matrix-valued: collapse to the plain scalar product.
            return scalar_part
        if kwargs.get('check', True):
            validate(*matrix_part)
        return obj

    @property
    def shape(self):
        """Rows of the first matrix factor by columns of the last one."""
        factors = [a for a in self.args if a.is_Matrix]
        first, last = factors[0], factors[-1]
        return (first.rows, last.cols)

    def _entry(self, i, j, expand=True):
        """Return the ``(i, j)`` entry of the product.

        With a single matrix factor this is ``coeff`` times that factor's
        entry.  Otherwise the product is split into its first factor ``X``
        and the product ``Y`` of the remaining ones, and the entry is the
        sum over ``k`` of ``X[i, k]*Y[k, j]``.  If either side contains an
        explicit ``ImmutableMatrix``, the terms are added directly.
        Otherwise a ``Sum`` is built, and it is only evaluated with
        ``doit()`` when ``expand`` is true and ``X.cols`` is a number.
        """
        coeff, factors = self.as_coeff_matrices()

        if len(factors) == 1:
            # e.g. 2*X: the entry is simply a scaled entry of X
            return coeff * factors[0][i, j]

        X, rest = factors[0], factors[1:]
        if not rest:
            raise ValueError("lenth of tail cannot be 0")
        Y = MatMul(*rest)

        from sympy.core.symbol import Dummy
        from sympy.concrete.summations import Sum
        from sympy.matrices import ImmutableMatrix

        k = Dummy('k', integer=True)
        explicit = X.has(ImmutableMatrix) or Y.has(ImmutableMatrix)
        if explicit:
            terms = [X[i, idx]*Y[idx, j] for idx in range(X.cols)]
            return coeff*Add(*terms)

        summed = Sum(coeff*X[i, k]*Y[k, j], (k, 0, X.cols - 1))
        # Symbolic bounds: don't spend time in doit() on the Sum
        bounds_numeric = X.cols.is_number
        if expand and bounds_numeric:
            return summed.doit()
        return summed

    def as_coeff_matrices(self):
        """Return ``(coeff, matrices)``.

        ``coeff`` is the ``Mul`` of the non-matrix arguments, and
        ``matrices`` is the list of matrix arguments in their original order.
        """
        scalars, matrices = [], []
        for arg in self.args:
            if arg.is_Matrix:
                matrices.append(arg)
            else:
                scalars.append(arg)
        return Mul(*scalars), matrices

    def as_coeff_mmul(self):
        """Like ``as_coeff_matrices`` but with the matrices wrapped in a MatMul."""
        scalar, mats = self.as_coeff_matrices()
        product = MatMul(*mats)
        return scalar, product

    def _eval_transpose(self):
        # Transpose each argument and reverse the order of the factors.
        flipped = [transpose(a) for a in reversed(self.args)]
        return MatMul(*flipped).doit()

    def _eval_adjoint(self):
        # Take the adjoint of each argument and reverse the order of the factors.
        flipped = [adjoint(a) for a in reversed(self.args)]
        return MatMul(*flipped).doit()

    def _eval_trace(self):
        """Pull a non-unit scalar coefficient out of the trace.

        Raises NotImplementedError when the coefficient is 1.
        """
        factor, mmul = self.as_coeff_mmul()
        if factor == 1:
            raise NotImplementedError("Can't simplify any further")
        from .trace import trace
        return factor * trace(mmul.doit())

    def _eval_determinant(self):
        """Coefficient**rows times the determinants of the square sub-products."""
        from sympy.matrices.expressions.determinant import Determinant
        factor, matrices = self.as_coeff_matrices()
        blocks = only_squares(*matrices)
        return factor**self.rows * Mul(*[Determinant(b) for b in blocks])

    def _eval_inverse(self):
        """Invert each argument in reverse order.

        Falls back to an unevaluated ``Inverse(self)`` on ShapeError.
        """
        try:
            inverted = [a.inverse() if isinstance(a, MatrixExpr) else a**-1
                        for a in reversed(self.args)]
            return MatMul(*inverted).doit()
        except ShapeError:
            from sympy.matrices.expressions.inverse import Inverse
            return Inverse(self)

    def doit(self, **kwargs):
        """Optionally ``doit`` the arguments (``deep=True``), then canonicalize."""
        if kwargs.get('deep', True):
            operands = [a.doit(**kwargs) for a in self.args]
        else:
            operands = self.args
        return canonicalize(MatMul(*operands))

    # Needed for partial compatibility with Mul
    def args_cnc(self, **kwargs):
        """Return ``[commutative, noncommutative]`` parts, matrices last."""
        coeff, matrices = self.as_coeff_matrices()
        # coeff is not expected to carry noncommutative factors, but any it
        # does report are placed ahead of the matrices.
        commutative, noncommutative = coeff.args_cnc(**kwargs)
        return commutative, noncommutative + matrices
def validate(*matrices):
    """ Checks for valid shapes for args of MatMul

    Raises ShapeError if the column count of any matrix differs from the
    row count of the matrix that follows it.
    """
    for A, B in zip(matrices[:-1], matrices[1:]):
        if A.cols != B.rows:
            raise ShapeError("Matrices %s and %s are not aligned"%(A, B))


# Rules


def newmul(*args):
    """Construct a MatMul from ``args`` via ``new``.

    A leading argument equal to 1 is dropped first.
    """
    if args[0] == 1:
        args = args[1:]
    return new(MatMul, *args)


def any_zeros(mul):
    """Return a ZeroMatrix of the product's shape if any factor is zero.

    A factor counts as zero when ``is_zero`` is true for it or it is a
    ZeroMatrix.  Otherwise ``mul`` is returned unchanged.
    """
    if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
           for arg in mul.args):
        matrices = [arg for arg in mul.args if arg.is_Matrix]
        return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
    return mul


def merge_explicit(matmul):
    """ Merge explicit MatrixBase arguments

    Runs of adjacent explicit matrices and Numbers are multiplied together.
    Symbolic factors are left in place and break a run.

    >>> from sympy import MatrixSymbol, eye, Matrix, MatMul, pprint
    >>> from sympy.matrices.expressions.matmul import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = Matrix([[1, 1], [1, 1]])
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatMul(A, B, C)
    >>> pprint(X)
      [1  1] [1  2]
    A*[    ]*[    ]
      [1  1] [3  4]
    >>> pprint(merge_explicit(X))
      [4  6]
    A*[    ]
      [4  6]

    >>> X = MatMul(B, A, C)
    >>> pprint(X)
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    >>> pprint(merge_explicit(X))
    [1  1]   [1  2]
    [    ]*A*[    ]
    [1  1]   [3  4]
    """
    if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
        return matmul
    newargs = []
    last = matmul.args[0]
    for arg in matmul.args[1:]:
        if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)):
            last = last * arg
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)


def xxinv(mul):
    """ Y * X * X.I -> Y

    Replace the first adjacent pair of square factors ``X, Y`` with
    ``X == Y.inverse()`` by an Identity of matching size.
    """
    factor, matrices = mul.as_coeff_matrices()
    for i, (X, Y) in enumerate(zip(matrices[:-1], matrices[1:])):
        try:
            if X.is_square and Y.is_square and X == Y.inverse():
                I = Identity(X.rows)
                return newmul(factor, *(matrices[:i] + [I] + matrices[i+2:]))
        except ValueError:  # Y might not be invertible
            pass
    return mul


def remove_ids(mul):
    """ Remove Identities from a MatMul

    This is a modified version of sympy.strategies.rm_id.
    This is necessary because MatMul may contain both MatrixExprs and Exprs
    as args.

    See Also
    --------
        sympy.strategies.rm_id
    """
    # Separate Exprs from MatrixExprs in args
    factor, mmul = mul.as_coeff_mmul()
    # Apply standard rm_id for MatMuls
    result = rm_id(lambda x: x.is_Identity is True)(mmul)
    if result != mmul:
        return newmul(factor, *result.args)  # Recombine and return
    else:
        return mul


def factor_in_front(mul):
    """Rebuild ``mul`` with its scalar coefficient as the first argument.

    Returns ``mul`` unchanged when the coefficient is 1.
    """
    factor, matrices = mul.as_coeff_matrices()
    if factor != 1:
        return newmul(factor, *matrices)
    return mul


# Rewrite rules tried (through sympy.strategies do_one/exhaust) on MatMul
# nodes by ``canonicalize``.
rules = (any_zeros, remove_ids, xxinv, unpack, rm_id(lambda x: x == 1),
         merge_explicit, factor_in_front, flatten)

canonicalize = exhaust(typed({MatMul: do_one(*rules)}))


def only_squares(*matrices):
    """ factor matrices only if they are square

    Split the chain ``matrices`` into consecutive runs whose products are
    square, and return the evaluated product of each run.  Raises
    RuntimeError if the chain as a whole is not square.
    """
    if matrices[0].rows != matrices[-1].cols:
        raise RuntimeError("Invalid matrices being multiplied")
    out = []
    start = 0
    for i, M in enumerate(matrices):
        if M.cols == matrices[start].rows:
            out.append(MatMul(*matrices[start:i+1]).doit())
            start = i+1
    return out


from sympy.assumptions.ask import ask, Q
from sympy.assumptions.refine import handlers_dict


def refine_MatMul(expr, assumptions):
    """
    Replace adjacent pairs ``X*X.T`` with an Identity when X is known to be
    orthogonal, and ``X*X.H`` (in either order) with an Identity when X is
    known to be unitary.  Scalar arguments are kept in front.

    >>> from sympy import MatrixSymbol, Q, assuming, refine
    >>> X = MatrixSymbol('X', 2, 2)
    >>> expr = X * X.T
    >>> print(expr)
    X*X.T
    >>> with assuming(Q.orthogonal(X)):
    ...     print(refine(expr))
    I
    """
    newargs = []
    exprargs = []

    for arg in expr.args:
        if arg.is_Matrix:
            exprargs.append(arg)
        else:
            newargs.append(arg)

    last = exprargs[0]
    for arg in exprargs[1:]:
        if arg == last.T and ask(Q.orthogonal(arg), assumptions):
            last = Identity(arg.shape[0])
        elif arg == adjoint(last) and (ask(Q.unitary(last), assumptions) or
                                       ask(Q.unitary(arg), assumptions)):
            # U*U.H == I for unitary U.  This needs the conjugate *transpose*:
            # U*conjugate(U) (element-wise conjugate) is not I in general.
            # Either factor being unitary implies the other one is too.
            last = Identity(arg.shape[0])
        else:
            newargs.append(last)
            last = arg
    newargs.append(last)

    return MatMul(*newargs)


handlers_dict['MatMul'] = refine_MatMul