|
from sympy.core.containers import Tuple |
|
from sympy.core.basic import Basic |
|
from sympy.core.expr import Expr |
|
from sympy.core.function import AppliedUndef |
|
from sympy.core.relational import Relational |
|
from sympy.core.symbol import Dummy |
|
from sympy.core.sympify import sympify |
|
from sympy.logic.boolalg import BooleanFunction |
|
from sympy.sets.fancysets import ImageSet |
|
from sympy.sets.sets import FiniteSet |
|
from sympy.tensor.indexed import Indexed |
|
|
|
|
|
def _get_free_symbols(exprs): |
|
"""Returns the free symbols of a symbolic expression. |
|
|
|
If the expression contains any of these elements, assume that they are |
|
the "free symbols" of the expression: |
|
|
|
* indexed objects |
|
* applied undefined function (useful for sympy.physics.mechanics module) |
|
""" |
|
if not isinstance(exprs, (list, tuple, set)): |
|
exprs = [exprs] |
|
if all(callable(e) for e in exprs): |
|
return set() |
|
|
|
free = set().union(*[e.atoms(Indexed) for e in exprs]) |
|
free = free.union(*[e.atoms(AppliedUndef) for e in exprs]) |
|
return free or set().union(*[e.free_symbols for e in exprs]) |
|
|
|
|
|
def extract_solution(set_sol, n=10):
    """Extract numerical solutions from a set solution (computed by solveset,
    linsolve, nonlinsolve). Often, it is not trivial to get something useful
    out of them.

    Parameters
    ==========

    set_sol : Set
        The set solution. Every ``ImageSet`` found inside it is replaced by a
        ``FiniteSet`` of some of its elements.
    n : int, optional
        In order to replace ImageSet with FiniteSet, an iterator is created
        for each ImageSet contained in `set_sol` and at most its first `n`
        elements are collected. If an ImageSet yields fewer than `n`
        elements, all of them are used. Non-positive values collect nothing.
        Default value: 10.

    Returns
    =======

    The set obtained from ``set_sol`` after substituting each ImageSet with
    the FiniteSet of its collected elements.
    """
    from itertools import islice

    # islice() rejects negative bounds, whereas the former range()-based
    # loop simply collected nothing: clamp to keep that behavior.
    count = max(n, 0)
    for image in set_sol.find(ImageSet):
        # islice stops cleanly when the ImageSet runs out of elements;
        # calling next() a fixed number of times would raise StopIteration.
        elements = FiniteSet(*islice(image, count))
        set_sol = set_sol.subs(image, elements)
    return set_sol
|
|
|
|
|
def _plot_sympify(args):
    """Recursively sympify the arguments passed to the plot functions.

    A single ``Expr`` is returned unchanged. Otherwise ``args`` is iterated
    and a new list is returned, in which:

    * nested lists/tuples are processed recursively and wrapped into a
      ``Tuple`` (without sympifying its elements a second time);
    * strings, dictionaries, callables and non-``Basic`` objects whose class
      is named ``Vector`` are kept as they are;
    * everything else goes through ``sympify``.

    Generally, users can provide the following arguments to a plot function:

    expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt]

    `expr, range1, ...` can be sympified, whereas `label, rendering_kw` can't.
    In particular, whenever a special character like $, {, }, ... is used in
    the `label`, sympify will raise an error.
    """
    if isinstance(args, Expr):
        return args

    def _leave_unchanged(obj):
        # NOTE(review): the "Vector" test is done by class name, presumably
        # to recognize sympy.physics.vector vectors without importing that
        # module — confirm before changing it.
        return (
            isinstance(obj, (str, dict))
            or callable(obj)
            or (obj.__class__.__name__ == "Vector" and not isinstance(obj, Basic))
        )

    processed = []
    for item in args:
        if isinstance(item, (list, tuple)):
            processed.append(Tuple(*_plot_sympify(item), sympify=False))
        elif _leave_unchanged(item):
            processed.append(item)
        else:
            processed.append(sympify(item))
    return processed
