|
from functools import reduce |
|
import operator |
|
|
|
from sympy.core import Basic, sympify |
|
from sympy.core.add import add, Add, _could_extract_minus_sign |
|
from sympy.core.sorting import default_sort_key |
|
from sympy.functions import adjoint |
|
from sympy.matrices.matrixbase import MatrixBase |
|
from sympy.matrices.expressions.transpose import transpose |
|
from sympy.strategies import (rm_id, unpack, flatten, sort, condition, |
|
exhaust, do_one, glom) |
|
from sympy.matrices.expressions.matexpr import MatrixExpr |
|
from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix |
|
from sympy.matrices.expressions._shape import validate_matadd_integer as validate |
|
from sympy.utilities.iterables import sift |
|
from sympy.utilities.exceptions import sympy_deprecation_warning |
|
|
|
|
|
class MatAdd(MatrixExpr, Add):
    """A Sum of Matrix Expressions

    MatAdd inherits from and operates like SymPy Add

    Examples
    ========

    >>> from sympy import MatAdd, MatrixSymbol
    >>> A = MatrixSymbol('A', 5, 5)
    >>> B = MatrixSymbol('B', 5, 5)
    >>> C = MatrixSymbol('C', 5, 5)
    >>> MatAdd(A, B, C)
    A + B + C
    """
    is_MatAdd = True

    # Additive identity for matrix sums. ``__new__`` strips it from the
    # arguments and returns it when there is nothing left to add.
    identity = GenericZeroMatrix()

    def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
        """Create a (by default unevaluated) sum of matrix expressions.

        Parameters
        ==========

        *args :
            The terms of the sum. Arguments equal to ``cls.identity`` are
            discarded before anything else happens.
        evaluate : bool, default False
            If true, the new object is canonicalized via ``_evaluate``.
        check : deprecated
            Any value other than ``None`` emits a deprecation warning.
            Shape validation runs unless ``check`` is ``False``.
        _sympify : bool, default True
            Whether to ``sympify`` the arguments.

        Returns ``cls.identity`` when called with no arguments, or when
        every argument is the identity.

        Raises
        ======

        TypeError
            If any (sympified) argument is not a ``MatrixExpr``.
        """
        if not args:
            return cls.identity

        # This must be removed aggressively in the constructor, before any
        # argument's shape is inspected: GenericZeroMatrix presumably has no
        # concrete shape to validate against.
        args = [arg for arg in args if cls.identity != arg]

        # Every argument was the identity, so the sum is the identity itself.
        # Otherwise ``validate`` would be called with no arguments and the
        # resulting MatAdd would have no ``args[0]`` for ``shape`` to read.
        if not args:
            return cls.identity

        if _sympify:
            args = [sympify(arg) for arg in args]

        if not all(isinstance(arg, MatrixExpr) for arg in args):
            raise TypeError("Mix of Matrix and Scalar symbols")

        obj = Basic.__new__(cls, *args)

        if check is not None:
            sympy_deprecation_warning(
                "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                deprecated_since_version="1.11",
                active_deprecations_target='remove-check-argument-from-matrix-operations')

        if check is not False:
            validate(*args)

        if evaluate:
            obj = cls._evaluate(obj)

        return obj

    @classmethod
    def _evaluate(cls, expr):
        """Return ``expr`` rewritten by the module-level ``canonicalize``."""
        return canonicalize(expr)

    @property
    def shape(self):
        """Shape of the sum, taken from its first term."""
        return self.args[0].shape

    def could_extract_minus_sign(self):
        """Delegate to the shared ``_could_extract_minus_sign`` helper."""
        return _could_extract_minus_sign(self)

    def expand(self, **kwargs):
        """Expand as ``Add`` does, then canonicalize the result."""
        expanded = super().expand(**kwargs)
        return self._evaluate(expanded)

    def _entry(self, i, j, **kwargs):
        """Entry ``(i, j)``: the scalar sum of each term's ``(i, j)`` entry."""
        return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

    def _eval_transpose(self):
        """Transpose distributes over the sum: ``(A + B).T == A.T + B.T``."""
        return MatAdd(*[transpose(arg) for arg in self.args]).doit()

    def _eval_adjoint(self):
        """Adjoint distributes over the sum, term by term."""
        return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

    def _eval_trace(self):
        """Trace is linear: the trace of a sum is the sum of the traces."""
        from .trace import trace
        return Add(*[trace(arg) for arg in self.args]).doit()

    def doit(self, **hints):
        """Evaluate the terms (unless ``deep=False``) and canonicalize."""
        deep = hints.get('deep', True)
        if deep:
            args = [arg.doit(**hints) for arg in self.args]
        else:
            args = self.args
        return canonicalize(MatAdd(*args))

    def _eval_derivative_matrix_lines(self, x):
        """Concatenate the derivative matrix lines of every term."""
        add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
        return [line for lines in add_lines for line in lines]
