|
"""Integration method that emulates by-hand techniques. |
|
|
|
This module also provides functionality to get the steps used to evaluate a |
|
particular integral, in the ``integral_steps`` function. This will return |
|
nested ``Rule`` s representing the integration rules used. |
|
|
|
Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g. |
|
``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule`` |
|
for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the |
|
integration result. |
|
|
|
The ``manualintegrate`` function computes the integral by calling ``eval`` |
|
on the rule returned by ``integral_steps``. |
|
|
|
The integrator can be extended with new heuristics and evaluation |
|
techniques. To do so, extend the ``Rule`` class, implement ``eval`` method, |
|
then write a function that accepts an ``IntegralInfo`` object and returns |
|
either a ``Rule`` instance or ``None``. If the new technique requires a new
match, add the key and the call to the antiderivative function in
``integral_steps``. To enable simple substitutions, add the match to
``find_substitutions``.
|
|
|
""" |
|
|
|
from __future__ import annotations |
|
from typing import NamedTuple, Type, Callable, Sequence |
|
from abc import ABC, abstractmethod |
|
from dataclasses import dataclass |
|
from collections import defaultdict |
|
from collections.abc import Mapping |
|
|
|
from sympy.core.add import Add |
|
from sympy.core.cache import cacheit |
|
from sympy.core.containers import Dict |
|
from sympy.core.expr import Expr |
|
from sympy.core.function import Derivative |
|
from sympy.core.logic import fuzzy_not |
|
from sympy.core.mul import Mul |
|
from sympy.core.numbers import Integer, Number, E |
|
from sympy.core.power import Pow |
|
from sympy.core.relational import Eq, Ne, Boolean |
|
from sympy.core.singleton import S |
|
from sympy.core.symbol import Dummy, Symbol, Wild |
|
from sympy.functions.elementary.complexes import Abs |
|
from sympy.functions.elementary.exponential import exp, log |
|
from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch, |
|
cosh, coth, sech, sinh, tanh, asinh) |
|
from sympy.functions.elementary.miscellaneous import sqrt |
|
from sympy.functions.elementary.piecewise import Piecewise |
|
from sympy.functions.elementary.trigonometric import (TrigonometricFunction, |
|
cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec) |
|
from sympy.functions.special.delta_functions import Heaviside, DiracDelta |
|
from sympy.functions.special.error_functions import (erf, erfi, fresnelc, |
|
fresnels, Ci, Chi, Si, Shi, Ei, li) |
|
from sympy.functions.special.gamma_functions import uppergamma |
|
from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f |
|
from sympy.functions.special.polynomials import (chebyshevt, chebyshevu, |
|
legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi, |
|
OrthogonalPolynomial) |
|
from sympy.functions.special.zeta_functions import polylog |
|
from .integrals import Integral |
|
from sympy.logic.boolalg import And |
|
from sympy.ntheory.factor_ import primefactors |
|
from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly |
|
from sympy.simplify.radsimp import fraction |
|
from sympy.simplify.simplify import simplify |
|
from sympy.solvers.solvers import solve |
|
from sympy.strategies.core import switch, do_one, null_safe, condition |
|
from sympy.utilities.iterables import iterable |
|
from sympy.utilities.misc import debug |
|
|
|
|
|
@dataclass
class Rule(ABC):
    """Base class of one step in a manual integration.

    ``integrand`` is the expression this step integrates and ``variable`` is
    the symbol of integration. Subclasses add their own parameters as extra
    dataclass fields.
    """
    integrand: Expr
    variable: Symbol

    @abstractmethod
    def eval(self) -> Expr:
        """Return the antiderivative produced by this step."""

    @abstractmethod
    def contains_dont_know(self) -> bool:
        """Return whether this step tree contains an unevaluated ``DontKnowRule``."""
|
|
|
|
|
@dataclass
class AtomicRule(Rule, ABC):
    """A simple rule that does not depend on other rules"""
    def contains_dont_know(self) -> bool:
        # Atomic rules have no substeps, so they can never hide a DontKnowRule.
        return False
|
|
|
|
|
@dataclass
class ConstantRule(AtomicRule):
    """integrate(a, x) -> a*x

    The integrand is treated as a constant with respect to ``variable``.
    """
    def eval(self) -> Expr:
        constant, x = self.integrand, self.variable
        return constant * x
|
|
|
|
|
@dataclass
class ConstantTimesRule(Rule):
    """integrate(a*f(x), x) -> a*integrate(f(x), x)

    ``substep`` integrates ``other`` (the non-constant factor). Its result is
    scaled by ``constant``.
    """
    constant: Expr
    other: Expr
    substep: Rule

    def eval(self) -> Expr:
        inner = self.substep.eval()
        return self.constant * inner

    def contains_dont_know(self) -> bool:
        inner_step = self.substep
        return inner_step.contains_dont_know()
|
|
|
|
|
@dataclass
class PowerRule(AtomicRule):
    """integrate(x**a, x)

    Gives ``x**(a+1)/(a+1)``, falling back to ``log(x)`` when ``a == -1``.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        b, n = self.base, self.exp
        general = (b**(n + 1))/(n + 1)
        return Piecewise(
            (general, Ne(n, -1)),
            (log(b), True),
        )
|
|
|
|
|
@dataclass
class NestedPowRule(AtomicRule):
    """integrate((x**a)**b, x)

    Here ``exp`` is the overall power of ``base`` in the integrand, so the
    antiderivative is ``base*integrand/(exp+1)``. For ``exp == -1`` it is
    ``base*integrand*log(base)``.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        power = self.exp
        m = self.base * self.integrand
        regular = m / (power + 1)
        logarithmic = m * log(self.base)
        return Piecewise((regular, Ne(power, -1)),
                         (logarithmic, True))
|
|
|
|
|
@dataclass
class AddRule(Rule):
    """integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)

    ``substeps`` holds one rule per summand.
    """
    substeps: list[Rule]

    def eval(self) -> Expr:
        terms = [step.eval() for step in self.substeps]
        return Add(*terms)

    def contains_dont_know(self) -> bool:
        return any(step.contains_dont_know() for step in self.substeps)
|
|
|
|
|
@dataclass
class URule(Rule):
    """integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)

    ``substep`` integrates in terms of ``u_var``. The result is then
    rewritten back in terms of ``u_func`` (= g(x)).
    """
    u_var: Symbol
    u_func: Expr
    substep: Rule

    def eval(self) -> Expr:
        u, g = self.u_var, self.u_func
        antiderivative = self.substep.eval()
        if g.is_Pow:
            base, power = g.as_base_exp()
            if power == -1:
                # With u = 1/base, back-substituting log(u) would give
                # log(1/base); write it directly as -log(base) instead.
                antiderivative = antiderivative.subs(log(u), -log(base))
        return antiderivative.subs(u, g)

    def contains_dont_know(self) -> bool:
        inner_step = self.substep
        return inner_step.contains_dont_know()
|
|
|
|
|
@dataclass
class PartsRule(Rule):
    """integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)

    ``v_step`` integrates ``dv`` to obtain ``v``. ``second_step`` integrates
    ``u'*v``. It may be ``None`` when this rule is only used as a component
    of a ``CyclicPartsRule``; ``eval`` requires it to be set.
    """
    u: Symbol
    dv: Expr
    v_step: Rule
    second_step: Rule | None

    def eval(self) -> Expr:
        assert self.second_step is not None
        uv = self.u * self.v_step.eval()
        return uv - self.second_step.eval()

    def contains_dont_know(self) -> bool:
        if self.v_step.contains_dont_know():
            return True
        remaining = self.second_step
        return remaining is not None and remaining.contains_dont_know()
|
|
|
|
|
@dataclass
class CyclicPartsRule(Rule):
    """Apply PartsRule multiple times to integrate exp(x)*sin(x)

    The ``u*v`` terms of the successive ``parts_rules`` are summed with
    alternating signs. The integral reappears with factor ``coefficient``,
    so the sum is divided by ``1 - coefficient``.
    """
    parts_rules: list[PartsRule]
    coefficient: Expr

    def eval(self) -> Expr:
        terms = [(-1)**k * rule.u * rule.v_step.eval()
                 for k, rule in enumerate(self.parts_rules)]
        return Add(*terms) / (1 - self.coefficient)

    def contains_dont_know(self) -> bool:
        return any(rule.contains_dont_know() for rule in self.parts_rules)
|
|
|
|
|
@dataclass
class TrigRule(AtomicRule, ABC):
    """Abstract marker base for the atomic trigonometric rules."""
    pass
