JAM / utils.py
renhang
update code without examples
8e872fa
raw
history blame
4.67 kB
def regroup_words(
    words: list[dict],
    max_len: float = 15.0,
    gap: float = 0.50,
) -> list[dict]:
    """
    Split a word list into consecutive segments, preserving word order.

    A new segment is started before word ``w`` when either:
      * ``w["end"]`` minus the current segment's start exceeds ``max_len``, or
      * ``w["start"]`` minus the previous word's ``"end"`` exceeds ``gap``.
    A segment always contains at least one word, even if that single word is
    longer than ``max_len``. Words are consumed in the order given (presumably
    already sorted by start time -- the list is not re-sorted here).

    Args:
        words: Word dicts with at least 'word', 'start' and 'end' keys.
        max_len: Maximum segment span, in the same units as the timestamps.
        gap: Largest silence allowed inside one segment.

    Returns:
        A list of segment dicts with keys:
            'start'   -- start of the segment's first word
            'end'     -- end of the segment's last word
            'segment' -- the segment's 'word' values joined by single spaces
            'words'   -- the word dicts belonging to the segment, in order
        An empty list when ``words`` is empty.
    """
    if not words:
        return []

    def _make_segment(members: list[dict]) -> dict:
        # Build one output record from a non-empty run of consecutive words.
        return {
            "start": members[0]["start"],
            "end": members[-1]["end"],
            "segment": " ".join(m["word"] for m in members),
            "words": members,
        }

    segments: list[dict] = []
    current: list[dict] = []
    for w in words:
        if current:
            over_max = (w["end"] - current[0]["start"]) > max_len
            long_gap = (w["start"] - current[-1]["end"]) > gap
            if over_max or long_gap:
                segments.append(_make_segment(current))
                # Rebind (not clear) so the flushed segment keeps its list.
                current = []
        current.append(w)
    # Flush the final, always non-empty, segment.
    segments.append(_make_segment(current))
    return segments
def text_to_words(text: str) -> list[dict]:
    """
    Parse timestamped text such as "It's[4.96:5.52] a[5.52:5.84] ..." into
    word dictionaries.

    A timestamp written as the literal 'xxx' is read as 0.0. Entries whose
    timestamps cannot be converted to float are silently skipped.

    Args:
        text: Space-separated "word[start:end]" tokens.
    Returns:
        A list of dicts with keys 'word', 'start', 'end' (floats), in the
        order the tokens appear; [] for blank input.
    """
    import re

    if not text.strip():
        return []

    def _seconds(raw: str) -> float:
        # 'xxx' is the placeholder for an unknown timestamp.
        return 0.0 if raw == 'xxx' else float(raw)

    parsed = []
    for match in re.finditer(r'(\S+?)\[([^:]+):([^\]]+)\]', text):
        token, raw_start, raw_end = match.groups()
        try:
            begin = _seconds(raw_start)
            finish = _seconds(raw_end)
        except ValueError:
            # Malformed timestamp: drop this entry and keep parsing.
            continue
        parsed.append({'word': token, 'start': begin, 'end': finish})
    return parsed
def words_to_text(words: list[dict]) -> str:
    """
    Serialize word dictionaries into "word[start:end] word[start:end] ...".

    Timestamps are printed with at most two decimals, dropping trailing zeros
    and a dangling decimal point (5.50 -> "5.5", 6.0 -> "6"). Missing keys
    fall back to '' for 'word' and 0.0 for 'start' / 'end'.

    Args:
        words: List of word dictionaries with keys: 'word', 'start', 'end'
    Returns:
        Space-separated tokens such as "It's[4.96:5.52] a[5.52:5.84]",
        or "" when ``words`` is empty.
    """
    if not words:
        return ""

    def _compact(seconds) -> str:
        # Fixed two-decimal form, then trim trailing zeros and a bare '.'.
        return f"{seconds:.2f}".rstrip('0').rstrip('.')

    def _token(entry: dict) -> str:
        span = ":".join(_compact(entry.get(key, 0.0)) for key in ("start", "end"))
        return f"{entry.get('word', '')}[{span}]"

    return " ".join(map(_token, words))
def json_to_text(
    json_data: dict,
    max_len: float = 5.0,
    gap: float = 0.50,
) -> str:
    """
    Convert JSON lyrics data to text format for display.

    Only the 'word' layer of ``json_data`` is used. Words are grouped into
    display lines with ``regroup_words`` and every line is rendered with
    ``words_to_text``. Each word appears in exactly one line, in input order.

    Args:
        json_data: Dictionary with a 'word' key holding a list of word dicts
            ('word', 'start', 'end').
        max_len: Maximum span of one display line; forwarded to
            ``regroup_words``.
        gap: Silence that forces a new display line; forwarded to
            ``regroup_words``.

    Returns:
        The rendered lines separated by a blank line (two newline
        characters), or "" when ``json_data`` is not a dict, has no 'word'
        key, or its word list is empty.
    """
    if not isinstance(json_data, dict) or 'word' not in json_data:
        return ""
    words = json_data['word']
    if not words:
        return ""

    # regroup_words reports each segment's time span and joined 'segment'
    # text, but not which word dicts it contains. Matching words back by time
    # range drops zero-length words and can put overlapping words in two
    # lines, so run it on stand-ins whose 'word' is the original list index
    # and read those indices back out of the joined text.
    stand_ins = [{**w, 'word': str(i)} for i, w in enumerate(words)]
    segments = regroup_words(stand_ins, max_len=max_len, gap=gap)

    segment_lines = []
    for seg in segments:
        members = [words[int(i)] for i in seg['segment'].split()]
        segment_lines.append(words_to_text(members))
    return '\n\n'.join(segment_lines)
def text_to_json(text: str) -> dict:
    """
    Parse timestamped text into the {'word': [...]} structure the model uses.

    Multi-line input is accepted: each line is stripped, blank lines are
    discarded, and the rest are joined with single spaces before parsing.

    Args:
        text: "word[start:end] word[start:end]..." text, possibly spanning
            several lines.
    Returns:
        A dict whose 'word' key holds the parsed word dicts.
    """
    stripped_lines = (row.strip() for row in text.split('\n'))
    flattened = ' '.join(filter(None, stripped_lines))
    return {"word": text_to_words(flattened)}