|
"""Module for differentiation using CSE.""" |
|
|
|
from sympy import cse, Matrix, Derivative, MatrixBase |
|
from sympy.utilities.iterables import iterable |
|
|
|
|
|
def _remove_cse_from_derivative(replacements, reduced_expressions): |
|
""" |
|
This function is designed to postprocess the output of a common subexpression |
|
elimination (CSE) operation. Specifically, it removes any CSE replacement |
|
symbols from the arguments of ``Derivative`` terms in the expression. This |
|
is necessary to ensure that the forward Jacobian function correctly handles |
|
derivative terms. |
|
|
|
Parameters |
|
========== |
|
|
|
replacements : list of (Symbol, expression) pairs |
|
Replacement symbols and relative common subexpressions that have been |
|
replaced during a CSE operation. |
|
|
|
reduced_expressions : list of SymPy expressions |
|
The reduced expressions with all the replacements from the |
|
replacements list above. |
|
|
|
Returns |
|
======= |
|
|
|
processed_replacements : list of (Symbol, expression) pairs |
|
Processed replacement list, in the same format of the |
|
``replacements`` input list. |
|
|
|
processed_reduced : list of SymPy expressions |
|
Processed reduced list, in the same format of the |
|
``reduced_expressions`` input list. |
|
""" |
|
|
|
def traverse(node, repl_dict): |
|
if isinstance(node, Derivative): |
|
return replace_all(node, repl_dict) |
|
if not node.args: |
|
return node |
|
new_args = [traverse(arg, repl_dict) for arg in node.args] |
|
return node.func(*new_args) |
|
|
|
def replace_all(node, repl_dict): |
|
result = node |
|
while True: |
|
free_symbols = result.free_symbols |
|
symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict} |
|
if not symbols_dict: |
|
break |
|
result = result.xreplace(symbols_dict) |
|
return result |
|
|
|
repl_dict = dict(replacements) |
|
processed_replacements = [ |
|
(rep_sym, traverse(sub_exp, repl_dict)) |
|
for rep_sym, sub_exp in replacements |
|
] |
|
processed_reduced = [ |
|
red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp]) |
|
for red_exp in reduced_expressions |
|
] |
|
|
|
return processed_replacements, processed_reduced |
|
|
|
|
|
def _forward_jacobian_cse(replacements, reduced_expr, wrt):
    """
    Core function to compute the Jacobian of an input Matrix of expressions
    through forward accumulation. Takes directly the output of a CSE operation
    (replacements and reduced_expr), and an iterable of variables (wrt) with
    respect to which to differentiate the reduced expression and returns the
    reduced Jacobian matrix and the ``replacements`` list.

    The function also returns a list of precomputed free symbols for each
    subexpression, which are useful in the substitution process.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation.

    reduced_expr : list of SymPy expressions
        The reduced expressions with all the replacements from the
        replacements list above. Only the first element is used and it must
        be a row or a column matrix.

    wrt : iterable
        Iterable of expressions with respect to which to compute the
        Jacobian matrix.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation. Compared to the input replacement list,
        the output one doesn't contain replacement symbols inside
        ``Derivative``'s arguments. Empty if the input list is empty.

    jacobian : list of SymPy expressions
        The list only contains one element, which is the Jacobian matrix with
        elements in reduced form (replacement symbols are present).

    precomputed_fs: list
        List of sets, one per sub-expression, which store the replacement
        symbols present among the free symbols of that sub-expression.
        Useful in the substitution process. Empty if there are no
        replacements.

    Raises
    ======

    TypeError
        If ``reduced_expr[0]`` is not a matrix, if it is not a row or a column
        matrix, if ``wrt`` is not iterable, or if ``wrt`` is not a row or a
        column matrix.
    """

    if not isinstance(reduced_expr[0], MatrixBase):
        raise TypeError("``expr`` must be of matrix type")

    if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
        raise TypeError("``expr`` must be a row or a column matrix")

    if not iterable(wrt):
        raise TypeError("``wrt`` must be an iterable of variables")

    elif not isinstance(wrt, MatrixBase):
        wrt = Matrix(wrt)

    if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
        raise TypeError("``wrt`` must be a row or a column matrix")

    replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr)

    if replacements:
        rep_sym, sub_expr = map(Matrix, zip(*replacements))
    else:
        rep_sym, sub_expr = Matrix([]), Matrix([])

    l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0])

    # f1: direct partial derivatives of the reduced expressions with respect
    # to ``wrt`` (replacement symbols are treated as independent here).
    f1 = reduced_expr[0].__class__.from_dok(
        l_red,
        l_wrt,
        {
            (i, j): diff_value
            for i, r in enumerate(reduced_expr[0])
            for j, w in enumerate(wrt)
            if (diff_value := r.diff(w)) != 0
        },
    )

    if not replacements:
        return [], [f1], []

    # f2: partial derivatives of the reduced expressions with respect to the
    # replacement symbols. The free-symbol test skips differentiations that
    # are known to be zero.
    f2 = Matrix.from_dok(
        l_red,
        l_sub,
        {
            (i, j): diff_value
            for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]])
            for j, s in enumerate(rep_sym)
            if s in fs and (diff_value := r.diff(s)) != 0
        },
    )

    rep_sym_set = set(rep_sym)
    precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr]

    # c_matrix accumulates, row by row, the derivative of every
    # sub-expression with respect to ``wrt``. The first row only needs the
    # direct partial derivatives of the first sub-expression.
    c_matrix = Matrix.from_dok(
        1,
        l_wrt,
        {
            (0, j): diff_value
            for j, w in enumerate(wrt)
            if (diff_value := sub_expr[0].diff(w)) != 0
        },
    )

    for i in range(1, l_sub):

        # Partial derivatives of sub-expression ``i`` with respect to the
        # replacement symbols defined before it. The matrix has exactly ``i``
        # columns, so only indices ``0 .. i - 1`` are valid.
        bi_matrix = Matrix.from_dok(
            1,
            i,
            {
                (0, j): diff_value
                for j in range(i)
                if rep_sym[j] in precomputed_fs[i]
                and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0
            },
        )

        # Direct partial derivatives of sub-expression ``i`` with respect to
        # ``wrt``.
        ai_matrix = Matrix.from_dok(
            1,
            l_wrt,
            {
                (0, j): diff_value
                for j, w in enumerate(wrt)
                if (diff_value := sub_expr[i].diff(w)) != 0
            },
        )

        if bi_matrix._rep.nnz():
            # Chain rule: contributions through earlier sub-expressions plus
            # the direct dependence on ``wrt``.
            ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
            c_matrix = Matrix.vstack(c_matrix, ci_matrix)
        else:
            # No dependence on earlier sub-expressions: skip the product.
            c_matrix = Matrix.vstack(c_matrix, ai_matrix)

    jacobian = f2.multiply(c_matrix).add(f1)
    jacobian = [reduced_expr[0].__class__(jacobian)]

    return replacements, jacobian, precomputed_fs