|
|
|
|
|
@dataclass
class SinRule(TrigRule):
    """integrate(sin(x), x) -> -cos(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return -cos(x)
|
|
|
|
|
@dataclass
class CosRule(TrigRule):
    """integrate(cos(x), x) -> sin(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return sin(x)
|
|
|
|
|
@dataclass
class SecTanRule(TrigRule):
    """integrate(sec(x)*tan(x), x) -> sec(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return sec(x)
|
|
|
|
|
@dataclass
class CscCotRule(TrigRule):
    """integrate(csc(x)*cot(x), x) -> -csc(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return -csc(x)
|
|
|
|
|
@dataclass
class Sec2Rule(TrigRule):
    """integrate(sec(x)**2, x) -> tan(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return tan(x)
|
|
|
|
|
@dataclass
class Csc2Rule(TrigRule):
    """integrate(csc(x)**2, x) -> -cot(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return -cot(x)
|
|
|
|
|
@dataclass
class HyperbolicRule(AtomicRule, ABC):
    """Abstract marker base for the atomic hyperbolic-function rules."""
    pass
|
|
|
|
|
@dataclass
class SinhRule(HyperbolicRule):
    """integrate(sinh(x), x) -> cosh(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return cosh(x)
|
|
|
|
|
@dataclass
class CoshRule(HyperbolicRule):
    """integrate(cosh(x), x) -> sinh(x)"""
    # Return annotation added for consistency with every other Rule.eval.
    def eval(self) -> Expr:
        return sinh(self.variable)
|
|
|
|
|
@dataclass
class ExpRule(AtomicRule):
    """integrate(a**x, x) -> a**x/ln(a)

    ``base`` is ``a`` and ``exp`` is the exponent. The integrand itself is
    reused as ``a**x`` in the result.
    """
    base: Expr
    exp: Expr

    def eval(self) -> Expr:
        denominator = log(self.base)
        return self.integrand / denominator
|
|
|
|
|
@dataclass
class ReciprocalRule(AtomicRule):
    """integrate(1/x, x) -> ln(x)"""
    base: Expr

    def eval(self) -> Expr:
        argument = self.base
        return log(argument)
|
|
|
|
|
@dataclass
class ArcsinRule(AtomicRule):
    """integrate(1/sqrt(1-x**2), x) -> asin(x)"""
    def eval(self) -> Expr:
        x = self.variable
        return asin(x)
|
|
|
|
|
@dataclass
class ArcsinhRule(AtomicRule):
    """integrate(1/sqrt(1+x**2), x) -> asinh(x)"""
    # Docstring previously claimed the result was asin(x); the antiderivative
    # of 1/sqrt(1+x**2) is the inverse hyperbolic sine, as returned below.
    def eval(self) -> Expr:
        return asinh(self.variable)
|
|
|
|
|
@dataclass
class ReciprocalSqrtQuadraticRule(AtomicRule):
    """integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)"""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        root_c = sqrt(c)
        radical = sqrt(a+b*x+c*x**2)
        return log(2*root_c*radical+b+2*c*x)/root_c
|
|
|
|
|
@dataclass
class SqrtQuadraticDenomRule(AtomicRule):
    """integrate(poly(x)/sqrt(a+b*x+c*x**2), x)

    ``coeffs`` are the coefficients of ``poly`` from the highest power down
    to the constant term: ``coeffs[i]`` multiplies ``x**(len(coeffs)-1-i)``.
    NOTE(review): at least two coefficients are assumed (``coeffs[-2]`` is
    read unconditionally); confirm callers guarantee this.
    """
    a: Expr
    b: Expr
    c: Expr
    coeffs: list[Expr]

    def eval(self) -> Expr:
        # Work on a private copy: the reduction below updates entries in place
        # and must not modify the rule's own ``coeffs``. (One copy suffices;
        # the former second ``.copy()`` was redundant.)
        a, b, c, coeffs, x = self.a, self.b, self.c, self.coeffs.copy(), self.variable
        # Lower the numerator degree with the recurrence for
        # I_n = integrate(x**n/s, x), s = sqrt(a+b*x+c*x**2):
        #   I_n = x**(n-1)*s/(n*c) - (2*n-1)*b/(2*n*c)*I_{n-1}
        #         - (n-1)*a/(n*c)*I_{n-2}
        # which follows from differentiating x**(n-1)*s.
        # result_coeffs[i] is the coefficient of x**(n-1)*s contributed by the
        # x**n term; the two lower-order contributions are folded back into
        # ``coeffs`` for the following iterations.
        result_coeffs = []
        for i in range(len(coeffs)-2):
            n = len(coeffs)-1-i
            coeff = coeffs[i]/(c*n)
            result_coeffs.append(coeff)
            coeffs[i+1] -= (2*n-1)*b/2*coeff
            coeffs[i+2] -= (n-1)*a*coeff
        # What remains is the linear numerator e*x + d:
        #   integrate((e*x + d)/s, x) = e/c*s + (d - b*e/(2*c))*integrate(1/s, x)
        d, e = coeffs[-1], coeffs[-2]
        s = sqrt(a+b*x+c*x**2)
        constant = d-b*e/(2*c)
        if constant == 0:
            I0 = 0
        else:
            step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False)
            I0 = constant*step.eval()
        return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i)
                     for i in range(len(result_coeffs))), e/c)*s + I0
|
|
|
|
|
@dataclass
class SqrtQuadraticRule(AtomicRule):
    """integrate(sqrt(a+b*x+c*x**2), x)

    Evaluation delegates to ``sqrt_quadratic_rule`` with the degenerate
    (``c == 0``) branch disabled.
    """
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        info = IntegralInfo(self.integrand, self.variable)
        delegated = sqrt_quadratic_rule(info, degenerate=False)
        return delegated.eval()
|
|
|
|
|
@dataclass
class AlternativeRule(Rule):
    """Multiple ways to do integration.

    Only the first entry of ``alternatives`` is used by ``eval``.
    """
    alternatives: list[Rule]

    def eval(self) -> Expr:
        preferred = self.alternatives[0]
        return preferred.eval()

    def contains_dont_know(self) -> bool:
        return any(alt.contains_dont_know() for alt in self.alternatives)
|
|
|
|
|
@dataclass
class DontKnowRule(Rule):
    """Leave the integral as is.

    ``eval`` returns an unevaluated ``Integral`` of the integrand.
    """
    def eval(self) -> Expr:
        integrand, x = self.integrand, self.variable
        return Integral(integrand, x)

    def contains_dont_know(self) -> bool:
        # This rule is itself the "don't know" marker.
        return True
|
|
|
|
|
@dataclass
class DerivativeRule(AtomicRule):
    """integrate(f'(x), x) -> f(x)

    The integrand must be a ``Derivative``. The differentiation count of the
    first occurrence of ``variable`` is reduced by one; the other entries
    are left unchanged.
    """
    def eval(self) -> Expr:
        derivative = self.integrand
        assert isinstance(derivative, Derivative)
        pairs = list(derivative.variable_count)
        position = next((idx for idx, (var, _) in enumerate(pairs)
                         if var == self.variable), None)
        if position is not None:
            var, count = pairs[position]
            pairs[position] = (var, count - 1)
        return Derivative(derivative.expr, *pairs)
|
|
|
|
|
@dataclass
class RewriteRule(Rule):
    """Rewrite integrand to another form that is easier to handle.

    ``rewritten`` is the new form and ``substep`` integrates it. Its result
    is returned unchanged.
    """
    rewritten: Expr
    substep: Rule

    def eval(self) -> Expr:
        inner_step = self.substep
        return inner_step.eval()

    def contains_dont_know(self) -> bool:
        inner_step = self.substep
        return inner_step.contains_dont_know()
|
|
|
|
|
@dataclass
class CompleteSquareRule(RewriteRule):
    """Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2"""
    # Behaves exactly like RewriteRule; the subclass only records which
    # rewrite was applied.
    pass
|
|
|
|
|
@dataclass
class PiecewiseRule(Rule):
    """Integrate a piecewise-defined case split.

    ``subfunctions`` pairs each step with the condition under which it
    applies. The result is the ``Piecewise`` of the evaluated steps, in
    the same order.
    """
    subfunctions: Sequence[tuple[Rule, bool | Boolean]]

    def eval(self) -> Expr:
        pieces = [(step.eval(), cond) for step, cond in self.subfunctions]
        return Piecewise(*pieces)

    def contains_dont_know(self) -> bool:
        return any(step.contains_dont_know() for step, _cond in self.subfunctions)
|
|
|
|
|
@dataclass
class HeavisideRule(Rule):
    """integrate(Heaviside(harg)*g(x), x)

    ``substep`` integrates the remaining factor. Its value at
    ``x = ibnd`` is subtracted before multiplying by ``Heaviside(harg)``.
    """
    harg: Expr
    ibnd: Expr
    substep: Rule

    def eval(self) -> Expr:
        # Shifting the inner antiderivative so it vanishes at x = ibnd keeps
        # the product continuous there. NOTE(review): this assumes ibnd is
        # the point where harg changes sign; it is set by the caller.
        antiderivative = self.substep.eval()
        at_step = antiderivative.subs(self.variable, self.ibnd)
        return Heaviside(self.harg) * (antiderivative - at_step)

    def contains_dont_know(self) -> bool:
        inner_step = self.substep
        return inner_step.contains_dont_know()
|
|
|
|
|
@dataclass
class DiracDeltaRule(AtomicRule):
    """integrate(DiracDelta(a+b*x, n), x)

    For ``n == 0`` the result is ``Heaviside(a+b*x)/b``. Otherwise it is
    ``DiracDelta(a+b*x, n-1)/b``.
    """
    n: Expr
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        n, a, b, x = self.n, self.a, self.b, self.variable
        argument = a+b*x
        lowered = Heaviside(argument) if n == 0 else DiracDelta(argument, n-1)
        return lowered/b
|
|
|
|
|
@dataclass
class TrigSubstitutionRule(Rule):
    """Integrate via a trigonometric substitution ``x = func(theta)``.

    ``substep`` integrates ``rewritten`` in terms of ``theta``. ``eval``
    maps that result back to ``x`` and wraps it in a ``Piecewise`` that
    applies only under ``restriction``.
    """
    theta: Expr
    func: Expr
    rewritten: Expr
    substep: Rule
    restriction: bool | Boolean

    def eval(self) -> Expr:
        theta, func, x = self.theta, self.func, self.variable
        # Express func using only sin/cos/tan of theta; exactly one such
        # function is expected to remain (asserted below).
        func = func.subs(sec(theta), 1/cos(theta))
        func = func.subs(csc(theta), 1/sin(theta))
        func = func.subs(cot(theta), 1/tan(theta))

        trig_function = list(func.find(TrigonometricFunction))
        assert len(trig_function) == 1
        trig_function = trig_function[0]
        # Solve x = func(theta) for that trig function; a unique solution is
        # required, e.g. sin(theta) = numer/denom.
        relation = solve(x - func, trig_function)
        assert len(relation) == 1
        numer, denom = fraction(relation[0])

        # Read the sides of a right triangle off the ratio numer/denom, so
        # that sin/cos/tan of theta (and theta itself) can be written in x.
        if isinstance(trig_function, sin):
            opposite = numer
            hypotenuse = denom
            adjacent = sqrt(denom**2 - numer**2)
            inverse = asin(relation[0])
        elif isinstance(trig_function, cos):
            adjacent = numer
            hypotenuse = denom
            opposite = sqrt(denom**2 - numer**2)
            inverse = acos(relation[0])
        else:
            # Remaining case: tan(theta) = opposite/adjacent.
            opposite = numer
            adjacent = denom
            hypotenuse = sqrt(denom**2 + numer**2)
            inverse = atan(relation[0])

        substitution = [
            (sin(theta), opposite/hypotenuse),
            (cos(theta), adjacent/hypotenuse),
            (tan(theta), opposite/adjacent),
            (theta, inverse)
        ]
        return Piecewise(
            (self.substep.eval().subs(substitution).trigsimp(), self.restriction)
        )

    def contains_dont_know(self) -> bool:
        return self.substep.contains_dont_know()
|
|
|
|
|
@dataclass
class ArctanRule(AtomicRule):
    """integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))"""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        scale = sqrt(c/b)
        return a/b / scale * atan(x/scale)
|
|
|
|
|
@dataclass
class OrthogonalPolyRule(AtomicRule, ABC):
    """Abstract base for integrating an orthogonal polynomial.

    ``n`` is the polynomial's first argument. Subclasses add any further
    parameters that precede the variable.
    """
    n: Expr
|
|
|
|
|
@dataclass
class JacobiRule(OrthogonalPolyRule):
    """integrate(jacobi(n, a, b, x), x)

    In general this is ``2*jacobi(n+1, a-1, b-1, x)/(n+a+b)``. When
    ``n + a + b == 0``, explicit antiderivatives are used for ``n = 0``
    and ``n = 1``.
    """
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        n, a, b, x = self.n, self.a, self.b, self.variable
        generic = 2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b)
        degree_one = (a + b + 2)*x**2/4 + (a - b)*x/2
        return Piecewise(
            (generic, Ne(n + a + b, 0)),
            (x, Eq(n, 0)),
            (degree_one, Eq(n, 1)))
|
|
|
|
|
@dataclass
class GegenbauerRule(OrthogonalPolyRule):
    """integrate(gegenbauer(n, a, x), x)

    Uses ``gegenbauer(n+1, a-1, x)/(2*(a-1))`` for ``a != 1``. For
    ``a == 1`` it uses ``chebyshevt(n+1, x)/(n+1)``, and 0 when
    ``n == -1``.
    """
    a: Expr

    def eval(self) -> Expr:
        n, a, x = self.n, self.a, self.variable
        shifted = gegenbauer(n + 1, a - 1, x)/(2*(a - 1))
        chebyshev_case = chebyshevt(n + 1, x)/(n + 1)
        return Piecewise(
            (shifted, Ne(a, 1)),
            (chebyshev_case, Ne(n, -1)),
            (S.Zero, True))
|
|
|
|
|
@dataclass
class ChebyshevTRule(OrthogonalPolyRule):
    """integrate(chebyshevt(n, x), x)

    The general formula ``(T_{n+1}/(n+1) - T_{n-1}/(n-1))/2`` is used unless
    ``|n| == 1``, in which case ``x**2/2`` is returned.
    """
    def eval(self) -> Expr:
        n, x = self.n, self.variable
        upper = chebyshevt(n + 1, x)/(n + 1)
        lower = chebyshevt(n - 1, x)/(n - 1)
        return Piecewise(
            ((upper - lower)/2, Ne(Abs(n), 1)),
            (x**2/2, True))
|
|
|
|
|
@dataclass
class ChebyshevURule(OrthogonalPolyRule):
    """integrate(chebyshevu(n, x), x) -> chebyshevt(n+1, x)/(n+1), or 0 for n == -1"""
    def eval(self) -> Expr:
        n, x = self.n, self.variable
        antiderivative = chebyshevt(n + 1, x)/(n + 1)
        return Piecewise(
            (antiderivative, Ne(n, -1)),
            (S.Zero, True))
|
|
|
|
|
@dataclass
class LegendreRule(OrthogonalPolyRule):
    """integrate(legendre(n, x), x) -> (P_{n+1}(x) - P_{n-1}(x))/(2*n + 1)"""
    def eval(self) -> Expr:
        n, x = self.n, self.variable
        difference = legendre(n + 1, x) - legendre(n - 1, x)
        return difference/(2*n + 1)
|
|
|
|
|
@dataclass
class HermiteRule(OrthogonalPolyRule):
    """integrate(hermite(n, x), x) -> hermite(n+1, x)/(2*(n+1))"""
    def eval(self) -> Expr:
        n, x = self.n, self.variable
        next_degree = n + 1
        return hermite(next_degree, x)/(2*next_degree)
|
|
|
|
|
@dataclass
class LaguerreRule(OrthogonalPolyRule):
    """integrate(laguerre(n, x), x) -> laguerre(n, x) - laguerre(n+1, x)"""
    def eval(self) -> Expr:
        n, x = self.n, self.variable
        current = laguerre(n, x)
        following = laguerre(n + 1, x)
        return current - following
|
|
|
|
|
@dataclass
class AssocLaguerreRule(OrthogonalPolyRule):
    """integrate(assoc_laguerre(n, a, x), x) -> -assoc_laguerre(n+1, a-1, x)"""
    a: Expr

    def eval(self) -> Expr:
        n, a, x = self.n, self.a, self.variable
        return -assoc_laguerre(n + 1, a - 1, x)
|
|
|
|
|
@dataclass
class IRule(AtomicRule, ABC):
    """Abstract base for the rules whose integrand involves ``a*x + b``.

    This covers exponential/trigonometric/hyperbolic integrals divided by
    ``x`` and the logarithmic integral. The coefficients are ``a`` and ``b``.
    """
    a: Expr
    b: Expr
|
|
|
|
|
@dataclass
class CiRule(IRule):
    """integrate(cos(a*x + b)/x, x) -> cos(b)*Ci(a*x) - sin(b)*Si(a*x)"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return cos(b)*Ci(ax) - sin(b)*Si(ax)
|
|
|
|
|
@dataclass
class ChiRule(IRule):
    """integrate(cosh(a*x + b)/x, x) -> cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x)"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return cosh(b)*Chi(ax) + sinh(b)*Shi(ax)
|
|
|
|
|
@dataclass
class EiRule(IRule):
    """integrate(exp(a*x + b)/x, x) -> exp(b)*Ei(a*x)"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return exp(b)*Ei(ax)
|
|
|
|
|
@dataclass
class SiRule(IRule):
    """integrate(sin(a*x + b)/x, x) -> sin(b)*Ci(a*x) + cos(b)*Si(a*x)"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return sin(b)*Ci(ax) + cos(b)*Si(ax)
|
|
|
|
|
@dataclass
class ShiRule(IRule):
    """integrate(sinh(a*x + b)/x, x) -> sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x)"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        ax = a*x
        return sinh(b)*Chi(ax) + cosh(b)*Shi(ax)
