|
import math |
|
from sympy.sets.sets import Interval |
|
from sympy.calculus.singularities import is_increasing, is_decreasing |
|
from sympy.codegen.rewriting import Optimization |
|
from sympy.core.function import UndefinedFunction |
|
|
|
""" |
|
This module collects classes useful for approximate rewriting of expressions. |
|
This can be beneficial when generating numeric code for which performance is |
|
of greater importance than precision (e.g. for preconditioners used in iterative |
|
methods). |
|
""" |
|
|
|
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic in their
    single free symbol, then bounds for those expressions are derived from the
    bounds of that symbol. Derived bounds are cached on the instance; the
    ``bounds`` mapping passed in by the caller is not modified.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        guaranteed magnitude among the terms, i.e. the largest
        ``min(|low|, |high|)`` over the terms whose bounds do not include
        zero. A term is dropped when its largest possible magnitude is below
        ``reltol`` times that value.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        # Bounds derived for monotonic single-symbol terms. They are kept
        # separate so the caller's ``bounds`` dict is never mutated, and keyed
        # on (term, low, high) of the symbol so that later edits to
        # ``self.bounds`` cannot leave stale derived bounds behind.
        self._derived_bounds = {}

    def __call__(self, expr):
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        return expr.is_Add

    def _term_bounds(self, term):
        """Return ``(low, high)`` bounds of one term of a sum, or None.

        Numbers are their own bounds, and explicit entries in ``self.bounds``
        are used as given. Otherwise, a term with a single bounded free symbol
        gets bounds from evaluating it at the symbol's bounds, provided it is
        monotonic over that interval. None means no bounds could be found.
        """
        if term.is_number:
            return (term, term)
        if term in self.bounds:
            return self.bounds[term]
        if len(term.free_symbols) != 1:
            return None
        fs, = term.free_symbols
        if fs not in self.bounds:
            return None
        lo, hi = self.bounds[fs]
        key = (term, lo, hi)
        if key not in self._derived_bounds:
            intrvl = Interval(lo, hi)
            if is_increasing(term, intrvl, fs):
                derived = (term.subs({fs: lo}), term.subs({fs: hi}))
            elif is_decreasing(term, intrvl, fs):
                # A decreasing term attains its low value at the upper bound.
                derived = (term.subs({fs: hi}), term.subs({fs: lo}))
            else:
                derived = None
            self._derived_bounds[key] = derived
        return self._derived_bounds[key]

    def value(self, add):
        """Return ``add`` with negligible terms removed.

        If any term of ``add`` has no known bounds, ``add`` is returned
        unchanged.
        """
        bounds = []
        for term in add.args:
            term_bounds = self._term_bounds(term)
            if term_bounds is None:
                return add
            bounds.append(term_bounds)

        # Largest magnitude some term is guaranteed to have over its whole
        # range; a term whose range contains zero guarantees nothing.
        largest_abs_guarantee = 0
        for lo, hi in bounds:
            if lo <= 0 <= hi:
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        if not new_terms:
            # This can only happen when reltol > 1. Dropping every term would
            # turn the sum into 0, which does not approximate it.
            return add
        return add.func(*new_terms)
|
|
|
|
|
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest relative error of the series approximation that is accepted
        at any of the check points.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.
        It must be even (an odd count would include the expansion point,
        where the error is trivially zero) and at least 2.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """
    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        if n_point_checks < 2:
            # With no check points every expansion would be accepted, down
            # to the constant term.
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Number of decimal digits used when evaluating the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        # ``UndefinedFunction`` is the metaclass of undefined functions such
        # as ``f = Function('f')``. An applied ``f(x)`` is an instance of
        # ``f``, so its ``func`` must be tested rather than the expression.
        return (expr.is_Function and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def _is_accurate(self, val, ref):
        """Return True if ``val`` matches ``ref`` to within ``self.reltol``."""
        if ref.is_zero:
            # The relative error is undefined where the reference vanishes.
            # 0/0 would give NaN, which cannot be compared, so only an exact
            # match is accepted here.
            return (val - ref).is_zero is True
        return not abs((1 - val/ref).evalf(self._prec)) > self.reltol

    def value(self, fexpr):
        """Return a series approximation of ``fexpr``, or ``fexpr`` itself.

        ``fexpr`` is expanded around the midpoint of the bounds of its single
        free symbol. Series truncated after powers ``max_order``,
        ``max_order - 1``, ... are tried in turn until one exceeds ``reltol``
        at some check point. The last expansion that passed is returned. If
        none passes, or the symbol is unbounded, ``fexpr`` is returned
        unchanged.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2
        # Check points and reference values do not depend on the order, so
        # compute them once.
        points = [lo + idx*(hi - lo)/(self.n_point_checks - 1)
                  for idx in range(self.n_point_checks)]
        refs = [fexpr.xreplace({symb: x}) for x in points]

        cheapest = None
        for n in range(self.max_order + 1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            if not all(self._is_accurate(fseri.xreplace({symb: x}), ref)
                       for x, ref in zip(points, refs)):
                # The search assumes lower orders will not do better, so it
                # stops at the first failing order.
                break
            cheapest = fseri

        return fexpr if cheapest is None else cheapest
|
|