jamtur01's picture
Upload folder using huggingface_hub
9c6594c verified
"""Module for differentiation using CSE."""
from sympy import cse, Matrix, Derivative, MatrixBase
from sympy.utilities.iterables import iterable
def _remove_cse_from_derivative(replacements, reduced_expressions):
    """
    Postprocess the output of a common subexpression elimination (CSE)
    operation so that no CSE replacement symbol appears inside a
    ``Derivative`` term.

    Every ``Derivative`` found while walking an expression has all of its
    replacement symbols substituted (repeatedly, until none is left) by the
    subexpressions they stand for. This is necessary to ensure that the
    forward Jacobian function correctly handles derivative terms.

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation.
    reduced_expressions : list of SymPy expressions or matrices
        The reduced expressions with all the replacements from the
        replacements list above. Matrix entries are processed element-wise;
        any other entry is processed as a single expression.

    Returns
    =======

    processed_replacements : list of (Symbol, expression) pairs
        Processed replacement list, in the same format and order as the
        ``replacements`` input list.
    processed_reduced : list
        Processed reduced list, in the same format as the
        ``reduced_expressions`` input list. Matrices keep both their class
        and their shape.
    """
    repl_dict = dict(replacements)

    def replace_all(node):
        # Substitute until no replacement symbol remains: a substituted
        # subexpression may itself contain further replacement symbols.
        result = node
        while True:
            symbols_dict = {
                s: repl_dict[s] for s in result.free_symbols if s in repl_dict
            }
            if not symbols_dict:
                return result
            result = result.xreplace(symbols_dict)

    def traverse(node):
        if isinstance(node, Derivative):
            return replace_all(node)
        if not node.args:
            return node
        new_args = [traverse(arg) for arg in node.args]
        # Rebuild only when some child actually changed; unchanged subtrees
        # are returned as-is instead of being reconstructed and re-evaluated.
        if all(new is old for new, old in zip(new_args, node.args)):
            return node
        return node.func(*new_args)

    def process_reduced(red_exp):
        if isinstance(red_exp, MatrixBase):
            # Pass the shape explicitly: constructing a matrix from a flat
            # list alone would turn a row (1 x n) matrix into a column one.
            return red_exp.__class__(
                red_exp.rows, red_exp.cols, [traverse(exp) for exp in red_exp]
            )
        return traverse(red_exp)

    processed_replacements = [
        (rep_sym, traverse(sub_exp)) for rep_sym, sub_exp in replacements
    ]
    processed_reduced = [
        process_reduced(red_exp) for red_exp in reduced_expressions
    ]
    return processed_replacements, processed_reduced
def _forward_jacobian_cse(replacements, reduced_expr, wrt):
    """
    Core function to compute the Jacobian of an input Matrix of expressions
    through forward accumulation. Takes directly the output of a CSE operation
    (replacements and reduced_expr), and an iterable of variables (wrt) with
    respect to which to differentiate the reduced expression and returns the
    reduced Jacobian matrix and the ``replacements`` list.

    The function also returns a list of precomputed free symbols for each
    subexpression, which are useful in the substitution process.

    Explanation
    ===========

    Writing ``r`` for the reduced expressions, ``x_k`` for the replacement
    symbols, ``e_k`` for their subexpressions and ``w`` for ``wrt``, the
    Jacobian is assembled as ``F1 + F2 * C`` where

    * ``F1[i, j] = d r_i / d w_j`` (replacement symbols held fixed),
    * ``F2[i, k] = d r_i / d x_k``,
    * row ``k`` of ``C`` is the total derivative of ``x_k`` with respect to
      ``w``, built forward as ``A_k + B_k * C[:k, :]`` with
      ``A_k[j] = d e_k / d w_j`` and ``B_k[m] = d e_k / d x_m`` (``m < k``).

    Parameters
    ==========

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation. Each subexpression is expected to
        reference only replacement symbols listed before it, as in the output
        of ``cse``.
    reduced_expr : list
        The reduced expressions with all the replacements from the
        replacements list above. Only the first element is differentiated;
        it must be a row or column matrix.
    wrt : iterable
        Iterable of expressions with respect to which to compute the
        Jacobian matrix. Non-matrix iterables are converted with ``Matrix``;
        the result must be a row or column matrix.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        Replacement symbols and relative common subexpressions that have been
        replaced during a CSE operation. Compared to the input replacement
        list, the output one doesn't contain replacement symbols inside
        ``Derivative``'s arguments.
    jacobian : list
        A one-element list holding the Jacobian matrix, of the same class as
        ``reduced_expr[0]``, with elements in reduced form (replacement
        symbols are present).
    precomputed_fs : list of sets
        For each subexpression, the set of replacement symbols among its free
        symbols. Empty when there are no replacements.

    Raises
    ======

    TypeError
        If ``reduced_expr[0]`` is not a row or column matrix, or if ``wrt``
        is not an iterable forming a row or column matrix.
    """
    if not isinstance(reduced_expr[0], MatrixBase):
        raise TypeError("``expr`` must be of matrix type")
    if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1):
        raise TypeError("``expr`` must be a row or a column matrix")
    if not iterable(wrt):
        raise TypeError("``wrt`` must be an iterable of variables")
    elif not isinstance(wrt, MatrixBase):
        wrt = Matrix(wrt)
    if not (wrt.shape[0] == 1 or wrt.shape[1] == 1):
        raise TypeError("``wrt`` must be a row or a column matrix")

    replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr)
    red_vec = reduced_expr[0]

    if replacements:
        rep_sym, sub_expr = map(Matrix, zip(*replacements))
    else:
        rep_sym, sub_expr = Matrix([]), Matrix([])

    l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(red_vec)

    # F1: direct derivatives of the reduced expressions w.r.t. ``wrt``.
    f1 = red_vec.__class__.from_dok(l_red, l_wrt, {
        (i, j): diff_value
        for i, r in enumerate(red_vec)
        for j, w in enumerate(wrt)
        if (diff_value := r.diff(w)) != 0
    })

    if not replacements:
        return [], [f1], []

    # Column index of every replacement symbol. Sparse entries below are
    # filled by visiting only the replacement symbols an expression actually
    # contains, rather than scanning the whole replacement list each time.
    rep_index = {s: j for j, s in enumerate(rep_sym)}

    # F2: derivatives of the reduced expressions w.r.t. replacement symbols.
    f2 = Matrix.from_dok(l_red, l_sub, {
        (i, rep_index[s]): diff_value
        for i, r in enumerate(red_vec)
        for s in r.free_symbols
        if s in rep_index and (diff_value := r.diff(s)) != 0
    })

    rep_sym_set = set(rep_sym)
    precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr]

    # Row k of ``c_matrix`` is the total derivative of rep_sym[k] w.r.t. wrt.
    c_matrix = Matrix.from_dok(1, l_wrt, {
        (0, j): diff_value
        for j, w in enumerate(wrt)
        if (diff_value := sub_expr[0].diff(w)) != 0
    })
    for i in range(1, l_sub):
        # B_i: dependence of sub_expr[i] on the i earlier replacement symbols.
        # Only columns j < i exist in this 1 x i row, so symbols at position
        # >= i (never referenced in valid CSE output) are skipped.
        bi_entries = {
            (0, j): diff_value
            for s in precomputed_fs[i]
            if (j := rep_index[s]) < i
            and (diff_value := sub_expr[i].diff(s)) != 0
        }
        # A_i: direct dependence of sub_expr[i] on ``wrt``.
        ai_matrix = Matrix.from_dok(1, l_wrt, {
            (0, j): diff_value
            for j, w in enumerate(wrt)
            if (diff_value := sub_expr[i].diff(w)) != 0
        })
        if bi_entries:
            bi_matrix = Matrix.from_dok(1, i, bi_entries)
            ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix)
        else:
            ci_matrix = ai_matrix
        c_matrix = Matrix.vstack(c_matrix, ci_matrix)

    jacobian = f2.multiply(c_matrix).add(f1)
    return replacements, [red_vec.__class__(jacobian)], precomputed_fs
