|
from __future__ import annotations |
|
|
|
import ast |
|
import builtins |
|
import inspect |
|
import json |
|
import textwrap |
|
from collections.abc import Callable |
|
|
|
|
|
class TranspilerError(Exception): |
|
"""Exception raised when transpilation fails or encounters ambiguous syntax.""" |
|
|
|
def __init__( |
|
self, |
|
issues: list[tuple[int, str, str]] | None = None, |
|
message: str | None = None, |
|
): |
|
self.issues = issues or [] |
|
if message: |
|
super().__init__(message) |
|
else: |
|
issue_count = len(self.issues) |
|
issues_text = ( |
|
f"{issue_count} issue{'s' if issue_count != 1 else ''} found:\n\n" |
|
) |
|
for line_no, message, code in self.issues: |
|
issues_text += f"* Line {line_no}: {message}\n>> {code}\n" |
|
super().__init__(issues_text) |
|
|
|
|
|
class PythonToJSVisitor(ast.NodeVisitor): |
|
def __init__(self): |
|
self.js_lines = [] |
|
self.indent_level = 0 |
|
self.declared_vars = set() |
|
self.issues: list[tuple[int, str, str]] = [] |
|
self.source_lines: list[str] = [] |
|
self.var_types: dict[str, type | None] = {} |
|
|
|
|
|
self.builtin_transforms = { |
|
"range": self.transform_range, |
|
"len": self.transform_len, |
|
} |
|
|
|
def add_issue(self, node: ast.AST, message: str) -> None: |
|
"""Add a transpilation issue with source code context.""" |
|
if hasattr(node, "lineno"): |
|
line_no = node.lineno |
|
line_text = self.source_lines[line_no - 1].strip() |
|
self.issues.append((line_no, message, line_text)) |
|
|
|
def visit(self, node): |
|
try: |
|
return super().visit(node) |
|
except TranspilerError as e: |
|
if e.issues: |
|
self.issues.extend(e.issues) |
|
return "" |
|
|
|
def generic_visit(self, node): |
|
self.add_issue(node, f"Unsupported syntax: {type(node).__name__}") |
|
return "" |
|
|
|
def indent(self) -> str: |
|
return " " * self.indent_level |
|
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef): |
|
|
|
params = [] |
|
for arg in node.args.args: |
|
params.append(arg.arg) |
|
if arg.annotation and isinstance(arg.annotation, ast.Name): |
|
type_name = arg.annotation.id |
|
type_obj = globals().get(type_name, getattr(builtins, type_name, None)) |
|
if type_obj is not None: |
|
self.var_types[arg.arg] = type_obj |
|
|
|
header = f"function {node.name}({', '.join(params)}) " + "{" |
|
self.js_lines.append(header) |
|
self.indent_level += 1 |
|
|
|
for stmt in node.body: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append("}") |
|
|
|
|
|
def visit_Attribute(self, node: ast.Attribute): |
|
|
|
value = self.visit(node.value) |
|
return f"{value}.{node.attr}" |
|
|
|
|
|
def visit_Return(self, node: ast.Return): |
|
ret_val = "" if node.value is None else self.visit(node.value) |
|
self.js_lines.append(f"{self.indent()}return {ret_val};") |
|
|
|
|
|
def visit_Expr(self, node: ast.Expr): |
|
expr = self.visit(node.value) |
|
self.js_lines.append(f"{self.indent()}{expr};") |
|
|
|
|
|
def visit_Assign(self, node: ast.Assign): |
|
if len(node.targets) != 1: |
|
raise TranspilerError("Multiple assignment targets are not supported yet.") |
|
target_node = node.targets[0] |
|
target = self.visit(target_node) |
|
value = self.visit(node.value) |
|
|
|
if ( |
|
isinstance(target_node, ast.Name) |
|
and target_node.id not in self.declared_vars |
|
): |
|
self.declared_vars.add(target_node.id) |
|
expr_type = self.get_expr_type(node.value) |
|
if expr_type is not None: |
|
self.var_types[target_node.id] = expr_type |
|
self.js_lines.append(f"{self.indent()}let {target} = {value};") |
|
else: |
|
self.js_lines.append(f"{self.indent()}{target} = {value};") |
|
|
|
|
|
def check_type_safety(self, node: ast.AST, *exprs: ast.AST, context: str) -> None: |
|
""" |
|
Check if an operation is type-safe. |
|
Raises TranspilerError if types are ambiguous or incompatible. |
|
""" |
|
types = [self.get_expr_type(expr) for expr in exprs] |
|
|
|
|
|
if any(t is None for t in types): |
|
self.add_issue( |
|
node, |
|
f"Ambiguous operation: Cannot determine types for {context}. " |
|
"Operation behavior may differ in JavaScript based on types.", |
|
) |
|
raise TranspilerError() |
|
|
|
|
|
if len(types) > 1 and not all(t == types[0] for t in types): |
|
type_names = [t.__name__ for t in types] |
|
self.add_issue( |
|
node, |
|
f"Ambiguous operation: Mixed types ({', '.join(type_names)}) in {context}. " |
|
"Behavior may differ in JavaScript.", |
|
) |
|
raise TranspilerError() |
|
|
|
def visit_BinOp(self, node: ast.BinOp): |
|
left = self.visit(node.left) |
|
right = self.visit(node.right) |
|
op = self.visit(node.op) |
|
|
|
self.check_type_safety( |
|
node, node.left, node.right, context=f"'{left} {op} {right}'" |
|
) |
|
return f"({left} {op} {right})" |
|
|
|
def visit_Add(self, node: ast.Add): |
|
return "+" |
|
|
|
def visit_Sub(self, node: ast.Sub): |
|
return "-" |
|
|
|
def visit_Mult(self, node: ast.Mult): |
|
return "*" |
|
|
|
def visit_Div(self, node: ast.Div): |
|
return "/" |
|
|
|
|
|
def visit_Compare(self, node: ast.Compare): |
|
if len(node.ops) != 1 or len(node.comparators) != 1: |
|
raise TranspilerError("Only single comparisons are supported") |
|
op = node.ops[0] |
|
left = self.visit(node.left) |
|
right = self.visit(node.comparators[0]) |
|
|
|
if isinstance(op, ast.In): |
|
|
|
return f"{right}.includes({left})" |
|
elif isinstance(op, ast.NotIn): |
|
|
|
return f"!{right}.includes({left})" |
|
else: |
|
|
|
self.check_type_safety( |
|
node, |
|
node.left, |
|
node.comparators[0], |
|
context=f"comparison {left} {self.visit(op)} {right}", |
|
) |
|
op_str = self.visit(op) |
|
return f"({left} {op_str} {right})" |
|
|
|
def visit_Gt(self, node: ast.Gt): |
|
return ">" |
|
|
|
def visit_Lt(self, node: ast.Lt): |
|
return "<" |
|
|
|
def visit_GtE(self, node: ast.GtE): |
|
return ">=" |
|
|
|
def visit_LtE(self, node: ast.LtE): |
|
return "<=" |
|
|
|
def visit_Eq(self, node: ast.Eq): |
|
return "===" |
|
|
|
def visit_NotEq(self, node: ast.NotEq): |
|
return "!==" |
|
|
|
|
|
def visit_In(self, node: ast.In): |
|
return "in" |
|
|
|
def visit_NotIn(self, node: ast.NotIn): |
|
return "not in" |
|
|
|
|
|
def visit_If(self, node: ast.If): |
|
test = self.visit(node.test) |
|
self.js_lines.append(f"{self.indent()}if ({test}) " + "{") |
|
self.indent_level += 1 |
|
for stmt in node.body: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append(f"{self.indent()}" + "}") |
|
|
|
|
|
current = node |
|
while ( |
|
current.orelse |
|
and len(current.orelse) == 1 |
|
and isinstance(current.orelse[0], ast.If) |
|
): |
|
current = current.orelse[0] |
|
test = self.visit(current.test) |
|
self.js_lines.append(f"{self.indent()}else if ({test}) " + "{") |
|
self.indent_level += 1 |
|
for stmt in current.body: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append(f"{self.indent()}" + "}") |
|
|
|
|
|
if current.orelse: |
|
self.js_lines.append(f"{self.indent()}else " + "{") |
|
self.indent_level += 1 |
|
for stmt in current.orelse: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append(f"{self.indent()}" + "}") |
|
|
|
|
|
def transform_range(self, node: ast.Call) -> str: |
|
"""Transform Python's range() to an equivalent JavaScript array expression.""" |
|
args = [self.visit(arg) for arg in node.args] |
|
for arg in node.args: |
|
self.check_type_safety(arg, arg, context="range() argument") |
|
if len(args) == 1: |
|
return f"Array.from({{length: {args[0]}}}, (_, i) => i)" |
|
elif len(args) == 2: |
|
return f"Array.from({{length: {args[1]} - {args[0]}}}, (_, i) => i + {args[0]})" |
|
elif len(args) == 3: |
|
raise TranspilerError("range() with step argument is not supported yet") |
|
else: |
|
raise TranspilerError("Invalid number of arguments for range()") |
|
|
|
def transform_len(self, node: ast.Call) -> str: |
|
"""Transform Python's len() to the equivalent JavaScript property access.""" |
|
if len(node.args) != 1: |
|
raise TranspilerError("len() takes exactly one argument") |
|
arg_code = self.visit(node.args[0]) |
|
t = self.get_expr_type(node.args[0]) |
|
|
|
if t is dict: |
|
return f"Object.keys({arg_code}).length" |
|
else: |
|
|
|
return f"{arg_code}.length" |
|
|
|
|
|
def _handle_gradio_component_updates(self, node: ast.Call): |
|
"""Handle Gradio component calls and return JSON representation.""" |
|
kwargs = {} |
|
for kw in node.keywords: |
|
if isinstance(kw.value, ast.Constant) and kw.value.value is None: |
|
|
|
|
|
kwargs[kw.arg] = None |
|
continue |
|
value = self.visit(kw.value) |
|
try: |
|
kwargs[kw.arg] = ast.literal_eval(value) |
|
except Exception: |
|
kwargs[kw.arg] = value |
|
kwargs["__type__"] = "update" |
|
return json.dumps(kwargs) |
|
|
|
def visit_Call(self, node: ast.Call): |
|
try: |
|
import gradio |
|
|
|
has_gradio = True |
|
except ImportError: |
|
has_gradio = False |
|
|
|
|
|
if isinstance(node.func, ast.Name): |
|
if node.func.id in self.builtin_transforms: |
|
return self.builtin_transforms[node.func.id](node) |
|
|
|
|
|
if has_gradio: |
|
try: |
|
|
|
if node.func.id == "update": |
|
return self._handle_gradio_component_updates(node) |
|
|
|
component_class = getattr(gradio, node.func.id, None) |
|
if component_class and issubclass( |
|
component_class, gradio.blocks.Block |
|
): |
|
return self._handle_gradio_component_updates(node) |
|
except Exception: |
|
pass |
|
|
|
for arg in node.args: |
|
self.check_type_safety( |
|
arg, arg, context=f"argument in {node.func.id}() call" |
|
) |
|
self.add_issue(node, f'Unsupported function "{node.func.id}()"') |
|
return "" |
|
|
|
|
|
if isinstance(node.func, ast.Attribute) and has_gradio: |
|
try: |
|
|
|
if isinstance(node.func.value, ast.Name) and node.func.value.id in { |
|
"gradio", |
|
"gr", |
|
}: |
|
|
|
if node.func.attr == "update": |
|
return self._handle_gradio_component_updates(node) |
|
|
|
component_class = getattr(gradio, node.func.attr, None) |
|
if component_class and issubclass( |
|
component_class, gradio.blocks.Block |
|
): |
|
return self._handle_gradio_component_updates(node) |
|
except Exception: |
|
pass |
|
|
|
|
|
func = self.visit(node.func) |
|
args = [self.visit(arg) for arg in node.args] |
|
|
|
if isinstance(node.func, ast.Attribute): |
|
self.check_type_safety( |
|
node.func, node.func.value, context=f"object in method call {func}" |
|
) |
|
for arg in node.args: |
|
self.check_type_safety( |
|
arg, arg, context=f"argument in method call {func}" |
|
) |
|
|
|
return f"{func}({', '.join(args)})" |
|
|
|
|
|
def visit_Name(self, node: ast.Name): |
|
return node.id |
|
|
|
|
|
def visit_Constant(self, node: ast.Constant): |
|
if node.value is None: |
|
return "null" |
|
return repr(node.value) |
|
|
|
|
|
def visit_For(self, node: ast.For): |
|
target = self.visit(node.target) |
|
iter_expr = self.visit(node.iter) |
|
|
|
|
|
if ( |
|
isinstance(node.iter, ast.Call) |
|
and isinstance(node.iter.func, ast.Name) |
|
and node.iter.func.id == "range" |
|
): |
|
if isinstance(node.target, ast.Name): |
|
self.var_types[node.target.id] = int |
|
|
|
self.js_lines.append(f"{self.indent()}for (let {target} of {iter_expr}) " + "{") |
|
self.indent_level += 1 |
|
for stmt in node.body: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append(f"{self.indent()}" + "}") |
|
|
|
|
|
def visit_While(self, node: ast.While): |
|
test = self.visit(node.test) |
|
self.js_lines.append(f"{self.indent()}while ({test}) " + "{") |
|
self.indent_level += 1 |
|
for stmt in node.body: |
|
self.visit(stmt) |
|
self.indent_level -= 1 |
|
self.js_lines.append(f"{self.indent()}" + "}") |
|
|
|
|
|
def visit_List(self, node: ast.List): |
|
elements = [self.visit(elt) for elt in node.elts] |
|
return f"[{', '.join(elements)}]" |
|
|
|
|
|
def visit_Tuple(self, node: ast.Tuple): |
|
elements = [self.visit(elt) for elt in node.elts] |
|
return f"[{', '.join(elements)}]" |
|
|
|
|
|
def visit_ListComp(self, node: ast.ListComp): |
|
""" |
|
Transform a Python list comprehension into a combination of filter and map calls. |
|
For example: |
|
[x * 2 for x in arr if x > 10] |
|
becomes: |
|
arr.filter(x => x > 10).map(x => x * 2) |
|
""" |
|
if len(node.generators) != 1: |
|
self.add_issue( |
|
node, "Only single generator list comprehensions are supported" |
|
) |
|
raise TranspilerError() |
|
gen = node.generators[0] |
|
iter_js = self.visit(gen.iter) |
|
target_js = self.visit(gen.target) |
|
elt_js = self.visit(node.elt) |
|
if gen.ifs: |
|
|
|
conditions = " && ".join(self.visit(if_node) for if_node in gen.ifs) |
|
result = f"{iter_js}.filter({target_js} => {conditions})" |
|
|
|
if not (isinstance(node.elt, ast.Name) and node.elt.id == gen.target.id): |
|
result += f".map({target_js} => {elt_js})" |
|
else: |
|
result = f"{iter_js}.map({target_js} => {elt_js})" |
|
return result |
|
|
|
|
|
def visit_Subscript(self, node: ast.Subscript): |
|
value = self.visit(node.value) |
|
slice_value = self.visit(node.slice) |
|
return f"{value}[{slice_value}]" |
|
|
|
|
|
def visit_AugAssign(self, node: ast.AugAssign): |
|
target = self.visit(node.target) |
|
op = self.visit(node.op).strip() |
|
value = self.visit(node.value) |
|
self.js_lines.append(f"{self.indent()}{target} {op}= {value};") |
|
|
|
|
|
def visit_BoolOp(self, node: ast.BoolOp): |
|
op = self.visit(node.op) |
|
values = [self.visit(value) for value in node.values] |
|
|
|
return f"({f' {op} '.join(values)})" |
|
|
|
def visit_And(self, node: ast.And): |
|
return "&&" |
|
|
|
def visit_Or(self, node: ast.Or): |
|
return "||" |
|
|
|
|
|
def visit_Dict(self, node: ast.Dict): |
|
pairs = [] |
|
for key, value in zip(node.keys, node.values): |
|
if key is None: |
|
continue |
|
key_js = self.visit(key) |
|
value_js = self.visit(value) |
|
pairs.append(f"{key_js}: {value_js}") |
|
return f"{{{', '.join(pairs)}}}" |
|
|
|
def get_expr_type(self, node: ast.AST) -> type | None: |
|
"""Determine the type of an expression if possible.""" |
|
if isinstance(node, ast.Constant): |
|
return type(node.value) |
|
elif isinstance(node, ast.Name): |
|
|
|
if node.id in self.var_types: |
|
return self.var_types[node.id] |
|
return None |
|
elif isinstance(node, ast.BinOp): |
|
left_type = self.get_expr_type(node.left) |
|
right_type = self.get_expr_type(node.right) |
|
if left_type == right_type and left_type is not None: |
|
return left_type |
|
return None |
|
elif isinstance(node, ast.Call): |
|
|
|
return None |
|
elif isinstance(node, ast.List): |
|
return list |
|
elif isinstance(node, ast.Dict): |
|
return dict |
|
elif isinstance(node, ast.Tuple): |
|
return tuple |
|
return None |
|
|
|
|
|
def transpile(fn: Callable, validate: bool = False) -> str:
    """
    Transpiles a Python function to JavaScript and returns the JavaScript code as a string.

    Parameters:
        fn: The Python function (or lambda) to transpile.
        validate: If True, the function will be validated to ensure it takes no arguments & only returns gradio component property updates. This is used when Groovy is used inside Gradio and `gradio` must be installed to use this. For a lambda, its body expression is validated as the returned value.

    Returns:
        The JavaScript code as a string.

    Raises:
        TranspilerError: If the function cannot be transpiled or if the transpiled function is not valid.
    """
    if validate:
        parameters = inspect.signature(fn).parameters
        if parameters:
            param_names = list(parameters.keys())
            raise TranspilerError(
                message=f"Function must take no arguments for client-side use, but got: {param_names}"
            )

    try:
        source = textwrap.dedent(inspect.getsource(fn))
    except Exception as e:
        # Boundary: any failure to obtain the source is reported uniformly.
        raise TranspilerError(
            message="Could not retrieve source code from the function."
        ) from e

    try:
        tree = ast.parse(source)
    except SyntaxError as e:
        raise TranspilerError(message="Could not parse function source.") from e

    # The same node is used for validation and for transpilation.
    func_node = _find_function_node(tree, fn.__name__)

    if validate:
        try:
            import gradio  # noqa: F401  (only checks availability)
        except ImportError as e:
            raise TranspilerError(message="Gradio must be installed for validation.") from e
        if func_node is not None:
            _validate_gradio_returns(func_node, source)

    if func_node is None:
        raise TranspilerError(
            message="No function or lambda definition found in the provided source."
        )

    visitor = PythonToJSVisitor()
    visitor.source_lines = source.splitlines()

    if isinstance(func_node, ast.Lambda):
        args = [arg.arg for arg in func_node.args.args]
        visitor.js_lines.append(f"function ({', '.join(args)}) " + "{")
        visitor.indent_level += 1
        visitor.js_lines.append(
            f"{visitor.indent()}return {visitor.visit(func_node.body)};"
        )
        visitor.indent_level -= 1
        visitor.js_lines.append("}")
    else:
        visitor.visit(func_node)

    if visitor.issues:
        raise TranspilerError(issues=visitor.issues)

    return "\n".join(visitor.js_lines)