|
|
|
|
|
def _create_ranges(exprs, ranges, npar, label="", params=None):
    """This function does two things:

    1. Check if the number of free symbols is in agreement with the type of
       plot chosen. For example, plot() requires 1 free symbol;
       plot3d() requires 2 free symbols.
    2. Sometimes users create plots without providing ranges for the
       variables. Here we create the necessary ranges.

    Parameters
    ==========

    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : iterable
        The limiting ranges provided by the user. It is not modified: a new
        list is returned.
    npar : int
        The number of free symbols required by the plot functions.
        For example,
        npar=1 for plot, npar=2 for plot3d, ...
    label : str, optional
        Not used by this function.
    params : dict
        A dictionary mapping symbols to parameters for interactive plot.
        Its keys are not considered free symbols.

    Returns
    =======

    ranges : list
        The user ranges, completed with default ranges ``(symbol, -10, 10)``
        when fewer than ``npar`` were given: first for the free symbols
        lacking a range (sorted by their string form, so the order is
        reproducible), then for new Dummy symbols until ``npar`` is reached.

    Raises
    ======

    ValueError
        If there are more free symbols or more ranges than ``npar``, if
        several ranges share the same symbol, or if exactly ``npar`` free
        symbols are found and some of them have no range.
    """
    def get_default_range(symbol):
        return Tuple(symbol, -10, 10)

    # Work on a copy: never mutate the caller's object (which may also be an
    # immutable tuple).
    ranges = list(ranges)

    free_symbols = _get_free_symbols(exprs)
    if params is not None:
        free_symbols = free_symbols.difference(params.keys())

    if len(free_symbols) > npar:
        raise ValueError(
            "Too many free symbols.\n"
            + "Expected {} free symbols.\n".format(npar)
            + "Received {}: {}".format(len(free_symbols), free_symbols)
        )

    if len(ranges) > npar:
        raise ValueError(
            "Too many ranges. Received %s, expected %s" % (len(ranges), npar))

    # free symbols in the ranges provided by the user
    rfs = {r[0] for r in ranges}
    if len(rfs) != len(ranges):
        raise ValueError("Multiple ranges with the same symbol")

    if len(ranges) < npar:
        # add a range for each missing free symbol; sorting makes the order
        # independent of set iteration order (which varies with hash seed)
        for s in sorted(free_symbols.difference(rfs), key=str):
            ranges.append(get_default_range(s))
        # if there is still room, fill it with dummy symbols
        for _ in range(npar - len(ranges)):
            ranges.append(get_default_range(Dummy()))

    if len(free_symbols) == npar:
        # When fewer free symbols than npar are present (e.g. plotting
        # f(x, y) = x), some ranges legitimately refer to symbols that are
        # not in the expressions; only check coverage in the "full" case.
        rfs = {r[0] for r in ranges}
        if free_symbols.difference(rfs):
            raise ValueError(
                "Incompatible free symbols of the expressions with "
                "the ranges.\n"
                + "Free symbols in the expressions: {}\n".format(free_symbols)
                + "Free symbols in the ranges: {}".format(rfs)
            )
    return ranges
|
|
|
|
|
def _is_range(r):
    """A range is defined as (symbol, start, end). start and end should
    be numbers.

    Return True if ``r`` is a sympy ``Tuple`` with exactly three elements
    whose second and third elements have a true ``is_number`` attribute,
    False otherwise. The first element is not inspected.

    Elements without an ``is_number`` attribute (strings, callables,
    dictionaries, ...) are simply not numbers: they make ``r`` "not a range"
    instead of raising AttributeError. This matters for entries such as
    ``(fx, fy, range)`` where ``fx`` and ``fy`` are callables.
    """
    if not isinstance(r, Tuple) or len(r) != 3:
        return False
    _, start, end = r.args
    # strings have no ``is_number`` attribute, so getattr rejects them too
    return all(
        bool(getattr(bound, "is_number", False)) for bound in (start, end)
    )
