|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import glob |
|
import importlib.util
|
import os |
|
import re |
|
from abc import ABC, abstractmethod |
|
from collections import Counter, defaultdict, deque |
|
from typing import Dict, Optional, Set, Union |
|
|
|
import libcst as cst |
|
from check_copies import run_ruff |
|
from create_dependency_mapping import find_priority_list |
|
from libcst import ClassDef, CSTVisitor |
|
from libcst import matchers as m |
|
from libcst.metadata import MetadataWrapper, ParentNodeProvider, PositionProvider, ScopeProvider |
|
|
|
from transformers import logging |
|
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 |
|
# This file was automatically generated from {relative_path}. |
|
# Do NOT edit this file manually as any edits will be overwritten by the generation of |
|
# the file from the modular. If any change should be made, please apply it to the
# {short_name} file directly. One of our CI checks enforces this.
|
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 |
|
""" |
|
|
|
|
|
def get_module_source_from_name(module_name: str) -> str:
    """Return the source code of the file defining the module `module_name`."""
    spec = importlib.util.find_spec(module_name)
|
if spec is None or spec.origin is None: |
|
raise ValueError(f"Cannot open file associated with {module_name} module.") |
|
|
|
with open(spec.origin, "r", encoding="utf-8") as file: |
|
source_code = file.read() |
|
return source_code |
|
|
|
|
|
def preserve_case_replace(text, patterns: dict, default_name: str):
    """Replace all occurrences of the keys of `patterns` in `text` by their corresponding values, preserving the
    casing of the match (a casing not present in `patterns`, e.g. a mixed-case variant, falls back to `default_name`)."""
    # Create a regex matching any of the keys case-insensitively, also capturing the following character
    regex_pattern = "|".join(re.escape(key) for key in patterns.keys())
    compiled_regex = re.compile(f"(?<![a-z0-9])({regex_pattern})(.|$)", re.IGNORECASE | re.DOTALL)
|
|
|
def replace(match): |
|
matched_pattern = match.group(1) |
|
next_char = match.group(2) |
|
new_pattern = patterns.get(matched_pattern, default_name) |
|
|
|
|
|
|
|
        # When only 2 patterns remain, the cased and uppercase versions of the old name are identical (e.g. `OPT`),
        # in which case an all-uppercase match followed by a non-letter must still use the uppercase replacement
        if len(patterns) == 2 and matched_pattern.isupper():
            if not next_char.isalpha():
                new_pattern = patterns[matched_pattern.lower()].upper()
|
|
|
return new_pattern + next_char |
|
|
|
return compiled_regex.sub(replace, text) |
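# Illustrative example (hypothetical names, following the pattern dict built by `ReplaceNameTransformer` below):
# preserve_case_replace(
#     "LlamaModel uses LLAMA_ATTENTION_CLASSES",
#     {"llama": "my_model", "LLAMA": "MY_MODEL", "Llama": "MyModel"},
#     "MyModel",
# )  # -> "MyModelModel uses MY_MODEL_ATTENTION_CLASSES"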
|
|
|
|
|
def get_cased_name(lowercase_name: str) -> str: |
|
"""From a model name in lowercase in the format `my_model`, return the cased name in the format `MyModel`.""" |
|
alt_lowercase_name = lowercase_name.replace("_", "-") |
|
if lowercase_name in CONFIG_MAPPING_NAMES: |
|
return CONFIG_MAPPING_NAMES[lowercase_name].replace("Config", "") |
|
elif alt_lowercase_name in CONFIG_MAPPING_NAMES: |
|
return CONFIG_MAPPING_NAMES[alt_lowercase_name].replace("Config", "") |
|
else: |
|
return "".join(x.title() for x in lowercase_name.split("_")) |
|
|
|
|
|
def get_lowercase_name(cased_name: str) -> str:
    """From a model name in CamelCase in the format `MyModel`, return the lowercase name in the format `my_model`."""
|
inverse_mapping = {value: key for key, value in CONFIG_MAPPING_NAMES.items()} |
|
if cased_name + "Config" in inverse_mapping: |
|
return inverse_mapping[cased_name + "Config"] |
|
else: |
|
return "_".join([s.lower() for s in re.findall(r"[A-Z][^A-Z]*", cased_name)]) |
|
|
|
|
|
class ReplaceNameTransformer(m.MatcherDecoratableTransformer): |
|
"""A transformer that replaces `old_name` with `new_name` in comments, string and any references. |
|
It should take into account name like `MyNewModel`, or `my_new_model`. Without using the AUTO_MAPPING. |
|
Supported renaming patterns: |
|
- llama -> my_new_model and my_new_model -> llama |
|
- Llama -> MyNewModel and MyNewModel -> Llama |
|
- LLAMA -> MY_NEW_MODEL and MY_NEW_MODEL -> LLAMA |
|
- LLaMa -> MyNewModel abd MyNewModel -> Llama |
|
""" |
|
|
|
def __init__(self, old_name: str, new_name: str, original_new_model_name: str = "", only_doc: bool = False): |
|
super().__init__() |
|
old_name = old_name.replace("-", "_") |
|
new_name = new_name.replace("-", "_") |
|
self.old_name = old_name |
|
self.new_name = new_name |
|
self.cased_new_name = get_cased_name(self.new_name) |
|
self.cased_old_name = get_cased_name(self.old_name) |
|
self.patterns = { |
|
old_name: new_name, |
|
old_name.upper(): new_name.upper(), |
|
|
|
self.cased_old_name: self.cased_new_name, |
|
} |
|
|
|
self.original_new_model_name = original_new_model_name |
|
self.only_doc = only_doc |
|
|
|
def _replace_name(self, original_node, updated_node): |
|
if re.findall(r"# Copied from", updated_node.value): |
|
return cst.RemoveFromParent() |
|
update = preserve_case_replace(updated_node.value, self.patterns, self.cased_new_name) |
|
return updated_node.with_changes(value=update) |
|
|
|
@m.leave(m.SimpleString() | m.Comment()) |
|
def replace_name(self, original_node, updated_node): |
|
return self._replace_name(original_node, updated_node) |
|
|
|
def leave_Name(self, original_node, updated_node): |
|
if not self.only_doc: |
|
return self._replace_name(original_node, updated_node) |
|
return updated_node |
|
|
|
def leave_ImportFrom(self, original_node, updated_node): |
|
"""The imports from other file types (configuration, processing etc) should use original model name.""" |
|
if self.original_new_model_name != self.new_name and m.matches(updated_node.module, m.Name()): |
|
patterns = "|".join(ALL_FILE_TYPES) |
|
regex = rf"({patterns})_{self.new_name}" |
|
new_source = re.sub( |
|
regex, lambda m: f"{m.group(1)}_{self.original_new_model_name}", updated_node.module.value |
|
) |
|
updated_node = updated_node.with_changes(module=updated_node.module.with_changes(value=new_source)) |
|
return updated_node |
|
|
|
|
|
DOCSTRING_NODE = m.SimpleStatementLine( |
|
body=[ |
|
m.Expr( |
|
            value=m.SimpleString(
                # Match anything enclosed in triple double quotes
                value=m.MatchIfTrue(lambda value: re.search(r"\"\"\"[\s\S]*\"\"\"", value) is not None)
|
) |
|
) |
|
] |
|
) |
|
|
|
|
|
def SUPER_CALL_NODE(func_name): |
|
return m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")), attr=m.Name(func_name))) |
|
|
|
|
|
def is_call_to_super(node, func_name): |
|
return m.matches( |
|
node, m.SimpleStatementLine(body=[m.Return(SUPER_CALL_NODE(func_name)) | m.Expr(SUPER_CALL_NODE(func_name))]) |
|
) |
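# For example, with func_name="forward", this matches the statements `super().forward(x)` and
# `return super().forward(x)`, but not an explicit parent call such as `SomeBase.forward(self, x)`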
|
|
|
|
|
def get_full_attribute_name(node: Union[cst.Attribute, cst.Name]) -> Optional[str]: |
|
"""Get the full name of an Attribute or Name node (e.g. `"nn.Module"` for an Attribute representing it). If the |
|
successive value of an Attribute are not Name nodes, return `None`.""" |
|
if m.matches(node, m.Name()): |
|
return node.value |
|
elif m.matches(node, m.Attribute()): |
|
if not m.matches(node.attr, m.Name()): |
|
return None |
|
name = node.attr.value |
|
new_node = node.value |
|
while m.matches(new_node, m.Attribute()): |
|
if not m.matches(new_node.attr, m.Name()): |
|
return None |
|
name = new_node.attr.value + "." + name |
|
new_node = new_node.value |
|
if not m.matches(new_node, m.Name()): |
|
return None |
|
return new_node.value + "." + name |
|
return None |
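# e.g. this returns "nn.Module" for an Attribute node representing `nn.Module`, and `None` for something
# like `foo().bar`, whose value is a Call rather than a chain of Names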
|
|
|
|
|
|
|
class ReplaceMethodCallTransformer(cst.CSTTransformer): |
|
def __init__(self, all_bases: Set[str]): |
|
self.all_bases = all_bases |
|
|
|
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.CSTNode: |
|
|
|
        # Handle a direct reference through one of the bases, e.g. `Base.forward`
        if (
|
m.matches(original_node.value, m.Name() | m.Attribute()) |
|
and get_full_attribute_name(original_node.value) in self.all_bases |
|
and m.matches(original_node.attr, m.Name()) |
|
): |
|
|
|
            # Replace the base-class reference with a call to `super()`
            return updated_node.with_changes(
                value=cst.Call(cst.Name("super")),
|
) |
|
|
|
        # Handle a reference through an instantiated base, e.g. `Base().forward`
        elif (
|
m.matches(original_node.value, m.Call()) |
|
and m.matches(original_node.value.func, m.Name() | m.Attribute()) |
|
and get_full_attribute_name(original_node.value.func) in self.all_bases |
|
and m.matches(original_node.attr, m.Name()) |
|
): |
|
|
|
return updated_node.with_changes(value=cst.Call(cst.Name("super"))) |
|
return updated_node |
|
|
|
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: |
|
|
|
        # Check whether the call goes through one of the parent classes, e.g. `Base().method(self, ...)`
        # or `Base.method(self, ...)`
        if m.matches(original_node.func, m.Attribute()) and (
|
|
|
( |
|
m.matches(original_node.func.value, m.Call()) |
|
and m.matches(original_node.func.value.func, m.Name() | m.Attribute()) |
|
and get_full_attribute_name(original_node.func.value.func) in self.all_bases |
|
) |
|
or |
|
|
|
( |
|
m.matches(original_node.func.value, m.Name() | m.Attribute()) |
|
and get_full_attribute_name(original_node.func.value) in self.all_bases |
|
) |
|
): |
|
|
|
            # If the first argument is `self`, remove it, as the call now goes through `super()`
            if len(original_node.args) > 0 and m.matches(original_node.args[0].value, m.Name("self")):
                new_args = updated_node.args[1:]
|
else: |
|
new_args = updated_node.args |
|
|
|
return updated_node.with_changes(args=new_args) |
|
return updated_node |
|
|
|
|
|
def get_docstring_indent(docstring):
    """Return the indentation level (in spaces) of the body of `docstring`."""
    # Match the first line after the opening triple quotes
    match = re.search(r'(?:"""|\'\'\'|```)\n(\s+)', docstring)
    if match:
        # Return the indentation of that first line
        return len(match.group(1))
    return 0
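# e.g. for the docstring literal '"""\n    Some description\n    """', this returns 4 (its body is indented by 4 spaces)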
|
|
|
|
|
def is_full_docstring(original_docstring: str, new_docstring: str, original_level: int) -> bool: |
|
"""Check if `new_docstring` is a full docstring, or if it is only part of a docstring that should then |
|
be merged with the existing old one. |
|
""" |
|
|
|
    # Remove the opening triple quotes (and potential prefix) to only consider the docstring body
    new_docstring = new_docstring.split('"""', 1)[1]
|
|
|
if re.search(r"\n\s*Args:\n", new_docstring): |
|
return True |
|
elif re.search(r"\n\s*Args:\n", original_docstring): |
|
return False |
|
|
|
    # Matches an entry describing a single parameter, i.e. something of the form `name (type): description`
    param_pattern = re.compile(
        rf"^\s{{0,{original_level}}}(\w+)\s*\(\s*([^, \)]*)(\s*.*?)\s*\)\s*:\s*((?:(?!\n^\s{{0,{original_level}}}\w+\s*\().)*)",
|
re.DOTALL | re.MULTILINE, |
|
) |
|
    # If the new docstring contains at least one full parameter description, consider it a full docstring
    match_object = param_pattern.search(new_docstring)
    if match_object is not None:
        return True
|
|
|
|
|
match_object = re.search(r"\n([^\S\n]*)Returns:\n", new_docstring) |
|
if match_object is not None: |
|
full_indent = match_object.group(1) |
|
striped_doc = new_docstring.strip("\n") |
|
if striped_doc.startswith(full_indent + " " * 4) or striped_doc.startswith(full_indent + "\t"): |
|
return True |
|
return False |
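# For instance, a new docstring containing its own `Args:` section is considered full and will replace the original,
# while one without it (when the original has one) is considered partial and will be merged into the original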
|
|
|
|
|
def merge_docstrings(original_docstring, updated_docstring): |
|
original_level = get_docstring_indent(original_docstring) |
|
    if not is_full_docstring(original_docstring, updated_docstring, original_level):
        # The new docstring only replaces part of the original one, e.g. the example between ``` markers
        parts = original_docstring.split("```")
|
if "```" in updated_docstring and len(parts) > 1: |
|
updated_docstring = updated_docstring.lstrip('r"') |
|
new_parts = updated_docstring.split("```") |
|
if len(new_parts) != 3: |
|
raise ValueError("There should only be one example, and it should have opening and closing '```'") |
|
parts[1] = new_parts[1] |
|
updated_docstring = "".join( |
|
[ |
|
f"\n{original_level * ' '}```", |
|
parts[1], |
|
"```", |
|
parts[2], |
|
] |
|
) |
|
docstring_opening, original_start_docstring = parts[0].rstrip(" \n").split('"""')[:2] |
|
new_start_docstring = new_parts[0].rstrip(" \n") |
|
docstring_opening += '"""' |
|
if new_start_docstring.startswith(original_start_docstring): |
|
updated_docstring = new_start_docstring + "\n" + updated_docstring |
|
elif original_start_docstring.endswith(new_start_docstring): |
|
updated_docstring = original_start_docstring + "\n" + updated_docstring |
|
else: |
|
updated_docstring = original_start_docstring + "\n" + new_start_docstring + "\n" + updated_docstring |
|
updated_docstring = docstring_opening + updated_docstring |
|
elif updated_docstring not in original_docstring: |
|
|
|
if re.search(r"\n\s*.*\(.*\)\:\n\s*\w", updated_docstring): |
|
updated_docstring = updated_docstring.replace("\n ", "\n ") |
|
updated_docstring = original_docstring.rstrip('"') + "\n" + updated_docstring.lstrip('r"\n') |
|
return updated_docstring |
|
|
|
|
|
class SuperTransformer(cst.CSTTransformer): |
|
METADATA_DEPENDENCIES = (ParentNodeProvider,) |
|
|
|
def __init__(self, python_module: cst.Module, original_methods, updated_methods, all_bases=None): |
|
self.python_module = python_module |
|
self.original_methods = original_methods |
|
self.updated_methods = updated_methods |
|
self.all_assign_target = {} |
|
self.deleted_targets = {} |
|
self.all_bases = all_bases or [] |
|
self.transformer = ReplaceMethodCallTransformer(set(self.all_bases)) |
|
|
|
def update_body(self, existing_body, new_statements): |
|
""" |
|
Helper method to update the body by removing duplicates before adding new statements. |
|
`existing_body` is the body of the original method, the parent class |
|
`new_statements` are the additional statements |
|
""" |
|
deduplicated_new_body = [] |
|
existing_nodes = set() |
|
for node in new_statements: |
|
if m.matches(node, m.SimpleStatementLine(body=[m.Assign()])): |
|
target = self.python_module.code_for_node(node.body[0].targets[0].target) |
|
self.all_assign_target[target] = node |
|
if m.matches(node, m.SimpleStatementLine(body=[m.Del()])): |
|
target = self.python_module.code_for_node(node.body[0].target) |
|
self.deleted_targets[target] = node |
|
|
|
for stmt in existing_body: |
|
if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): |
|
target = self.python_module.code_for_node(stmt.body[0].targets[0].target) |
|
if target in self.deleted_targets: |
|
continue |
|
if target in self.all_assign_target: |
|
stmt = self.all_assign_target[target] |
|
|
|
elif m.matches(stmt, DOCSTRING_NODE): |
|
continue |
|
comment_less_code = re.sub(r"#.*", "", self.python_module.code_for_node(stmt)).strip() |
|
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() |
|
deduplicated_new_body.append(stmt) |
|
existing_nodes.add(comment_less_code) |
|
|
|
for node in new_statements: |
|
code = self.python_module.code_for_node(node) |
|
comment_less_code = re.sub(r"#.*", "", code).strip() |
|
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() |
|
if node not in deduplicated_new_body and comment_less_code not in existing_nodes: |
|
if not m.matches(node, m.SimpleStatementLine(body=[m.Del()])): |
|
deduplicated_new_body.append(node) |
|
existing_nodes.add(comment_less_code) |
|
|
|
deduplicated_new_body = self._fix_post_init_location(deduplicated_new_body) |
|
|
|
return deduplicated_new_body |
|
|
|
def _fix_post_init_location(self, new_body: list[cst.CSTNode]): |
|
"""Fix the location of the `post_init()` in the new body, if we added statements after the call to |
|
`super()` (it needs to be the very last statement called)""" |
|
|
|
for i, node in enumerate(new_body): |
|
code = self.python_module.code_for_node(node) |
|
comment_less_code = re.sub(r"#.*", "", code).strip() |
|
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() |
|
if "self.post_init(" in comment_less_code and i < len(new_body) - 1: |
|
|
|
new_body.pop(i) |
|
new_body.append(node) |
|
break |
|
return new_body |
|
|
|
def _fix_init_location(self, new_body): |
|
"""Fix the location of the `super().__init__()` in the new body, if we had new statements before it.""" |
|
start_index = 0 |
|
for i, node in enumerate(new_body): |
|
if m.matches(node, DOCSTRING_NODE) and i == start_index: |
|
start_index += 1 |
|
continue |
|
code = self.python_module.code_for_node(node) |
|
comment_less_code = re.sub(r"#.*", "", code).strip() |
|
comment_less_code = re.sub(r"\ *\n", "\n", comment_less_code).strip() |
|
if "super().__init__" in comment_less_code and i > start_index: |
|
|
|
node = new_body.pop(i) |
|
new_body = new_body[:start_index] + [node] + new_body[start_index:] |
|
break |
|
return new_body |
|
|
|
def replace_super_calls(self, node: cst.IndentedBlock, func_name: str) -> cst.CSTNode: |
|
"""Updates the body of the input `node`'s `func_name` function by replacing calls |
|
to super().func_name() with the source code of the parent class' `func_name`. |
|
It keeps everything that is defined before `super().func_name()`. |
|
""" |
|
self.has_docstring = False |
|
parent_has_docstring = False |
|
if func_name in self.original_methods: |
|
parent_has_docstring = m.matches(self.original_methods[func_name].body.body[0], DOCSTRING_NODE) |
|
new_body = [] |
|
has_super_call = False |
|
|
|
for i, expr in enumerate(node.body): |
|
if is_call_to_super(expr, func_name): |
|
has_super_call = True |
|
new_body.extend(self.update_body(self.original_methods[func_name].body.body, node.body[i + 1 :])) |
|
new_body = self._fix_init_location(new_body) |
|
else: |
|
expr = expr.visit(self.transformer) |
|
if m.matches(expr, DOCSTRING_NODE): |
|
self.has_docstring = True |
|
if parent_has_docstring: |
|
original_docstring = self.original_methods[func_name].body.body[0].body[0].value.value |
|
updated_docstring = expr.body[0].value.value |
|
merged_doc = merge_docstrings(original_docstring, updated_docstring) |
|
new_node = [expr.with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))])] |
|
else: |
|
new_node = [expr] |
|
new_body.extend(new_node) |
|
elif not m.matches(expr, m.SimpleStatementLine(body=[m.Del()])) and not has_super_call: |
|
new_body.append(expr) |
|
if not self.has_docstring and parent_has_docstring: |
|
new_body = [self.original_methods[func_name].body.body[0]] + new_body |
|
return node.with_changes(body=new_body) |
|
|
|
def leave_FunctionDef(self, original_node: cst.Call, updated_node: cst.Call) -> cst.CSTNode: |
|
if updated_node.name.value in self.updated_methods: |
|
name = updated_node.name.value |
|
new_body = self.replace_super_calls(updated_node.body, name) |
|
return updated_node.with_changes(body=new_body, params=updated_node.params) |
|
return updated_node |
|
|
|
def leave_Return(self, original_node: cst.Return, updated_node: cst.Return) -> cst.CSTNode: |
|
""" "When a return statement is reached, it is replaced with the unrolled super code""" |
|
if m.matches(updated_node.value, m.Call(func=m.Attribute(attr=m.Name("super")))): |
|
func_def = self.get_metadata(ParentNodeProvider, original_node) |
|
            if m.matches(func_def, m.FunctionDef()) and func_def.name.value in self.original_methods:
|
updated_return_value = updated_node.value.with_changes( |
|
args=[ |
|
cst.Arg( |
|
value=cst.Call(func=cst.Name("super"), args=[cst.Arg(value=cst.Name(func_def.name.value))]) |
|
) |
|
] |
|
) |
|
return updated_node.with_changes(value=updated_return_value) |
|
return updated_node |
|
|
|
|
|
def find_all_dependencies( |
|
dependency_mapping: Dict[str, set], |
|
start_entity: Optional[str] = None, |
|
initial_dependencies: Optional[set] = None, |
|
initial_checked_dependencies: Optional[set] = None, |
|
return_parent: bool = False, |
|
) -> Union[list, set]: |
|
"""Return all the dependencies of the given `start_entity` or `initial_dependencies`. This is basically some kind of |
|
BFS traversal algorithm. It can either start from `start_entity`, or `initial_dependencies`. |
|
|
|
Args: |
|
dependency_mapping (`Dict[str, set]`): |
|
A mapping from entities (usually function/assignment names), to immediate dependencies. That is, for function names, |
|
a mapping {"foo": {"bar", "test"}} would indicate that functions `bar` and `test` are immediately called |
|
in `foo`'s definition. |
|
start_entity (str | None, *optional*): |
|
A key of `dependency_mapping`, indicating from which entity to start the search. |
|
initial_dependencies (set | None, *optional*): |
|
If `start_entity` is not provided, this can be used as an alternative. In this case, the search will continue |
|
from all the entities in `initial_dependencies`, if they are in `dependency_mapping`. |
|
initial_checked_dependencies (set | None, *optional*): |
|
If provided, entities already present in `initial_checked_dependencies` will not be part of the returned dependencies. |
|
return_parent (bool, *optional*): |
|
            If `True`, will return a list consisting of tuples (dependency, parent) instead of a simple set of dependencies. Note
            that the order of the items in the list reflects the traversal order. Thus, no parent can ever appear before its children.
|
Returns: |
|
A set of all the dependencies, or a list of tuples `(dependency, parent)` if `return_parent=True`. |
|
|
|
Example: |
|
Given the following structure in the `modular_xxx.py` file: |
|
``` |
|
def foo1(): |
|
pass |
|
|
|
def foo2(): |
|
pass |
|
|
|
def bar(): |
|
foo1() |
|
|
|
def foobar(): |
|
bar() |
|
foo2() |
|
|
|
class MyLayer(SomeOtherModelLayer): |
|
def forward(...): |
|
foobar() |
|
``` |
|
and the `dependency_mapping` created when visiting the `modular_xxx.py` file, we get: |
|
``` |
|
dependency_mapping = {'bar': {'foo1'}, 'foobar': {'bar', 'foo2'}} |
|
find_all_dependencies(dependency_mapping, start_entity='foobar', return_parent=True) |
|
>>> [('bar', 'foobar'), ('foo2', 'foobar'), ('foo1', 'bar')] |
|
``` |
|
That is, all the functions needed (and potentially their immediate parent) so that the function to be added |
|
in MyLayer (`foobar`) can work correctly. |
|
""" |
|
if initial_dependencies is None and start_entity is not None: |
|
initial_dependencies = dependency_mapping[start_entity] |
|
if initial_checked_dependencies is None: |
|
initial_checked_dependencies = set() |
|
|
|
dependency_queue = deque(initial_dependencies) |
|
all_dependencies = set() |
|
all_dependencies_with_parent = [] |
|
checked_dependencies = set(initial_checked_dependencies) |
|
parents = dict.fromkeys(initial_dependencies, start_entity) |
|
while len(dependency_queue) > 0: |
|
|
|
current = dependency_queue.popleft() |
|
if current not in checked_dependencies: |
|
|
|
all_dependencies.add(current) |
|
all_dependencies_with_parent += [(current, parents[current])] |
|
if current in dependency_mapping.keys(): |
|
|
|
dependency_queue.extend(dependency_mapping[current]) |
|
parents.update(dict.fromkeys(dependency_mapping[current], current)) |
|
|
|
checked_dependencies.add(current) |
|
|
|
if not return_parent: |
|
return all_dependencies |
|
|
|
return all_dependencies_with_parent |
|
|
|
|
|
|
|
ASSIGNMENTS_REGEX_TO_KEEP = [r"_CHECKPOINT", r"_EXPECTED", r"_FOR_DOC", r"_HIDDEN_STATES_START_POSITION"] |
|
|
|
|
|
ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE = [r"_DOCSTRING"] |
|
|
|
|
|
class ClassDependencyMapper(CSTVisitor): |
|
"""A visitor which is designed to analyze a single class node to get all its dependencies that are shared with the set of |
|
`global_names`. |
|
""" |
|
|
|
def __init__( |
|
self, class_name: str, global_names: set[str], objects_imported_from_modeling: Optional[set[str]] = None |
|
): |
|
super().__init__() |
|
self.class_name = class_name |
|
self.dependencies = set() |
|
self.global_names = global_names |
|
self.objects_imported_from_modeling = ( |
|
set() if objects_imported_from_modeling is None else objects_imported_from_modeling |
|
) |
|
|
|
def visit_Name(self, node): |
|
if ( |
|
node.value != self.class_name |
|
and node.value in self.global_names |
|
and node.value not in self.objects_imported_from_modeling |
|
): |
|
self.dependencies.add(node.value) |
|
|
|
|
|
def dependencies_for_class_node(node: cst.ClassDef, global_names: set[str]) -> set: |
|
"""Create immediate dependencies for a class node based on the `global_names`.""" |
|
temp_module = cst.Module(body=[node]) |
|
visitor = ClassDependencyMapper(node.name.value, global_names) |
|
temp_module.visit(visitor) |
|
return visitor.dependencies |
|
|
|
|
|
def augmented_dependencies_for_class_node( |
|
node: cst.ClassDef, mapper: "ModuleMapper", objects_imported_from_modeling: Optional[set[str]] = None |
|
) -> set: |
|
"""Create augmented dependencies for a class node based on a `mapper`. |
|
Augmented dependencies means immediate dependencies + recursive function and assignments dependencies. |
|
""" |
|
temp_module = cst.Module(body=[node]) |
|
visitor = ClassDependencyMapper(node.name.value, set(mapper.global_nodes.keys()), objects_imported_from_modeling) |
|
temp_module.visit(visitor) |
|
return mapper.augment_dependencies(visitor.dependencies) |
|
|
|
|
|
|
|
ALL_FILE_TYPES = ( |
|
"modeling", |
|
"configuration", |
|
"tokenization", |
|
"processing", |
|
"image_processing", |
|
"video_processing", |
|
"feature_extractor", |
|
) |
|
|
|
|
|
class ModuleMapper(CSTVisitor, ABC): |
|
"""An abstract visitor class which analyses a module, creating a mapping of dependencies for classes, functions and assignments. |
|
Class dependencies are computed with `compute_class_dependencies()`, while function and assignment dependencies are stored in |
|
`self.object_recursive_dependency_mapping` (can be computed by `_compute_recursive_object_dependencies()`). |
|
It defines common visiting patterns (i.e. common visit_xxx/leave_xxx functions) between the modular file and the |
|
modeling files that will be visited. |
|
""" |
|
|
|
METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider) |
|
|
|
def __init__(self, python_module: cst.Module): |
|
|
|
self.python_module: cst.Module = python_module |
|
self.classes: Dict[str, cst.ClassDef] = {} |
|
self.imports = [] |
|
self.functions: Dict[str, cst.FunctionDef] = {} |
|
self.object_dependency_mapping = defaultdict(set) |
|
self.assignments: Dict[str, cst.SimpleStatementLine] = {} |
|
self.current_function = None |
|
self.current_class = None |
|
self.current_assignment = None |
|
|
|
self.objects_imported_from_modeling = set() |
|
|
|
self.match_patterns = "|".join(ALL_FILE_TYPES) |
|
|
|
|
|
def visit_ImportFrom(self, node): |
|
"""This keeps track of objects imported from neighbor modeling files (e.g. in `modeling_xxx.py, we have |
|
`from .configuration_xxx import Config`, then `Config` should be recorded as it is not a dependency that needs |
|
to be added (because it will be part of the imports)""" |
|
import_module = self.python_module.code_for_node(node.module) |
|
import_statement = "." * len(node.relative) + import_module |
|
if re.search(rf"^\.({self.match_patterns})_.*", import_statement): |
|
for imported_object in node.names: |
|
|
|
if imported_object.evaluated_alias is not None: |
|
self.objects_imported_from_modeling.add(imported_object.evaluated_alias) |
|
else: |
|
self.objects_imported_from_modeling.add(imported_object.evaluated_name) |
|
|
|
def visit_SimpleStatementLine(self, node): |
|
""" |
|
Global Assigns like `GEMMA_INPUT_DOCSTRING = 'THIS IS THE INPUT'` and all import statements |
|
are extracted and saved in their corresponding dict. They are then used when updating dependency mappings. |
|
""" |
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) |
|
simple_top_level_assign_structure = m.SimpleStatementLine( |
|
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] |
|
) |
|
if m.matches(parent_node, m.Module()): |
|
if m.matches(node, simple_top_level_assign_structure): |
|
left_hand_side = node.body[0].targets[0].target.value |
|
self.current_assignment = left_hand_side |
|
self.assignments[left_hand_side] = node |
|
elif m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): |
|
self.imports.append(node) |
|
|
|
def leave_SimpleStatementLine(self, node): |
|
|
|
|
|
self.current_assignment = None |
|
|
|
def visit_FunctionDef(self, node): |
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) |
|
if m.matches(parent_node, m.Module()): |
|
self.current_function = node.name.value |
|
self.functions[node.name.value] = node |
|
|
|
def leave_FunctionDef(self, node): |
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) |
|
if m.matches(parent_node, m.Module()): |
|
self.current_function = None |
|
|
|
def visit_If(self, node): |
|
|
|
        # Only record module-level `if` blocks that protect imports (e.g. conditional availability checks)
        if self.current_function is None and self.current_class is None:
|
for stmt in node.body.body: |
|
if m.matches(stmt, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])): |
|
self.imports.append(node) |
|
|
|
def visit_ClassDef(self, node: ClassDef) -> None: |
|
"""Record class nodes to create their dependencies at the end.""" |
|
self.classes[node.name.value] = node |
|
self.current_class = node.name.value |
|
|
|
def leave_ClassDef(self, node): |
|
self.current_class = None |
|
|
|
    def visit_Name(self, node: cst.Name):
|
"""This is used to create a mapping from module-scope functions and assignments to objects used inside them.""" |
|
if self.current_function is not None: |
|
self.object_dependency_mapping[self.current_function].add(node.value) |
|
if self.current_assignment is not None: |
|
self.object_dependency_mapping[self.current_assignment].add(node.value) |
|
|
|
def leave_Module(self, node): |
|
"""When leaving the module, we store the position of each global scoped node to allow sorting the dependencies |
|
based on their position in the code later. We use the PositionProvider metadata wrapper for this. |
|
We also make sure to update `self.object_dependency_mapping` so that it contains only names recorded in |
|
`self.global_nodes`. |
|
""" |
|
|
|
self.global_nodes = {**self.assignments, **self.classes, **self.functions} |
|
|
|
        self.start_lines = {}
        for name, global_node in self.global_nodes.items():
            self.start_lines[name] = self.get_metadata(cst.metadata.PositionProvider, global_node).start.line
|
|
|
def _restrict_dependencies_to_known_entities(self): |
|
"""Since we added every Name as part of `self.object_dependency_mapping`, we need to remove those that |
|
are not part of the recorded objects in `self.global_nodes` (i.e. built-in variables, imports, etc). |
|
This should be called only after all merging operations have been finalized!!""" |
|
global_objects = set(self.global_nodes.keys()) |
|
for object_name, dependencies in self.object_dependency_mapping.items(): |
|
self.object_dependency_mapping[object_name] = {dep for dep in dependencies if dep in global_objects} |
|
|
|
def _compute_recursive_object_dependencies(self) -> dict[str, set]: |
|
"""Based on immediate dependency mapping, create the recursive dependency mapping. For example, given the |
|
following file: |
|
``` |
|
def foo(): |
|
pass |
|
|
|
def bar(): |
|
foo() |
|
|
|
def test(): |
|
bar() |
|
``` |
|
        this visitor can only record immediate dependencies, i.e. it will record the following
        `self.object_dependency_mapping = {"test": {"bar"}, "bar": {"foo"}}`. This function is used to create
        the recursive mapping, i.e. `recursive_dependencies = {"test": {"bar", "foo"}, "bar": {"foo"}}`.
|
""" |
|
recursive_dependencies = {} |
|
for object_name in self.object_dependency_mapping.keys(): |
|
all_dependencies = find_all_dependencies(self.object_dependency_mapping, start_entity=object_name) |
|
recursive_dependencies[object_name] = all_dependencies |
|
return recursive_dependencies |
|
|
|
def augment_dependencies(self, dependencies: set[str]) -> set[str]: |
|
"""For a set of `dependencies`, augment them by adding all potential dependencies of the **functions** and |
|
**assignments** present in the `dependencies`. |
|
""" |
|
new_dependencies = dependencies.copy() |
|
|
|
for dep in tuple(dependencies): |
|
if dep in self.object_recursive_dependency_mapping.keys(): |
|
new_dependencies.update(self.object_recursive_dependency_mapping[dep]) |
|
return new_dependencies |
|
|
|
def compute_class_dependencies(self): |
|
"""For each visited class, find its dependencies based on visiting the current file + potential merged dependencies.""" |
|
self.class_dependency_mapping = {} |
|
for class_name, class_node in self.classes.items(): |
|
dependencies = dependencies_for_class_node(class_node, set(self.global_nodes.keys())) |
|
|
|
self.class_dependency_mapping[class_name] = self.augment_dependencies(dependencies) |
|
|
|
@abstractmethod |
|
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: |
|
raise NotImplementedError |
|
|
|
|
|
class ModelFileMapper(ModuleMapper): |
|
"""A mapper designed to parse modeling files (like `modeling_llama.py`). When encountering such a file |
|
in the `modular_xxx.py` file, we need to correctly visit it and merge the dependencies of the modular and current file. |
|
For this reason, this class should only be instantiated from the class method `visit_and_merge_dependencies`, which takes |
|
care of correctly merging dependencies, then finalizes all dependency graph computations. |
|
Note that we only merge functions and assignments here, as classes will be treated later on as they may be modified. |
|
For example, if you redefine `apply_rotary_pos_emb()` in the modular, the new node should be used in the dependencies |
|
of the modeling files as well. |
|
""" |
|
|
|
def __init__(self, python_module: cst.Module): |
|
super().__init__(python_module) |
|
|
|
def compute_relative_order(self, missing_dependencies: set[str]) -> dict[str, int]: |
|
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that |
|
will be created based on the modular. |
|
""" |
|
relative_order = {} |
|
idx = 0 |
|
classes = sorted( |
|
[dep for dep in tuple(missing_dependencies) if dep in self.classes], key=lambda x: self.start_lines[x] |
|
) |
|
|
|
|
|
if len(classes) > 0 and not hasattr(self, "class_dependency_mapping"): |
|
raise ValueError("Cannot correctly find the relative order of the dependencies.") |
|
|
|
remaining_dependencies = missing_dependencies.copy() |
|
|
|
|
|
for class_name in classes: |
|
class_dependencies = tuple(self.class_dependency_mapping[class_name] & remaining_dependencies) |
|
original_dependencies = [] |
|
merged_dependencies = [] |
|
|
|
|
|
for class_dep in class_dependencies: |
|
if class_dep in self.start_lines: |
|
original_dependencies.append(class_dep) |
|
else: |
|
merged_dependencies.append(class_dep) |
|
|
|
|
|
            # Sort on names first (reversed), then on start lines: as Python's sort is stable, this gives a
            # deterministic order, even for dependencies missing from `self.start_lines`
            original_dependencies = sorted(original_dependencies, reverse=True)
            original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
            merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
|
|
|
|
|
for dep in original_dependencies + merged_dependencies: |
|
remaining_dependencies.remove(dep) |
|
relative_order[dep] = idx |
|
idx += 1 |
|
|
|
|
|
if class_name in remaining_dependencies: |
|
remaining_dependencies.remove(class_name) |
|
relative_order[class_name] = idx |
|
idx += 1 |
|
|
|
|
|
remaining_dependencies = tuple(remaining_dependencies) |
|
original_dependencies = [] |
|
merged_dependencies = [] |
|
for dep in remaining_dependencies: |
|
if dep in self.modular_file_start_lines: |
|
merged_dependencies.append(dep) |
|
else: |
|
original_dependencies.append(dep) |
|
|
|
|
|
    # Same deterministic double sort as above (stable sort: names first, then start lines)
    original_dependencies = sorted(original_dependencies, reverse=True)
    original_dependencies = sorted(original_dependencies, key=lambda x: self.start_lines.get(x, 1e10))
    merged_dependencies = sorted(merged_dependencies, key=lambda x: self.modular_file_start_lines[x])
|
|
|
|
|
for dep in original_dependencies + merged_dependencies: |
|
relative_order[dep] = idx |
|
idx += 1 |
|
|
|
return relative_order |
|
|
|
def _merge_functions(self, functions: dict[str, cst.CSTNode], object_mapping: dict[str, set]): |
|
"""Update the global nodes and function dependency mapping with those from the modular file. |
|
|
|
Merging rule: if any function with the same name was redefined in the modular, use it and its dependencies |
|
instead of the original ones (this may mean to add new functions as well, if any redefined function uses a new one). |
|
""" |
|
|
|
self.functions.update(functions) |
|
self.object_dependency_mapping.update( |
|
{obj: dep for obj, dep in object_mapping.items() if obj in functions.keys()} |
|
) |
|
|
|
self.global_nodes.update(self.functions) |
|
|
|
def _merge_assignments(self, assignments: dict[str, cst.CSTNode], object_mapping: dict[str, set]): |
|
"""Update the global nodes with the assignment from the modular file. |
|
|
|
        Merging rule: if any assignment with the same name was redefined in the modular, we use it and its dependencies ONLY if it matches
        a pattern in `ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE` and its value is not None, or if it matches a pattern in `ASSIGNMENTS_REGEX_TO_KEEP`.
        Otherwise, we use the original value and dependencies. This rule was chosen to avoid having to rewrite the big docstrings.
|
""" |
|
for assignment, node in assignments.items(): |
|
should_keep = any(re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP) |
|
|
|
should_keep_if_not_none = any( |
|
re.search(pattern, assignment) for pattern in ASSIGNMENTS_REGEX_TO_KEEP_IF_NOT_NONE |
|
) and not (hasattr(node.body[0].value, "value") and node.body[0].value.value == "None") |
|
|
|
if should_keep or should_keep_if_not_none or assignment not in self.assignments: |
|
self.assignments[assignment] = node |
|
if assignment in object_mapping: |
|
self.object_dependency_mapping[assignment] = object_mapping[assignment] |
|
|
|
self.global_nodes.update(self.assignments) |
|
|
|
def _merge_classes(self, classes: dict[str, cst.CSTNode]): |
|
"""Update the global nodes with the new classes from the modular (i.e. classes which do not exist in current file, and |
|
are not imported). We do NOT update any dependency mapping here. This is because we only need the names of newly defined |
|
classes in the modular to be discoverable when computing dependencies for new nodes later on. For this reason, we |
|
do not add the new classes to `self.classes`, but only to `global_nodes`. |
|
""" |
|
|
|
self.global_nodes.update( |
|
{ |
|
name: node |
|
for name, node in classes.items() |
|
if name not in self.classes and name not in self.objects_imported_from_modeling |
|
} |
|
) |
|
|
|
def merge_modular_dependencies(self, classes, functions, assignments, object_mapping, start_lines): |
|
"""Merge classes, functions and assignments from the modular definitions into the current module file, |
|
then record the relative order of all nodes. |
|
Note: This function takes care of updating `global_nodes` and `object_recursive_dependency_mapping` as well after the |
|
merge with other files dependencies. |
|
""" |
|
self._merge_functions(functions, object_mapping) |
|
self._merge_assignments(assignments, object_mapping) |
|
self._merge_classes(classes) |
|
self.modular_file_start_lines = start_lines |
|
|
|
|
|
self._restrict_dependencies_to_known_entities() |
|
|
|
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() |
|
|
|
@classmethod |
|
def visit_and_merge_dependencies( |
|
cls, module: cst.Module, classes, functions, assignments, object_mapping, start_lines |
|
) -> "ModelFileMapper": |
|
wrapper = MetadataWrapper(module) |
|
mapper = cls(module) |
|
wrapper.visit(mapper) |
|
|
|
mapper.merge_modular_dependencies(classes, functions, assignments, object_mapping, start_lines) |
|
|
|
mapper.compute_class_dependencies() |
|
return mapper |
|
|
|
|
|
def common_partial_suffix(str1: str, str2: str) -> str:
    """Return the longest common suffix of the 2 strings. If one string is a full suffix of the other string,
    we do not consider it a common suffix and return `""`."""
|
common_suffix = "" |
|
for i in range(1, min(len(str1), len(str2)) + 1): |
|
if str1[-i] == str2[-i]: |
|
common_suffix = str1[-i] + common_suffix |
|
else: |
|
break |
|
|
|
if common_suffix == str1 or common_suffix == str2: |
|
common_suffix = "" |
|
return common_suffix |
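# e.g. common_partial_suffix("LlamaModel", "GemmaModel") -> "maModel", but
# common_partial_suffix("Model", "LlamaModel") -> "" as "Model" is a full suffix of "LlamaModel"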
|
|
|
|
|
def replace_class_node( |
|
mapper: ModelFileMapper, class_node: cst.ClassDef, renamed_super_class: str, original_super_class: str |
|
): |
|
""" |
|
Replace a class node which inherits from another modeling class. This function works in the following way: |
|
- start from the base class node of the inherited class (a cst.Node) |
|
- replace all methods of the base node with the methods defined in the child class |
|
- append all new methods defined in the child class |
|
- replace all calls to super() with the unravelled code |
|
|
|
| ```python | | ```python |
|
| class GemmaModel(LlamaModel): | | class GemmaModel(nn.Module): |
|
| def __init__(self): | | def __init__(self): |
|
Going from: | super().__init__() | to: | super().__init__(config) |
|
| self.dropout = 0.2 | | self.dropout = 0.2 |
|
| ``` | | self.padding_idx = config.pad_token_id |
|
| self.vocab_size = config.vocab_size |
|
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) |
|
| self.layers = nn.ModuleList( |
|
| [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] |
|
| ) |
|
| self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
|
| self.gradient_checkpointing = False |
|
| # Initialize weights and apply final processing |
|
| self.post_init() |
|
| ``` |
|
""" |
|
all_bases = [get_full_attribute_name(k.value) for k in class_node.bases] |
|
if any(base is None for base in all_bases): |
|
raise ValueError(f"Could not parse the name of the bases for {class_node.name.value}") |
|
|
|
original_node = mapper.classes[renamed_super_class] |
|
|
|
new_name = class_node.name |
|
|
|
|
|
    # If the new class name differs from the renamed super class, rename references in docstrings and comments
    if new_name.value != renamed_super_class:
        common_suffix = common_partial_suffix(new_name.value, renamed_super_class)
        # The non-common parts are the old and new name prefixes to swap in the docs
        old, new = renamed_super_class.replace(common_suffix, ""), new_name.value.replace(common_suffix, "")
|
temp_module = cst.Module(body=[original_node]) |
|
original_node = temp_module.visit( |
|
ReplaceNameTransformer(get_lowercase_name(old), get_lowercase_name(new), only_doc=True) |
|
).body[0] |
|
|
|
|
|
|
|
    # If the class has additional bases, try to swap any original base sharing a cased suffix with one of them
    additional_bases = [base for base in all_bases if base != original_super_class]
|
new_bases = [] |
|
for original_base in original_node.bases: |
|
new_base = original_base |
|
|
|
if m.matches(original_base.value, m.Name()): |
|
original_base_name = original_base.value.value |
|
for additional_base_name in additional_bases: |
|
suffix = common_partial_suffix(original_base_name, additional_base_name) |
|
if len(suffix) > 0 and suffix[0].isupper(): |
|
new_name_node = original_base.value.with_changes(value=additional_base_name) |
|
new_base = original_base.with_changes(value=new_name_node) |
|
break |
|
new_bases.append(new_base) |
|
|
|
original_methods = { |
|
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f |
|
for f in original_node.body.body |
|
} |
|
updated_methods = { |
|
f.name.value if hasattr(f, "name") else mapper.python_module.code_for_node(f): f for f in class_node.body.body |
|
} |
|
end_meth = [] |
|
|
|
assign_targets = {} |
|
docstring_node = [] |
|
|
|
for func in original_node.body.body: |
|
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) |
|
if m.matches(func, m.FunctionDef()) and name in updated_methods and updated_methods[name] is not None: |
|
new_params = updated_methods[name].params |
|
|
|
            # If the child method captures the remaining parent arguments with `**super_kwargs`, merge the
            # parent's explicit parameters into the new signature
            kwarg_name = getattr(updated_methods[name].params, "star_kwarg", None)
            if kwarg_name and kwarg_name.name.value == "super_kwargs":
|
parent_params = {k.name.value: k for k in func.params.params} |
|
parent_params.update({k.name.value: k for k in new_params.params[1:]}) |
|
new_params = new_params.with_changes( |
|
params=list(parent_params.values()), star_kwarg=func.params.star_kwarg |
|
) |
|
|
|
new_decorators = ( |
|
updated_methods[name].decorators if len(updated_methods[name].decorators) > 0 else func.decorators |
|
) |
|
|
|
|
|
new_return_annotation = updated_methods[name].returns if updated_methods[name].returns else func.returns |
|
|
|
            # A method redefined in the modular with a body that only raises an Error is removed from the generated class
            if not re.match(
                r"\ndef .*\(.*\):\n    raise.*Error\(.*",
|
mapper.python_module.code_for_node(updated_methods[name]), |
|
): |
|
func = func.with_changes( |
|
body=updated_methods[name].body, |
|
params=new_params, |
|
decorators=new_decorators, |
|
returns=new_return_annotation, |
|
) |
|
else: |
|
continue |
|
|
|
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): |
|
target = mapper.python_module.code_for_node(func.body[0].targets[0]) |
|
assign_targets[target] = func |
|
elif m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): |
|
target = mapper.python_module.code_for_node(func.body[0].target) |
|
assign_targets[target] = func |
|
elif m.matches(func, DOCSTRING_NODE): |
|
docstring_node = [func] |
|
else: |
|
end_meth.append(func) |
|
|
|
|
|
for func in class_node.body.body: |
|
name = func.name.value if hasattr(func, "name") else mapper.python_module.code_for_node(func) |
|
if m.matches(func, DOCSTRING_NODE): |
|
|
|
updated_docstring = func.body[0].value.value |
|
if len(docstring_node) == 0: |
|
docstring_node = [ |
|
cst.SimpleStatementLine(body=[cst.Expr(value=cst.SimpleString(value=updated_docstring))]) |
|
] |
|
else: |
|
original_docstring = docstring_node[0].body[0].value.value |
|
merged_doc = merge_docstrings(original_docstring, updated_docstring) |
|
|
|
docstring_node = [ |
|
docstring_node[0].with_changes(body=[cst.Expr(value=cst.SimpleString(value=merged_doc))]) |
|
] |
|
if name not in original_methods and func is not None and isinstance(func, cst.FunctionDef): |
|
end_meth.append(func) |
|
if m.matches(func, m.SimpleStatementLine(body=[m.Assign()])): |
|
|
|
target = mapper.python_module.code_for_node(func.body[0].targets[0]) |
|
assign_targets[target] = func |
|
if m.matches(func, m.SimpleStatementLine(body=[m.AnnAssign()])): |
|
target = mapper.python_module.code_for_node(func.body[0].target) |
|
assign_targets[target] = func |
|
end_meth = docstring_node + list(assign_targets.values()) + end_meth |
|
|
|
|
|
result_node = original_node.with_changes(body=cst.IndentedBlock(body=end_meth)) |
|
temp_module = cst.Module(body=[result_node]) |
|
new_module = MetadataWrapper(temp_module) |
|
new_replacement_class = new_module.visit( |
|
SuperTransformer(temp_module, original_methods, updated_methods, all_bases) |
|
) |
|
new_replacement_body = new_replacement_class.body[0].body |
|
|
|
|
|
new_decorators = class_node.decorators if len(class_node.decorators) > 0 else original_node.decorators |
|
|
|
return original_node.with_changes( |
|
body=new_replacement_body, decorators=new_decorators, bases=new_bases, name=new_name |
|
) |
|
|
|
|
|
TYPE_TO_FILE_TYPE = { |
|
"Config": "configuration", |
|
"Tokenizer": "tokenization", |
|
"Processor": "processing", |
|
"ImageProcessor": "image_processing", |
|
"ImageProcessorFast": "image_processing*_fast", |
|
"VideoProcessor": "video_processing", |
|
"VideoProcessorInitKwargs": "video_processing", |
|
"FastImageProcessorKwargs": "image_processing*_fast", |
|
"FeatureExtractor": "feature_extractor", |
|
"ProcessorKwargs": "processing", |
|
"VideosKwargs": "processing", |
|
"ImagesKwargs": "processing", |
|
"TextKwargs": "processing", |
|
} |
|
|
|
|
|
def find_file_type(class_name: str) -> str: |
|
"""Based on a class name, find the file type corresponding to the class. |
|
If the class name is `LlamaConfig` it will return `configuration`. |
|
The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling` |
|
""" |
|
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys()) |
|
match = re.search(rf"({match_pattern})$", class_name) |
|
if match: |
|
file_type = TYPE_TO_FILE_TYPE[match.group(1)] |
|
else: |
|
file_type = "modeling" |
|
return file_type |
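# e.g. find_file_type("LlamaConfig") -> "configuration", while find_file_type("LlamaModel") -> "modeling"
# as "Model" is not a suffix listed in `TYPE_TO_FILE_TYPE`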
|
|
|
|
|
|
|
|
|
VARIABLES_AT_THE_BEGINNING = ( |
|
"logger", |
|
"_CHECKPOINT_FOR_DOC", |
|
"_CONFIG_FOR_DOC", |
|
) |
|
|
|
|
|
IMPORTS_TO_SKIP_IN_MODULAR = ("auto.modeling_auto",) |
|
|
|
|
|
def append_new_import_node( |
|
node: cst.CSTNode, unused_imports: set[str], added_names: set, imports_to_keep: list[cst.CSTNode] |
|
): |
|
"""Insert the new `node` to the list of `imports_to_keep` in-place, if it is not part of the `unused_imports` or `added_names`. |
|
Also modifies `added_names` in-place accordingly.""" |
|
import_node = node.body[0] |
|
names_to_keep = [] |
|
for name in import_node.names: |
|
name_value = name.evaluated_alias or name.evaluated_name |
|
if name_value not in unused_imports and name_value not in added_names: |
|
names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) |
|
added_names.add(name_value) |
|
if len(names_to_keep) > 0: |
|
new_node = node.with_changes(body=[import_node.with_changes(names=names_to_keep)]) |
|
imports_to_keep.append(new_node) |
|
|
|
|
|
def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) -> list[cst.CSTNode]: |
|
"""Get all the imports needed in the `body`, from the list of `all_imports`. |
|
`body` is a dict with the following structure `{str: {"insert_idx": int, "node": cst.CSTNode}}`. |
|
Note: we need to use `isinstance` on scope assignments, m.matches apparently does not work here yet! |
|
""" |
|
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] |
|
wrapper = MetadataWrapper(cst.Module(body=all_imports + new_body)) |
|
scopes = set(wrapper.resolve(ScopeProvider).values()) |
|
unused_imports = set() |
|
import_ref_count = defaultdict(lambda: 0) |
|
for scope in scopes: |
|
for assignment in scope.assignments: |
|
node = assignment.node |
|
if isinstance(assignment, cst.metadata.Assignment) and isinstance(node, (cst.Import, cst.ImportFrom)): |
|
ref_count = len(assignment.references) |
|
name = assignment.name |
|
import_ref_count[name] = max(ref_count, import_ref_count[name]) |
|
|
|
|
|
    # An import is unused if it is never referenced, or if the same name is redefined in the new body
    unused_imports = {name for name, count in import_ref_count.items() if count <= 0 or name in body.keys()}
|
|
|
imports_to_keep = [] |
|
|
|
|
|
added_names = set() |
|
existing_protected_statements = set() |
|
for node in all_imports: |
|
if m.matches(node, m.If()): |
|
new_statements = [] |
|
for stmt_node in node.body.body: |
|
append_new_import_node(stmt_node, unused_imports, added_names, new_statements) |
|
new_statements = [stmt for stmt in new_statements if str(stmt) not in existing_protected_statements] |
|
if len(new_statements) > 0: |
|
new_node = node.with_changes(body=node.body.with_changes(body=new_statements)) |
|
imports_to_keep.append(new_node) |
|
existing_protected_statements.update({str(stmt) for stmt in new_statements}) |
|
else: |
|
append_new_import_node(node, unused_imports, added_names, imports_to_keep) |
|
|
|
protected_import_nodes = [node for node in imports_to_keep if m.matches(node, m.If())] |
|
usual_import_nodes = [node for node in imports_to_keep if not m.matches(node, m.If())] |
|
|
|
|
|
return usual_import_nodes + protected_import_nodes |
|
|
|
|
|
def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]: |
|
"""Split the `__all__` assignment found in the modular between each corresponding files.""" |
|
all_all_per_file = {} |
|
assign_node = node.body[0] |
|
if isinstance(assign_node.value, cst.List): |
|
|
|
all_all_to_add = defaultdict(list) |
|
for element in assign_node.value.elements: |
|
if isinstance(element.value, cst.SimpleString): |
|
|
|
class_name = element.value.value |
|
file = find_file_type(element.value.evaluated_value) |
|
all_all_to_add[file] += [class_name] |
|
for file, new_alls in all_all_to_add.items(): |
|
new_node = assign_node.with_changes( |
|
value=cst.List(elements=[cst.Element(value=cst.SimpleString(value=k)) for k in new_alls]) |
|
) |
|
all_all_per_file[file] = node.with_changes(body=[new_node]) |
|
return all_all_per_file |
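# Illustrative example: for `__all__ = ["MyModel", "MyConfig"]` in the modular file, this returns a dict of the form
# {"modeling": <node for `__all__ = ["MyModel"]`>, "configuration": <node for `__all__ = ["MyConfig"]`>}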
|
|
|
|
|
class ModularFileMapper(ModuleMapper): |
|
"""This is a Mapper to visit a modular file (like `modular_llama.py`). It visits the whole file, recording dependency, |
|
then visits all imported modeling files (like `modeling_llama.py`), and manages their mutual dependencies. |
|
Calling the method `create_modules()` after visit will create all modules based on this modular file. |
|
""" |
|
|
|
def __init__(self, python_module, new_name): |
|
super().__init__(python_module) |
|
|
|
self.model_name = new_name |
|
|
|
self.model_specific_imported_objects: Dict[str, str] = {} |
|
self.model_specific_modules: Dict[str, cst.Module] = {} |
|
|
|
self.all_all_to_add = {} |
|
|
|
|
|
def visit_ImportFrom(self, node: cst.ImportFrom) -> None: |
|
"""When visiting imports from modeling files (i.e. `transformers.models.xxx`) we get the code, parse it, |
|
and save it in `self.model_specific_modules` to later visit. The imported objects are saved in `self.model_specific_imported_objects`. |
|
""" |
|
import_module = self.python_module.code_for_node(node.module) |
|
import_statement = "." * len(node.relative) + import_module |
|
if any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR): |
|
return |
|
if m.matches(node.module, m.Attribute()): |
|
for imported_ in node.names: |
|
_import = re.search( |
|
rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement |
|
) |
|
if _import: |
|
source = _import.group(1) |
|
if source == "modeling" and "Config" in self.python_module.code_for_node(imported_): |
|
raise ValueError( |
|
f"You are importing {self.python_module.code_for_node(imported_)} from the modeling file. Import from the `configuration_xxxx.py` file instead" |
|
) |
|
if import_module not in self.model_specific_modules: |
|
if "models" not in import_module: |
|
import_module = "models." + import_module |
|
if "transformers" not in import_module: |
|
import_module = "transformers." + import_module |
|
source_code = get_module_source_from_name(import_module) |
|
tree = cst.parse_module(source_code) |
|
self.model_specific_modules[import_module] = tree |
|
imported_object = self.python_module.code_for_node(imported_.name) |
|
self.model_specific_imported_objects[imported_object] = import_module |
|
if m.matches(node.module, m.Name()): |
|
if "transformers" == import_module: |
|
raise ValueError( |
|
f"You are importing from {import_module} directly using global imports. Import from the correct local path" |
|
) |
|
|
|
def visit_SimpleStatementLine(self, node): |
|
"""If we visit an import statement not previously visited, record it. If we visit a module-scope assignment, |
|
simply record it or, if it is `__all__`, split it between files where we should dispatch it. |
|
""" |
|
parent_node = self.get_metadata(cst.metadata.ParentNodeProvider, node) |
|
simple_top_level_assign_structure = m.SimpleStatementLine( |
|
body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])] |
|
) |
|
if m.matches(parent_node, m.Module()): |
|
if m.matches(node, m.SimpleStatementLine(body=[m.Import()])): |
|
self.imports.append(node) |
|
elif m.matches(node, m.SimpleStatementLine(body=[m.ImportFrom()])): |
|
import_module = self.python_module.code_for_node(node.body[0].module) |
|
import_statement = "." * len(node.body[0].relative) + import_module |
|
if not ( |
|
re.search(rf"(?:transformers\.models\.)|(?:\.\.)\w+\.({self.match_patterns})_.*", import_statement) |
|
and not any(import_to_skip in import_statement for import_to_skip in IMPORTS_TO_SKIP_IN_MODULAR) |
|
): |
|
self.imports.append(node) |
|
elif m.matches(node, simple_top_level_assign_structure): |
|
assigned_variable = node.body[0].targets[0].target.value |
|
|
|
if assigned_variable == "__all__": |
|
self.all_all_to_add = split_all_assignment(node) |
|
else: |
|
self.current_assignment = assigned_variable |
|
self.assignments[assigned_variable] = node |
|
|
|
def leave_Module(self, node): |
|
"""When we leave the modular file, we do the following in order: |
|
1. for each modeling file found in the imports, rename it with the new model name, visit it, and update |
|
its dependency graph with the new function and assignment definitions found in the modular |
|
2. update the modular dependency graph with the imported functions and assignments (found when visiting the matching files) |
|
3. compute the nested (recursive) function and assignment dependencies |
|
""" |
|
|
|
super().leave_Module(node) |
|
|
|
|
|
self.visited_modules = {} |
|
self.renamers = {} |
|
name_prefixes = self.infer_new_model_name() |
|
for file, module in self.model_specific_modules.items(): |
|
file_model_name = file.split(".")[-2] |
|
new_name = name_prefixes[file] |
|
renamer = ReplaceNameTransformer(file_model_name, new_name, self.model_name) |
|
renamed_module = module.visit(renamer) |
|
self.visited_modules[file] = ModelFileMapper.visit_and_merge_dependencies( |
|
renamed_module, |
|
self.classes, |
|
self.functions, |
|
self.assignments, |
|
self.object_dependency_mapping, |
|
self.start_lines, |
|
) |
|
|
|
self.renamers[file] = renamer |
|
|
|
|
|
|
|
self.merge_model_specific_imports(self.visited_modules) |
|
|
|
|
|
self.object_recursive_dependency_mapping = self._compute_recursive_object_dependencies() |
|
|
|
|
|
|
|
self.imported_objects_per_file = defaultdict(set) |
|
for file, mapper in self.visited_modules.items(): |
|
file_type = re.search(rf"^transformers\.models\.\w+\.({self.match_patterns})_.*", file).group(1) |
|
self.imported_objects_per_file[file_type].update(mapper.objects_imported_from_modeling) |
|
|
|
def merge_model_specific_imports(self, visited_modules): |
|
"""Merge the functions and assignments imported from the modeling files to the modular nodes and dependency graph, |
|
based on the visited files.""" |
|
self.start_lines_file_mapping = {} |
|
self.added_objects_file_mapping = {} |
|
for object_name, file in self.model_specific_imported_objects.items(): |
|
visited_module = visited_modules[file] |
|
self.start_lines_file_mapping[file] = visited_module.start_lines |
|
|
|
if object_name in visited_module.functions and object_name not in self.functions: |
|
self.functions[object_name] = visited_module.functions[object_name] |
|
self.added_objects_file_mapping[object_name] = file |
|
dependencies = visited_module.object_dependency_mapping.get(object_name, None) |
|
if dependencies is not None: |
|
self.object_dependency_mapping[object_name] = dependencies |
|
for dep in dependencies: |
|
if dep not in self.global_nodes: |
|
self.added_objects_file_mapping[dep] = file |
|
self.functions[dep] = visited_module.global_nodes[dep] |
|
|
|
|
|
|
|
|
|
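                # Make the new function (and its recursive dependencies) available to all other visited modules as
                # well, as they may need it when resolving their own dependencies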
recursive_dependencies = visited_module.object_recursive_dependency_mapping.get(object_name, set()) |
|
node_recursive_dependencies_mapping = { |
|
dep: visited_module.global_nodes[dep] for dep in recursive_dependencies |
|
} |
|
for filename, module_mapper in self.visited_modules.items(): |
|
if filename != file: |
|
module_mapper.global_nodes[object_name] = visited_module.functions[object_name] |
|
if len(recursive_dependencies) > 0: |
|
module_mapper.object_recursive_dependency_mapping[object_name] = recursive_dependencies |
|
module_mapper.global_nodes.update(node_recursive_dependencies_mapping) |
|
|
|
|
|
elif object_name in visited_module.assignments and object_name not in self.assignments: |
|
self.assignments[object_name] = visited_module.assignments[object_name] |
|
self.added_objects_file_mapping[object_name] = file |
|
dependencies = visited_module.object_dependency_mapping.get(object_name, None) |
|
if dependencies is not None: |
|
self.object_dependency_mapping[object_name] = dependencies |
|
for dep in dependencies: |
|
if dep not in self.global_nodes: |
|
self.added_objects_file_mapping[dep] = file |
|
self.assignments[dep] = visited_module.global_nodes[dep] |
|
|
|
|
|
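        # Re-create the global nodes mapping after the merge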
self.global_nodes = {**self.assignments, **self.classes, **self.functions} |
|
|
|
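        # And restrict all dependency mappings to the entities that are actually known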
self._restrict_dependencies_to_known_entities() |
|
|
|
def compute_relative_order(self, missing_dependencies: set) -> dict[str, int]: |
|
"""Compute in which relative order the `missing_dependencies` should appear when the nodes are added to the final file that |
|
will be created based on the modular. |
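
        For example (illustrative), if `foo` and `bar` were both added from the visited `modeling_llama.py` (with

        `bar` defined first there) and `baz` comes from the modular itself, the result is `{"bar": 0, "foo": 1, "baz": 2}`.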
|
""" |
|
relative_order = {} |
|
idx = 0 |
|
|
|
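        # Distinguish the dependencies added from other visited files from those belonging to this file itself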
original_dependencies = [] |
|
other_files_dependencies = defaultdict(list) |
|
for dep in tuple(missing_dependencies): |
|
if dep in self.added_objects_file_mapping: |
|
file = self.added_objects_file_mapping[dep] |
|
other_files_dependencies[file].append(dep) |
|
else: |
|
original_dependencies.append(dep) |
|
|
|
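        # Sort each group by position in its source file; dependencies from other files come first, then the local ones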
all_dependencies = [] |
|
for file, dependencies in other_files_dependencies.items(): |
|
sorted_dependencies = sorted(dependencies, key=lambda x: self.start_lines_file_mapping[file][x]) |
|
all_dependencies += sorted_dependencies |
|
all_dependencies += sorted(original_dependencies, key=lambda x: self.start_lines[x]) |
|
|
|
|
|
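        # The relative index simply follows the order computed above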
for dep in all_dependencies: |
|
relative_order[dep] = idx |
|
idx += 1 |
|
|
|
return relative_order |
|
|
|
def infer_new_model_name(self) -> dict: |
|
"""Infer whether we are using a model name prefix different from the usual model name as defined from the filename. |
|
This is useful e.g. when we define a new multi-modal model, and only the text part inherits from `LlamaModel`, |
|
so we have something like: |
|
```python |
|
class NewModelNameTextDecoderLayer(LlamaDecoderLayer): |
|
pass |
|
``` |
|
with the `Text` prefix added to the model name. |
|
        However, if multiple prefixes are used, we raise a warning and use the most frequent one, to avoid parsing

        the same file multiple times and creating inconsistencies in the objects added from dependencies.

        If the new prefix collides with the prefix of another class in the file we are importing from, then we also

        raise a warning, and use the default prefix (the model name) to avoid collisions in dependencies.
|
""" |
|
prefix_model_name_mapping = defaultdict(Counter) |
|
cased_default_name = get_cased_name(self.model_name) |
|
|
|
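        # For each class inheriting from a model-specific class, infer the prefix it uses and count it per source file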
for class_name, class_node in self.classes.items(): |
|
modeling_bases = [ |
|
k.value.value for k in class_node.bases if k.value.value in self.model_specific_imported_objects |
|
] |
|
if len(modeling_bases) > 1: |
|
raise ValueError( |
|
f"{class_name} was defined with more than 1 model-specific super class. This is unsupported. We found {(*modeling_bases,)}." |
|
) |
|
if len(modeling_bases) == 1: |
|
filename = self.model_specific_imported_objects[modeling_bases[0]] |
|
cased_model_name = cased_default_name |
|
suffix = common_partial_suffix(class_name, modeling_bases[0]) |
|
if len(suffix) > 0 and suffix[0].isupper(): |
|
cased_model_name = class_name.replace(suffix, "") |
|
|
|
|
|
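                # If the inferred prefix is shorter than the default name while the class name still contains the
                # default name, the suffix match was too greedy: fall back to the default name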
if len(cased_model_name) < len(cased_default_name) and cased_default_name in class_name: |
|
cased_model_name = cased_default_name |
|
prefix_model_name_mapping[filename].update([cased_model_name]) |
|
|
|
|
|
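        # For each file, resolve the final prefix to use, warning about ambiguities and collisions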
final_name_mapping = {} |
|
for file, prefixes_counter in prefix_model_name_mapping.items(): |
|
if len(prefixes_counter) > 1: |
|
_, total = prefixes_counter.most_common(1)[0] |
|
most_used_entities = [name for name, count in prefixes_counter.most_common() if count == total] |
|
|
|
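                # In case of a tie, favor the default name if it is among the most used prefixes, otherwise take the last one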
final_name = cased_default_name if cased_default_name in most_used_entities else most_used_entities[-1] |
|
else: |
|
final_name = list(prefixes_counter)[0] |
|
|
|
old_cased_model_name = get_cased_name(file.split(".")[-2]) |
|
old_model_name_prefix = final_name.replace(cased_default_name, old_cased_model_name) |
|
|
|
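            # Check whether the chosen prefix, mapped back to the old model name, collides with a class already
            # defined in the source file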
has_prefix_collision = f"\nclass {old_model_name_prefix}" in get_module_source_from_name(file) |
|
if final_name != cased_default_name and has_prefix_collision: |
|
if len(prefixes_counter) > 1: |
|
logger.warning( |
|
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. However, the " |
|
f"most used one, '{final_name}', is already present in the source file and will likely cause consistency " |
|
f"issues. For this reason we fallback to the default prefix '{cased_default_name}' when grabbing args " |
|
"and dependencies. Make sure to subclass the intermediate classes with the prefix you want (if different " |
|
f"from '{cased_default_name}') or use a single prefix in all the modular (best)." |
|
) |
|
else: |
|
logger.warning( |
|
f"We detected the use of the new default prefix {final_name} when inheriting from {file}. However, it is " |
|
"already present in the source file and will likely cause consistency issues. For this reason we fallback " |
|
f"to the default prefix '{cased_default_name}' when grabbing args and dependencies. Make sure to subclass " |
|
f"the intermediate classes with the prefix you want (if different from '{cased_default_name}')" |
|
) |
|
final_name = cased_default_name |
|
elif len(prefixes_counter) > 1: |
|
logger.warning( |
|
f"We detected multiple prefix names when inheriting from {file}: {(*set(prefixes_counter),)}. We will only " |
|
f"use the most used '{final_name}' prefix when grabbing args and dependencies. Make sure to subclass the " |
|
f"intermediate classes with the prefix you want (if different from '{final_name}') or use a single prefix " |
|
"in all the modular (best)." |
|
) |
|
final_name_mapping[file] = get_lowercase_name(final_name) |
|
|
|
|
|
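        # Fill in the default model name for the remaining files (those from which no prefix could be inferred)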
for file in self.model_specific_modules.keys(): |
|
if file not in final_name_mapping.keys(): |
|
final_name_mapping[file] = self.model_name |
|
|
|
return final_name_mapping |
|
|
|
|
|
def check_dependencies_and_create_import_node( |
|
file_type: str, new_dependencies: set[str], mapper: ModuleMapper, new_name: str |
|
) -> tuple[set[str], dict[str, cst.CSTNode]]: |
|
"""Check that all class nodes in the `new_dependencies` belong to the correct `file_type`. If this is not the case, |
|
we need to remove it from the dependencies, and create a new import to it instead. |
|
This scenario may appear in the following case: |
|
If a new class in the `modular_xxx.py` file does not belong to `type_xxx.py`, but is used somewhere in `other_type_xxx.py` |
|
(e.g. as a type hint), but none of the visited files had a similar class, then it would be imported in `type_xxx.py` as |
|
part of the standard dependency graph (because we never encountered an import towards this new class in any file). |
|
For example imagine the following `modular.py`: |
|
``` |
|
from ..llama.modeling_llama import LlamaModel |
|
|
|
class NewNameTextConfig(PretrainedConfig): |
|
... |
|
|
|
class NewNameConfig(PretrainedConfig): |
|
... |
|
|
|
class NewNameModel(LlamaModel): |
|
config = NewNameConfig() |
|
text_config = NewNameTextConfig() |
|
... |
|
``` |
|
    then without the help of this function, `NewNameTextConfig` would end up defined in `modeling_newname.py` as well

    as in `configuration_newname.py`, because `modeling_llama.py` tells us that `NewNameConfig` should be imported

    rather than redefined, but it has no knowledge of `NewNameTextConfig`.
|
""" |
|
class_dependencies = {dep for dep in new_dependencies if m.matches(mapper.global_nodes[dep], m.ClassDef())} |
|
corrected_dependencies = new_dependencies.copy() |
|
new_imports = {} |
|
for class_name in class_dependencies: |
|
class_file_type = find_file_type(class_name) |
|
|
|
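        # The class belongs to another file type: remove it from the dependencies and import it instead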
if class_file_type != file_type: |
|
corrected_dependencies.remove(class_name) |
|
import_statement = f"from .{class_file_type}_{new_name} import {class_name}" |
|
new_imports[class_name] = cst.parse_statement(import_statement) |
|
|
|
return corrected_dependencies, new_imports |
|
|
|
|
|
def get_class_node_and_dependencies( |
|
modular_mapper: ModularFileMapper, class_name: str, node: cst.CSTNode, files: dict[str, dict] |
|
) -> tuple[dict, str, dict]: |
|
"""Return a single class node (and all its dependency nodes), to be added to the `files`. It creates the new |
|
class node based on the inherited classes if needed. Also returns any new imports of a new class defined in |
|
    the modular that we may need.
|
""" |
|
|
|
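    # A class may inherit from at most one model-specific base class (enforced when inferring the model name)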
model_specific_bases = [ |
|
k.value.value for k in node.bases if k.value.value in modular_mapper.model_specific_imported_objects |
|
] |
|
super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None |
|
|
|
file_type = find_file_type(class_name) |
|
file_to_update = files[file_type] |
|
model_name = modular_mapper.model_name |
|
|
|
|
|
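    # Objects already imported into this file type must not be added again as dependency nodes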
imported_objects = modular_mapper.imported_objects_per_file[file_type] |
|
|
|
|
|
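    # If the class inherits from a model-specific class, its node is created from the renamed super class definition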
if super_class is not None: |
|
super_file_name = modular_mapper.model_specific_imported_objects[super_class] |
|
|
|
|
|
mapper = modular_mapper.visited_modules[super_file_name] |
|
|
|
renamer = modular_mapper.renamers[super_file_name] |
|
renamed_super_class = preserve_case_replace(super_class, renamer.patterns, renamer.cased_new_name) |
|
|
|
|
|
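        # Create the new class node by unravelling the (renamed) super class definition into the modular one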
updated_node = replace_class_node(mapper, node, renamed_super_class, super_class) |
|
|
|
|
|
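        # Gather all dependencies of the newly created node, excluding objects already imported in the file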
new_node_dependencies = augmented_dependencies_for_class_node(updated_node, mapper, imported_objects) |
|
|
|
|
|
|
|
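        # Dependencies belonging to another file type are removed and replaced by imports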
new_node_dependencies, new_imports = check_dependencies_and_create_import_node( |
|
file_type, new_node_dependencies, mapper, model_name |
|
) |
|
|
|
|
|
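        # Recursively expand the dependencies, skipping those already present in the file being updated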
all_dependencies_to_add = find_all_dependencies( |
|
dependency_mapping=mapper.class_dependency_mapping, |
|
initial_dependencies=new_node_dependencies, |
|
initial_checked_dependencies=set(file_to_update.keys()), |
|
) |
|
|
|
relative_dependency_order = mapper.compute_relative_order(all_dependencies_to_add) |
|
nodes_to_add = { |
|
dep: (relative_dependency_order[dep], mapper.global_nodes[dep]) for dep in all_dependencies_to_add |
|
} |
|
|
|
|
|
else: |
|
updated_node = node |
|
|
|
|
|
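        # The class is defined from scratch in the modular: take its dependencies from the modular mapper itself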
all_dependencies_to_add = augmented_dependencies_for_class_node(updated_node, modular_mapper, imported_objects) |
|
|
|
|
|
|
|
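        # As above, dispatch dependencies belonging to another file type to imports instead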
all_dependencies_to_add, new_imports = check_dependencies_and_create_import_node( |
|
file_type, all_dependencies_to_add, modular_mapper, model_name |
|
) |
|
|
|
relative_dependency_order = modular_mapper.compute_relative_order(all_dependencies_to_add) |
|
nodes_to_add = { |
|
dep: (relative_dependency_order[dep], modular_mapper.global_nodes[dep]) |
|
for dep in all_dependencies_to_add |
|
if dep not in file_to_update.keys() |
|
} |
|
|
|
|
|
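    # Add the class node itself, after all of its dependencies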
class_idx = max(relative_dependency_order.values()) + 1 if len(relative_dependency_order) > 0 else 0 |
|
nodes_to_add[class_name] = (class_idx, updated_node) |
|
|
|
return nodes_to_add, file_type, new_imports |
|
|
|
|
|
def create_modules(modular_mapper: ModularFileMapper) -> dict[str, cst.Module]: |
|
"""Create all the new modules based on visiting the modular file. It replaces all classes as necessary.""" |
|
files = defaultdict(dict) |
|
current_file_indices = defaultdict(lambda: 0) |
|
|
|
|
|
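    # Create the node (and all needed dependency nodes) for every class defined in the modular file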
for class_name, node in modular_mapper.classes.items(): |
|
nodes_to_add, file_type, new_imports = get_class_node_and_dependencies(modular_mapper, class_name, node, files) |
|
|
|
|
|
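        # Register the new imports, so that the corresponding objects are not added again as dependency nodes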
modular_mapper.imported_objects_per_file[file_type].update(new_imports.keys()) |
|
modular_mapper.imports.extend(list(new_imports.values())) |
|
|
|
|
|
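        # Sort the nodes according to their relative order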
nodes_to_add = sorted(nodes_to_add.items(), key=lambda x: x[1][0]) |
|
|
|
for dependency, (_, node) in nodes_to_add: |
|
|
|
try: |
|
|
|
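                # Variables listed in VARIABLES_AT_THE_BEGINNING must appear first in the file, hence the very negative index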
idx = -1000 + VARIABLES_AT_THE_BEGINNING.index(dependency) |
|
except ValueError: |
|
idx = current_file_indices[file_type] |
|
current_file_indices[file_type] += 1 |
|
files[file_type][dependency] = {"insert_idx": idx, "node": node} |
|
|
|
|
|
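    # Add the dispatched `__all__` assignments last in each corresponding file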
for file_type, node in modular_mapper.all_all_to_add.items(): |
|
idx = current_file_indices[file_type] |
|
files[file_type]["__all__"] = {"insert_idx": idx, "node": node} |
|
|
|
|
|
|
|
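    # Gather the imports from the modular and all visited files, deduplicating based on their source code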
all_imports = modular_mapper.imports.copy() |
|
all_imports_code = {modular_mapper.python_module.code_for_node(node).strip() for node in all_imports} |
|
for file, mapper in modular_mapper.visited_modules.items(): |
|
new_imports = [ |
|
node for node in mapper.imports if mapper.python_module.code_for_node(node).strip() not in all_imports_code |
|
] |
|
new_imports_code = {mapper.python_module.code_for_node(node).strip() for node in new_imports} |
|
all_imports.extend(new_imports) |
|
all_imports_code.update(new_imports_code) |
|
|
|
|
|
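    # Assemble each final module: the needed imports first, then the nodes sorted by insertion index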
for file, body in files.items(): |
|
new_body = [k[1]["node"] for k in sorted(body.items(), key=lambda x: x[1]["insert_idx"])] |
|
needed_imports = get_needed_imports(body, all_imports) |
|
full_module = needed_imports + new_body |
|
new_module = cst.Module(body=full_module, header=modular_mapper.python_module.header) |
|
files[file] = new_module |
|
|
|
return files |
|
|
|
|
|
def convert_modular_file(modular_file): |
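    """Convert a `modular_xxx.py` file into a mapping from generated file type (e.g. `modeling`) to a

    `[formatted_code, unformatted_code]` pair holding the generated code for that file."""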
|
pattern = re.search(r"modular_(.*)(?=\.py$)", modular_file) |
|
output = {} |
|
if pattern is not None: |
|
model_name = pattern.groups()[0] |
|
|
|
with open(modular_file, "r", encoding="utf-8") as file: |
|
code = file.read() |
|
module = cst.parse_module(code) |
|
wrapper = MetadataWrapper(module) |
|
cst_transformers = ModularFileMapper(module, model_name) |
|
wrapper.visit(cst_transformers) |
|
for file, module in create_modules(cst_transformers).items(): |
|
if module != {}: |
|
|
|
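                # Recover the path of the modular file relative to the repo root, to reference it in the generated header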
relative_path = re.search( |
|
r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/") |
|
).group(1) |
|
|
|
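                # Prepend the auto-generated header and clean the result with two ruff passes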
header = AUTO_GENERATED_MESSAGE.format( |
|
relative_path=relative_path, short_name=os.path.basename(relative_path) |
|
) |
|
ruffed_code = run_ruff(header + module.code, True) |
|
formatted_code = run_ruff(ruffed_code, False) |
|
output[file] = [formatted_code, ruffed_code] |
|
return output |
|
else: |
|
print(f"modular pattern not found in {modular_file}, exiting") |
|
return {} |
|
|
|
|
|
def save_modeling_file(modular_file, converted_file): |
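    """Write each converted file next to the modular file, replacing the `modular_` prefix with the file type

    (e.g. `modular_my_model.py` -> `modeling_my_model.py`). If the formatted code contains nothing but comments,

    fall back to the unformatted version and emit a warning."""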
|
for file_type in converted_file.keys(): |
|
file_name_prefix = file_type.split("*")[0] |
|
file_name_suffix = file_type.split("*")[-1] if "*" in file_type else "" |
|
new_file_name = modular_file.replace("modular_", f"{file_name_prefix}_").replace( |
|
".py", f"{file_name_suffix}.py" |
|
) |
|
non_comment_lines = len( |
|
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")] |
|
) |
|
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0: |
|
with open(new_file_name, "w", encoding="utf-8") as f: |
|
f.write(converted_file[file_type][0]) |
|
else: |
|
            non_comment_lines = len(

                [line for line in converted_file[file_type][1].strip().split("\n") if not line.strip().startswith("#")]

            )
|
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0: |
|
logger.warning("The modeling code contains errors, it's written without formatting") |
|
with open(new_file_name, "w", encoding="utf-8") as f: |
|
f.write(converted_file[file_type][1]) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"--files_to_parse", |
|
default=["all"], |
|
nargs="+", |
|
help="A list of `modular_xxxx` files that should be converted to single model file", |
|
) |
|
args = parser.parse_args() |
|
if args.files_to_parse == ["all"]: |
|
args.files_to_parse = glob.glob("src/transformers/models/**/modular_*.py", recursive=True) |
|
if args.files_to_parse == ["examples"]: |
|
args.files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True) |
|
else: |
|
for i, model_name in enumerate(args.files_to_parse): |
|
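            # Bare model names are resolved to their modular file path; entries that are already paths are kept as-is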
if os.sep not in model_name: |
|
full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py") |
|
|
|
if not os.path.isfile(full_path): |
|
full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py") |
|
|
|
if not os.path.isfile(full_path): |
|
raise ValueError(f"Cannot find a modular file for {model_name}. Please provide the full path.") |
|
args.files_to_parse[i] = full_path |
|
|
|
priority_list, _ = find_priority_list(args.files_to_parse) |
|
assert len(priority_list) == len(args.files_to_parse), "Some files will not be converted" |
|
|
|
for file_name in priority_list: |
|
print(f"Converting {file_name} to a single model single file format") |
|
        converted_files = convert_modular_file(file_name)

        save_modeling_file(file_name, converted_files)
|
|