|
|
|
# Register MatAdd as the handler class for the (Add, MatAdd) class pair with
# the ``add`` dispatcher from sympy.core.add. Presumably this makes ``add(...)``
# build a MatAdd whenever a scalar Add is combined with a MatAdd; confirm
# against the AssocOpDispatcher documentation.
add.register_handlerclass((Add, MatAdd), MatAdd)
|
|
|
|
|
def factor_of(arg):
    """Return the coefficient part of a term, for use as ``glom``'s count.

    Takes element 0 of ``arg.as_coeff_mmul()``, which is presumably the
    (scalar coefficient, matrix part) pair of the term.
    """
    return arg.as_coeff_mmul()[0]


def matrix_of(arg):
    """Return the matrix part of a term, for use as ``glom``'s key.

    Takes element 1 of ``arg.as_coeff_mmul()`` and passes it through
    ``unpack``, which collapses single-argument wrappers so that like terms
    compare equal.
    """
    return unpack(arg.as_coeff_mmul()[1])


def combine(cnt, mat):
    """Rebuild a term from its collected coefficient ``cnt`` and matrix ``mat``.

    A coefficient equal to 1 is dropped so the bare matrix is returned
    instead of ``1 * mat``.
    """
    if cnt == 1:
        return mat
    return cnt * mat
|
|
|
|
|
def merge_explicit(matadd):
    """ Merge explicit MatrixBase arguments

    All explicit (``MatrixBase``) terms of ``matadd`` are summed into a
    single explicit matrix, which is placed after the symbolic terms. When
    there are fewer than two explicit terms, ``matadd`` is returned as is.

    Examples
    ========

    >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
    >>> from sympy.matrices.expressions.matadd import merge_explicit
    >>> A = MatrixSymbol('A', 2, 2)
    >>> B = eye(2)
    >>> C = Matrix([[1, 2], [3, 4]])
    >>> X = MatAdd(A, B, C)
    >>> pprint(X)
        [1  0]   [1  2]
    A + [    ] + [    ]
        [0  1]   [3  4]
    >>> pprint(merge_explicit(X))
        [2  2]
    A + [    ]
        [3  5]
    """
    by_kind = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase))
    explicit_terms = by_kind[True]

    # Nothing to merge unless at least two explicit matrices are present.
    if len(explicit_terms) <= 1:
        return matadd

    merged = reduce(operator.add, explicit_terms)
    return MatAdd(*by_kind[False], merged)
|
|
|
|
|
# Rewrite rules used by ``canonicalize``. Their order matters: ``do_one``
# (sympy.strategies) is expected to apply only the first rule that changes the
# expression on each pass, so earlier rules take priority.
rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)),  # drop 0 / ZeroMatrix terms
         unpack,  # presumably: a single-term MatAdd becomes that term
         flatten,  # presumably: splice nested MatAdds into the parent
         glom(matrix_of, factor_of, combine),  # collect like terms by matrix part
         merge_explicit,  # sum explicit MatrixBase terms into one matrix
         sort(default_sort_key))  # order the terms canonically

# Canonicalize an expression: the rules are applied only while it is still a
# MatAdd (``condition``), and ``exhaust`` is expected to repeat the rewrite
# until it no longer changes the expression.
canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd),
                                 do_one(*rules)))
|
|