|
|
|
|
|
@dataclass
class LiRule(IRule):
    """integrate(1/log(a*x + b), x) -> li(a*x + b)/a"""
    def eval(self) -> Expr:
        a, b, x = self.a, self.b, self.variable
        argument = a*x + b
        return li(argument)/a
|
|
|
|
|
@dataclass
class ErfRule(AtomicRule):
    """integrate(exp(a*x**2 + b*x + c), x)

    If ``a`` is known to be extended real, the result is a Piecewise: the
    ``erf`` form for ``a < 0`` and the ``erfi`` form otherwise. If that is
    not known, only the ``erfi`` form is returned.
    """
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        scale = exp(c - b**2/(4*a))
        erfi_form = sqrt(S.Pi)/sqrt(a)/2 * scale * erfi((2*a*x + b)/(2*sqrt(a)))
        if not a.is_extended_real:
            return erfi_form
        erf_form = sqrt(S.Pi)/sqrt(-a)/2 * scale * erf((-2*a*x - b)/(2*sqrt(-a)))
        return Piecewise((erf_form, a < 0), (erfi_form, True))
|
|
|
|
|
@dataclass
class FresnelCRule(AtomicRule):
    """integrate(cos(a*x**2 + b*x + c), x) in terms of Fresnel integrals."""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        phase = b**2/(4*a) - c
        z = (2*a*x + b)/sqrt(2*a*S.Pi)
        return sqrt(S.Pi)/sqrt(2*a) * (
            cos(phase)*fresnelc(z) + sin(phase)*fresnels(z))
|
|
|
|
|
@dataclass
class FresnelSRule(AtomicRule):
    """integrate(sin(a*x**2 + b*x + c), x) in terms of Fresnel integrals."""
    a: Expr
    b: Expr
    c: Expr

    def eval(self) -> Expr:
        a, b, c, x = self.a, self.b, self.c, self.variable
        phase = b**2/(4*a) - c
        z = (2*a*x + b)/sqrt(2*a*S.Pi)
        return sqrt(S.Pi)/sqrt(2*a) * (
            cos(phase)*fresnels(z) - sin(phase)*fresnelc(z))
|
|
|
|
|
@dataclass
class PolylogRule(AtomicRule):
    """integrate(polylog(b, a*x)/x, x) -> polylog(b+1, a*x)"""
    a: Expr
    b: Expr

    def eval(self) -> Expr:
        order = self.b + 1
        argument = self.a * self.variable
        return polylog(order, argument)
|
|
|
|
|
@dataclass
class UpperGammaRule(AtomicRule):
    """integrate(x**e*exp(a*x), x) -> x**e*(-a*x)**(-e)*uppergamma(e+1, -a*x)/a"""
    a: Expr
    e: Expr

    def eval(self) -> Expr:
        a, e, x = self.a, self.e, self.variable
        t = -a*x
        return x**e * t**(-e) * uppergamma(e + 1, t)/a
|
|
|
|
|
@dataclass
class EllipticFRule(AtomicRule):
    """integrate(1/sqrt(a - d*sin(x)**2), x) -> elliptic_f(x, d/a)/sqrt(a)"""
    a: Expr
    d: Expr

    def eval(self) -> Expr:
        a, d = self.a, self.d
        parameter = d/a
        return elliptic_f(self.variable, parameter)/sqrt(a)
|
|
|
|
|
@dataclass
class EllipticERule(AtomicRule):
    """integrate(sqrt(a - d*sin(x)**2), x) -> elliptic_e(x, d/a)*sqrt(a)"""
    a: Expr
    d: Expr

    def eval(self) -> Expr:
        a, d = self.a, self.d
        parameter = d/a
        return elliptic_e(self.variable, parameter)*sqrt(a)
|
|
|
|
|
class IntegralInfo(NamedTuple):
    """An integrand together with its integration symbol.

    The rule functions in this module take one of these and unpack it as
    ``integrand, symbol = integral``.
    """
    integrand: Expr
    symbol: Symbol
|
|
|
|
|
def manual_diff(f, symbol):
    """Derivative of f in form expected by find_substitutions

    SymPy's derivatives for some trig functions (like cot) are not in a form
    that works well with finding substitutions; this replaces the
    derivatives for those particular forms with something that works better.

    ``tan``, ``cot``, ``sec`` and ``csc`` are differentiated by hand, with
    the chain rule, into products of ``sec``/``csc``/``tan``/``cot``. Sums
    are differentiated term by term. Two-factor products ``Number*expr``
    keep the number outside. This lets the special forms also appear
    inside sums and scaled terms. Everything else uses ``f.diff(symbol)``.
    """
    if not f.args:
        return f.diff(symbol)
    inner = f.args[0]
    if isinstance(f, tan):
        return inner.diff(symbol) * sec(inner)**2
    if isinstance(f, cot):
        return -inner.diff(symbol) * csc(inner)**2
    if isinstance(f, sec):
        return inner.diff(symbol) * sec(inner) * tan(inner)
    if isinstance(f, csc):
        return -inner.diff(symbol) * csc(inner) * cot(inner)
    if isinstance(f, Add):
        return sum(manual_diff(term, symbol) for term in f.args)
    if isinstance(f, Mul) and len(f.args) == 2 and isinstance(inner, Number):
        return inner * manual_diff(f.args[1], symbol)
    return f.diff(symbol)
|
|
|
def manual_subs(expr, *args):
    """
    A wrapper for `expr.subs(*args)` with additional logic for substitution
    of invertible functions.

    Accepts the same argument forms as ``subs``:

    - a single mapping (``dict``, SymPy ``Dict`` or other ``Mapping``),
    - a single iterable of ``(old, new)`` pairs, or
    - separate ``old, new`` arguments.

    When ``old`` is ``log(x0)``, powers ``x0**p`` in ``expr`` are first
    rewritten as ``exp(p*new)``, and ``x0`` itself is also substituted by
    ``exp(new)``.

    Raises
    ======
    ValueError
        If the arguments do not have one of the accepted forms.
    """
    if len(args) == 1:
        sequence = args[0]
        if isinstance(sequence, (Dict, Mapping)):
            sequence = sequence.items()
        elif not iterable(sequence):
            raise ValueError("Expected an iterable of (old, new) pairs")
    elif len(args) == 2:
        sequence = [args]
    else:
        raise ValueError("subs accepts either 1 or 2 arguments")

    # Materialize once. The pairs are traversed twice: once below to collect
    # extra substitutions, and again in the final ``subs`` call. A one-shot
    # iterator (e.g. a generator) would otherwise be empty by the second
    # pass, and every requested substitution would be silently dropped.
    sequence = list(sequence)

    new_subs = []
    for old, new in sequence:
        if isinstance(old, log):
            # If log(x0) = new, then x0**p = exp(p*new). Rewrite nontrivial
            # powers of x0 before subs turns them into exp(new)**p, but do
            # not replace x0 itself yet, to avoid producing log(exp(new)).
            x0 = old.args[0]
            expr = expr.replace(lambda x: x.is_Pow and x.base == x0,
                                lambda x: exp(x.exp*new))
            new_subs.append((x0, exp(new)))

    return expr.subs(sequence + new_subs)
|
|
|
|
|
|
|
|
|
# Inverse trigonometric function classes. find_substitutions treats the
# argument of any of these as a candidate u-substitution.
inverse_trig_functions = (atan, asin, acos, acot, acsc, asec)
|
|
|
|
|
def find_substitutions(integrand, symbol, u_var):
    """Find candidate u-substitutions for ``integrand``.

    Candidate ``u`` expressions are subexpressions of ``integrand``, collected
    by ``possible_subterms`` and ``exp_subterms``. For each candidate,
    ``integrand/du`` is rewritten with ``u_var`` standing for ``u``. The
    candidate is kept only if no ``symbol`` remains.

    Returns a duplicate-free list of ``(u, constant, new_integrand)`` triples.
    In each, ``integrand == constant*new_integrand(u)*du/dsymbol``, and
    ``constant`` is the factor independent of ``u_var``.
    """
    results = []

    def test_subterm(u, u_diff):
        # Return False if u is unusable. Otherwise return the
        # (independent, dependent) split of integrand/u_diff in terms of u_var.
        if u_diff == 0:
            return False
        substituted = integrand / u_diff
        debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var))
        substituted = manual_subs(substituted, u, u_var).cancel()

        # symbol is still present: u does not capture all x-dependence.
        if substituted.has_free(symbol):
            return False

        # Reject substitutions that raise the degree of a rational integrand.
        if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var):
            deg_before = max(degree(t, symbol) for t in integrand.as_numer_denom())
            deg_after = max(degree(t, u_var) for t in substituted.as_numer_denom())
            if deg_after > deg_before:
                return False
        return substituted.as_independent(u_var, as_Add=False)

    def exp_subterms(term: Expr):
        # Candidates from exp() factors. All exp(n*symbol) with integer n are
        # folded into a single exp(gcd(n, ...)*symbol). Any other exp(...)
        # depending on symbol is kept as its own candidate.
        linear_coeffs = []
        terms = []
        n = Wild('n', properties=[lambda n: n.is_Integer])
        for exp_ in term.find(exp):
            arg = exp_.args[0]
            if symbol not in arg.free_symbols:
                continue
            match = arg.match(n*symbol)
            if match:
                linear_coeffs.append(match[n])
            else:
                terms.append(exp_)
        if linear_coeffs:
            terms.append(exp(gcd_list(linear_coeffs)*symbol))
        return terms

    def possible_subterms(term):
        # Recursively collect subexpressions of term that may serve as u:
        # - arguments of elementary and special functions (the variable
        #   argument of orthogonal polynomials);
        # - factors of products and terms of sums, recursively;
        # - for powers, the base/exponent containing symbol and, for integer
        #   exponents, base**d for prime factors d of the exponent.
        if isinstance(term, (TrigonometricFunction, HyperbolicFunction,
                             *inverse_trig_functions,
                             exp, log, Heaviside)):
            return [term.args[0]]
        elif isinstance(term, (chebyshevt, chebyshevu,
                        legendre, hermite, laguerre)):
            return [term.args[1]]
        elif isinstance(term, (gegenbauer, assoc_laguerre)):
            return [term.args[2]]
        elif isinstance(term, jacobi):
            return [term.args[3]]
        elif isinstance(term, Mul):
            r = []
            for u in term.args:
                r.append(u)
                r.extend(possible_subterms(u))
            return r
        elif isinstance(term, Pow):
            r = [arg for arg in term.args if arg.has(symbol)]
            if term.exp.is_Integer:
                r.extend([term.base**d for d in primefactors(term.exp)
                    if 1 < d < abs(term.args[1])])
                if term.base.is_Add:
                    r.extend([t for t in possible_subterms(term.base)
                        if t.is_Pow])
            return r
        elif isinstance(term, Add):
            r = []
            for arg in term.args:
                r.append(arg)
                r.extend(possible_subterms(arg))
            return r
        return []

    # dict.fromkeys removes duplicate candidates and keeps first-seen order.
    for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))):
        if u == symbol:
            continue
        u_diff = manual_diff(u, symbol)
        new_integrand = test_subterm(u, u_diff)
        if new_integrand is not False:
            constant, new_integrand = new_integrand
            # Skip substitutions that merely rename symbol to u_var.
            if new_integrand == integrand.subs(symbol, u_var):
                continue
            substitution = (u, constant, new_integrand)
            if substitution not in results:
                results.append(substitution)

    return results