|
|
|
|
|
def _forward_jacobian_norm_in_cse_out(expr, wrt):
    """
    Compute the Jacobian of a Matrix of expressions through forward
    accumulation and return it in reduced (CSE) form.

    The input ``expr`` is first passed through ``cse``; the resulting
    replacements and reduced expressions are then handed to
    ``_forward_jacobian_cse`` together with ``wrt``. The three values returned
    by that function are returned unchanged.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.

    wrt : iterable
        The vector with respect to which to perform the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation. The output replacement list doesn't
        contain replacement symbols inside ``Derivative``'s arguments.

    jacobian : list of SymPy expressions
        The list only contains one element, which is the Jacobian matrix with
        elements in reduced form (replacement symbols are present).

    precomputed_fs: list
        List of sets, which store the free symbols present in each
        sub-expression. Useful in the substitution process.
    """

    cse_replacements, cse_reduced = cse(expr)
    processed, reduced_jacobian, free_symbol_sets = _forward_jacobian_cse(
        cse_replacements, cse_reduced, wrt
    )
    return processed, reduced_jacobian, free_symbol_sets
|
|
|
|
|
def _forward_jacobian(expr, wrt):
    """
    Function to compute the Jacobian of an input Matrix of expressions through
    forward accumulation. Takes a sympy Matrix of expressions (expr) as input
    and an iterable of variables (wrt) with respect to which to compute the
    Jacobian matrix.

    Explanation
    ===========

    Expressions often contain repeated subexpressions. Using a tree structure,
    these subexpressions are duplicated and differentiated multiple times,
    leading to inefficiency.

    Instead, if a data structure called a directed acyclic graph (DAG) is used
    then each of these repeated subexpressions will only exist a single time.
    This function uses a combination of representing the expression as a DAG and
    a forward accumulation algorithm (repeated application of the chain rule
    symbolically) to more efficiently calculate the Jacobian matrix of a target
    expression ``expr`` with respect to an expression or set of expressions
    ``wrt``.

    Note that this function is intended to improve performance when
    differentiating large expressions that contain many common subexpressions.
    For small and simple expressions it is likely less performant than using
    SymPy's standard differentiation functions and methods.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.

    wrt : iterable
        The vector with respect to which to do the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    jacobian : Matrix
        The Jacobian matrix, with one row per entry of ``expr`` and one column
        per entry of ``wrt``. All CSE replacement symbols are substituted
        back, so the result is expressed in the original symbols.

    See Also
    ========

    Directed Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
    """

    replacements, reduced_expr = cse(expr)

    # Replacement symbols in definition order. A plain list is enough here:
    # the symbols are only looked up by position below.
    rep_syms = [sym for sym, _ in replacements]

    replacements, jacobian, precomputed_fs = _forward_jacobian_cse(
        replacements, reduced_expr, wrt
    )

    if not replacements:
        return jacobian[0]

    # Expand every sub-expression in definition order. ``deps`` holds the
    # replacement symbols used by the current sub-expression; they map to
    # entries processed in earlier iterations, so after the loop each value in
    # ``sub_rep`` is free of replacement symbols.
    sub_rep = dict(replacements)
    for sym, deps in zip(rep_syms, precomputed_fs):
        sub_dict = {dep: sub_rep[dep] for dep in deps}
        sub_rep[sym] = sub_rep[sym].xreplace(sub_dict)

    return jacobian[0].xreplace(sub_rep)
|
|