|
""" |
|
Additional AST nodes for operations on matrices. The nodes in this module |
|
are meant to represent optimization of matrix expressions within codegen's |
|
target languages that cannot be represented by SymPy expressions. |
|
|
|
As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the |
|
``matin_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to |
|
transform matrix multiplication under certain assumptions: |
|
|
|
>>> from sympy import symbols, MatrixSymbol |
|
>>> n = symbols('n', integer=True) |
|
>>> A = MatrixSymbol('A', n, n) |
|
>>> x = MatrixSymbol('x', n, 1) |
|
>>> expr = A**(-1) * x |
|
>>> from sympy import assuming, Q |
|
>>> from sympy.codegen.rewriting import matinv_opt, optimize |
|
>>> with assuming(Q.fullrank(A)): |
|
... optimize(expr, [matinv_opt]) |
|
MatrixSolve(A, vector=x) |
|
""" |
|
|
|
from .ast import Token |
|
from sympy.matrices import MatrixExpr |
|
from sympy.core.sympify import sympify |
|
|
|
|
|
class MatrixSolve(Token, MatrixExpr): |
|
"""Represents an operation to solve a linear matrix equation. |
|
|
|
Parameters |
|
========== |
|
|
|
matrix : MatrixSymbol |
|
|
|
Matrix representing the coefficients of variables in the linear |
|
equation. This matrix must be square and full-rank (i.e. all columns must |
|
be linearly independent) for the solving operation to be valid. |
|
|
|
vector : MatrixSymbol |
|
|
|
One-column matrix representing the solutions to the equations |
|
represented in ``matrix``. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import symbols, MatrixSymbol |
|
>>> from sympy.codegen.matrix_nodes import MatrixSolve |
|
>>> n = symbols('n', integer=True) |
|
>>> A = MatrixSymbol('A', n, n) |
|
>>> x = MatrixSymbol('x', n, 1) |
|
>>> from sympy.printing.numpy import NumPyPrinter |
|
>>> NumPyPrinter().doprint(MatrixSolve(A, x)) |
|
'numpy.linalg.solve(A, x)' |
|
>>> from sympy import octave_code |
|
>>> octave_code(MatrixSolve(A, x)) |
|
'A \\\\ x' |
|
|
|
""" |
|
__slots__ = _fields = ('matrix', 'vector') |
|
|
|
_construct_matrix = staticmethod(sympify) |
|
_construct_vector = staticmethod(sympify) |
|
|
|
@property |
|
def shape(self): |
|
return self.vector.shape |
|
|
|
def _eval_derivative(self, x): |
|
A, b = self.matrix, self.vector |
|
return MatrixSolve(A, b.diff(x) - A.diff(x) * MatrixSolve(A, b)) |
|
|