|
|
|
def rewriter(condition, rewrite):
    """Strategy that rewrites an integrand.

    The returned strategy takes an ``IntegralInfo``. If ``condition`` holds
    and ``rewrite`` actually changes the integrand, it integrates the
    rewritten form with ``integral_steps``. It returns a ``RewriteRule``
    unless that produced nothing or a bare ``DontKnowRule``; in every other
    case it returns ``None``.
    """
    def _rewriter(integral):
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol))
        if not condition(*integral):
            return None
        rewritten = rewrite(*integral)
        if rewritten != integrand:
            substep = integral_steps(rewritten, symbol)
            if substep and not isinstance(substep, DontKnowRule):
                return RewriteRule(integrand, symbol, rewritten, substep)
        return None
    return _rewriter
|
|
|
def proxy_rewriter(condition, rewrite):
    """Strategy that rewrites an integrand based on some other criteria.

    The returned strategy takes a ``(criteria, integral)`` pair. ``criteria``
    is a list of extra arguments placed before the integral's
    ``(integrand, symbol)`` when calling ``condition`` and ``rewrite``.
    Unlike ``rewriter``, the resulting ``RewriteRule`` is returned whatever
    ``integral_steps`` produces for the rewritten form.
    """
    def _proxy_rewriter(criteria):
        extra, integral = criteria
        integrand, symbol = integral
        debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, extra))
        call_args = extra + list(integral)
        if condition(*call_args):
            rewritten = rewrite(*call_args)
            if rewritten != integrand:
                substep = integral_steps(rewritten, symbol)
                return RewriteRule(integrand, symbol, rewritten, substep)
        return None
    return _proxy_rewriter
|
|
|
def multiplexer(conditions):
    """Apply the rule that matches the condition, else None

    ``conditions`` maps predicates to rules and is scanned in iteration
    order. The rule paired with the first predicate that accepts ``expr``
    is applied to it.
    """
    def multiplexer_rl(expr):
        for predicate in conditions:
            if predicate(expr):
                return conditions[predicate](expr)
        return None
    return multiplexer_rl
|
|
|
def alternatives(*rules):
    """Strategy that makes an AlternativeRule out of multiple possible results.

    Every rule is tried on the integral. Results that are falsy, a
    ``DontKnowRule``, or duplicates are discarded. A single surviving result
    is returned as-is. Several are wrapped in an ``AlternativeRule``; if some
    of them contain no ``DontKnowRule``, only those are kept. With no
    survivors the strategy returns ``None``.
    """
    def _alternatives(integral):
        candidates = []
        debug("List of Alternative Rules")
        for index, rule in enumerate(rules, start=1):
            debug("Rule {}: {}".format(index, rule))
            result = rule(integral)
            usable = result and not isinstance(result, DontKnowRule)
            if usable and result != integral and result not in candidates:
                candidates.append(result)
        if not candidates:
            return None
        if len(candidates) == 1:
            return candidates[0]
        doable = [alt for alt in candidates if not alt.contains_dont_know()]
        return AlternativeRule(*integral, doable or candidates)
    return _alternatives
|
|
|
def constant_rule(integral):
    """Integrate an integrand treated as constant in the symbol (see ConstantRule)."""
    integrand, symbol = integral
    return ConstantRule(integrand, symbol)
|
|
|
def power_rule(integral):
    """Integrate ``x**n`` (``n`` free of x) or ``a**x`` (``a`` free of x).

    For ``x**n``, a ``ReciprocalRule`` is returned when ``n + 1`` simplifies
    to 0, and a ``PowerRule`` otherwise. For ``a**x``, an ``ExpRule`` is
    returned when ``log(a)`` is known to be nonzero, and ``ConstantRule(1)``
    when it is known to be zero. If neither is known, both are combined in a
    ``PiecewiseRule``. Any other shape gives ``None``.
    """
    integrand, symbol = integral
    base, expt = integrand.as_base_exp()

    if symbol not in expt.free_symbols and isinstance(base, Symbol):
        if simplify(expt + 1) == 0:
            return ReciprocalRule(integrand, symbol, base)
        return PowerRule(integrand, symbol, base, expt)

    if symbol not in base.free_symbols and isinstance(expt, Symbol):
        exp_step = ExpRule(integrand, symbol, base, expt)
        log_base = log(base)
        if fuzzy_not(log_base.is_zero):
            return exp_step
        if log_base.is_zero:
            return ConstantRule(1, symbol)
        return PiecewiseRule(integrand, symbol, [
            (exp_step, Ne(log_base, 0)),
            (ConstantRule(1, symbol), True)
        ])

    return None
|
|
|
def exp_rule(integral):
    """Integrate ``exp(x)`` when the exponent is a bare Symbol, else None."""
    integrand, symbol = integral
    exponent = integrand.args[0]
    if not isinstance(exponent, Symbol):
        return None
    return ExpRule(integrand, symbol, E, exponent)
|
|
|
|
|
def orthogonal_poly_rule(integral):
    """Integrate an orthogonal polynomial in the integration symbol.

    Matches only when the polynomial's variable argument is exactly the
    symbol and none of the preceding parameters contain it. Those
    parameters are passed on to the corresponding rule class. Returns
    ``None`` otherwise.
    """
    # (polynomial class, rule class, index of the variable argument)
    table = (
        (jacobi, JacobiRule, 3),
        (gegenbauer, GegenbauerRule, 2),
        (chebyshevt, ChebyshevTRule, 1),
        (chebyshevu, ChebyshevURule, 1),
        (legendre, LegendreRule, 1),
        (hermite, HermiteRule, 1),
        (laguerre, LaguerreRule, 1),
        (assoc_laguerre, AssocLaguerreRule, 2),
    )
    integrand, symbol = integral
    for poly_cls, rule_cls, var_index in table:
        if not isinstance(integrand, poly_cls):
            continue
        params = integrand.args[:var_index]
        if (integrand.args[var_index] is symbol and
                not any(p.has(symbol) for p in params)):
            return rule_cls(integrand, symbol, *params)
    return None
|
|
|
|
|
# Pattern table for special_function_rule, filled lazily on its first call.
# Each entry is (class the integrand must be an instance of, pattern in terms
# of _symbol, optional extra constraint on the matched wild values, rule class).
_special_function_patterns: list[tuple[Type, Expr, Callable | None, Type[Rule]]] = []
# Wild symbols used by the patterns, in (a, b, c, d, e) order; populated
# together with _special_function_patterns.
_wilds = []
# Placeholder symbol the patterns are written in; integrands are rewritten in
# terms of it before matching.
_symbol = Dummy('x')
|
|
|
|
|
def special_function_rule(integral):
    """Match integrands whose antiderivatives are special functions.

    Covered: Ei, Ci, Chi, Si, Shi, li, erf/erfi, Fresnel integrals, the
    upper incomplete gamma function, polylog, and elliptic F/E. Returns the
    rule for the first pattern that matches and whose extra constraint (if
    any) holds, or None.
    """
    integrand, symbol = integral
    # Build the pattern table once, on first use; it is cached in the
    # module-level lists _special_function_patterns and _wilds.
    if not _special_function_patterns:
        a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        b = Wild('b', exclude=[_symbol])
        c = Wild('c', exclude=[_symbol])
        d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero])
        e = Wild('e', exclude=[_symbol], properties=[
            lambda x: not (x.is_nonnegative and x.is_integer)])
        _wilds.extend((a, b, c, d, e))
        # Each entry: a SymPy class, a wildcard pattern, an optional condition
        # written as a lambda (used where Wild properties are not enough),
        # and the rule class to build on a match.
        linear_pattern = a*_symbol + b
        quadratic_pattern = a*_symbol**2 + b*_symbol + c
        _special_function_patterns.extend((
            (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule),
            (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule),
            (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule),
            (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule),
            (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule),
            (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule),
            (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule),
            (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule),
            (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule),
            (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule),
            (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule),
            (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticFRule),
            (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2),
                lambda a, d: a != d, EllipticERule),
        ))
    # The patterns are written in terms of the placeholder _symbol.
    _integrand = integrand.subs(symbol, _symbol)
    for type_, pattern, constraint, rule in _special_function_patterns:
        if isinstance(_integrand, type_):
            match = _integrand.match(pattern)
            if match:
                # Matched wild values in (a, b, c, d, e) order, skipping wilds
                # absent from the match; they become the rule's extra fields.
                wild_vals = tuple(match.get(w) for w in _wilds
                                  if match.get(w) is not None)
                if constraint is None or constraint(*wild_vals):
                    return rule(integrand, symbol, *wild_vals)
|
|
|
|
|
def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule:
    """Guard ``generic_step`` by ``generic_cond`` and fall back to ``degenerate_step``.

    Returns ``generic_step`` unchanged when there is no degenerate step.
    Otherwise returns a ``PiecewiseRule`` built from:

    - the generic piece(s), each condition combined with ``generic_cond``;
    - then the degenerate piece(s), unconditional unless the degenerate
      step is itself piecewise.
    """
    if degenerate_step is None:
        return generic_step

    if isinstance(generic_step, PiecewiseRule):
        pieces = [(step, (cond & generic_cond).simplify())
                  for step, cond in generic_step.subfunctions]
    else:
        pieces = [(generic_step, generic_cond)]

    fallback = (degenerate_step.subfunctions
                if isinstance(degenerate_step, PiecewiseRule)
                else [(degenerate_step, S.true)])
    pieces.extend(fallback)
    return PiecewiseRule(generic_step.integrand, generic_step.variable, pieces)
|
|
|
|
|
def nested_pow_rule(integral: IntegralInfo):
    """Integrate nested powers of a single linear factor, e.g. (c*(a+b*x)**d)**e.

    The integrand must be built from numeric coefficients, x-free factors,
    and powers with x-free exponents of one common linear expression
    ``a + b*x``. It is then written as const*(x + a/b)**exp_, and a
    ``NestedPowRule`` is produced. If ``b`` is not known to be nonzero, a
    constant fallback for ``b == 0`` is added as a Piecewise branch.
    Returns None when the integrand does not have this shape.
    """
    integrand, x = integral

    a_ = Wild('a', exclude=[x])
    b_ = Wild('b', exclude=[x, 0])
    pattern = a_+b_*x
    # Condition under which the generic result is valid; set to Ne(b, 0)
    # when a linear factor is matched.
    generic_cond = S.true

    class NoMatch(Exception):
        pass

    def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]:
        # Return (base, total exponent) with base = x + a/b, or (1, 0) for
        # factors free of x; raise NoMatch if expr is not of the supported shape.
        if not expr.has_free(x):
            return S.One, S.Zero
        if expr.is_Mul:
            _, terms = expr.as_coeff_mul()
            if not terms:
                return S.One, S.Zero
            results = [_get_base_exp(term) for term in terms]
            bases = {b for b, _ in results}
            bases.discard(S.One)
            # All x-dependent factors must share one base; exponents add up.
            if len(bases) == 1:
                return bases.pop(), Add(*(e for _, e in results))
            raise NoMatch
        if expr.is_Pow:
            b, e = expr.base, expr.exp
            if e.has_free(x):
                raise NoMatch
            base_, sub_exp = _get_base_exp(b)
            return base_, sub_exp * e
        match = expr.match(pattern)
        if match:
            a, b = match[a_], match[b_]
            # Normalize a + b*x to b*(x + a/b); the factor b is absorbed into
            # the constant multiple already present in the integrand.
            base_ = x + a/b
            nonlocal generic_cond
            generic_cond = Ne(b, 0)
            return base_, S.One
        raise NoMatch

    try:
        base, exp_ = _get_base_exp(integrand)
    except NoMatch:
        return
    if generic_cond is S.true:
        degenerate_step = None
    else:
        # With b == 0 the integrand no longer depends on x. Substituting
        # x = 0 gives that constant without having to locate b.
        degenerate_step = ConstantRule(integrand.subs(x, 0), x)
    generic_step = NestedPowRule(integrand, x, base, exp_)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