def _forward_jacobian_norm_in_cse_out(expr, wrt):
    """
    Compute the Jacobian of ``expr`` with respect to ``wrt`` by forward
    accumulation, leaving the result in CSE-reduced form.

    ``expr`` is first passed through ``cse``. The resulting replacement
    pairs and reduced expressions are handed to ``_forward_jacobian_cse``,
    whose output is returned unchanged. The returned list of precomputed
    free symbols per subexpression is useful in the substitution process.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.
    wrt : iterable
        The vector with respect to which to perform the differentiation.
        Either a matrix or any iterable of variables.

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        The CSE replacement pairs, with no replacement symbols left inside
        the arguments of any ``Derivative``.
    jacobian : list
        A one-element list holding the Jacobian matrix in reduced form
        (replacement symbols may still be present).
    precomputed_fs : list of sets
        For each subexpression, the replacement symbols among its free
        symbols; useful when substituting the replacements back.
    """
    cse_pairs, cse_reduced = cse(expr)
    return _forward_jacobian_cse(cse_pairs, cse_reduced, wrt)
def _forward_jacobian(expr, wrt):
    """
    Function to compute the Jacobian of an input Matrix of expressions through
    forward accumulation. Takes a sympy Matrix of expressions (expr) as input
    and an iterable of variables (wrt) with respect to which to compute the
    Jacobian matrix.

    Explanation
    ===========

    Expressions often contain repeated subexpressions. Using a tree structure,
    these subexpressions are duplicated and differentiated multiple times,
    leading to inefficiency.

    Instead, if a data structure called a directed acyclic graph (DAG) is used
    then each of these repeated subexpressions will only exist a single time.
    This function uses a combination of representing the expression as a DAG
    and a forward accumulation algorithm (repeated application of the chain
    rule symbolically) to more efficiently calculate the Jacobian matrix of a
    target expression ``expr`` with respect to an expression or set of
    expressions ``wrt``.

    Note that this function is intended to improve performance when
    differentiating large expressions that contain many common
    subexpressions. For small and simple expressions it is likely less
    performant than using SymPy's standard differentiation functions and
    methods.

    Parameters
    ==========

    expr : Matrix
        The vector to be differentiated.
    wrt : iterable
        The vector with respect to which to do the differentiation.
        Can be a matrix or an iterable of variables.

    Returns
    =======

    jacobian : Matrix
        The Jacobian matrix of ``expr`` with respect to ``wrt``, with every
        CSE replacement symbol substituted back by its subexpression.

    See Also
    ========

    Directed Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph
    """
    replacements, jacobian, precomputed_fs = _forward_jacobian_norm_in_cse_out(expr, wrt)
    if not replacements:
        return jacobian[0]

    # Expand each replacement in list order. Every subexpression is rewritten
    # in terms of the already-expanded earlier replacements, so it ends up
    # free of replacement symbols, provided it only references earlier ones
    # (as in ``cse`` output).
    sub_rep = dict(replacements)
    for (rep_sym, _), rep_fs in zip(replacements, precomputed_fs):
        sub_rep[rep_sym] = sub_rep[rep_sym].xreplace({s: sub_rep[s] for s in rep_fs})
    return jacobian[0].xreplace(sub_rep)