|
from __future__ import annotations |
|
from functools import wraps |
|
|
|
from sympy.core import S, Integer, Basic, Mul, Add |
|
from sympy.core.assumptions import check_assumptions |
|
from sympy.core.decorators import call_highest_priority |
|
from sympy.core.expr import Expr, ExprBuilder |
|
from sympy.core.logic import FuzzyBool |
|
from sympy.core.symbol import Str, Dummy, symbols, Symbol |
|
from sympy.core.sympify import SympifyError, _sympify |
|
from sympy.external.gmpy import SYMPY_INTS |
|
from sympy.functions import conjugate, adjoint |
|
from sympy.functions.special.tensor_functions import KroneckerDelta |
|
from sympy.matrices.exceptions import NonSquareMatrixError |
|
from sympy.matrices.kind import MatrixKind |
|
from sympy.matrices.matrixbase import MatrixBase |
|
from sympy.multipledispatch import dispatch |
|
from sympy.utilities.misc import filldedent |
|
|
|
|
|
def _sympifyit(arg, retval=None):
    """Decorator factory for binary methods ``f(self, other)``.

    The returned decorator sympifies the second operand (this variant also
    handles mutable matrices through ``_sympify``) before calling the wrapped
    method.  Any ``SympifyError`` raised while sympifying *or* while running
    the wrapped method makes the wrapper return ``retval`` instead.

    ``arg`` is accepted for signature compatibility and is not used.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(a, b):
            # The call to ``func`` deliberately sits inside the ``try`` too,
            # so a SympifyError raised by the method itself maps to retval.
            try:
                return func(a, _sympify(b))
            except SympifyError:
                return retval
        return wrapper
    return decorator
|
|
|
|
|
class MatrixExpr(Expr):
    """Superclass for Matrix Expressions

    MatrixExprs represent abstract matrices, linear transformations represented
    within a particular basis.

    Examples
    ========

    >>> from sympy import MatrixSymbol
    >>> A = MatrixSymbol('A', 3, 3)
    >>> y = MatrixSymbol('y', 3, 1)
    >>> x = (A.T*A).I * A * y

    See Also
    ========

    MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
    """
    __slots__: tuple[str, ...] = ()

    # Matrix expressions are not meant to be iterated element by element;
    # a subclass that genuinely is iterable may set this to True.
    _iterable = False

    # Priority consulted by ``call_highest_priority`` on the binary operators
    # below when deciding which operand's method handles an operation.
    _op_priority = 11.0

    # Default type-classification flags (subclasses may override).
    is_Matrix: bool = True
    is_MatrixExpr: bool = True
    is_Identity: FuzzyBool = None
    is_Inverse = False
    is_Transpose = False
    is_ZeroMatrix = False
    is_MatAdd = False
    is_MatMul = False

    # Matrix multiplication is not commutative; by default a matrix
    # expression is not a number, symbol or scalar (MatrixSymbol, for
    # instance, overrides ``is_symbol``).
    is_commutative = False
    is_number = False
    is_symbol = False
    is_scalar = False

    kind: MatrixKind = MatrixKind()

    def __new__(cls, *args, **kwargs):
        # Sympify every positional argument before building the node.
        args = map(_sympify, args)
        return Basic.__new__(cls, *args, **kwargs)

    # The following is adapted from the core Expr object

    @property
    def shape(self) -> tuple[Expr, Expr]:
        """(rows, cols) of the expression; subclasses must implement this."""
        raise NotImplementedError

    @property
    def _add_handler(self):
        """Class used to build sums involving this expression."""
        return MatAdd

    @property
    def _mul_handler(self):
        """Class used to build products involving this expression."""
        return MatMul

    def __neg__(self):
        return MatMul(S.NegativeOne, self).doit()

    def __abs__(self):
        raise NotImplementedError

    # Binary operators: ``_sympifyit`` sympifies ``other`` and turns a
    # SympifyError into ``NotImplemented``; ``call_highest_priority`` may
    # hand the operation to the named method of the other operand.  The
    # result is built with MatAdd/MatMul/MatPow and evaluated with doit().

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__radd__')
    def __add__(self, other):
        return MatAdd(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__add__')
    def __radd__(self, other):
        return MatAdd(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rsub__')
    def __sub__(self, other):
        return MatAdd(self, -other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__sub__')
    def __rsub__(self, other):
        return MatAdd(other, -self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        return MatMul(self, other).doit()

    # ``@`` behaves exactly like ``*`` for matrix expressions.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rmul__')
    def __matmul__(self, other):
        return MatMul(self, other).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__mul__')
    def __rmatmul__(self, other):
        return MatMul(other, self).doit()

    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rpow__')
    def __pow__(self, other):
        return MatPow(self, other).doit()

    # ``scalar ** matrix`` is not supported.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__pow__')
    def __rpow__(self, other):
        raise NotImplementedError("Matrix Power not defined")

    # ``A / b`` is computed as ``A * b**-1``.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__rtruediv__')
    def __truediv__(self, other):
        return self * other**S.NegativeOne

    # Dividing something by a matrix expression is not supported.
    @_sympifyit('other', NotImplemented)
    @call_highest_priority('__truediv__')
    def __rtruediv__(self, other):
        raise NotImplementedError()

    @property
    def rows(self):
        """Number of rows (first entry of ``shape``)."""
        return self.shape[0]

    @property
    def cols(self):
        """Number of columns (second entry of ``shape``)."""
        return self.shape[1]

    @property
    def is_square(self) -> bool | None:
        """True/False when squareness can be decided, otherwise None.

        With two explicit Integer dimensions the answer is exact; for
        symbolic dimensions only structural equality (``rows == cols``)
        yields True, anything else is undecided (None).
        """
        rows, cols = self.shape
        if isinstance(rows, Integer) and isinstance(cols, Integer):
            return rows == cols
        if rows == cols:
            return True
        return None

    def _eval_conjugate(self):
        # Elementwise conjugate expressed as the adjoint of the transpose.
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(Transpose(self))

    def as_real_imag(self, deep=True, **hints):
        """Return ``(real, imag)`` parts; ``deep`` and ``hints`` are unused."""
        return self._eval_as_real_imag()

    def _eval_as_real_imag(self):
        # real = (A + conj(A))/2, imag = (A - conj(A))/(2*I)
        real = S.Half * (self + self._eval_conjugate())
        im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit)
        return (real, im)

    def _eval_inverse(self):
        return Inverse(self)

    def _eval_determinant(self):
        return Determinant(self)

    def _eval_transpose(self):
        return Transpose(self)

    def _eval_trace(self):
        # No generic simplification of the trace is available.
        return None

    def _eval_power(self, exp):
        """
        Override this in sub-classes to implement simplification of powers.  The cases where the exponent
        is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases.
        """
        return MatPow(self, exp)

    def _eval_simplify(self, **kwargs):
        # Atoms are returned as-is; otherwise simplify each argument and
        # rebuild the expression with the same head.
        if self.is_Atom:
            return self
        else:
            from sympy.simplify import simplify
            return self.func(*[simplify(x, **kwargs) for x in self.args])

    def _eval_adjoint(self):
        from sympy.matrices.expressions.adjoint import Adjoint
        return Adjoint(self)

    def _eval_derivative_n_times(self, x, n):
        # Bypass Expr's implementation and use Basic's generic one.
        return Basic._eval_derivative_n_times(self, x, n)

    def _eval_derivative(self, x):
        # `x` is a scalar:
        if self.has(x):
            # Defer to the generic derivative machinery.
            return super()._eval_derivative(x)
        else:
            # Independent of x: the derivative is a zero matrix of our shape.
            return ZeroMatrix(*self.shape)

    @classmethod
    def _check_dim(cls, dim):
        """Helper function to check invalid matrix dimensions

        Raises ValueError only when ``dim`` is a Float or is known *not* to
        be a nonnegative integer; undecidable (symbolic) dimensions pass.
        """
        ok = not dim.is_Float and check_assumptions(
            dim, integer=True, nonnegative=True)
        if ok is False:
            raise ValueError(
                "The dimension specification {} should be "
                "a nonnegative integer.".format(dim))

    def _entry(self, i, j, **kwargs):
        """Return element (i, j); subclasses that support indexing override."""
        raise NotImplementedError(
            "Indexing not implemented for %s" % self.__class__.__name__)

    def adjoint(self):
        return adjoint(self)

    def as_coeff_Mul(self, rational=False):
        """Efficiently extract the coefficient of a product."""
        return S.One, self

    def conjugate(self):
        return conjugate(self)

    def transpose(self):
        from sympy.matrices.expressions.transpose import transpose
        return transpose(self)

    @property
    def T(self):
        '''Matrix transposition'''
        return self.transpose()

    def inverse(self):
        """Return the inverse; raises NonSquareMatrixError if known non-square."""
        if self.is_square is False:
            raise NonSquareMatrixError('Inverse of non-square matrix')
        return self._eval_inverse()

    def inv(self):
        """Alias of :meth:`inverse`."""
        return self.inverse()

    def det(self):
        from sympy.matrices.expressions.determinant import det
        return det(self)

    @property
    def I(self):
        """Alias of :meth:`inverse`."""
        return self.inverse()

    def valid_index(self, i, j):
        """Return False only if (i, j) is definitely outside the matrix.

        Negative indices down to ``-rows``/``-cols`` are accepted.  Each
        comparison is tested with ``!= False``, so comparisons that cannot
        be decided (e.g. symbolic indices or dimensions) do not reject.
        Note the parenthesisation: ``self.rows is None`` only short-circuits
        the row checks; the column checks always run.
        """
        def is_valid(idx):
            return isinstance(idx, (int, Integer, Symbol, Expr))
        return (is_valid(i) and is_valid(j) and
                (self.rows is None or
                (i >= -self.rows) != False and (i < self.rows) != False) and
                (j >= -self.cols) != False and (j < self.cols) != False)

    def __getitem__(self, key):
        """Index the matrix expression.

        Supported keys:

        * a single slice -> MatrixSlice over rows, all columns
        * a 2-tuple containing a slice -> MatrixSlice
        * a 2-tuple of indices -> the (i, j) entry
        * a single integer -> row-major flat index (requires a known
          Integer column count)

        Raises IndexError for invalid or unsupported keys.
        """
        if not isinstance(key, tuple) and isinstance(key, slice):
            from sympy.matrices.expressions.slice import MatrixSlice
            return MatrixSlice(self, key, (0, None, 1))
        if isinstance(key, tuple) and len(key) == 2:
            i, j = key
            if isinstance(i, slice) or isinstance(j, slice):
                from sympy.matrices.expressions.slice import MatrixSlice
                return MatrixSlice(self, i, j)
            i, j = _sympify(i), _sympify(j)
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid indices (%s, %s)" % (i, j))
        elif isinstance(key, (SYMPY_INTS, Integer)):
            # row-wise decomposition of matrix
            rows, cols = self.shape
            # allow single indexing if number of columns is known
            if not isinstance(cols, Integer):
                raise IndexError(filldedent('''
                    Single indexing is only supported when the number
                    of columns is known.'''))
            key = _sympify(key)
            i = key // cols
            j = key % cols
            if self.valid_index(i, j) != False:
                return self._entry(i, j)
            else:
                raise IndexError("Invalid index %s" % key)
        elif isinstance(key, (Symbol, Expr)):
            raise IndexError(filldedent('''
                Only integers may be used when addressing the matrix
                with a single index.'''))
        raise IndexError("Invalid index, wanted %s[i,j]" % self)

    def _is_shape_symbolic(self) -> bool:
        """True if either dimension is not a concrete integer."""
        return (not isinstance(self.rows, (SYMPY_INTS, Integer))
            or not isinstance(self.cols, (SYMPY_INTS, Integer)))

    def as_explicit(self):
        """
        Returns a dense Matrix with elements represented explicitly

        Returns an object of type ImmutableDenseMatrix.

        Raises ValueError if the shape is symbolic.

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.as_explicit()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_mutable: returns mutable Matrix type

        """
        if self._is_shape_symbolic():
            raise ValueError(
                'Matrix with symbolic shape '
                'cannot be represented explicitly.')
        from sympy.matrices.immutable import ImmutableDenseMatrix
        return ImmutableDenseMatrix([[self[i, j]
                            for j in range(self.cols)]
                            for i in range(self.rows)])

    def as_mutable(self):
        """
        Returns a dense, mutable matrix with elements represented explicitly

        Examples
        ========

        >>> from sympy import Identity
        >>> I = Identity(3)
        >>> I
        I
        >>> I.shape
        (3, 3)
        >>> I.as_mutable()
        Matrix([
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1]])

        See Also
        ========
        as_explicit: returns ImmutableDenseMatrix
        """
        return self.as_explicit().as_mutable()

    def __array__(self, dtype=object, copy=None):
        """NumPy conversion hook: return an object ndarray of the entries.

        ``copy=False`` is rejected with TypeError because a new array is
        always created.  NOTE(review): the ``dtype`` argument is accepted but
        ignored -- the array is always created with ``dtype=object``.
        """
        if copy is not None and not copy:
            raise TypeError("Cannot implement copy=False when converting Matrix to ndarray")
        from numpy import empty
        a = empty(self.shape, dtype=object)
        for i in range(self.rows):
            for j in range(self.cols):
                a[i, j] = self[i, j]
        return a

    def equals(self, other):
        """
        Test elementwise equality between matrices, potentially of different
        types

        The comparison is delegated to the explicit form of ``self``, so a
        symbolic shape raises ValueError.

        >>> from sympy import Identity, eye
        >>> Identity(3).equals(eye(3))
        True
        """
        return self.as_explicit().equals(other)

    def canonicalize(self):
        return self

    def as_coeff_mmul(self):
        """Return ``(coefficient, MatMul)``; the coefficient here is always 1."""
        return S.One, MatMul(self)

    @staticmethod
    def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
        r"""
        Parse expression of matrices with explicitly summed indices into a
        matrix expression without indices, if possible.

        This transformation expressed in mathematical notation:

        `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`

        Optional parameter ``first_index``: specify which free index to use as
        the index starting the expression.

        Optional parameter ``last_index``: passed on after ``first_index`` in
        the list of leading free indices.  ``dimensions`` is currently unused.

        Examples
        ========

        >>> from sympy import MatrixSymbol, MatrixExpr, Sum
        >>> from sympy.abc import i, j, k, l, N
        >>> A = MatrixSymbol("A", N, N)
        >>> B = MatrixSymbol("B", N, N)
        >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B

        Transposition is detected:

        >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A.T*B

        Detect the trace:

        >>> expr = Sum(A[i, i], (i, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        Trace(A)

        More complicated expressions:

        >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
        >>> MatrixExpr.from_index_summation(expr)
        A*B.T*A.T
        """
        from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
        from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
        # Indexed expression -> array expression -> matrix expression.
        first_indices = []
        if first_index is not None:
            first_indices.append(first_index)
        if last_index is not None:
            first_indices.append(last_index)
        arr = convert_indexed_to_array(expr, first_indices=first_indices)
        return convert_array_to_matrix(arr)

    def applyfunc(self, func):
        """Return an unevaluated elementwise application of ``func``."""
        from .applyfunc import ElementwiseApplyFunction
        return ElementwiseApplyFunction(func, self)
|
|
|
|
@dispatch(MatrixExpr, Expr)
def _eval_is_eq(lhs, rhs):
    """Equality test for a matrix expression against a general Expr.

    This dispatch case always answers False; the more specific
    (MatrixExpr, MatrixExpr) case handles matrix-versus-matrix comparison.
    """
    return False
|
|
|
@dispatch(MatrixExpr, MatrixExpr)
def _eval_is_eq(lhs, rhs):
    """Equality test between two matrix expressions.

    Returns False when the shapes differ, True when ``lhs - rhs`` is known
    to be a zero matrix, and None (undecided) in every other case.
    """
    if lhs.shape == rhs.shape:
        return True if (lhs - rhs).is_ZeroMatrix else None
    return False
|
|
|
def get_postprocessor(cls):
    """Build a constructor postprocessor for ``Mul`` or ``Add``.

    The returned function splits the arguments of a freshly built ``cls``
    instance into matrix and non-matrix terms:

    * no matrix terms -> a plain ``cls`` of the scalar terms;
    * ``Add`` with both kinds -> ``Add`` of the scalars and one ``MatAdd``
      of the matrices (kept so algorithms may treat matrices like
      non-commutative scalars);
    * ``Mul`` with scalars and an explicit (non-``MatrixExpr``-flagged)
      matrix -> the scalars are absorbed into that matrix;
    * otherwise -> a ``MatMul``/``MatAdd`` of everything.
    """
    def _postprocessor(expr):
        # MatMul/MatAdd are only looked up here, at call time.
        mat_class = {Mul: MatMul, Add: MatAdd}[cls]
        matrices = [term for term in expr.args if isinstance(term, MatrixExpr)]
        nonmatrices = [term for term in expr.args if not isinstance(term, MatrixExpr)]

        if not matrices:
            return cls._from_args(nonmatrices)

        if nonmatrices:
            if cls == Mul:
                for idx, mat in enumerate(matrices):
                    if not mat.is_MatrixExpr:
                        # Fold the scalar factors into one explicit matrix;
                        # which one does not matter since doit() combines
                        # explicit matrices.
                        matrices[idx] = mat.__mul__(cls._from_args(nonmatrices))
                        nonmatrices = []
                        break
            else:
                return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)])

        if mat_class == MatAdd:
            return mat_class(*matrices).doit(deep=False)
        return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False)

    return _postprocessor
|
|
|
|
|
# Register, for MatrixExpr, the Mul and Add postprocessors built by
# get_postprocessor under the "Mul"/"Add" keys of Basic's constructor
# postprocessor mapping (presumably applied when Mul/Add objects are
# constructed with MatrixExpr arguments).
Basic._constructor_postprocessor_mapping[MatrixExpr] = dict(
    Mul=[get_postprocessor(Mul)],
    Add=[get_postprocessor(Add)],
)
|
|
|
|
|
def _matrix_derivative(expr, x, old_algorithm=False):
    """Differentiate the matrix expression ``expr`` with respect to ``x``.

    Explicit matrices (``MatrixBase``) on either side, or
    ``old_algorithm=True``, use the derivative-line algorithm.  Otherwise the
    expression is converted to an array expression, differentiated there,
    and converted back to a matrix expression.
    """
    if old_algorithm or isinstance(expr, MatrixBase) or isinstance(x, MatrixBase):
        return _matrix_derivative_old_algorithm(expr, x)

    from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
    from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive
    from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix

    return convert_array_to_matrix(array_derive(convert_matrix_to_array(expr), x))
|
|
|
|
|
def _matrix_derivative_old_algorithm(expr, x):
    """Differentiate ``expr`` w.r.t. ``x`` with the derivative-line algorithm.

    ``expr._eval_derivative_matrix_lines(x)`` returns objects with a
    ``build()`` method (e.g. ``_LeftRightArgs``); each is built into a list of
    factors, and every factor is converted from an array expression back to
    a matrix expression.

    If the first term has at most two non-trivial dimensions, each term's
    factors are contracted into a single matrix/scalar and the terms are
    summed.  Otherwise an unevaluated ``ArrayDerivative(expr, x)`` is
    returned.

    Raises ValueError if a term has extra (higher) factors but its first
    two lines multiply to a matrix, which cannot be combined with them.
    """
    from sympy.tensor.array.array_derivatives import ArrayDerivative
    lines = expr._eval_derivative_matrix_lines(x)

    parts = [i.build() for i in lines]

    from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix

    parts = [[convert_array_to_matrix(j) for j in i] for i in parts]

    def _get_shape(elem):
        # Scalars behave like 1x1 matrices for rank counting.
        if isinstance(elem, MatrixExpr):
            return elem.shape
        return 1, 1

    def get_rank(factors):
        # Count dimensions that are neither trivial (1) nor unknown (None).
        return sum(j not in (1, None) for i in factors for j in _get_shape(i))

    # Only the first term's rank decides the representation of the result.
    rank = get_rank(parts[0])

    def contract_one_dims(factors):
        # Combine [first_line, second_line, *higher] into one expression:
        # first * second.T, dropping 1x1 identities.
        if len(factors) == 1:
            return factors[0]
        p1, p2 = factors[:2]
        if p2.is_Matrix:
            p2 = p2.T
        if p1 == Identity(1):
            pbase = p2
        elif p2 == Identity(1):
            pbase = p1
        else:
            pbase = p1*p2
        if len(factors) == 2:
            return pbase
        # len(factors) > 2: the remaining factors can only scale a scalar.
        if pbase.is_Matrix:
            raise ValueError(
                "cannot multiply the matrix-valued product of the first two "
                "derivative lines by further factors: %s" % (pbase,))
        return pbase*Mul.fromiter(factors[2:])

    if rank <= 2:
        return Add.fromiter([contract_one_dims(i) for i in parts])

    return ArrayDerivative(expr, x)
|
|
|
|
|
class MatrixElement(Expr):
    """A single (scalar) entry ``parent[i, j]`` of a matrix.

    ``args`` are ``(parent, i, j)``.  The entry is commutative, behaves like
    a symbol and can be used as a differentiation variable.
    """
    _diff_wrt = True
    is_symbol = True
    is_commutative = True

    @property
    def parent(self):
        return self.args[0]

    @property
    def i(self):
        return self.args[1]

    @property
    def j(self):
        return self.args[2]

    def __new__(cls, name, n, m):
        n, m = _sympify(n), _sympify(m)
        if isinstance(name, str):
            name = Symbol(name)
        else:
            explicit = isinstance(name, MatrixBase)
            # Concrete indices into an explicit matrix: return the entry.
            if explicit and n.is_Integer and m.is_Integer:
                return name[n, m]
            # Sympify (for explicit matrices this makes mutable immutable).
            name = _sympify(name)
            if not explicit and not isinstance(name.kind, MatrixKind):
                raise TypeError("First argument of MatrixElement should be a matrix")
            in_range = getattr(name, 'valid_index', lambda n, m: True)
            if not in_range(n, m):
                raise IndexError('indices out of range')
        return Expr.__new__(cls, name, n, m)

    @property
    def symbol(self):
        return self.args[0]

    def doit(self, **hints):
        """Evaluate by indexing the (optionally evaluated) parent."""
        if hints.get('deep', True):
            parent, row, col = [arg.doit(**hints) for arg in self.args]
        else:
            parent, row, col = self.args
        return parent[row, col]

    @property
    def indices(self):
        return self.args[1:]

    def _eval_derivative(self, v):
        # Not an entry: differentiate the whole parent, then pick our entry.
        if not isinstance(v, MatrixElement):
            return self.parent.diff(v)[self.i, self.j]

        parent = self.args[0]
        rows, cols = self.parent.shape

        # Entry of the same matrix: product of Kronecker deltas on indices.
        if parent == v.args[0]:
            row_delta = KroneckerDelta(self.args[1], v.args[1], (0, rows-1))
            col_delta = KroneckerDelta(self.args[2], v.args[2], (0, cols-1))
            return row_delta * col_delta

        # d(Y^-1)[i, j] = -sum Y^-1[i, i1] * dY[i1, i2] * Y^-1[i2, j]
        if isinstance(parent, Inverse):
            from sympy.concrete.summations import Sum
            i, j = self.args[1:]
            i1, i2 = symbols("z1, z2", cls=Dummy)
            inner = parent.args[0]
            r1, r2 = inner.shape
            summand = parent[i, i1]*inner[i1, i2].diff(v)*parent[i2, j]
            return -Sum(summand, (i1, 0, r1-1), (i2, 0, r2-1))

        # Depends on v's matrix in an unhandled way -> undecided (None);
        # otherwise the derivative vanishes.
        return None if self.has(v.args[0]) else S.Zero
|
|
|
|
|
class MatrixSymbol(MatrixExpr):
    """Symbolic representation of a Matrix object

    A named matrix of a given (possibly symbolic) shape that can take part
    in matrix expressions.  A string ``name`` is stored as a ``Str``; both
    dimensions are sympified and validated.

    Examples
    ========

    >>> from sympy import MatrixSymbol, Identity
    >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
    >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
    >>> A.shape
    (3, 4)
    >>> 2*A*B + Identity(3)
    I + 2*A*B
    """
    is_commutative = False
    is_symbol = True
    _diff_wrt = True

    def __new__(cls, name, n, m):
        n = _sympify(n)
        m = _sympify(m)

        # The column count is validated before the row count.
        for dim in (m, n):
            cls._check_dim(dim)

        if isinstance(name, str):
            name = Str(name)
        return Basic.__new__(cls, name, n, m)

    @property
    def shape(self):
        return self.args[1:3]

    @property
    def name(self):
        return self.args[0].name

    def _entry(self, i, j, **kwargs):
        return MatrixElement(self, i, j)

    @property
    def free_symbols(self):
        return {self}

    def _eval_simplify(self, **kwargs):
        return self

    def _eval_derivative(self, x):
        # x is a scalar: a matrix symbol does not depend on it.
        return ZeroMatrix(*self.shape)

    def _eval_derivative_matrix_lines(self, x):
        """Return the single derivative line pair for ``d self / d x``.

        For ``x`` equal to this symbol the lines are identities, otherwise
        zero matrices; dimensions of size 1 are replaced by scalar 1 / 0.
        """
        rows, cols = self.shape
        if self != x:
            first = ZeroMatrix(x.shape[0], rows) if rows != 1 else S.Zero
            second = ZeroMatrix(x.shape[1], cols) if cols != 1 else S.Zero
        else:
            first = Identity(rows) if rows != 1 else S.One
            second = Identity(cols) if cols != 1 else S.One
        return [_LeftRightArgs([first, second])]
|
|
|
|
|
def matrix_symbols(expr):
    """Return, as a list, the free symbols of ``expr`` whose ``is_Matrix`` is true."""
    return list(filter(lambda sym: sym.is_Matrix, expr.free_symbols))
|
|
|
|
|
class _LeftRightArgs:
    r"""
    Helper class to compute matrix derivatives.

    The logic: when an expression is derived by a matrix `X_{mn}`, two lines of
    matrix multiplications are created: the one contracted to `m` (first line),
    and the one contracted to `n` (second line).

    Transposition flips the side by which new matrices are connected to the
    lines.

    The trace connects the end of the two lines.

    ``_lines`` holds the line expressions (possibly ``ExprBuilder`` trees).
    The ``*_pointer_parent``/``*_pointer_index`` pairs locate the slot where
    new factors are attached to each line.  ``*_line_index`` records which
    entry of ``_lines`` currently plays the first/second role.
    """

    def __init__(self, lines, higher=S.One):
        self._lines = list(lines)
        self._first_pointer_parent = self._lines
        self._first_pointer_index = 0
        self._first_line_index = 0
        self._second_pointer_parent = self._lines
        self._second_pointer_index = 1
        self._second_line_index = 1
        self.higher = higher

    @property
    def first_pointer(self):
        return self._first_pointer_parent[self._first_pointer_index]

    @first_pointer.setter
    def first_pointer(self, value):
        self._first_pointer_parent[self._first_pointer_index] = value

    @property
    def second_pointer(self):
        return self._second_pointer_parent[self._second_pointer_index]

    @second_pointer.setter
    def second_pointer(self, value):
        self._second_pointer_parent[self._second_pointer_index] = value

    @property
    def first(self):
        """Built expression of the line currently acting as the first line."""
        return self._build(self._lines[self._first_line_index])

    @property
    def second(self):
        """Built expression of the line currently acting as the second line."""
        return self._build(self._lines[self._second_line_index])

    def __repr__(self):
        built = [self._build(i) for i in self._lines]
        return "_LeftRightArgs(lines=%s, higher=%s)" % (
            built,
            self.higher,
        )

    def transpose(self):
        """Swap the roles of the first and second line in place; return self."""
        self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent
        self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index
        self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index
        return self

    @staticmethod
    def _build(expr):
        # ExprBuilder -> built expression; [func, args] list -> func(*args)
        # built recursively; a 1-element list -> its element; else as-is.
        if isinstance(expr, ExprBuilder):
            return expr.build()
        if isinstance(expr, list):
            if len(expr) == 1:
                return expr[0]
            else:
                return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]])
        else:
            return expr

    def build(self):
        """Return the built lines, followed by ``higher`` unless it is 1."""
        data = [self._build(i) for i in self._lines]
        if self.higher != 1:
            data.append(self._build(self.higher))
        return data

    def matrix_form(self):
        """Combine the two lines into one matrix expression (first * second.T).

        Raises ValueError when a non-trivial ``higher`` part is present
        together with a non-trivial first line, or when the line shapes are
        incompatible.
        """
        first, second = self.first, self.second
        if first != 1 and self.higher != 1:
            raise ValueError("higher dimensional array cannot be represented")

        def _get_shape(elem):
            if isinstance(elem, MatrixExpr):
                return elem.shape
            return (None, None)

        if _get_shape(first)[1] != _get_shape(second)[1]:
            # Remove one-dimensional identity matrices:
            # (this is needed by `a.diff(a)` where `a` is a vector)
            if _get_shape(second) == (1, 1):
                return first*second[0, 0]
            if _get_shape(first) == (1, 1):
                # The only entry of a 1x1 matrix is at (0, 0).
                return first[0, 0]*second.T
            raise ValueError("incompatible shapes")
        if first != 1:
            return first*second.T
        else:
            return self.higher

    def rank(self):
        """
        Number of dimensions different from trivial (warning: not related to
        matrix rank).
        """
        def _nontrivial_dims(line):
            # Scalars (objects without a shape) contribute no dimensions.
            return sum(i != 1 for i in getattr(line, 'shape', ()))

        first, second = self.first, self.second
        rank = 0
        if first != 1:
            rank += _nontrivial_dims(first)
        if second != 1:
            rank += _nontrivial_dims(second)
        if self.higher != 1:
            rank += 2
        return rank

    def _multiply_pointer(self, pointer, other):
        """Return a builder contracting ``pointer`` and ``other`` on axes (1, 2)."""
        from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
        from ...tensor.array.expressions.array_expressions import ArrayContraction

        subexpr = ExprBuilder(
            ArrayContraction,
            [
                ExprBuilder(
                    ArrayTensorProduct,
                    [
                        pointer,
                        other
                    ]
                ),
                (1, 2)
            ],
            validator=ArrayContraction._validate
        )

        return subexpr

    def append_first(self, other):
        """Multiply the first-line attachment slot by ``other`` in place."""
        self.first_pointer *= other

    def append_second(self, other):
        """Multiply the second-line attachment slot by ``other`` in place."""
        self.second_pointer *= other
|
|
|
|
|
def _make_matrix(x):
    """Return ``x`` if it is a MatrixExpr, else wrap it as a 1x1 ImmutableDenseMatrix."""
    from sympy.matrices.immutable import ImmutableDenseMatrix
    return x if isinstance(x, MatrixExpr) else ImmutableDenseMatrix([[x]])
|
|
|
|
|
from .matmul import MatMul |
|
from .matadd import MatAdd |
|
from .matpow import MatPow |
|
from .transpose import Transpose |
|
from .inverse import Inverse |
|
from .special import ZeroMatrix, Identity |
|
from .determinant import Determinant |
|
|