|
|
|
|
|
def inverse_trig_rule(integral: IntegralInfo, degenerate=True):
    """Integrate ``(a + b*x + c*x**2)**exp`` for ``exp`` equal to -1/2 or 1/2.

    For ``exp == -1/2`` the quadratic is written as ``k + c*(x - h)**2``.
    When the signs of ``k`` and ``c`` are decided (or can be expressed as
    conditions) the integral is reduced to ``ArcsinRule``/``ArcsinhRule``;
    otherwise ``ReciprocalSqrtQuadraticRule`` is used. The perfect-square
    case ``k == 0`` is handled by ``NestedPowRule`` whenever it cannot be
    ruled out, including under the inverse-trig conditions' fallback. For
    ``exp == 1/2`` ``SqrtQuadraticRule`` is used.

    Set degenerate=False on recursive call where coefficient of quadratic term
    is assumed non-zero. Otherwise a fallback for ``c == 0`` is attached
    under the condition ``Ne(c, 0)`` when that condition is not known true.

    Returns ``None`` when the base is not a quadratic in ``symbol`` (the
    quadratic coefficient must match as nonzero) or ``exp`` is neither
    -1/2 nor 1/2.
    """
    integrand, symbol = integral
    base, exp = integrand.as_base_exp()
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol, 0])
    match = base.match(a + b*symbol + c*symbol**2)

    if not match:
        return

    def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule:
        # Build the step chain for 1/sqrt(sign_a*a + sign_c*c*(symbol - h)**2)
        # ending in RuleClass on the standard form, adding a substitution,
        # constant factor and completed square only where needed.
        u_var = Dummy("u")
        rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2)  # a>0, c>0
        quadratic_base = sqrt(c/a)*(symbol-h)
        constant = 1/sqrt(c)
        u_func = None
        if quadratic_base is not symbol:
            u_func = quadratic_base
            quadratic_base = u_var
        standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2)
        substep = RuleClass(standard_form, quadratic_base)
        if constant != 1:
            substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep)
        if u_func is not None:
            substep = URule(rewritten, symbol, u_var, u_func, substep)
        if h != 0:
            substep = CompleteSquareRule(integrand, symbol, rewritten, substep)
        return substep

    a, b, c = [match.get(i, S.Zero) for i in (a, b, c)]
    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = ConstantRule(a ** exp, symbol)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol))

    if simplify(2*exp + 1) == 0:
        # exp == -1/2: rewrite the base as k + c*(symbol - h)**2
        h, k = -b/(2*c), a - b**2/(4*c)
        non_square_cond = Ne(k, 0)
        square_step = None
        if non_square_cond is not S.true:
            # k == 0 makes the base c*(symbol - h)**2, a perfect square.
            square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne)
        if non_square_cond is S.false:
            return square_step
        generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c)
        # ``step`` is the generic rule plus the k == 0 fallback (if needed).
        step = _add_degenerate_step(non_square_cond, generic_step, square_step)
        if k.is_real and c.is_real:
            # list of ((rule, base_exp, a, sign_a, b, sign_b), condition)
            rules = []
            for args, cond in (  # don't apply ArccoshRule to x**2-1
                ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)),  # 1-x**2
                ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)),  # 1+x**2
            ):
                if cond is S.true:
                    return make_inverse_trig(*args)
                if cond is not S.false:
                    rules.append((make_inverse_trig(*args), cond))
            if rules:
                if not k.is_positive:
                    # The conditions above do not cover k <= 0; fall back to
                    # ``step`` so that the k == 0 case keeps its own rule.
                    rules.append((step, S.true))
                step = PiecewiseRule(integrand, symbol, rules)
            # With no applicable inverse-trig rule ``step`` is kept as is,
            # so the k == 0 fallback is not discarded.
        return _add_degenerate_step(generic_cond, step, degenerate_step)

    if exp == S.Half:
        step = SqrtQuadraticRule(integrand, symbol, a, b, c)
        return _add_degenerate_step(generic_cond, step, degenerate_step)
|
|
|
|
|
def add_rule(integral):
    """Integrate a sum term by term.

    Returns an ``AddRule`` with one substep per term (in
    ``as_ordered_terms`` order), or ``None`` if any term's steps are ``None``.
    """
    integrand, symbol = integral
    substeps = []
    for term in integrand.as_ordered_terms():
        substeps.append(integral_steps(term, symbol))
    if None in substeps:
        return None
    return AddRule(integrand, symbol, substeps)
|
|
|
|
|
def mul_rule(integral: IntegralInfo):
    """Move a factor independent of the variable outside the integral.

    Splits the integrand as ``coeff * f`` with ``coeff`` free of the
    variable. If ``coeff`` is not 1, returns a ``ConstantTimesRule`` wrapping
    the steps for ``f``; otherwise returns ``None``.
    """
    integrand, symbol = integral
    coeff, f = integrand.as_independent(symbol)
    if coeff == 1:
        return None
    inner = integral_steps(f, symbol)
    if inner is None:
        return None
    return ConstantTimesRule(integrand, symbol, coeff, f, inner)
|
|
|
|
|
def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None:
    """Choose ``u`` and ``dv`` for one integration by parts step.

    Candidates for ``u`` are tried in LIATE order (Log, Inverse trig,
    Algebraic, Trig (sin/cos), Exponential). On success returns
    ``(u, dv, v, du, v_step)`` where ``v_step`` integrates ``dv`` to ``v``
    and ``du`` is the derivative of ``u``; otherwise returns ``None``.
    """

    def pull_out_algebraic(integrand):
        # Split a product into its algebraic factors (u) and the rest (dv).
        integrand = integrand.cancel().together()
        # iterating over Piecewise args would not work here
        algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul
                     else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)])
        if algebraic:
            u = Mul(*algebraic)
            dv = (integrand / u).cancel()
            return u, dv

    def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]:
        # Build a splitter that takes as u the product of the integrand's
        # args that are instances of ``functions``.
        def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None:
            if any(integrand.has(f) for f in functions):
                args = [arg for arg in integrand.args
                        if any(isinstance(arg, cls) for cls in functions)]
                if args:
                    u = Mul(*args)
                    dv = integrand / u
                    return u, dv
            return None

        return pull_out_u_rl

    liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions),
                   pull_out_algebraic, pull_out_u(sin, cos),
                   pull_out_u(exp)]

    # A bare log/inverse-trig integrand is not a Mul, so multiply by a dummy
    # factor; the dummy becomes dv and is replaced by 1 below.
    dummy = Dummy("temporary")
    # we can integrate log(x) and atan(x) by setting dv = 1
    if isinstance(integrand, (log, *inverse_trig_functions)):
        integrand = dummy * integrand

    for index, rule in enumerate(liate_rules):
        result = rule(integrand)

        if result:
            u, dv = result

            # Don't pick u to be a constant if possible
            if symbol not in u.free_symbols and not u.has(dummy):
                return None

            u = u.subs(dummy, 1)
            dv = dv.subs(dummy, 1)

            # Don't pick a non-polynomial algebraic to be differentiated
            if rule == pull_out_algebraic and not u.is_polynomial(symbol):
                return None
            # Don't trade log(...) for 1/linear with parts; give up instead.
            if isinstance(u, log):
                rec_dv = 1/dv
                if (rec_dv.is_polynomial(symbol) and
                    degree(rec_dv, symbol) == 1):
                        return None

            # Can integrate a polynomial times OrthogonalPolynomial, a
            # derivative, or something containing a trig function.
            if rule == pull_out_algebraic:
                if dv.is_Derivative or dv.has(TrigonometricFunction) or \
                        isinstance(dv, OrthogonalPolynomial):
                    v_step = integral_steps(dv, symbol)
                    if v_step.contains_dont_know():
                        return None
                    else:
                        du = u.diff(symbol)
                        v = v_step.eval()
                        return u, dv, v, du, v_step

            # make sure dv is amenable to integration
            accept = False
            if index < 2:  # log and inverse trig are usually worth trying
                accept = True
            elif (rule == pull_out_algebraic and dv.args and
                  all(isinstance(a, (sin, cos, exp))
                  for a in dv.args)):
                    accept = True
            else:
                # Accept if a later LIATE rule would pull out exactly dv.
                for lrule in liate_rules[index + 1:]:
                    r = lrule(integrand)
                    if r and r[0].subs(dummy, 1).equals(dv):
                        accept = True
                        break

            if accept:
                du = u.diff(symbol)
                v_step = integral_steps(simplify(dv), symbol)
                if not v_step.contains_dont_know():
                    v = v_step.eval()
                    return u, dv, v, du, v_step
            # Otherwise fall through and try the next LIATE rule.
    return None
|
|
|
|
|
def parts_rule(integral):
    """Integrate by parts, detecting cyclic integrals where possible.

    A rational coefficient is split off first and reapplied with
    ``ConstantTimesRule``. After the first ``_parts_rule`` step, up to four
    further rounds are attempted on ``v*du``: if ``v*du`` becomes a
    constant multiple (other than 1) of the integrand, a ``CyclicPartsRule``
    is returned; otherwise a chain of ``PartsRule`` steps is built and the
    final ``v*du`` is integrated with ``integral_steps``. Returns ``None``
    when no parts step applies.
    """
    integrand, symbol = integral
    constant, integrand = integrand.as_coeff_Mul()

    result = _parts_rule(integrand, symbol)

    steps = []
    if result:
        u, dv, v, du, v_step = result
        debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step))
        steps.append(result)

        # An unevaluated Integral for v means dv could not be integrated.
        if isinstance(v, Integral):
            return

        # Set a limit on the number of times u can be used
        # (per manualintegrate call; the counts live in _parts_u_cache).
        if isinstance(u, (sin, cos, exp, sinh, cosh)):
            cachekey = u.xreplace({symbol: _cache_dummy})
            if _parts_u_cache[cachekey] > 2:
                return
            _parts_u_cache[cachekey] += 1

        # Try cyclic integration by parts a few times
        for _ in range(4):
            debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand))
            coefficient = ((v * du) / integrand).cancel()
            # v*du reproduces the integrand exactly: stop trying.
            if coefficient == 1:
                break
            # v*du is a constant multiple of the integrand: cyclic case.
            if symbol not in coefficient.free_symbols:
                rule = CyclicPartsRule(integrand, symbol,
                    [PartsRule(None, None, u, dv, v_step, None)
                     for (u, dv, v, du, v_step) in steps],
                    (-1) ** len(steps) * coefficient)
                if (constant != 1) and rule:
                    rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
                return rule

            # _parts_rule is sensitive to constants, factor it out
            next_constant, next_integrand = (v * du).as_coeff_Mul()
            result = _parts_rule(next_integrand, symbol)

            if result:
                u, dv, v, du, v_step = result
                u *= next_constant
                du *= next_constant
                steps.append((u, dv, v, du, v_step))
            else:
                break

        def make_second_step(steps, integrand):
            # Chain the remaining parts steps; the last v*du is integrated
            # directly.
            if steps:
                u, dv, v, du, v_step = steps[0]
                return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
            return integral_steps(integrand, symbol)

        if steps:
            u, dv, v, du, v_step = steps[0]
            rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du))
            if (constant != 1) and rule:
                rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule)
            return rule
|
|
|
|
|
def trig_rule(integral):
    """Integrate a single trigonometric function.

    ``sin(x)``, ``cos(x)``, ``sec(x)**2`` and ``csc(x)**2`` map directly to
    their dedicated rules. ``tan``, ``cot``, ``sec`` and ``csc`` (of any
    argument) are rewritten into forms amenable to substitution and
    integrated recursively. Anything else yields ``None``.
    """
    integrand, symbol = integral

    direct_rules = (
        (sin(symbol), SinRule),
        (cos(symbol), CosRule),
        (sec(symbol)**2, Sec2Rule),
        (csc(symbol)**2, Csc2Rule),
    )
    for target, rule_cls in direct_rules:
        if integrand == target:
            return rule_cls(integrand, symbol)

    if not isinstance(integrand, (tan, cot, sec, csc)):
        return None

    arg = integrand.args[0]
    if isinstance(integrand, tan):
        rewritten = sin(arg) / cos(arg)
    elif isinstance(integrand, cot):
        rewritten = cos(arg) / sin(arg)
    elif isinstance(integrand, sec):
        numerator = sec(arg)**2 + tan(arg) * sec(arg)
        rewritten = numerator / (sec(arg) + tan(arg))
    else:
        numerator = csc(arg)**2 + cot(arg) * csc(arg)
        rewritten = numerator / (csc(arg) + cot(arg))

    return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
