|
from sympy.core.basic import Basic |
|
from sympy.printing import pprint |
|
|
|
import random |
|
|
|
def interactive_traversal(expr):
    """Walk down an expression tree, asking the user which branch to follow.

    At every stage the current node is pretty-printed, followed by a numbered
    list of its children, and the user is prompted for a command:

    * an integer index -- descend into the child with that index
    * ``f`` / ``l``    -- descend into the first / last child
    * ``r``            -- descend into a randomly chosen child
    * ``d`` or empty   -- stop and return the current node
    * ``?``            -- print a short help text and show the stage again

    The children of a ``Basic`` instance are its ordered terms (``is_Add``),
    its ordered factors (``is_Mul``) or otherwise its ``args``.  Any other
    object that has ``__iter__`` is expanded with ``list()``.  The walk ends,
    returning the current node, when the node is neither of those, when it
    has no children, when the user chooses "done", or when the prompt hits
    end-of-file.
    """
    # ANSI colour escape sequences: plain and bold variant of each colour.
    RED = '\033[0;31m'
    BRED = '\033[1;31m'
    GREEN = '\033[0;32m'
    BGREEN = '\033[1;32m'
    YELLOW = '\033[0;33m'
    BYELLOW = '\033[1;33m'
    BLUE = '\033[0;34m'
    BBLUE = '\033[1;34m'
    MAGENTA = '\033[0;35m'
    BMAGENTA = '\033[1;35m'
    CYAN = '\033[0;36m'
    BCYAN = '\033[1;36m'
    END = '\033[0m'

    def cprint(*pieces):
        # Concatenate the pieces and always reset the terminal colour at the end.
        print("".join(str(piece) for piece in pieces) + END)

    node = expr
    stage = 0

    # One pass of this loop is one "screen": show the node, list its children,
    # read a command.  ``continue`` redisplays the same stage; rebinding
    # ``node``/``stage`` at the bottom moves one level down the tree.
    while True:
        if stage > 0:
            print()

        cprint("Current expression (stage ", BYELLOW, stage, END, "):")
        print(BCYAN)
        pprint(node)
        print(END)

        # Work out the children of the current node.
        if not isinstance(node, Basic):
            if not hasattr(node, "__iter__"):
                # Not traversable at all: this is the result.
                return node
            children = list(node)
        elif node.is_Add:
            children = node.as_ordered_terms()
        elif node.is_Mul:
            children = node.as_ordered_factors()
        else:
            children = node.args

        count = len(children)
        if count == 0:
            # Leaf node: nothing left to choose from.
            return node

        for index, child in enumerate(children):
            cprint(GREEN, "[", BGREEN, index, GREEN, "] ", BLUE, type(child), END)
            pprint(child)
            print()

        # Human-readable description of the valid index range.
        choices = '0' if count == 1 else '0-%d' % (count - 1)

        try:
            answer = input("Your choice [%s,f,l,r,d,?]: " % choices)
        except EOFError:
            # Input stream closed: finish with the current node.
            print()
            return node

        if answer == '?':
            cprint(RED, "%s - select subexpression with the given index" %
                   choices)
            cprint(RED, "f - select the first subexpression")
            cprint(RED, "l - select the last subexpression")
            cprint(RED, "r - select a random subexpression")
            cprint(RED, "d - done\n")
            continue

        if answer in ('d', ''):
            return node

        if answer == 'f':
            selected = children[0]
        elif answer == 'l':
            selected = children[-1]
        elif answer == 'r':
            selected = random.choice(children)
        else:
            try:
                position = int(answer)
            except ValueError:
                cprint(BRED,
                       "Choice must be a number in %s range\n" % choices)
                continue

            if not 0 <= position < count:
                cprint(BRED, "Choice must be in %s range\n" % choices)
                continue

            selected = children[position]

        # Descend into the chosen child.
        node = selected
        stage += 1
|
|