def _find_function_node(
    tree: ast.AST, name: str
) -> ast.FunctionDef | ast.Lambda | None:
    """Return the ``def`` named ``name``, else the first function or lambda in ``tree``."""
    first_candidate = None
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == name:
            return node
        if first_candidate is None and isinstance(node, (ast.FunctionDef, ast.Lambda)):
            first_candidate = node
    return first_candidate


def _validate_gradio_returns(
    func_node: ast.FunctionDef | ast.Lambda, source: str
) -> None:
    """Raise TranspilerError unless every returned value of ``func_node`` is a Gradio update."""
    if isinstance(func_node, ast.Lambda):
        # A lambda's body expression is its (only) return value.
        returned = [(func_node.body.lineno, func_node.body)]
    else:
        returned = [
            (node.lineno, node.value)
            for node in ast.walk(func_node)
            if isinstance(node, ast.Return) and node.value is not None
        ]

    if not returned:
        raise TranspilerError(
            message="Function must return Gradio component updates, but no return statement found."
        )

    source_lines = source.splitlines()
    for line_no, value in returned:
        if not _is_valid_gradio_return(value):
            line_text = source_lines[line_no - 1].strip()
            raise TranspilerError(
                message=f"Function must only return Gradio component updates. Invalid return at line {line_no}: {line_text}"
            )