|
|
|
def trig_product_rule(integral: IntegralInfo):
    """Recognize ``sec(x)*tan(x)`` and ``csc(x)*cot(x)``; otherwise ``None``."""
    integrand, symbol = integral
    known_products = (
        (sec(symbol) * tan(symbol), SecTanRule),
        (csc(symbol) * cot(symbol), CscCotRule),
    )
    for product, rule_cls in known_products:
        if integrand == product:
            return rule_cls(integrand, symbol)
    return None
|
|
|
|
|
def quadratic_denom_rule(integral):
    """Integrate quotients whose denominator is quadratic in ``symbol``.

    Three shapes are tried in turn (coefficients free of ``symbol``):

    * ``a/(b*x**2 + c)``: ``ArctanRule``. When ``b`` and ``c`` are extended
      real and ``c/b > 0`` is not known true, a logarithmic partial-fraction
      form is used for the negative case (under a ``PiecewiseRule`` if the
      sign is undecided). Otherwise ``ArctanRule`` is used for ``c != 0``
      and the power rule for ``x**-2`` as the fallback.
    * ``a/(b*x**2 + c*x + d)``: shift ``x`` by ``c/(2*b)`` to remove the
      linear term and integrate the shifted integrand.
    * ``(a*x + b)/(c*x**2 + d*x + e)``: split into a multiple of
      ``(2*c*x + d)/denominator`` (a logarithm via ``u = denominator``) plus
      a constant over the denominator.

    Returns ``None`` if no shape matches.
    """
    integrand, symbol = integral
    a = Wild('a', exclude=[symbol])
    b = Wild('b', exclude=[symbol])
    c = Wild('c', exclude=[symbol])

    match = integrand.match(a / (b * symbol ** 2 + c))

    if match:
        a, b, c = match[a], match[b], match[c]
        general_rule = ArctanRule(integrand, symbol, a, b, c)
        if b.is_extended_real and c.is_extended_real:
            positive_cond = c/b > 0
            if positive_cond is S.true:
                return general_rule
            # c/b < 0: 1/(b*x**2 + c) splits into two reciprocal terms.
            coeff = a/(2*sqrt(-c)*sqrt(b))
            constant = sqrt(-c/b)
            r1 = 1/(symbol-constant)
            r2 = 1/(symbol+constant)
            log_steps = [ReciprocalRule(r1, symbol, symbol-constant),
                         ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))]
            rewritten = sub = r1 - r2
            negative_step = AddRule(sub, symbol, log_steps)
            if coeff != 1:
                rewritten = Mul(coeff, sub, evaluate=False)
                negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step)
            negative_step = RewriteRule(integrand, symbol, rewritten, negative_step)
            if positive_cond is S.false:
                return negative_step
            return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)])

        # c == 0 leaves a/(b*x**2), integrated with the power rule.
        power = PowerRule(integrand, symbol, symbol, -2)
        if b != 1:
            power = ConstantTimesRule(integrand, symbol, 1/b, symbol**-2, power)

        return PiecewiseRule(integrand, symbol, [(general_rule, Ne(c, 0)), (power, True)])

    d = Wild('d', exclude=[symbol])
    match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d))
    if match2:
        b, c = match2[b], match2[c]
        if b.is_zero:
            return
        u = Dummy('u')
        u_func = symbol + c/(2*b)
        integrand2 = integrand.subs(symbol, u - c / (2*b))
        next_step = integral_steps(integrand2, u)
        if next_step:
            # The URule integrates the original integrand in ``symbol``;
            # ``integrand2`` is the form in ``u`` handled by ``next_step``.
            return URule(integrand, symbol, u, u_func, next_step)
        else:
            return
    e = Wild('e', exclude=[symbol])
    match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e))
    if match3:
        a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e]
        if c.is_zero:
            return
        denominator = c * symbol**2 + d * symbol + e
        # a*x + b == const*(2*c*x + d) + numer2
        const = a/(2*c)
        numer1 = (2*c*symbol+d)
        numer2 = - const*d + b
        u = Dummy('u')
        # u = denominator turns numer1/denominator into 1/u.
        log_integrand = numer1/denominator
        step1 = URule(log_integrand, symbol,
                      u, denominator, integral_steps(u**(-1), u))
        if const != 1:
            step1 = ConstantTimesRule(const*log_integrand, symbol,
                                      const, log_integrand, step1)
        if numer2.is_zero:
            return step1
        step2 = integral_steps(numer2/denominator, symbol)
        substeps = AddRule(integrand, symbol, [step1, step2])
        rewritten = const*numer1/denominator+numer2/denominator
        return RewriteRule(integrand, symbol, rewritten, substeps)

    return
