"""Module for differentiation using CSE.""" from sympy import cse, Matrix, Derivative, MatrixBase from sympy.utilities.iterables import iterable def _remove_cse_from_derivative(replacements, reduced_expressions): """ This function is designed to postprocess the output of a common subexpression elimination (CSE) operation. Specifically, it removes any CSE replacement symbols from the arguments of ``Derivative`` terms in the expression. This is necessary to ensure that the forward Jacobian function correctly handles derivative terms. Parameters ========== replacements : list of (Symbol, expression) pairs Replacement symbols and relative common subexpressions that have been replaced during a CSE operation. reduced_expressions : list of SymPy expressions The reduced expressions with all the replacements from the replacements list above. Returns ======= processed_replacements : list of (Symbol, expression) pairs Processed replacement list, in the same format of the ``replacements`` input list. processed_reduced : list of SymPy expressions Processed reduced list, in the same format of the ``reduced_expressions`` input list. """ def traverse(node, repl_dict): if isinstance(node, Derivative): return replace_all(node, repl_dict) if not node.args: return node new_args = [traverse(arg, repl_dict) for arg in node.args] return node.func(*new_args) def replace_all(node, repl_dict): result = node while True: free_symbols = result.free_symbols symbols_dict = {k: repl_dict[k] for k in free_symbols if k in repl_dict} if not symbols_dict: break result = result.xreplace(symbols_dict) return result repl_dict = dict(replacements) processed_replacements = [ (rep_sym, traverse(sub_exp, repl_dict)) for rep_sym, sub_exp in replacements ] processed_reduced = [ red_exp.__class__([traverse(exp, repl_dict) for exp in red_exp]) for red_exp in reduced_expressions ] return processed_replacements, processed_reduced def _forward_jacobian_cse(replacements, reduced_expr, wrt): """ Core function to compute the Jacobian of an input Matrix of expressions through forward accumulation. Takes directly the output of a CSE operation (replacements and reduced_expr), and an iterable of variables (wrt) with respect to which to differentiate the reduced expression and returns the reduced Jacobian matrix and the ``replacements`` list. The function also returns a list of precomputed free symbols for each subexpression, which are useful in the substitution process. Parameters ========== replacements : list of (Symbol, expression) pairs Replacement symbols and relative common subexpressions that have been replaced during a CSE operation. reduced_expr : list of SymPy expressions The reduced expressions with all the replacements from the replacements list above. wrt : iterable Iterable of expressions with respect to which to compute the Jacobian matrix. Returns ======= replacements : list of (Symbol, expression) pairs Replacement symbols and relative common subexpressions that have been replaced during a CSE operation. Compared to the input replacement list, the output one doesn't contain replacement symbols inside ``Derivative``'s arguments. jacobian : list of SymPy expressions The list only contains one element, which is the Jacobian matrix with elements in reduced form (replacement symbols are present). precomputed_fs: list List of sets, which store the free symbols present in each sub-expression. Useful in the substitution process. """ if not isinstance(reduced_expr[0], MatrixBase): raise TypeError("``expr`` must be of matrix type") if not (reduced_expr[0].shape[0] == 1 or reduced_expr[0].shape[1] == 1): raise TypeError("``expr`` must be a row or a column matrix") if not iterable(wrt): raise TypeError("``wrt`` must be an iterable of variables") elif not isinstance(wrt, MatrixBase): wrt = Matrix(wrt) if not (wrt.shape[0] == 1 or wrt.shape[1] == 1): raise TypeError("``wrt`` must be a row or a column matrix") replacements, reduced_expr = _remove_cse_from_derivative(replacements, reduced_expr) if replacements: rep_sym, sub_expr = map(Matrix, zip(*replacements)) else: rep_sym, sub_expr = Matrix([]), Matrix([]) l_sub, l_wrt, l_red = len(sub_expr), len(wrt), len(reduced_expr[0]) f1 = reduced_expr[0].__class__.from_dok(l_red, l_wrt, { (i, j): diff_value for i, r in enumerate(reduced_expr[0]) for j, w in enumerate(wrt) if (diff_value := r.diff(w)) != 0 }, ) if not replacements: return [], [f1], [] f2 = Matrix.from_dok(l_red, l_sub, { (i, j): diff_value for i, (r, fs) in enumerate([(r, r.free_symbols) for r in reduced_expr[0]]) for j, s in enumerate(rep_sym) if s in fs and (diff_value := r.diff(s)) != 0 }, ) rep_sym_set = set(rep_sym) precomputed_fs = [s.free_symbols & rep_sym_set for s in sub_expr ] c_matrix = Matrix.from_dok(1, l_wrt, {(0, j): diff_value for j, w in enumerate(wrt) if (diff_value := sub_expr[0].diff(w)) != 0}) for i in range(1, l_sub): bi_matrix = Matrix.from_dok(1, i, {(0, j): diff_value for j in range(i + 1) if rep_sym[j] in precomputed_fs[i] and (diff_value := sub_expr[i].diff(rep_sym[j])) != 0}) ai_matrix = Matrix.from_dok(1, l_wrt, {(0, j): diff_value for j, w in enumerate(wrt) if (diff_value := sub_expr[i].diff(w)) != 0}) if bi_matrix._rep.nnz(): ci_matrix = bi_matrix.multiply(c_matrix).add(ai_matrix) c_matrix = Matrix.vstack(c_matrix, ci_matrix) else: c_matrix = Matrix.vstack(c_matrix, ai_matrix) jacobian = f2.multiply(c_matrix).add(f1) jacobian = [reduced_expr[0].__class__(jacobian)] return replacements, jacobian, precomputed_fs def _forward_jacobian_norm_in_cse_out(expr, wrt): """ Function to compute the Jacobian of an input Matrix of expressions through forward accumulation. Takes a sympy Matrix of expressions (expr) as input and an iterable of variables (wrt) with respect to which to compute the Jacobian matrix. The matrix is returned in reduced form (containing replacement symbols) along with the ``replacements`` list. The function also returns a list of precomputed free symbols for each subexpression, which are useful in the substitution process. Parameters ========== expr : Matrix The vector to be differentiated. wrt : iterable The vector with respect to which to perform the differentiation. Can be a matrix or an iterable of variables. Returns ======= replacements : list of (Symbol, expression) pairs Replacement symbols and relative common subexpressions that have been replaced during a CSE operation. The output replacement list doesn't contain replacement symbols inside ``Derivative``'s arguments. jacobian : list of SymPy expressions The list only contains one element, which is the Jacobian matrix with elements in reduced form (replacement symbols are present). precomputed_fs: list List of sets, which store the free symbols present in each sub-expression. Useful in the substitution process. """ replacements, reduced_expr = cse(expr) replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) return replacements, jacobian, precomputed_fs def _forward_jacobian(expr, wrt): """ Function to compute the Jacobian of an input Matrix of expressions through forward accumulation. Takes a sympy Matrix of expressions (expr) as input and an iterable of variables (wrt) with respect to which to compute the Jacobian matrix. Explanation =========== Expressions often contain repeated subexpressions. Using a tree structure, these subexpressions are duplicated and differentiated multiple times, leading to inefficiency. Instead, if a data structure called a directed acyclic graph (DAG) is used then each of these repeated subexpressions will only exist a single time. This function uses a combination of representing the expression as a DAG and a forward accumulation algorithm (repeated application of the chain rule symbolically) to more efficiently calculate the Jacobian matrix of a target expression ``expr`` with respect to an expression or set of expressions ``wrt``. Note that this function is intended to improve performance when differentiating large expressions that contain many common subexpressions. For small and simple expressions it is likely less performant than using SymPy's standard differentiation functions and methods. Parameters ========== expr : Matrix The vector to be differentiated. wrt : iterable The vector with respect to which to do the differentiation. Can be a matrix or an iterable of variables. See Also ======== Direct Acyclic Graph : https://en.wikipedia.org/wiki/Directed_acyclic_graph """ replacements, reduced_expr = cse(expr) if replacements: rep_sym, _ = map(Matrix, zip(*replacements)) else: rep_sym = Matrix([]) replacements, jacobian, precomputed_fs = _forward_jacobian_cse(replacements, reduced_expr, wrt) if not replacements: return jacobian[0] sub_rep = dict(replacements) for i, ik in enumerate(precomputed_fs): sub_dict = {j: sub_rep[j] for j in ik} sub_rep[rep_sym[i]] = sub_rep[rep_sym[i]].xreplace(sub_dict) return jacobian[0].xreplace(sub_rep)