|
|
|
|
|
def _is_valid_gradio_return(node: ast.AST) -> bool: |
|
""" |
|
Check if a return value is a valid Gradio component or collection of components. |
|
|
|
Args: |
|
node: The AST node representing the return value |
|
|
|
Returns: |
|
bool: True if the return value is valid, False otherwise |
|
""" |
|
|
|
if isinstance(node, ast.Call): |
|
if isinstance(node.func, ast.Attribute) and isinstance( |
|
node.func.value, ast.Name |
|
): |
|
if node.func.value.id in {"gr", "gradio"}: |
|
try: |
|
import gradio |
|
|
|
if node.func.attr == "update": |
|
return True |
|
|
|
component_class = getattr(gradio, node.func.attr, None) |
|
if component_class and issubclass( |
|
component_class, gradio.blocks.Block |
|
): |
|
if node.args: |
|
return False |
|
for kw in node.keywords: |
|
if kw.arg == "value": |
|
return False |
|
return True |
|
except (ImportError, AttributeError): |
|
pass |
|
return False |
|
elif isinstance(node.func, ast.Name): |
|
try: |
|
import gradio |
|
|
|
if node.func.id == "update": |
|
return True |
|
|
|
component_class = getattr(gradio, node.func.id, None) |
|
if component_class and issubclass(component_class, gradio.blocks.Block): |
|
if node.args: |
|
return False |
|
for kw in node.keywords: |
|
if kw.arg == "value": |
|
return False |
|
return True |
|
except (ImportError, AttributeError): |
|
pass |
|
return False |
|
|
|
elif isinstance(node, (ast.Tuple, ast.List)): |
|
if not node.elts: |
|
return False |
|
return all(_is_valid_gradio_return(elt) for elt in node.elts) |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: transpile a zero-argument function that returns a Gradio update,
    # with validation enabled (requires gradio to be installed).
    import gradio as gr

    def filter_rows_by_term():
        return gr.update(selected=2, visible=True, info=None)

    print(transpile(filter_rows_by_term, validate=True))
|
|