|
|
|
|
|
def sqrt_linear_rule(integral: IntegralInfo):
    """
    Substitute common (a+b*x)**(1/n)

    All non-integer rational powers of linear bases ``a + b*x`` must have
    proportional bases (``a0*b1 == a1*b0`` with ``b0/b1`` nonnegative).
    With ``q0`` the lcm of the exponent denominators, substitute
    ``u = (a0 + b0*x)**(1/q0)``. Returns ``None`` if a non-rational power
    of an ``x``-dependent base occurs, the bases are incompatible, no
    linear base is found, or the substituted integral is not known.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x, 0])
    a0 = b0 = 0
    bases, qs, bs = [], [], []
    for pow_ in integrand.find(Pow):  # collect all (a+b*x)**(p/q)
        base, exp_ = pow_.base, pow_.exp
        if exp_.is_Integer or x not in base.free_symbols:  # skip 1/x and sqrt(2)
            continue
        if not exp_.is_Rational:  # exclude x**pi
            return
        match = base.match(a+b*x)
        if not match:  # skip non-linear
            continue  # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x)
        a1, b1 = match[a], match[b]
        # All bases must be nonnegative multiples of one another.
        if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative:  # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x)
            return
        # Keep the base with the smaller b when that comparison is decidable.
        if b0 == 0 or (b0/b1 > 1) is S.true:  # choose the latter of sqrt(2*x) and sqrt(x) as representative
            a0, b0 = a1, b1
        bases.append(base)
        bs.append(b1)
        qs.append(exp_.q)
    if b0 == 0:  # no such pattern found
        return
    q0: Integer = lcm_list(qs)
    u_x = (a0 + b0*x)**(1/q0)
    u = Dummy("u")
    # base**(1/q) == (b/b0)**(1/q) * u**(q0/q); then x = (u**q0 - a0)/b0.
    substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q)
                                  for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0)
    # dx = q0*u**(q0-1)/b0 du
    substep = integral_steps(substituted*u**(q0-1)*q0/b0, u)
    if not substep.contains_dont_know():
        step: Rule = URule(integrand, x, u, u_x, substep)
        generic_cond = Ne(b0, 0)
        if generic_cond is not S.true:  # possible degenerate case
            # Degenerate case: every matched b coefficient set to 0.
            simplified = integrand.subs(dict.fromkeys(bs, 0))
            degenerate_step = integral_steps(simplified, x)
            step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)])
        return step
|
|
|
|
|
def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True):
    """Integrate ``f(x) * sqrt(a + b*x + c*x**2)**n`` for polynomial ``f``.

    ``n`` must be an odd integer. For ``n > 0`` the integrand is first
    rewritten as a polynomial over ``sqrt(a + b*x + c*x**2)``; ``n == -1``
    is handled directly; other negative ``n`` return ``None``. With
    ``degenerate=True`` a fallback for ``c == 0`` is attached when ``c``
    is not known to be nonzero.
    """
    integrand, x = integral
    a = Wild('a', exclude=[x])
    b = Wild('b', exclude=[x])
    c = Wild('c', exclude=[x, 0])
    f = Wild('f')
    n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd])
    match = integrand.match(f*sqrt(a+b*x+c*x**2)**n)
    if not match:
        return
    a, b, c, f, n = match[a], match[b], match[c], match[f], match[n]
    f_poly = f.as_poly(x)
    if f_poly is None:
        return

    generic_cond = Ne(c, 0)
    if not degenerate or generic_cond is S.true:
        degenerate_step = None
    elif b.is_zero:
        degenerate_step = integral_steps(f*sqrt(a)**n, x)
    else:
        degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x))

    def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr):
        # Integrate numer_poly/sqrt(a + b*x + c*x**2).
        denom = sqrt(a+b*x+c*x**2)
        deg = numer_poly.degree()
        if deg <= 1:
            # integrand == (d+e*x)/sqrt(a+b*x+c*x**2)
            e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr())
            # rewrite numerator to A*(2*c*x+b) + B
            A = e/(2*c)
            B = d-A*b
            pre_substitute = (2*c*x+b)/denom
            constant_step: Rule | None = None
            linear_step: Rule | None = None
            if A != 0:
                # (2*c*x + b)/sqrt(quadratic): substitute u = quadratic.
                u = Dummy("u")
                pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half)
                linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule)
                if A != 1:
                    linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step)
            if B != 0:
                # 1/sqrt(quadratic): c is assumed nonzero here.
                constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False)
                if B != 1:
                    constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step)
            if linear_step and constant_step:
                add = Add(A*pre_substitute, B/denom, evaluate=False)
                step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step]))
            else:
                step = linear_step or constant_step
        else:
            # Higher-degree numerators use the dedicated rule.
            coeffs = numer_poly.all_coeffs()
            step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs)
        return step

    if n > 0:  # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s)
        numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2)
        rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2)
        substep = sqrt_quadratic_denom_rule(numer_poly, rewritten)
        generic_step = RewriteRule(integrand, x, rewritten, substep)
    elif n == -1:
        generic_step = sqrt_quadratic_denom_rule(f_poly, integrand)
    else:
        return  # todo: handle n < -1 case
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
|
|
|
|
|
def hyperbolic_rule(integral: tuple[Expr, Symbol]):
    """Integrate a hyperbolic function whose argument is exactly the variable.

    ``sinh``/``cosh`` map to their dedicated rules. ``tanh`` and ``coth`` are
    written as quotients and integrated with ``u = cosh(x)``/``u = sinh(x)``.
    ``sech`` and ``csch`` are rewritten in terms of ``tanh`` and integrated
    with ``u = tanh(x/2)``. Returns ``None`` otherwise.
    """
    integrand, symbol = integral
    if not (isinstance(integrand, HyperbolicFunction) and integrand.args[0] == symbol):
        return None

    func = integrand.func
    if func == sinh:
        return SinhRule(integrand, symbol)
    if func == cosh:
        return CoshRule(integrand, symbol)

    u = Dummy('u')
    if func == tanh:
        rewritten = sinh(symbol)/cosh(symbol)
        substitution = URule(rewritten, symbol, u, cosh(symbol), ReciprocalRule(1/u, u, u))
        return RewriteRule(integrand, symbol, rewritten, substitution)
    if func == coth:
        rewritten = cosh(symbol)/sinh(symbol)
        substitution = URule(rewritten, symbol, u, sinh(symbol), ReciprocalRule(1/u, u, u))
        return RewriteRule(integrand, symbol, rewritten, substitution)

    rewritten = integrand.rewrite(tanh)
    if func == sech:
        inner = ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One)
    elif func == csch:
        inner = ReciprocalRule(1/u, u, u)
    else:
        return None
    substitution = URule(rewritten, symbol, u, tanh(symbol/2), inner)
    return RewriteRule(integrand, symbol, rewritten, substitution)
|
|
|
@cacheit
def make_wilds(symbol):
    """Return the wildcards ``(a, b, m, n)`` used by the trig power patterns.

    ``a`` and ``b`` match coefficients free of ``symbol``; ``m`` and ``n``
    are also free of ``symbol`` and must be sympy ``Integer`` instances.
    """
    def is_integer(value):
        return isinstance(value, Integer)

    coefficient_wilds = [Wild(name, exclude=[symbol]) for name in ('a', 'b')]
    exponent_wilds = [Wild(name, exclude=[symbol], properties=[is_integer])
                      for name in ('m', 'n')]
    return (*coefficient_wilds, *exponent_wilds)
|
|
|
@cacheit
def sincos_pattern(symbol):
    """Return ``sin(a*x)**m * cos(b*x)**n`` followed by its wilds a, b, m, n."""
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (sin(a*symbol)**m * cos(b*symbol)**n, *wilds)
|
|
|
@cacheit
def tansec_pattern(symbol):
    """Return ``tan(a*x)**m * sec(b*x)**n`` followed by its wilds a, b, m, n."""
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (tan(a*symbol)**m * sec(b*symbol)**n, *wilds)
|
|
|
@cacheit
def cotcsc_pattern(symbol):
    """Return ``cot(a*x)**m * csc(b*x)**n`` followed by its wilds a, b, m, n."""
    wilds = make_wilds(symbol)
    a, b, m, n = wilds
    return (cot(a*symbol)**m * csc(b*symbol)**n, *wilds)
|
|
|
@cacheit
def heaviside_pattern(symbol):
    """Return ``Heaviside(m*x + b) * g`` followed by the wilds ``m``, ``b``, ``g``.

    ``m`` and ``b`` must be free of ``symbol``; ``g`` matches anything.
    """
    slope, offset = (Wild(name, exclude=[symbol]) for name in ('m', 'b'))
    rest = Wild('g')
    return Heaviside(slope*symbol + offset) * rest, slope, offset, rest
|
|
|
def uncurry(func):
    """Adapt ``func`` so that it takes its positional arguments as one sequence.

    ``uncurry(f)(args)`` is equivalent to ``f(*args)``.
    """
    def uncurry_rl(args):
        positional = tuple(args)
        return func(*positional)
    return uncurry_rl
|
|
|
def trig_rewriter(rewrite):
    """Wrap ``rewrite`` into a rule over ``(a, b, m, n, integrand, symbol)``.

    The returned rule computes ``rewrite(a, b, m, n, integrand, symbol)``. If
    the result differs from the integrand it returns a ``RewriteRule`` that
    integrates the rewritten form; otherwise it returns ``None``.
    """
    def trig_rewriter_rl(args):
        a, b, m, n, integrand, symbol = args
        rewritten = rewrite(a, b, m, n, integrand, symbol)
        if rewritten == integrand:
            return None
        substep = integral_steps(rewritten, symbol)
        return RewriteRule(integrand, symbol, rewritten, substep)
    return trig_rewriter_rl
|
|
|
# Condition/rewrite pairs for trig powers. Every callable receives the tuple
# (a, b, m, n, integrand, symbol) built from a match of
# sin/tan/cot(a*x)**m * cos/sec/csc(b*x)**n.

# sin**m * cos**n with m and n both even and nonnegative: half-angle identities.
sincos_botheven_condition = uncurry(
    lambda a, b, m, n, i, s: m.is_even and n.is_even and
    m.is_nonnegative and n.is_nonnegative)

sincos_botheven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) *
                                    (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) ))

# Odd power of sin (m >= 3): sin**m == (1 - cos**2)**((m-1)/2) * sin.
sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3)

sincos_sinodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) *
                                    sin(a*symbol) *
                                    cos(b*symbol) ** n))

# Odd power of cos (n >= 3): cos**n == (1 - sin**2)**((n-1)/2) * cos.
sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3)

sincos_cosodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) *
                                    cos(b*symbol) *
                                    sin(a*symbol) ** m))

# Even power of sec (n >= 4): sec**n == (1 + tan**2)**(n/2 - 1) * sec**2.
tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
tansec_seceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) *
                                    sec(b*symbol)**2 *
                                    tan(a*symbol) ** m ))

# Odd power of tan: tan**m == (sec**2 - 1)**((m-1)/2) * tan.
tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
tansec_tanodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                     tan(a*symbol) *
                                     sec(b*symbol) ** n ))

# tan**2 alone: tan**2 == sec**2 - 1.
tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0)
tan_tansquared = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1))

# Even power of csc (n >= 4): csc**n == (1 + cot**2)**(n/2 - 1) * csc**2.
cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4)
cotcsc_csceven = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) *
                                    csc(b*symbol)**2 *
                                    cot(a*symbol) ** m ))

# Odd power of cot: cot**m == (csc**2 - 1)**((m-1)/2) * cot.
cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd)
cotcsc_cotodd = trig_rewriter(
    lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) *
                                    cot(a*symbol) *
                                    csc(b*symbol) ** n ))
|
|
|
def trig_sincos_rule(integral):
    """Reduce ``sin(a*x)**m * cos(b*x)**n`` with the sin/cos power identities.

    Tries, in order, the both-even, odd-sin and odd-cos rewrites. Returns
    ``None`` if the integrand has no ``sin``/``cos``, does not match the
    pattern, or no rewrite applies.
    """
    integrand, symbol = integral
    if not (integrand.has(sin) or integrand.has(cos)):
        return None

    pattern, a, b, m, n = sincos_pattern(symbol)
    found = integrand.match(pattern)
    if not found:
        return None

    dispatch = multiplexer({
        sincos_botheven_condition: sincos_botheven,
        sincos_sinodd_condition: sincos_sinodd,
        sincos_cosodd_condition: sincos_cosodd
    })
    params = tuple(found.get(w, S.Zero) for w in (a, b, m, n))
    return dispatch(params + (integrand, symbol))
|
|
|
def trig_tansec_rule(integral):
    """Reduce ``tan(a*x)**m * sec(b*x)**n`` with the tan/sec power identities.

    ``1/cos(x)`` is first rewritten as ``sec(x)``. Then the odd-tan, even-sec
    and ``tan**2`` rewrites are tried in that order. Returns ``None`` if no
    ``tan``/``sec`` is present, the pattern does not match, or no rewrite
    applies.
    """
    integrand, symbol = integral
    integrand = integrand.subs({1 / cos(symbol): sec(symbol)})

    if not (integrand.has(tan) or integrand.has(sec)):
        return None

    pattern, a, b, m, n = tansec_pattern(symbol)
    found = integrand.match(pattern)
    if not found:
        return None

    dispatch = multiplexer({
        tansec_tanodd_condition: tansec_tanodd,
        tansec_seceven_condition: tansec_seceven,
        tan_tansquared_condition: tan_tansquared
    })
    params = tuple(found.get(w, S.Zero) for w in (a, b, m, n))
    return dispatch(params + (integrand, symbol))
|
|
|
def trig_cotcsc_rule(integral):
    """Reduce ``cot(a*x)**m * csc(b*x)**n`` with the cot/csc power identities.

    ``1/sin(x)``, ``1/tan(x)`` and ``cos(x)/sin(x)`` are first rewritten as
    ``csc(x)``, ``cot(x)`` and ``cot(x)``. Then the odd-cot and even-csc
    rewrites are tried in that order. Returns ``None`` if no ``cot``/``csc``
    is present, the pattern does not match, or no rewrite applies.
    """
    integrand, symbol = integral
    integrand = integrand.subs({
        1 / sin(symbol): csc(symbol),
        1 / tan(symbol): cot(symbol),
        # cot(x) == cos(x)/sin(x); note cos(x)/tan(x) == cos(x)**2/sin(x),
        # which is *not* cot(x) and must not be replaced by it.
        cos(symbol) / sin(symbol): cot(symbol)
    })

    if any(integrand.has(f) for f in (cot, csc)):
        pattern, a, b, m, n = cotcsc_pattern(symbol)
        match = integrand.match(pattern)
        if not match:
            return

        return multiplexer({
            cotcsc_cotodd_condition: cotcsc_cotodd,
            cotcsc_csceven_condition: cotcsc_csceven
        })(tuple(
            [match.get(i, S.Zero) for i in (a, b, m, n)] +
            [integrand, symbol]))
|
|
|
def trig_sindouble_rule(integral):
    """Replace a ``sin(2*x)`` factor by ``2*sin(x)*cos(x)`` and recurse.

    Returns the steps for the rewritten integrand, or ``None`` when the
    integrand is not ``sin(2*x)`` times something.
    """
    integrand, symbol = integral
    double_angle = sin(2*symbol)
    rest = Wild('a', exclude=[double_angle])
    if not integrand.match(double_angle*rest):
        return None
    factor = 2*sin(symbol)*cos(symbol)/double_angle
    return integral_steps(integrand * factor, symbol)
|
|
|
def trig_powers_products_rule(integral):
    """Try the sin/cos, tan/sec, cot/csc and sin(2x) rules; first result wins."""
    candidates = (trig_sincos_rule, trig_tansec_rule,
                  trig_cotcsc_rule, trig_sindouble_rule)
    return do_one(*map(null_safe, candidates))(integral)
|
|
|
def trig_substitution_rule(integral):
    """Try a trigonometric substitution for each ``A + B*x**2`` subexpression.

    Depending on the signs of ``A`` and ``B``:

    * both positive: ``x = sqrt(A)/sqrt(B)*tan(theta)``;
    * ``A > 0``, ``B < 0``: ``x = sqrt(A)/sqrt(-B)*sin(theta)``;
    * ``A < 0``, ``B > 0``: ``x = sqrt(-A)/sqrt(B)*sec(theta)``.

    Returns a ``TrigSubstitutionRule`` for the first substitution whose
    integral in ``theta`` is fully known, otherwise ``None``.
    """
    integrand, symbol = integral
    A = Wild('a', exclude=[0, symbol])
    B = Wild('b', exclude=[0, symbol])
    theta = Dummy("theta")
    target_pattern = A + B*symbol**2

    matches = integrand.find(target_pattern)
    for expr in matches:
        match = expr.match(target_pattern)
        a = match.get(A, S.Zero)
        b = match.get(B, S.Zero)

        # Sign tests: decide numerically for numbers, else use assumptions.
        a_positive = ((a.is_number and a > 0) or a.is_positive)
        b_positive = ((b.is_number and b > 0) or b.is_positive)
        a_negative = ((a.is_number and a < 0) or a.is_negative)
        b_negative = ((b.is_number and b < 0) or b.is_negative)
        x_func = None
        if a_positive and b_positive:
            # a**2 + b**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2
            x_func = (sqrt(a)/sqrt(b)) * tan(theta)
            # Do not restrict the domain: tan(theta) takes on any real
            # value on the interval -pi/2 < theta < pi/2 so x takes on
            # any value
            restriction = True
        elif a_positive and b_negative:
            # a**2 - b**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2
            constant = sqrt(a)/sqrt(-b)
            x_func = constant * sin(theta)
            restriction = And(symbol > -constant, symbol < constant)
        elif a_negative and b_positive:
            # b**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi
            constant = sqrt(-a)/sqrt(b)
            x_func = constant * sec(theta)
            # NOTE(review): x = constant*sec(theta) implies |x| >= constant,
            # yet this restriction is |x| < constant -- confirm whether the
            # inequality is intended.
            restriction = And(symbol > -constant, symbol < constant)
        if x_func:
            # Manually simplify sqrt(trig(theta)**2) to trig(theta)
            # Valid due to assumed domain restriction
            substitutions = {}
            for f in [sin, cos, tan,
                      sec, csc, cot]:
                substitutions[sqrt(f(theta)**2)] = f(theta)
                substitutions[sqrt(f(theta)**(-2))] = 1/f(theta)

            replaced = integrand.subs(symbol, x_func).trigsimp()
            replaced = manual_subs(replaced, substitutions)
            # Only proceed if the substitution eliminated the original variable.
            if not replaced.has(symbol):
                # dx = x_func'(theta) dtheta
                replaced *= manual_diff(x_func, theta)
                replaced = replaced.trigsimp()
                secants = replaced.find(1/cos(theta))
                if secants:
                    replaced = replaced.xreplace({
                        1/cos(theta): sec(theta)
                    })

                substep = integral_steps(replaced, theta)
                if not substep.contains_dont_know():
                    return TrigSubstitutionRule(integrand, symbol,
                        theta, x_func, replaced, substep, restriction)
|
|
|
def heaviside_rule(integral):
    """Integrate ``Heaviside(m*x + b) * g`` by integrating ``g``.

    The steps for ``g`` are passed to ``HeavisideRule`` together with the
    Heaviside argument ``m*x + b`` and its root ``-b/m``. Returns ``None``
    when the pattern does not match or ``g`` is zero.
    """
    integrand, symbol = integral
    pattern, m, b, g = heaviside_pattern(symbol)
    found = integrand.match(pattern)
    if not found or found[g] == 0:
        return None
    substep = integral_steps(found[g], symbol)
    slope, offset = found[m], found[b]
    return HeavisideRule(integrand, symbol, slope*symbol + offset, -offset/slope, substep)
|
|
|
|
|
def dirac_delta_rule(integral: IntegralInfo):
    """Integrate ``DiracDelta(a + b*x, n)`` for a nonnegative integer ``n``.

    Returns a ``DiracDeltaRule``, together with a constant fallback for
    ``b == 0`` when ``b`` is not known to be nonzero. Returns ``None`` for a
    non-integer or negative derivative order, or a non-linear argument.
    """
    integrand, x = integral
    delta_args = integrand.args
    n = delta_args[1] if len(delta_args) != 1 else S.Zero
    if not n.is_Integer or n < 0:
        return None

    a_wild = Wild('a', exclude=[x])
    b_wild = Wild('b', exclude=[x, 0])
    found = delta_args[0].match(a_wild + b_wild*x)
    if not found:
        return None
    a, b = found[a_wild], found[b_wild]

    generic_cond = Ne(b, 0)
    degenerate_step = None if generic_cond is S.true else ConstantRule(DiracDelta(a, n), x)
    generic_step = DiracDeltaRule(integrand, x, n, a, b)
    return _add_degenerate_step(generic_cond, generic_step, degenerate_step)
|
|
|
|
|
def substitution_rule(integral):
    """Try every u-substitution proposed by ``find_substitutions``.

    Each candidate ``(u_func, c, substituted)`` is integrated in the dummy
    ``u``; candidates whose steps contain a ``DontKnowRule`` are dropped.
    When the constant ``c`` is not 1 it is applied with
    ``ConstantTimesRule``. If the denominator of ``c`` has free symbols,
    each factor that might vanish gets a separate ``Eq(factor, 0)`` branch
    in a ``PiecewiseRule``. Returns an ``AlternativeRule`` for several
    candidates, the single ``URule`` for one, or ``None``.
    """
    integrand, symbol = integral

    u_var = Dummy("u")
    substitutions = find_substitutions(integrand, symbol, u_var)
    count = 0
    if substitutions:
        debug("List of Substitution Rules")
        ways = []
        for u_func, c, substituted in substitutions:
            subrule = integral_steps(substituted, u_var)
            count = count + 1
            debug("Rule {}: {}".format(count, subrule))

            if subrule.contains_dont_know():
                continue

            if simplify(c - 1) != 0:
                _, denom = c.as_numer_denom()
                if subrule:
                    subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule)

                # A symbolic denominator may vanish; handle those cases separately.
                if denom.free_symbols:
                    piecewise = []
                    could_be_zero = []

                    if isinstance(denom, Mul):
                        could_be_zero = denom.args
                    else:
                        could_be_zero.append(denom)

                    for expr in could_be_zero:
                        # Only factors not known to be nonzero need a branch.
                        if not fuzzy_not(expr.is_zero):
                            substep = integral_steps(manual_subs(integrand, expr, 0), symbol)

                            if substep:
                                piecewise.append((
                                    substep,
                                    Eq(expr, 0)
                                ))
                    piecewise.append((subrule, True))
                    subrule = PiecewiseRule(substituted, symbol, piecewise)

            ways.append(URule(integrand, symbol, u_var, u_func, subrule))

        if len(ways) > 1:
            return AlternativeRule(integrand, symbol, ways)
        elif ways:
            return ways[0]
|
|
|
|
|
# Rewriting strategies built with ``rewriter(condition, rewrite)`` (defined
# elsewhere in this module). Presumably each applies ``rewrite`` when
# ``condition`` holds and integrates the result -- confirm against
# ``rewriter``.

# Rational functions: rewrite as partial fractions.
partial_fractions_rule = rewriter(
    lambda integrand, symbol: integrand.is_rational_function(),
    lambda integrand, symbol: integrand.apart(symbol))

# Always attempt ``cancel`` (the condition is unconditionally True).
cancel_rule = rewriter(
    lambda integrand, symbol: True,
    lambda integrand, symbol: integrand.cancel())

# Products/powers, or expressions whose args are all powers or polynomials
# in ``symbol``: expand them.
distribute_expand_rule = rewriter(
    lambda integrand, symbol: (
        isinstance(integrand, (Pow, Mul)) or all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)),
    lambda integrand, symbol: integrand.expand())

# Trig functions with more than one distinct argument: expand with trig=True.
trig_expand_rule = rewriter(
    lambda integrand, symbol: (
        len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1),
    lambda integrand, symbol: integrand.expand(trig=True))
|
|
|
def derivative_rule(integral):
    """Integrate an unevaluated ``Derivative`` integrand.

    If the differentiated expression does not contain the integration
    variable, the integrand is treated as a constant. If it does and the
    derivative is taken with respect to that variable, ``DerivativeRule`` is
    returned; otherwise the integral is reported as unknown.
    """
    integrand = integral[0]
    symbol = integral.symbol
    if symbol not in integrand.expr.free_symbols:
        return ConstantRule(*integral)
    if symbol in integrand.variables:
        return DerivativeRule(*integral)
    return DontKnowRule(integrand, symbol)
|
|
|
def rewrites_rule(integral):
    """Rewrite the integrand ``1/cos(x)`` as ``sec(x)`` and integrate that.

    Returns a ``RewriteRule`` when the integrand is exactly ``1/cos(symbol)``,
    otherwise ``None``.
    """
    integrand, symbol = integral

    # ``match`` returns an empty dict -- which is falsy -- for a successful
    # match against a pattern without wildcards, and ``None`` on failure, so
    # compare with ``None`` instead of relying on truthiness.
    if integrand.match(1/cos(symbol)) is not None:
        rewritten = integrand.subs(1/cos(symbol), sec(symbol))
        return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol))
|
|
|
def fallback_rule(integral):
    """Give up on the integral by returning a ``DontKnowRule`` for it."""
    integrand, symbol = integral
    return DontKnowRule(integrand, symbol)
|
|
|
|
|
|
|
|
|
# Recursion guard for ``integral_steps``. Keys are integrands with the
# integration variable replaced by ``_cache_dummy``. An entry (value ``None``)
# exists while that integrand is being processed, so a re-entrant request for
# the same integrand returns ``DontKnowRule`` instead of recursing forever.
_integral_cache: dict[Expr, Expr | None] = {}
# Counts how often ``parts_rule`` has used each sin/cos/exp/sinh/cosh ``u``
# (keys also use ``_cache_dummy``); cleared by ``manualintegrate``.
_parts_u_cache: dict[Expr, int] = defaultdict(int)
# Placeholder that makes the cache keys independent of the variable's name.
_cache_dummy = Dummy("z")
|
|
|
def integral_steps(integrand, symbol, **options):
    """Returns the steps needed to compute an integral.

    Explanation
    ===========

    This function attempts to mirror what a student would do by hand as
    closely as possible.

    SymPy Gamma uses this to provide a step-by-step explanation of an
    integral. The code it uses to format the results of this function can be
    found at
    https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py.

    Examples
    ========

    >>> from sympy import exp, sin
    >>> from sympy.integrals.manualintegrate import integral_steps
    >>> from sympy.abc import x
    >>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \
    # doctest: +NORMALIZE_WHITESPACE
    URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x),
    substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1))
    >>> print(repr(integral_steps(sin(x), x))) \
    # doctest: +NORMALIZE_WHITESPACE
    SinRule(integrand=sin(x), variable=x)
    >>> print(repr(integral_steps((x**2 + 3)**2, x))) \
    # doctest: +NORMALIZE_WHITESPACE
    RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9,
    substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x,
    substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4),
    ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2,
    substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)),
    ConstantRule(integrand=9, variable=x)]))

    Returns
    =======

    rule : Rule
        The first step; most rules have substeps that must also be
        considered. These substeps can be evaluated using ``manualintegrate``
        to obtain a result.

    """
    cachekey = integrand.xreplace({symbol: _cache_dummy})
    if cachekey in _integral_cache:
        if _integral_cache[cachekey] is None:
            # Stop this attempt, because it leads around in a loop
            return DontKnowRule(integrand, symbol)
        else:
            # TODO: This is for future development, as currently
            # _integral_cache gets no values other than None
            # NOTE(review): this branch is unreachable today, and the
            # expression would need reworking if non-None values were stored.
            return (_integral_cache[cachekey].xreplace(_cache_dummy, symbol),
                symbol)
    else:
        _integral_cache[cachekey] = None

    try:
        integral = IntegralInfo(integrand, symbol)

        def key(integral):
            # Dispatch key: Number for integrands free of the variable,
            # otherwise the integrand's class (grouped for a few families).
            integrand = integral.integrand

            if symbol not in integrand.free_symbols:
                return Number
            for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial):
                if isinstance(integrand, cls):
                    return cls
            return type(integrand)

        def integral_is_subclass(*klasses):
            def _integral_is_subclass(integral):
                k = key(integral)
                return k and issubclass(k, klasses)
            return _integral_is_subclass

        result = do_one(
            null_safe(special_function_rule),
            null_safe(switch(key, {
                Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(quadratic_denom_rule)),
                Symbol: power_rule,
                exp: exp_rule,
                Add: add_rule,
                Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule),
                            null_safe(heaviside_rule), null_safe(quadratic_denom_rule),
                            null_safe(sqrt_linear_rule),
                            null_safe(sqrt_quadratic_rule)),
                Derivative: derivative_rule,
                TrigonometricFunction: trig_rule,
                Heaviside: heaviside_rule,
                DiracDelta: dirac_delta_rule,
                OrthogonalPolynomial: orthogonal_poly_rule,
                Number: constant_rule
            })),
            do_one(
                null_safe(trig_rule),
                null_safe(hyperbolic_rule),
                null_safe(alternatives(
                    rewrites_rule,
                    substitution_rule,
                    condition(
                        integral_is_subclass(Mul, Pow),
                        partial_fractions_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        cancel_rule),
                    condition(
                        integral_is_subclass(Mul, log,
                        *inverse_trig_functions),
                        parts_rule),
                    condition(
                        integral_is_subclass(Mul, Pow),
                        distribute_expand_rule),
                    trig_powers_products_rule,
                    trig_expand_rule
                )),
                null_safe(condition(integral_is_subclass(Mul, Pow), nested_pow_rule)),
                null_safe(trig_substitution_rule)
            ),
            fallback_rule)(integral)
    finally:
        # Always release the recursion guard, even if a rule raised: callers
        # may catch the exception, and a stale entry would make every later
        # request for this integrand return DontKnowRule.
        _integral_cache.pop(cachekey, None)
    return result
|
|
|
|
|
def manualintegrate(f, var):
    """manualintegrate(f, var)

    Explanation
    ===========

    Compute indefinite integral of a single variable using an algorithm that
    resembles what a student would do by hand.

    Unlike :func:`~.integrate`, var can only be a single symbol.

    Examples
    ========

    >>> from sympy import sin, cos, tan, exp, log, integrate
    >>> from sympy.integrals.manualintegrate import manualintegrate
    >>> from sympy.abc import x
    >>> manualintegrate(1 / x, x)
    log(x)
    >>> integrate(1/x)
    log(x)
    >>> manualintegrate(log(x), x)
    x*log(x) - x
    >>> integrate(log(x))
    x*log(x) - x
    >>> manualintegrate(exp(x) / (1 + exp(2 * x)), x)
    atan(exp(x))
    >>> integrate(exp(x) / (1 + exp(2 * x)))
    RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x))))
    >>> manualintegrate(cos(x)**4 * sin(x), x)
    -cos(x)**5/5
    >>> integrate(cos(x)**4 * sin(x), x)
    -cos(x)**5/5
    >>> manualintegrate(cos(x)**4 * sin(x)**3, x)
    cos(x)**7/7 - cos(x)**5/5
    >>> integrate(cos(x)**4 * sin(x)**3, x)
    cos(x)**7/7 - cos(x)**5/5
    >>> manualintegrate(tan(x), x)
    -log(cos(x))
    >>> integrate(tan(x), x)
    -log(cos(x))

    See Also
    ========

    sympy.integrals.integrals.integrate
    sympy.integrals.integrals.Integral.doit
    sympy.integrals.integrals.Integral
    """
    try:
        result = integral_steps(f, var).eval()
    finally:
        # Clear the parts-rule reuse counts for the next call, even if
        # building or evaluating the steps raised.
        _parts_u_cache.clear()

    # If the result is Piecewise((A, Eq(...)), (B, True)), present the
    # generic case first: Piecewise((B, Ne(...)), (A, True)).
    if isinstance(result, Piecewise) and len(result.args) == 2:
        cond = result.args[0][1]
        if isinstance(cond, Eq) and result.args[1][1] == True:
            result = result.func(
                (result.args[1][0], Ne(*cond.args)),
                (result.args[0][0], True))
    return result
|
|