|
|
|
|
|
def _unpack_args(*args): |
|
"""Given a list/tuple of arguments previously processed by _plot_sympify() |
|
and/or _check_arguments(), separates and returns its components: |
|
expressions, ranges, label and rendering keywords. |
|
|
|
Examples |
|
======== |
|
|
|
>>> from sympy import cos, sin, symbols |
|
>>> from sympy.plotting.utils import _plot_sympify, _unpack_args |
|
>>> x, y = symbols('x, y') |
|
>>> args = (sin(x), (x, -10, 10), "f1") |
|
>>> args = _plot_sympify(args) |
|
>>> _unpack_args(*args) |
|
([sin(x)], [(x, -10, 10)], 'f1', None) |
|
|
|
>>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2") |
|
>>> args = _plot_sympify(args) |
|
>>> _unpack_args(*args) |
|
([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None) |
|
|
|
>>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3") |
|
>>> args = _plot_sympify(args) |
|
>>> _unpack_args(*args) |
|
([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None) |
|
""" |
|
ranges = [t for t in args if _is_range(t)] |
|
labels = [t for t in args if isinstance(t, str)] |
|
label = None if not labels else labels[0] |
|
rendering_kw = [t for t in args if isinstance(t, dict)] |
|
rendering_kw = None if not rendering_kw else rendering_kw[0] |
|
|
|
|
|
results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args] |
|
exprs = [a for a, b in zip(args, results) if b] |
|
return exprs, ranges, label, rendering_kw |
|
|
|
|
|
def _check_arguments(args, nexpr, npar, **kwargs): |
|
"""Checks the arguments and converts into tuples of the |
|
form (exprs, ranges, label, rendering_kw). |
|
|
|
Parameters |
|
========== |
|
|
|
args |
|
The arguments provided to the plot functions |
|
nexpr |
|
The number of sub-expression forming an expression to be plotted. |
|
For example: |
|
nexpr=1 for plot. |
|
nexpr=2 for plot_parametric: a curve is represented by a tuple of two |
|
elements. |
|
nexpr=1 for plot3d. |
|
nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple |
|
of three elements. |
|
npar |
|
The number of free symbols required by the plot functions. For example, |
|
npar=1 for plot, npar=2 for plot3d, ... |
|
**kwargs : |
|
keyword arguments passed to the plotting function. It will be used to |
|
verify if ``params`` has ben provided. |
|
|
|
Examples |
|
======== |
|
|
|
.. plot:: |
|
:context: reset |
|
:format: doctest |
|
:include-source: True |
|
|
|
>>> from sympy import cos, sin, symbols |
|
>>> from sympy.plotting.plot import _check_arguments |
|
>>> x = symbols('x') |
|
>>> _check_arguments([cos(x), sin(x)], 2, 1) |
|
[(cos(x), sin(x), (x, -10, 10), None, None)] |
|
|
|
>>> _check_arguments([cos(x), sin(x), "test"], 2, 1) |
|
[(cos(x), sin(x), (x, -10, 10), 'test', None)] |
|
|
|
>>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1) |
|
[(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})] |
|
|
|
>>> _check_arguments([x, x**2], 1, 1) |
|
[(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)] |
|
""" |
|
if not args: |
|
return [] |
|
output = [] |
|
params = kwargs.get("params", None) |
|
|
|
if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]): |
|
|
|
|
|
|
|
|
|
exprs, ranges, label, rendering_kw = _unpack_args(*args) |
|
free_symbols = set().union(*[e.free_symbols for e in exprs]) |
|
ranges = _create_ranges(exprs, ranges, npar, label, params) |
|
|
|
if nexpr > 1: |
|
|
|
|
|
if len(exprs) == nexpr: |
|
exprs = (tuple(exprs),) |
|
for expr in exprs: |
|
|
|
|
|
is_expr = isinstance(expr, (Expr, Relational, BooleanFunction)) |
|
e = (expr,) if is_expr else expr |
|
output.append((*e, *ranges, label, rendering_kw)) |
|
|
|
else: |
|
|
|
|
|
|
|
|
|
_, ranges, labels, rendering_kw = _unpack_args(*args) |
|
labels = [labels] if labels else [] |
|
|
|
|
|
n = (len(ranges) + len(labels) + |
|
(len(rendering_kw) if rendering_kw is not None else 0)) |
|
new_args = args[:-n] if n > 0 else args |
|
|
|
|
|
|
|
|
|
if not isinstance(new_args[0], (list, tuple, Tuple)): |
|
new_args = [new_args] |
|
|
|
|
|
|
|
for arg in new_args: |
|
|
|
l = [a for a in arg if isinstance(a, str)] |
|
if not l: |
|
l = labels |
|
r = [a for a in arg if _is_range(a)] |
|
if not r: |
|
r = ranges.copy() |
|
rend_kw = [a for a in arg if isinstance(a, dict)] |
|
rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0] |
|
|
|
|
|
|
|
arg = [arg[i] for i in range(nexpr)] |
|
free_symbols = set() |
|
if all(not callable(a) for a in arg): |
|
free_symbols = free_symbols.union(*[a.free_symbols for a in arg]) |
|
if len(r) != npar: |
|
r = _create_ranges(arg, r, npar, "", params) |
|
|
|
label = None if not l else l[0] |
|
output.append((*arg, *r, label, rend_kw)) |
|
return output |
|
|