Commit · 572abf8
Parent(s): 753bd5a
Initial commit
- Dockerfile +31 -0
- README.md +4 -4
- __pycache__/utils.cpython-311.pyc +0 -0
- agents/__innit__.py +0 -0
- agents/__pycache__/agents.cpython-311.pyc +0 -0
- agents/__pycache__/harmonize.cpython-311.pyc +0 -0
- agents/__pycache__/harmonize.cpython-312.pyc +0 -0
- agents/__pycache__/utils.cpython-311.pyc +0 -0
- agents/agents.py +422 -0
- agents/utils.py +15 -0
- anticipation/__init__.py +9 -0
- anticipation/__pycache__/__init__.cpython-311.pyc +0 -0
- anticipation/__pycache__/config.cpython-311.pyc +0 -0
- anticipation/__pycache__/convert.cpython-311.pyc +0 -0
- anticipation/__pycache__/ops.cpython-311.pyc +0 -0
- anticipation/__pycache__/sample.cpython-311.pyc +0 -0
- anticipation/__pycache__/tokenize.cpython-311.pyc +0 -0
- anticipation/__pycache__/visuals.cpython-311.pyc +0 -0
- anticipation/__pycache__/vocab.cpython-311.pyc +0 -0
- anticipation/config-original.py +60 -0
- anticipation/config.py +60 -0
- anticipation/convert-original.py +342 -0
- anticipation/convert.py +365 -0
- anticipation/ops.py +285 -0
- anticipation/sample.py +280 -0
- anticipation/tokenize.py +219 -0
- anticipation/visuals.py +65 -0
- anticipation/vocab.py +58 -0
- api.py +240 -0
- examples/full-score3.mid +0 -0
- examples/strawberry.mid +0 -0
- requirements.txt +11 -0
Dockerfile
ADDED
@@ -0,0 +1,31 @@
FROM python:3.9-slim

# Create dedicated user with home directory
RUN useradd -m -u 1000 user

# Set Hugging Face cache to user's writable directory
ENV HF_HOME=/home/user/.cache/huggingface
ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface

# Create cache directory with proper permissions
RUN mkdir -p ${HF_HOME} && chown -R user:user /home/user

# Set working directory (app will live here)
WORKDIR /app

# Install dependencies as root
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt gunicorn

# Copy app files (maintain ownership)
COPY --chown=user:user . .

RUN rm -rf /root/.cache/pip

# Switch to non-root user
USER user

EXPOSE 7860
CMD ["gunicorn", "--workers", "1", "--timeout", "120", "--bind", "0.0.0.0:7860", "api:app"]
README.md
CHANGED
@@ -1,8 +1,8 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: InScoreAI
+emoji: 📚
+colorFrom: gray
+colorTo: purple
 sdk: docker
 pinned: false
 ---
__pycache__/utils.cpython-311.pyc
ADDED
Binary file (229 Bytes).

agents/__innit__.py
ADDED
File without changes

agents/__pycache__/agents.cpython-311.pyc
ADDED
Binary file (18.1 kB).

agents/__pycache__/harmonize.cpython-311.pyc
ADDED
Binary file (19.2 kB).

agents/__pycache__/harmonize.cpython-312.pyc
ADDED
Binary file (11 kB).

agents/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (691 Bytes).
agents/agents.py
ADDED
@@ -0,0 +1,422 @@
from anticipation import ops
from anticipation.sample import generate
from anticipation.tokenize import extract_instruments
from anticipation.convert import events_to_midi, midi_to_events, compound_to_midi
from anticipation.config import *
from anticipation.vocab import *
from anticipation.convert import midi_to_compound
import mido
from agents.utils import load_midi_metadata


SMALL_MODEL = 'stanford-crfm/music-small-800k'    # faster inference, worse sample quality
MEDIUM_MODEL = 'stanford-crfm/music-medium-800k'  # slower inference, better sample quality
LARGE_MODEL = 'stanford-crfm/music-large-800k'    # slowest inference, best sample quality


def harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):

    # Turn the full MIDI file into events
    events = midi_to_events(midi)

    print("Midi converted to events")

    # Clip from 0 to the end of the full MIDI and rebase to time zero
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))

    # Extract melody and accompaniment
    events, melody = extract_instruments(segment, [0])

    print("Melody extracted")

    print("Start time:", start_time)
    print("End time:", end_time)

    # Get the initial prompt
    history = ops.clip(events, 0, start_time, clip_duration=False)

    anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]

    # Generate accompaniment conditioned on the melody
    accompaniment = generate(model, start_time, end_time, inputs=history, controls=melody, top_p=top_p, debug=False)

    # Append the anticipated continuation to the accompaniment
    accompaniment = ops.combine(accompaniment, anticipated)

    print("Accompaniment generated")

    # 1) render each voice separately
    mel_mid = events_to_midi(melody)
    acc_mid = events_to_midi(accompaniment)

    # 2) build a fresh MidiFile
    combined = mido.MidiFile()
    combined.ticks_per_beat = mel_mid.ticks_per_beat  # or TIME_RESOLUTION//2

    print("Midi built")

    # 3) meta track with tempo & time signature
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)

    # 4) append melody *then* accompaniment, skipping each file's own meta track
    combined.tracks.extend(mel_mid.tracks[1:])
    combined.tracks.extend(acc_mid.tracks[1:])

    # 5) clamp note parameters to valid MIDI ranges
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if hasattr(msg, 'velocity'):
                    msg.velocity = min(max(msg.velocity, 0), 127)
                if hasattr(msg, 'note'):
                    msg.note = min(max(msg.note, 0), 127)

    print(f"Melody tracks: {len(mel_mid.tracks)}")
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")

    # Track cleanup: keep only unique tracks
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks

    print(f"Final track count: {len(combined.tracks)}")

    print("Output Midi metadata added")

    return combined


def harmonizer(ai_model, midi_file, start_time, end_time, top_p):
    """
    Harmonize a melody in a MIDI file and return the harmonized MIDI.

    Args:
        midi_file: the MIDI file to harmonize (a mido.MidiFile)
        start_time: start time of the selected measure (the melody to harmonize) in milliseconds
        end_time: end time of the selected measure in milliseconds
    """

    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log any out-of-range note parameters in the input
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")

    # Load the original MIDI and extract its metadata
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)

    print("Midi metadata loaded")

    # load an anticipatory music transformer
    model = ai_model  # add .cuda() if you have a GPU

    print("Model loaded")

    harmonized_midi = harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)

    print("Midi generated")

    print(f"Harmonized MIDI tracks: {len(harmonized_midi.tracks)}")

    # Clamp any invalid note parameters in the output
    for track in harmonized_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)

    print("Midi validated")

    return harmonized_midi


def infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):

    # Turn the full MIDI file into events
    events = midi_to_events(midi)

    print("Midi converted to events")

    # Clip from 0 to the end of the full MIDI and rebase to time zero
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))

    # Get the initial prompt
    history = ops.clip(events, 0, start_time, clip_duration=False)

    anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]

    # Generate the infill, conditioned on the anticipated continuation
    infilling = generate(model, start_time, end_time, inputs=history, controls=anticipated, top_p=top_p, debug=False)

    # Append the anticipated continuation to the infill
    full_events = ops.combine(infilling, anticipated)

    print("Infill generated")

    # 1) render the combined events
    full_mid = events_to_midi(full_events)

    # 2) build a fresh MidiFile
    combined = mido.MidiFile()
    combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2

    print("Midi built")

    # 3) meta track with tempo & time signature
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)

    # 4) append all rendered tracks
    combined.tracks.extend(full_mid.tracks[:])

    # 5) clamp note parameters to valid MIDI ranges
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if hasattr(msg, 'velocity'):
                    msg.velocity = min(max(msg.velocity, 0), 127)
                if hasattr(msg, 'note'):
                    msg.note = min(max(msg.note, 0), 127)

    print(f"Rendered tracks: {len(full_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")

    # Track cleanup: keep only unique tracks
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks

    print(f"Final track count: {len(combined.tracks)}")

    print("Output Midi metadata added")

    return combined


def infiller(ai_model, midi_file, start_time, end_time, top_p):
    """
    Infill the selected region of a MIDI file and return the infilled MIDI.

    Args:
        midi_file: the MIDI file to infill (a mido.MidiFile)
        start_time: start time of the selected region in milliseconds
        end_time: end time of the selected region in milliseconds
    """

    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log any out-of-range note parameters in the input
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")

    # Load the original MIDI and extract its metadata
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)

    print("Midi metadata loaded")

    # load an anticipatory music transformer
    model = ai_model  # add .cuda() if you have a GPU

    print("Model loaded")

    infilled_midi = infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)

    print("Midi generated")

    print(f"Infilled MIDI tracks: {len(infilled_midi.tracks)}")

    # Clamp any invalid note parameters in the output
    for track in infilled_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)

    print("Midi validated")

    return infilled_midi


def change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):

    events = midi_to_events(midi)
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))

    # Extract the melody (instrument 0) as events and the accompaniment as controls
    instruments = list(ops.get_instruments(segment).keys())
    accompaniment_instruments = [instr for instr in instruments if instr != 0]
    melody_events, accompaniment_controls = extract_instruments(segment, accompaniment_instruments)

    # Get the initial prompt (melody before start_time)
    history = ops.clip(melody_events, 0, start_time, clip_duration=False)

    # Include accompaniment controls for the entire duration
    controls = accompaniment_controls  # full accompaniment as controls

    # Generate a new melody conditioned on the accompaniment
    infilling = generate(model, start_time, end_time, inputs=history, controls=controls, top_p=top_p, debug=False)

    # Append the anticipated melody continuation
    anticipated_melody = [CONTROL_OFFSET + tok for tok in ops.clip(melody_events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
    full_events = ops.combine(infilling, anticipated_melody)

    acc_mid = events_to_midi(accompaniment_controls)

    # Render and combine MIDI tracks
    full_mid = events_to_midi(full_events)
    combined = mido.MidiFile()
    combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2

    print("Midi built")

    # meta track with tempo & time signature
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)

    # append the new melody *then* the accompaniment
    combined.tracks.extend(full_mid.tracks[:])
    combined.tracks.extend(acc_mid.tracks[:])

    # clamp note parameters to valid MIDI ranges
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if hasattr(msg, 'velocity'):
                    msg.velocity = min(max(msg.velocity, 0), 127)
                if hasattr(msg, 'note'):
                    msg.note = min(max(msg.note, 0), 127)

    print(f"Melody tracks: {len(full_mid.tracks)}")
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")

    # Track cleanup: keep only unique tracks
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks

    print(f"Final track count: {len(combined.tracks)}")

    print("Output Midi metadata added")

    return combined


def change_melody(ai_model, midi_file, start_time, end_time, top_p):
    """
    Regenerate the melody over the selected region of a MIDI file,
    conditioned on the accompaniment, and return the modified MIDI.

    Args:
        midi_file: the MIDI file to modify (a mido.MidiFile)
        start_time: start time of the selected region in milliseconds
        end_time: end time of the selected region in milliseconds
    """

    print(f"Original MIDI tracks: {len(midi_file.tracks)}")

    # Log any out-of-range note parameters in the input
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")

    # Load the original MIDI and extract its metadata
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)

    print("Midi metadata loaded")

    # load an anticipatory music transformer
    model = ai_model  # add .cuda() if you have a GPU

    print("Model loaded")

    change_melody_gen_midi = change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)

    print("Midi generated")

    print(f"Modified MIDI tracks: {len(change_melody_gen_midi.tracks)}")

    # Clamp any invalid note parameters in the output
    for track in change_melody_gen_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)

    print("Midi validated")

    return change_melody_gen_midi
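
For orientation, a minimal usage sketch of the agents above. It assumes the checkpoint loads as a standard transformers causal LM (as the stanford-crfm model names suggest); the file paths and region boundaries are illustrative, not part of this commit.

import mido
from transformers import AutoModelForCausalLM

from agents.agents import SMALL_MODEL, harmonizer

model = AutoModelForCausalLM.from_pretrained(SMALL_MODEL)  # add .cuda() for a GPU
midi = mido.MidiFile('example.mid')  # illustrative path

# harmonize the selected region; units follow the docstrings above
result = harmonizer(model, midi, 5, 10, top_p=0.95)
result.save('harmonized.mid')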
agents/utils.py
ADDED
@@ -0,0 +1,15 @@
def load_midi_metadata(midi_file):

    original_tempo = 500000     # default tempo (120 BPM)
    original_time_sig = (4, 4)  # default time signature

    for msg in midi_file:
        if msg.type == 'set_tempo':
            original_tempo = msg.tempo
        elif msg.type == 'time_signature':
            original_time_sig = (msg.numerator, msg.denominator)

    return midi_file, original_tempo, original_time_sig
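
Note that load_midi_metadata scans every message and keeps the last tempo and time signature it encounters, so a file carrying neither falls back to the defaults. A small sketch of that fallback (the empty file is illustrative):

import mido

from agents.utils import load_midi_metadata

empty = mido.MidiFile()
empty.tracks.append(mido.MidiTrack())

_, tempo, time_sig = load_midi_metadata(empty)
assert tempo == 500000     # 500000 microseconds per beat = 120 BPM
assert time_sig == (4, 4)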
anticipation/__init__.py
ADDED
@@ -0,0 +1,9 @@
""" Infrastructure for constructing anticipatory infilling models.

This package provides infrastructure to preprocess Midi music datasets
for training anticipatory music infilling models. For more context, see:

  Anticipatory Music Transformer
  John Thickstun, David Hall, Chris Donahue, Percy Liang
  Preprint Report, 2023
"""
anticipation/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (598 Bytes).

anticipation/__pycache__/config.cpython-311.pyc
ADDED
Binary file (2.08 kB).

anticipation/__pycache__/convert.cpython-311.pyc
ADDED
Binary file (19 kB).

anticipation/__pycache__/ops.cpython-311.pyc
ADDED
Binary file (12.1 kB).

anticipation/__pycache__/sample.cpython-311.pyc
ADDED
Binary file (13.6 kB).

anticipation/__pycache__/tokenize.cpython-311.pyc
ADDED
Binary file (12.8 kB).

anticipation/__pycache__/visuals.cpython-311.pyc
ADDED
Binary file (4.02 kB).

anticipation/__pycache__/vocab.cpython-311.pyc
ADDED
Binary file (2.62 kB).
anticipation/config-original.py
ADDED
@@ -0,0 +1,60 @@
"""
Global configuration for anticipatory infilling models.
"""

# model hyper-parameters

CONTEXT_SIZE = 1024  # model context
EVENT_SIZE = 3       # each event/control is encoded as 3 tokens
M = 341              # model context (1024 = 1 + EVENT_SIZE*M)
DELTA = 5            # anticipation time in seconds

assert CONTEXT_SIZE == 1+EVENT_SIZE*M

# vocabulary constants

MAX_TIME_IN_SECONDS = 100     # exclude very long training sequences
MAX_DURATION_IN_SECONDS = 10  # maximum duration of a note
TIME_RESOLUTION = 100         # 10ms time resolution = 100 bins/second

MAX_PITCH = 128                 # 128 MIDI pitches
MAX_INSTR = 129                 # 129 MIDI instruments (128 + drums)
MAX_NOTE = MAX_PITCH*MAX_INSTR  # note = pitch x instrument

MAX_INTERARRIVAL_IN_SECONDS = 10  # maximum interarrival time (for MIDI-like encoding)

# preprocessing settings

PREPROC_WORKERS = 16

COMPOUND_SIZE = 5                 # event size in the intermediate compound tokenization
MAX_TRACK_INSTR = 16              # exclude tracks with large numbers of instruments
MAX_TRACK_TIME_IN_SECONDS = 3600  # exclude very long tracks (longer than 1 hour)
MIN_TRACK_TIME_IN_SECONDS = 10    # exclude very short tracks (less than 10 seconds)
MIN_TRACK_EVENTS = 100            # exclude very short tracks (less than 100 events)

# LakhMIDI dataset splits

LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
LAKH_VALID = ['e']
LAKH_TEST = ['f']

# derived quantities

MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS

MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS


if __name__ == '__main__':
    print('Model constants:')
    print(f' -> anticipation interval: {DELTA}s')
    print('Vocabulary constants:')
    print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
    print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
    print(f' -> time resolution: {TIME_RESOLUTION}bins/s ({1000//TIME_RESOLUTION}ms)')
    print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
    print('Preprocessing constants:')
    print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
    print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}')
anticipation/config.py
ADDED
@@ -0,0 +1,60 @@
"""
Global configuration for anticipatory infilling models.
"""

# model hyper-parameters

CONTEXT_SIZE = 1024  # model context
EVENT_SIZE = 3       # each event/control is encoded as 3 tokens
M = 341              # model context (1024 = 1 + EVENT_SIZE*M)
DELTA = 5            # anticipation time in seconds

assert CONTEXT_SIZE == 1+EVENT_SIZE*M

# vocabulary constants

MAX_TIME_IN_SECONDS = 100     # exclude very long training sequences
MAX_DURATION_IN_SECONDS = 10  # maximum duration of a note
TIME_RESOLUTION = 100         # 10ms time resolution = 100 bins/second

MAX_PITCH = 128                 # 128 MIDI pitches
MAX_INSTR = 129                 # 129 MIDI instruments (128 + drums)
MAX_NOTE = MAX_PITCH*MAX_INSTR  # note = pitch x instrument

MAX_INTERARRIVAL_IN_SECONDS = 10  # maximum interarrival time (for MIDI-like encoding)

# preprocessing settings

PREPROC_WORKERS = 16

COMPOUND_SIZE = 5                 # event size in the intermediate compound tokenization
MAX_TRACK_INSTR = 16              # exclude tracks with large numbers of instruments
MAX_TRACK_TIME_IN_SECONDS = 3600  # exclude very long tracks (longer than 1 hour)
MIN_TRACK_TIME_IN_SECONDS = 10    # exclude very short tracks (less than 10 seconds)
MIN_TRACK_EVENTS = 100            # exclude very short tracks (less than 100 events)

# LakhMIDI dataset splits

LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
LAKH_VALID = ['e']
LAKH_TEST = ['f']

# derived quantities

MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS

MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS


if __name__ == '__main__':
    print('Model constants:')
    print(f' -> anticipation interval: {DELTA}s')
    print('Vocabulary constants:')
    print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
    print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
    print(f' -> time resolution: {TIME_RESOLUTION}bins/s ({1000//TIME_RESOLUTION}ms)')
    print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
    print('Preprocessing constants:')
    print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
    print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}')
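
These constants fix the token grid: with TIME_RESOLUTION = 100 there are 100 bins per second (10 ms each), and the derived limits follow by multiplication. A short sanity-check sketch:

from anticipation.config import MAX_DUR, MAX_NOTE, MAX_TIME, TIME_RESOLUTION

# 2.37 seconds of real time lands on tick 237 of the 10ms grid
assert round(TIME_RESOLUTION * 2.37) == 237

# derived limits: 100s * 100 bins/s and 10s * 100 bins/s
assert MAX_TIME == 10_000
assert MAX_DUR == 1_000

# the note vocabulary is the pitch x instrument product: 128 * 129
assert MAX_NOTE == 16_512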
anticipation/convert-original.py
ADDED
@@ -0,0 +1,342 @@
"""
Utilities for converting to and from Midi data and encoded/tokenized data.
"""

from collections import defaultdict

import mido

from anticipation.config import *
from anticipation.vocab import *
from anticipation.ops import unpad


def midi_to_interarrival(midifile, debug=False, stats=False):
    midi = mido.MidiFile(midifile)

    tokens = []
    dt = 0

    instruments = defaultdict(int)  # default to code 0 = piano
    tempo = 500000  # default tempo: 500000 microseconds per beat
    truncations = 0
    for message in midi:
        dt += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            delta_ticks = min(round(TIME_RESOLUTION*dt), MAX_INTERARRIVAL-1)
            if delta_ticks != round(TIME_RESOLUTION*dt):
                truncations += 1

            if delta_ticks > 0:  # if time elapsed since last token
                tokens.append(MIDI_TIME_OFFSET + delta_ticks)  # add a time step event

            # special case: channel 9 is drums!
            inst = 128 if message.channel == 9 else instruments[message.channel]
            offset = MIDI_START_OFFSET if message.type == 'note_on' and message.velocity > 0 else MIDI_END_OFFSET
            tokens.append(offset + (2**7)*inst + message.note)
            dt = 0
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            pass  # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass  # we don't attempt to model these
        elif message.type == 'control_change':
            pass  # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass  # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass  # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass  # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)

    if stats:
        return tokens, truncations

    return tokens


def interarrival_to_midi(tokens, debug=False):
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    track_idx = {}  # maps instrument to (track, current time, channel index)
    time_in_ticks = 0
    num_tracks = 0
    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            time_in_ticks += token - MIDI_TIME_OFFSET
        elif token < MIDI_END_OFFSET:
            token -= MIDI_START_OFFSET
            instrument = token // 2**7
            pitch = token - (2**7)*instrument

            try:
                track, previous_time, idx = track_idx[instrument]
            except KeyError:
                idx = num_tracks
                previous_time = 0
                track = mido.MidiTrack()
                mid.tracks.append(track)
                if instrument == 128:  # drums always go on channel 9
                    idx = 9
                    message = mido.Message('program_change', channel=idx, program=0)
                else:
                    message = mido.Message('program_change', channel=idx, program=instrument)
                track.append(message)
                num_tracks += 1
                if num_tracks == 9:
                    num_tracks += 1  # skip the drums channel

            track.append(mido.Message('note_on', note=pitch, channel=idx, velocity=96, time=time_in_ticks-previous_time))
            track_idx[instrument] = (track, time_in_ticks, idx)
        else:
            token -= MIDI_END_OFFSET
            instrument = token // 2**7
            pitch = token - (2**7)*instrument

            try:
                track, previous_time, idx = track_idx[instrument]
            except KeyError:
                # shouldn't happen because we should have a corresponding onset
                if debug:
                    print('IGNORING bad offset')

                continue

            track.append(mido.Message('note_off', note=pitch, channel=idx, time=time_in_ticks-previous_time))
            track_idx[instrument] = (track, time_in_ticks, idx)

    return mid


def midi_to_compound(midifile, debug=False):
    if type(midifile) == str:
        midi = mido.MidiFile(midifile)
    else:
        midi = midifile

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)

    time = 0
    instruments = defaultdict(int)  # default to code 0 = piano
    tempo = 500000  # default tempo: 500000 microseconds per beat
    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]

            if message.type == 'note_on' and message.velocity > 0:  # onset
                # time quantization
                time_in_ticks = round(TIME_RESOLUTION*time)

                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(time_in_ticks)  # 10ms resolution
                tokens.append(-1)  # placeholder (we'll fill this in later)
                tokens.append(message.note)
                tokens.append(instr)
                tokens.append(message.velocity)

                open_notes[(instr,message.note,message.channel)].append((note_idx, time))
                note_idx += 1
            else:  # offset
                try:
                    open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    tokens[5*open_idx + 1] = duration_ticks
                    #del open_notes[(instr,message.note,message.channel)]
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            pass  # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass  # we don't attempt to model these
        elif message.type == 'control_change':
            pass  # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass  # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass  # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass  # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = 0
    for _,v in open_notes.items():
        unclosed_count += len(v)

    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print('  ', midifile)

    return tokens


def compound_to_midi(tokens, debug=False):
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    it = iter(tokens)
    time_index = defaultdict(list)
    for _, (time_in_ticks,duration,note,instrument,velocity) in enumerate(zip(it,it,it,it,it)):
        time_index[(time_in_ticks,0)].append((note, instrument, velocity))           # 0 = onset
        time_index[(time_in_ticks+duration,1)].append((note, instrument, velocity))  # 1 = offset

    track_idx = {}  # maps instrument to (track, current time, channel index)
    num_tracks = 0
    for time_in_ticks, event_type in sorted(time_index.keys()):
        for (note, instrument, velocity) in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    idx = num_tracks
                    previous_time = 0
                    track = mido.MidiTrack()
                    mid.tracks.append(track)
                    if instrument == 128:  # drums always go on channel 9
                        idx = 9
                        message = mido.Message('program_change', channel=idx, program=0)
                    else:
                        message = mido.Message('program_change', channel=idx, program=instrument)
                    track.append(message)
                    num_tracks += 1
                    if num_tracks == 9:
                        num_tracks += 1  # skip the drums channel

                track.append(mido.Message(
                    'note_on', note=note, channel=idx, velocity=velocity,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)
            else:  # offset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    # shouldn't happen because we should have a corresponding onset
                    if debug:
                        print('IGNORING bad offset')

                    continue

                track.append(mido.Message(
                    'note_off', note=note, channel=idx,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)

    return mid


def compound_to_events(tokens, stats=False):
    assert len(tokens) % 5 == 0
    tokens = tokens.copy()

    # remove velocities
    del tokens[4::5]

    # combine (note, instrument)
    assert all(-1 <= tok < 2**7 for tok in tokens[2::4])
    assert all(-1 <= tok < 129 for tok in tokens[3::4])
    tokens[2::4] = [SEPARATOR if note == -1 else MAX_PITCH*instr + note
                    for note, instr in zip(tokens[2::4],tokens[3::4])]
    tokens[2::4] = [NOTE_OFFSET + tok for tok in tokens[2::4]]
    del tokens[3::4]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum([1 for tok in tokens[1::3] if tok >= MAX_DUR])
    tokens[1::3] = [TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1)
                    for tok in tokens[1::3]]
    tokens[1::3] = [DUR_OFFSET + tok for tok in tokens[1::3]]

    assert min(tokens[0::3]) >= 0
    tokens[0::3] = [TIME_OFFSET + tok for tok in tokens[0::3]]

    assert len(tokens) % 3 == 0

    if stats:
        return tokens, truncations

    return tokens


def events_to_compound(tokens, debug=False):
    tokens = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove type offsets
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    offset = 0     # add max time from previous track for synthesis
    track_max = 0  # keep track of max time in track
    for j, (time,dur,note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # strip sequence separators
    assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]

    assert len(tokens) % 3 == 0
    out = 5*(len(tokens)//3)*[0]
    out[0::5] = tokens[0::3]
    out[1::5] = tokens[1::3]
    out[2::5] = [tok - (2**7)*(tok//2**7) for tok in tokens[2::3]]
    out[3::5] = [tok//2**7 for tok in tokens[2::3]]
    out[4::5] = (len(tokens)//3)*[72]  # default velocity

    assert max(out[1::5]) < MAX_DUR
    assert max(out[2::5]) < MAX_PITCH
    assert max(out[3::5]) < MAX_INSTR
    assert all(tok >= 0 for tok in out)

    return out


def events_to_midi(tokens, debug=False):
    return compound_to_midi(events_to_compound(tokens, debug=debug), debug=debug)

def midi_to_events(midifile, debug=False):
    return compound_to_events(midi_to_compound(midifile, debug=debug))
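
The functions above compose into two pipelines: MIDI → compound 5-tuples → 3-token events, and the reverse for synthesis. A round-trip sketch (the file paths are illustrative):

from anticipation.convert import events_to_midi, midi_to_events

# encode: each note becomes a (time, duration, note) triple of offset tokens
events = midi_to_events('example.mid')
assert len(events) % 3 == 0

# decode: re-synthesize a MidiFile (velocities are reset to the default of 72)
mid = events_to_midi(events)
mid.save('roundtrip.mid')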
anticipation/convert.py
ADDED
@@ -0,0 +1,365 @@
"""
Utilities for converting to and from Midi data and encoded/tokenized data.
"""

from collections import defaultdict

import mido

from anticipation.config import *
from anticipation.vocab import *
from anticipation.ops import unpad


def midi_to_interarrival(midifile, debug=False, stats=False):
    midi = mido.MidiFile(midifile)

    tokens = []
    dt = 0

    instruments = defaultdict(int)  # default to code 0 = piano
    tempo = 500000  # default tempo: 500000 microseconds per beat
    truncations = 0
    for message in midi:
        dt += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            instruments[message.channel] = message.program
        elif message.type in ['note_on', 'note_off']:
            delta_ticks = min(round(TIME_RESOLUTION*dt), MAX_INTERARRIVAL-1)
            if delta_ticks != round(TIME_RESOLUTION*dt):
                truncations += 1

            if delta_ticks > 0:  # if time elapsed since last token
                tokens.append(MIDI_TIME_OFFSET + delta_ticks)  # add a time step event

            # special case: channel 9 is drums!
            inst = 128 if message.channel == 9 else instruments[message.channel]
            offset = MIDI_START_OFFSET if message.type == 'note_on' and message.velocity > 0 else MIDI_END_OFFSET
            tokens.append(offset + (2**7)*inst + message.note)
            dt = 0
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            pass  # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass  # we don't attempt to model these
        elif message.type == 'control_change':
            pass  # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass  # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass  # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass  # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)

    if stats:
        return tokens, truncations

    return tokens


def interarrival_to_midi(tokens, debug=False):
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    track_idx = {}  # maps instrument to (track, current time, channel index)
    time_in_ticks = 0
    num_tracks = 0
    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            time_in_ticks += token - MIDI_TIME_OFFSET
        elif token < MIDI_END_OFFSET:
            token -= MIDI_START_OFFSET
            instrument = token // 2**7
            pitch = token - (2**7)*instrument

            try:
                track, previous_time, idx = track_idx[instrument]
            except KeyError:
                idx = num_tracks
                previous_time = 0
                track = mido.MidiTrack()
                mid.tracks.append(track)
                if instrument == 128:  # drums always go on channel 9
                    idx = 9
                    message = mido.Message('program_change', channel=idx, program=0)
                else:
                    message = mido.Message('program_change', channel=idx, program=instrument)
                track.append(message)
                num_tracks += 1
                if num_tracks == 9:
                    num_tracks += 1  # skip the drums channel

            track.append(mido.Message('note_on', note=pitch, channel=idx, velocity=96, time=time_in_ticks-previous_time))
            track_idx[instrument] = (track, time_in_ticks, idx)
        else:
            token -= MIDI_END_OFFSET
            instrument = token // 2**7
            pitch = token - (2**7)*instrument

            try:
                track, previous_time, idx = track_idx[instrument]
            except KeyError:
                # shouldn't happen because we should have a corresponding onset
                if debug:
                    print('IGNORING bad offset')

                continue

            track.append(mido.Message('note_off', note=pitch, channel=idx, time=time_in_ticks-previous_time))
            track_idx[instrument] = (track, time_in_ticks, idx)

    return mid


def midi_to_compound(midifile, debug=False):
    if type(midifile) == str:
        midi = mido.MidiFile(midifile)
    else:
        midi = midifile

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)

    time = 0
    instruments = defaultdict(lambda: {'program': 0, 'channel': None})  # track channel assignments
    next_channel = 0

    tempo = 500000  # default tempo: 500000 microseconds per beat
    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        if message.type == 'program_change':
            # Reserve channels 0-8, 10-15 (skip 9 for drums)
            if message.channel != 9 and message.channel not in instruments:
                instruments[message.channel]['program'] = message.program
                instruments[message.channel]['channel'] = next_channel
                next_channel += 1
                if next_channel == 9:  # skip channel 9 (drums)
                    next_channel = 10
        elif message.type in ['note_on', 'note_off']:
            # special case: channel 9 is drums!
            instr = 128 if message.channel == 9 else instruments[message.channel]['program']
            channel = 9 if message.channel == 9 else instruments[message.channel]['channel']
            if channel is None:
                channel = message.channel  # no program_change was seen on this channel
            compound_instr = (instr << 4) | channel
            if message.type == 'note_on' and message.velocity > 0:  # onset
                # time quantization
                time_in_ticks = round(TIME_RESOLUTION*time)

                # Our compound word is: (time, duration, note, instr, velocity)
                tokens.append(time_in_ticks)  # 10ms resolution
                tokens.append(-1)  # placeholder (we'll fill this in later)
                tokens.append(message.note)
                tokens.append(compound_instr)
                tokens.append(message.velocity)

                open_notes[(instr,message.note,message.channel)].append((note_idx, time))
                note_idx += 1
            else:  # offset
                try:
                    open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
                except IndexError:
                    if debug:
                        print('WARNING: ignoring bad offset')
                else:
                    duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
                    tokens[5*open_idx + 1] = duration_ticks
                    #del open_notes[(instr,message.note,message.channel)]
        elif message.type == 'set_tempo':
            tempo = message.tempo
        elif message.type == 'time_signature':
            pass  # we use real time
        elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
            pass  # we don't attempt to model these
        elif message.type == 'control_change':
            pass  # this includes pedal and per-track volume: ignore for now
        elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
                              'copyright', 'marker', 'instrument_name', 'cue_marker',
                              'device_name', 'sequence_number']:
            pass  # possibly useful metadata but ignore for now
        elif message.type == 'channel_prefix':
            pass  # relatively common, but can we ignore this?
        elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
            pass  # I have no idea what this is
        else:
            if debug:
                print('UNHANDLED MESSAGE', message.type, message)

    unclosed_count = 0
    for _,v in open_notes.items():
        unclosed_count += len(v)

    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print('  ', midifile)

    return tokens


def compound_to_midi(tokens, debug=False):
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # Decode program and channel from the instrument field of each 5-tuple
    tracks = {}
    for token in tokens[3::5]:
        program = (token >> 4) & 0x7F
        channel = token & 0x0F

        if (program, channel) not in tracks:
            track = mido.MidiTrack()
            mid.tracks.append(track)
            tracks[(program, channel)] = track
            track.append(mido.Message('program_change',
                                      program=program,
                                      channel=channel))

    it = iter(tokens)
    time_index = defaultdict(list)
    for _, (time_in_ticks,duration,note,instrument,velocity) in enumerate(zip(it,it,it,it,it)):
        time_index[(time_in_ticks,0)].append((note, instrument, velocity))           # 0 = onset
        time_index[(time_in_ticks+duration,1)].append((note, instrument, velocity))  # 1 = offset

    track_idx = {}  # maps instrument to (track, current time, channel index)
    num_tracks = 0
    for time_in_ticks, event_type in sorted(time_index.keys()):
        for (note, instrument, velocity) in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    idx = num_tracks
                    previous_time = 0
                    track = mido.MidiTrack()
                    mid.tracks.append(track)
                    if instrument == 128:  # drums always go on channel 9
                        idx = 9
                        message = mido.Message('program_change', channel=idx, program=0)
                    else:
                        message = mido.Message('program_change', channel=idx, program=instrument)
                    track.append(message)
                    num_tracks += 1
                    if num_tracks == 9:
                        num_tracks += 1  # skip the drums channel

                track.append(mido.Message(
                    'note_on', note=note, channel=idx, velocity=velocity,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)
            else:  # offset
                try:
                    track, previous_time, idx = track_idx[instrument]
                except KeyError:
                    # shouldn't happen because we should have a corresponding onset
                    if debug:
                        print('IGNORING bad offset')

                    continue

                track.append(mido.Message(
                    'note_off', note=note, channel=idx,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)

    return mid


def compound_to_events(tokens, stats=False):
    assert len(tokens) % 5 == 0
    tokens = tokens.copy()

    # remove velocities
    del tokens[4::5]

    # combine (note, instrument)
    assert all(-1 <= tok < 2**7 for tok in tokens[2::4])
    assert all(-1 <= tok < 129 for tok in tokens[3::4])
    tokens[2::4] = [SEPARATOR if note == -1 else MAX_PITCH*instr + note
                    for note, instr in zip(tokens[2::4],tokens[3::4])]
    tokens[2::4] = [NOTE_OFFSET + tok for tok in tokens[2::4]]
    del tokens[3::4]

    # max duration cutoff and set unknown durations to 250ms
    truncations = sum([1 for tok in tokens[1::3] if tok >= MAX_DUR])
    tokens[1::3] = [TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1)
                    for tok in tokens[1::3]]
    tokens[1::3] = [DUR_OFFSET + tok for tok in tokens[1::3]]

    assert min(tokens[0::3]) >= 0
    tokens[0::3] = [TIME_OFFSET + tok for tok in tokens[0::3]]

    assert len(tokens) % 3 == 0

    if stats:
        return tokens, truncations

    return tokens


def events_to_compound(tokens, debug=False):
    tokens = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove type offsets
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    offset = 0     # add max time from previous track for synthesis
    track_max = 0  # keep track of max time in track
    for j, (time,dur,note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # strip sequence separators
    assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]

    assert len(tokens) % 3 == 0
    out = 5*(len(tokens)//3)*[0]
    out[0::5] = tokens[0::3]
    out[1::5] = tokens[1::3]
    out[2::5] = [tok - (2**7)*(tok//2**7) for tok in tokens[2::3]]
    out[3::5] = [tok//2**7 for tok in tokens[2::3]]
    out[4::5] = (len(tokens)//3)*[72]  # default velocity

    assert max(out[1::5]) < MAX_DUR
    assert max(out[2::5]) < MAX_PITCH
    assert max(out[3::5]) < MAX_INSTR
    assert all(tok >= 0 for tok in out)

    return out


def events_to_midi(tokens, debug=False):
    return compound_to_midi(events_to_compound(tokens, debug=debug), debug=debug)
|
363 |
+
|
364 |
+
def midi_to_events(midifile, debug=False):
|
365 |
+
return compound_to_events(midi_to_compound(midifile, debug=debug))
|
anticipation/ops.py
ADDED
@@ -0,0 +1,285 @@
+"""
+Utilities for operating on encoded MIDI sequences.
+"""
+
+from collections import defaultdict
+
+from anticipation.config import *
+from anticipation.vocab import *
+
+
+def print_tokens(tokens):
+    print('---------------------')
+    for j, (tm, dur, note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
+        if note == SEPARATOR:
+            assert tm == SEPARATOR and dur == SEPARATOR
+            print(j, 'SEPARATOR')
+            continue
+
+        if note == REST:
+            assert tm < CONTROL_OFFSET
+            assert dur == DUR_OFFSET+0
+            print(j, tm, 'REST')
+            continue
+
+        if note < CONTROL_OFFSET:
+            tm = tm - TIME_OFFSET
+            dur = dur - DUR_OFFSET
+            note = note - NOTE_OFFSET
+            instr = note//2**7
+            pitch = note - (2**7)*instr
+            print(j, tm, dur, instr, pitch)
+        else:
+            tm = tm - ATIME_OFFSET
+            dur = dur - ADUR_OFFSET
+            note = note - ANOTE_OFFSET
+            instr = note//2**7
+            pitch = note - (2**7)*instr
+            print(j, tm, dur, instr, pitch, '(A)')
+
+
+def clip(tokens, start, end, clip_duration=True, seconds=True):
+    if seconds:
+        start = int(TIME_RESOLUTION*start)
+        end = int(TIME_RESOLUTION*end)
+
+    new_tokens = []
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note < CONTROL_OFFSET:
+            this_time = time - TIME_OFFSET
+            this_dur = dur - DUR_OFFSET
+        else:
+            this_time = time - ATIME_OFFSET
+            this_dur = dur - ADUR_OFFSET
+
+        if this_time < start or end < this_time:
+            continue
+
+        # truncate extended notes
+        if clip_duration and end < this_time + this_dur:
+            dur -= this_time + this_dur - end
+
+        new_tokens.extend([time, dur, note])
+
+    return new_tokens
+
+
+def mask(tokens, start, end):
+    new_tokens = []
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note < CONTROL_OFFSET:
+            this_time = (time - TIME_OFFSET)/float(TIME_RESOLUTION)
+        else:
+            this_time = (time - ATIME_OFFSET)/float(TIME_RESOLUTION)
+
+        if start < this_time < end:
+            continue
+
+        new_tokens.extend([time, dur, note])
+
+    return new_tokens
+
+
+def delete(tokens, criterion):
+    new_tokens = []
+    for token in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if criterion(token):
+            continue
+
+        new_tokens.extend(token)
+
+    return new_tokens
+
+
+def sort(tokens):
+    """ sort a sequence of events or controls (but not both) """
+
+    times = tokens[0::3]
+    indices = sorted(range(len(times)), key=times.__getitem__)
+
+    sorted_tokens = []
+    for idx in indices:
+        sorted_tokens.extend(tokens[3*idx:3*(idx+1)])
+
+    return sorted_tokens
+
+
+def split(tokens):
+    """ split a sequence into events and controls """
+
+    events = []
+    controls = []
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note < CONTROL_OFFSET:
+            events.extend([time, dur, note])
+        else:
+            controls.extend([time, dur, note])
+
+    return events, controls
+
+
+def pad(tokens, end_time=None, density=TIME_RESOLUTION):
+    end_time = TIME_OFFSET+(end_time if end_time else max_time(tokens, seconds=False))
+    new_tokens = []
+    previous_time = TIME_OFFSET+0
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        # must pad before separation, anticipation
+        assert note < CONTROL_OFFSET
+
+        # insert pad tokens to ensure the desired density
+        while time > previous_time + density:
+            new_tokens.extend([previous_time+density, DUR_OFFSET+0, REST])
+            previous_time += density
+
+        new_tokens.extend([time, dur, note])
+        previous_time = time
+
+    while end_time > previous_time + density:
+        new_tokens.extend([previous_time+density, DUR_OFFSET+0, REST])
+        previous_time += density
+
+    return new_tokens
+
+
+def unpad(tokens):
+    new_tokens = []
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note == REST: continue
+
+        new_tokens.extend([time, dur, note])
+
+    return new_tokens
+
+
+def anticipate(events, controls, delta=DELTA*TIME_RESOLUTION):
+    """
+    Interleave a sequence of events with anticipated controls.
+
+    Inputs:
+      events   : a sequence of events
+      controls : a sequence of time-localized controls
+      delta    : the anticipation interval
+
+    Returns:
+      tokens   : interleaved events and anticipated controls
+      controls : unconsumed controls (control time > max_time(events) + delta)
+    """
+
+    if len(controls) == 0:
+        return events, controls
+
+    tokens = []
+    event_time = 0
+    control_time = controls[0] - ATIME_OFFSET
+    for time, dur, note in zip(events[0::3],events[1::3],events[2::3]):
+        while event_time >= control_time - delta:
+            tokens.extend(controls[0:3])
+            controls = controls[3:] # consume this control
+            control_time = controls[0] - ATIME_OFFSET if len(controls) > 0 else float('inf')
+
+        assert note < CONTROL_OFFSET
+        event_time = time - TIME_OFFSET
+        tokens.extend([time, dur, note])
+
+    return tokens, controls
+
+
+def sparsity(tokens):
+    max_dt = 0
+    previous_time = TIME_OFFSET+0
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note == SEPARATOR: continue
+        assert note < CONTROL_OFFSET # don't operate on interleaved sequences
+
+        max_dt = max(max_dt, time - previous_time)
+        previous_time = time
+
+    return max_dt
+
+
+def min_time(tokens, seconds=True, instr=None):
+    mt = None
+    for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        # stop calculating at the sequence separator
+        if note == SEPARATOR: break
+
+        if note < CONTROL_OFFSET:
+            time -= TIME_OFFSET
+            note -= NOTE_OFFSET
+        else:
+            time -= ATIME_OFFSET
+            note -= ANOTE_OFFSET
+
+        # min time of a particular instrument
+        if instr is not None and instr != note//2**7:
+            continue
+
+        mt = time if mt is None else min(mt, time)
+
+    if mt is None: mt = 0
+    return mt/float(TIME_RESOLUTION) if seconds else mt
+
+
+def max_time(tokens, seconds=True, instr=None):
+    mt = 0
+    for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        # keep checking for max_time, even if it appears after a separator
+        # (this is important because we use this check for vocab overflow in tokenization)
+        if note == SEPARATOR: continue
+
+        if note < CONTROL_OFFSET:
+            time -= TIME_OFFSET
+            note -= NOTE_OFFSET
+        else:
+            time -= ATIME_OFFSET
+            note -= ANOTE_OFFSET
+
+        # max time of a particular instrument
+        if instr is not None and instr != note//2**7:
+            continue
+
+        mt = max(mt, time)
+
+    return mt/float(TIME_RESOLUTION) if seconds else mt
+
+
+def get_instruments(tokens):
+    instruments = defaultdict(int)
+    for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        if note >= SPECIAL_OFFSET: continue
+
+        if note < CONTROL_OFFSET:
+            note -= NOTE_OFFSET
+        else:
+            note -= ANOTE_OFFSET
+
+        instr = note//2**7
+        instruments[instr] += 1
+
+    return instruments
+
+
+def translate(tokens, dt, seconds=False):
+    if seconds:
+        dt = int(TIME_RESOLUTION*dt)
+
+    new_tokens = []
+    for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+        # stop translating after EOT
+        if note == SEPARATOR:
+            new_tokens.extend([time, dur, note])
+            dt = 0
+            continue
+
+        if note < CONTROL_OFFSET:
+            this_time = time - TIME_OFFSET
+        else:
+            this_time = time - ATIME_OFFSET
+
+        assert 0 <= this_time + dt
+        new_tokens.extend([time+dt, dur, note])
+
+    return new_tokens
+
+
+def combine(events, controls):
+    return sort(events + [token - CONTROL_OFFSET for token in controls])
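
A short sketch of how these helpers compose (the events variable is illustrative; it stands for an arrival-time token sequence such as the output of anticipation.convert.midi_to_events):

    from anticipation import ops

    excerpt = ops.clip(events, 5.0, 15.0)  # keep events between 5 and 15 seconds
    # rebase the excerpt to start at time zero (tick units, hence seconds=False)
    excerpt = ops.translate(excerpt, -ops.min_time(excerpt, seconds=False), seconds=False)
    print('length:', ops.max_time(excerpt), 'seconds')
    print('instruments:', dict(ops.get_instruments(excerpt)))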
anticipation/sample.py
ADDED
@@ -0,0 +1,280 @@
+"""
+API functions for sampling from anticipatory infilling models.
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+from tqdm import tqdm
+
+from anticipation import ops
+from anticipation.config import *
+from anticipation.vocab import *
+
+
+def safe_logits(logits, idx):
+    logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
+    logits[SPECIAL_OFFSET:] = -float('inf')               # don't generate special tokens
+
+    # don't generate stuff in the wrong time slot
+    if idx % 3 == 0:
+        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
+        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
+    elif idx % 3 == 1:
+        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
+        logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
+    elif idx % 3 == 2:
+        logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
+        logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
+
+    return logits
+
+
+def nucleus(logits, top_p):
+    # from the HF implementation
+    if top_p < 1.0:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
+        sorted_indices_to_remove = cumulative_probs > top_p
+
+        # shift the indices to the right to also keep the first token above the threshold
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # scatter sorted tensors back to the original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
+        logits[indices_to_remove] = -float("inf")
+
+    return logits
+
+
+def future_logits(logits, curtime):
+    """ don't sample events in the past """
+    if curtime > 0:
+        logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')
+
+    return logits
+
+
+def instr_logits(logits, full_history):
+    """ don't sample more than 16 instruments """
+    instrs = ops.get_instruments(full_history)
+    if len(instrs) < 15: # 16 - 1 to account for the reserved drum track
+        return logits
+
+    for instr in range(MAX_INSTR):
+        if instr not in instrs:
+            logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')
+
+    return logits
+
+
+def add_token(model, z, tokens, top_p, current_time, debug=False):
+    assert len(tokens) % 3 == 0
+
+    history = tokens.copy()
+    lookback = max(len(tokens) - 1017, 0)
+    history = history[lookback:] # Markov window
+    offset = ops.min_time(history, seconds=False)
+    history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer
+
+    new_token = []
+    with torch.no_grad():
+        for i in range(3):
+            input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
+            logits = model(input_tokens).logits[0,-1]
+
+            idx = input_tokens.shape[1]-1
+            logits = safe_logits(logits, idx)
+            if i == 0:
+                logits = future_logits(logits, current_time - offset)
+            elif i == 2:
+                logits = instr_logits(logits, tokens)
+            logits = nucleus(logits, top_p)
+
+            probs = F.softmax(logits, dim=-1)
+            token = torch.multinomial(probs, 1)
+            new_token.append(int(token))
+
+    new_token[0] += offset # revert to full sequence timing
+    if debug:
+        print(f'  OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')
+
+    return new_token
+
+
+def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
+    if inputs is None:
+        inputs = []
+
+    if controls is None:
+        controls = []
+
+    start_time = int(TIME_RESOLUTION*start_time)
+    end_time = int(TIME_RESOLUTION*end_time)
+
+    # the prompt is the events up to start_time
+    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
+
+    # treat events beyond start_time as controls
+    future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
+    if debug:
+        print('Future')
+        ops.print_tokens(future)
+
+    # clip controls that precede the sequence
+    controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)
+
+    if debug:
+        print('Controls')
+        ops.print_tokens(controls)
+
+    z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
+    if debug:
+        print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')
+
+    # interleave the controls with the events
+    tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))
+
+    if debug:
+        print('Prompt')
+        ops.print_tokens(tokens)
+
+    current_time = ops.max_time(prompt, seconds=False)
+    if debug:
+        print('Current time:', current_time)
+
+    with tqdm(range(end_time-start_time)) as progress:
+        if controls:
+            atime, adur, anote = controls[0:3]
+            anticipated_tokens = controls[3:]
+            anticipated_time = atime - ATIME_OFFSET
+        else:
+            # nothing to anticipate
+            anticipated_time = math.inf
+
+        while True:
+            while current_time >= anticipated_time - delta:
+                tokens.extend([atime, adur, anote])
+                if debug:
+                    note = anote - ANOTE_OFFSET
+                    instr = note//2**7
+                    print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)
+
+                if len(anticipated_tokens) > 0:
+                    atime, adur, anote = anticipated_tokens[0:3]
+                    anticipated_tokens = anticipated_tokens[3:]
+                    anticipated_time = atime - ATIME_OFFSET
+                else:
+                    # nothing more to anticipate
+                    anticipated_time = math.inf
+
+            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
+            new_time = new_token[0] - TIME_OFFSET
+            if new_time >= end_time:
+                break
+
+            if debug:
+                new_note = new_token[2] - NOTE_OFFSET
+                new_instr = new_note//2**7
+                new_pitch = new_note - (2**7)*new_instr
+                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
+
+            tokens.extend(new_token)
+            dt = new_time - current_time
+            assert dt >= 0
+            current_time = new_time
+            progress.update(dt)
+
+    events, _ = ops.split(tokens)
+    return ops.sort(ops.unpad(events) + future)
+
+
+def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
+    if inputs is None:
+        inputs = []
+
+    if controls is None:
+        controls = []
+    else:
+        # treat controls as ordinary tokens
+        controls = [token-CONTROL_OFFSET for token in controls]
+
+    start_time = int(TIME_RESOLUTION*start_time)
+    end_time = int(TIME_RESOLUTION*end_time)
+
+    inputs = ops.sort(inputs + controls)
+
+    # the prompt is the events up to start_time
+    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
+    if debug:
+        print('Prompt')
+        ops.print_tokens(prompt)
+
+    # treat events beyond start_time as controls
+    controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
+    if debug:
+        print('Future')
+        ops.print_tokens(controls)
+
+    z = [AUTOREGRESS]
+    if debug:
+        print('AR Mode')
+
+    current_time = ops.max_time(prompt, seconds=False)
+    if debug:
+        print('Current time:', current_time)
+
+    tokens = prompt
+    with tqdm(range(end_time-start_time)) as progress:
+        if controls:
+            atime, adur, anote = controls[0:3]
+            anticipated_tokens = controls[3:]
+            anticipated_time = atime - TIME_OFFSET
+        else:
+            # nothing to anticipate
+            anticipated_time = math.inf
+
+        while True:
+            new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
+            new_time = new_token[0] - TIME_OFFSET
+            if new_time >= end_time:
+                break
+
+            dt = new_time - current_time
+            assert dt >= 0
+            current_time = new_time
+
+            # backfill anything that should have come before the new token
+            while current_time >= anticipated_time:
+                tokens.extend([atime, adur, anote])
+                if debug:
+                    note = anote - NOTE_OFFSET
+                    instr = note//2**7
+                    print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)
+
+                if len(anticipated_tokens) > 0:
+                    atime, adur, anote = anticipated_tokens[0:3]
+                    anticipated_tokens = anticipated_tokens[3:]
+                    anticipated_time = atime - TIME_OFFSET
+                else:
+                    # nothing more to anticipate
+                    anticipated_time = math.inf
+
+            if debug:
+                new_note = new_token[2] - NOTE_OFFSET
+                new_instr = new_note//2**7
+                new_pitch = new_note - (2**7)*new_instr
+                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
+
+            tokens.extend(new_token)
+            progress.update(dt)
+
+    if anticipated_time != math.inf:
+        tokens.extend([atime, adur, anote])
+
+    return ops.sort(ops.unpad(tokens) + controls)
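
An end-to-end sampling sketch (not part of this commit; the checkpoint name matches the one api.py loads, and the file path is illustrative). Because generate treats input events after start_time as anticipated controls, this fills in the window between 5 and 10 seconds:

    from transformers import AutoModelForCausalLM
    from anticipation.convert import midi_to_events, events_to_midi
    from anticipation.sample import generate

    model = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-small-800k')
    events = midi_to_events('examples/strawberry.mid')
    filled = generate(model, start_time=5.0, end_time=10.0, inputs=events, top_p=0.95)
    events_to_midi(filled).save('infilled.mid')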
anticipation/tokenize.py
ADDED
@@ -0,0 +1,219 @@
+"""
+Top-level functions for preprocessing data to be used for training.
+"""
+
+from tqdm import tqdm
+
+import numpy as np
+
+from anticipation import ops
+from anticipation.config import *
+from anticipation.vocab import *
+from anticipation.convert import compound_to_events, midi_to_interarrival
+
+
+def extract_spans(all_events, rate):
+    events = []
+    controls = []
+    span = True
+    next_span = end_span = TIME_OFFSET+0
+    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet
+
+        # end of an anticipated span; decide when to do it again (next_span)
+        if span and time >= end_span:
+            span = False
+            next_span = time+int(TIME_RESOLUTION*np.random.exponential(1./rate))
+
+        # anticipate a 3-second span
+        if (not span) and time >= next_span:
+            span = True
+            end_span = time + DELTA*TIME_RESOLUTION
+
+        if span:
+            # mark this event as a control
+            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+        else:
+            events.extend([time, dur, note])
+
+    return events, controls
+
+
+ANTICIPATION_RATES = 10
+def extract_random(all_events, rate):
+    events = []
+    controls = []
+    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet
+
+        if np.random.random() < rate/float(ANTICIPATION_RATES):
+            # mark this event as a control
+            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+        else:
+            events.extend([time, dur, note])
+
+    return events, controls
+
+
+def extract_instruments(all_events, instruments):
+    events = []
+    controls = []
+    for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+        assert note < CONTROL_OFFSET         # shouldn't be in the sequence yet
+        assert note not in [SEPARATOR, REST] # these shouldn't be either
+
+        instr = (note-NOTE_OFFSET)//2**7
+        if instr in instruments:
+            # mark this event as a control
+            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+        else:
+            events.extend([time, dur, note])
+
+    return events, controls
+
+
+def maybe_tokenize(compound_tokens):
+    # skip sequences with very few events
+    if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
+        return None, None, 1 # short track
+
+    events, truncations = compound_to_events(compound_tokens, stats=True)
+    end_time = ops.max_time(events, seconds=False)
+
+    # don't want to deal with extremely short tracks
+    if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
+        return None, None, 1 # short track
+
+    # don't want to deal with extremely long tracks
+    if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
+        return None, None, 2 # long track
+
+    # skip sequences with more instruments than MIDI channels (16)
+    if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
+        return None, None, 3 # too many instruments
+
+    return events, truncations, 0
+
+
+def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
+    assert augment_factor == 1 # can't augment interarrival-tokenized data
+
+    all_truncations = 0
+    seqcount = rest_count = 0
+    stats = 4*[0] # (short, long, too many instruments, inexpressible)
+    np.random.seed(0)
+
+    with open(output, 'w') as outfile:
+        concatenated_tokens = []
+        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
+            with open(filename, 'r') as f:
+                _, _, status = maybe_tokenize([int(token) for token in f.read().split()])
+
+            if status > 0:
+                stats[status-1] += 1
+                continue
+
+            filename = filename[:-len('.compound.txt')] # get the original MIDI
+
+            # already parsed; shouldn't raise an exception
+            tokens, truncations = midi_to_interarrival(filename, stats=True)
+            tokens[0:0] = [MIDI_SEPARATOR]
+            concatenated_tokens.extend(tokens)
+            all_truncations += truncations
+
+            # write out full sequences to file
+            while len(concatenated_tokens) >= CONTEXT_SIZE:
+                seq = concatenated_tokens[0:CONTEXT_SIZE]
+                concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
+                outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
+                seqcount += 1
+
+    if debug:
+        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
+        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))
+
+    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
+
+
+def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
+    tokens = []
+    all_truncations = 0
+    seqcount = rest_count = 0
+    stats = 4*[0] # (short, long, too many instruments, inexpressible)
+    np.random.seed(0)
+
+    with open(output, 'w') as outfile:
+        concatenated_tokens = []
+        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
+            with open(filename, 'r') as f:
+                all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])
+
+            if status > 0:
+                stats[status-1] += 1
+                continue
+
+            instruments = list(ops.get_instruments(all_events).keys())
+            end_time = ops.max_time(all_events, seconds=False)
+
+            # different random augmentations
+            for k in range(augment_factor):
+                if k % 10 == 0:
+                    # no augmentation
+                    events = all_events.copy()
+                    controls = []
+                elif k % 10 == 1:
+                    # span augmentation
+                    lmbda = .05
+                    events, controls = extract_spans(all_events, lmbda)
+                elif k % 10 < 6:
+                    # random augmentation
+                    r = np.random.randint(1,ANTICIPATION_RATES)
+                    events, controls = extract_random(all_events, r)
+                else:
+                    if len(instruments) > 1:
+                        # instrument augmentation: at least one, but not all instruments
+                        u = 1+np.random.randint(len(instruments)-1)
+                        subset = np.random.choice(instruments, u, replace=False)
+                        events, controls = extract_instruments(all_events, subset)
+                    else:
+                        # no augmentation
+                        events = all_events.copy()
+                        controls = []
+
+                if len(concatenated_tokens) == 0:
+                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS
+
+                all_truncations += truncations
+                events = ops.pad(events, end_time)
+                rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
+                tokens, controls = ops.anticipate(events, controls)
+                assert len(controls) == 0 # should have consumed all controls (because of padding)
+                tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
+                concatenated_tokens.extend(tokens)
+
+                # write out full sequences to file
+                while len(concatenated_tokens) >= EVENT_SIZE*M:
+                    seq = concatenated_tokens[0:EVENT_SIZE*M]
+                    concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]
+
+                    # relativize time to the context
+                    seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
+                    assert ops.min_time(seq, seconds=False) == 0
+                    if ops.max_time(seq, seconds=False) >= MAX_TIME:
+                        stats[3] += 1
+                        continue
+
+                    # if seq contains a SEPARATOR, the global controls describe the first sequence
+                    seq.insert(0, z)
+
+                    outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
+                    seqcount += 1
+
+                    # grab the current augmentation controls if we didn't already
+                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS
+
+    if debug:
+        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
+        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))
+
+    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
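
For intuition, the three extractors can be exercised directly. A sketch, assuming all_events is an offset-encoded event list containing no SEPARATOR or REST tokens (e.g. the first return value of maybe_tokenize):

    import numpy as np
    from anticipation import ops
    from anticipation.tokenize import extract_spans, extract_random, extract_instruments

    np.random.seed(0)
    ev1, ctl1 = extract_spans(all_events, rate=.05)  # anticipate ~3-second spans
    ev2, ctl2 = extract_random(all_events, rate=5)   # anticipate each event w.p. 5/10
    instruments = list(ops.get_instruments(all_events).keys())
    ev3, ctl3 = extract_instruments(all_events, instruments[:1])  # anticipate one instrument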
anticipation/visuals.py
ADDED
@@ -0,0 +1,65 @@
+"""
+Utilities for inspecting encoded music data.
+"""
+
+import numpy as np
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+import anticipation.ops as ops
+from anticipation.config import *
+from anticipation.vocab import *
+
+
+def visualize(tokens, output, selected=None):
+    #colors = ['white', 'silver', 'red', 'sienna', 'darkorange', 'gold', 'yellow', 'palegreen', 'seagreen', 'cyan',
+    #          'dodgerblue', 'slategray', 'navy', 'mediumpurple', 'mediumorchid', 'magenta', 'lightpink']
+    colors = ['white', '#426aa0', '#b26789', '#de9283', '#eac29f', 'silver', 'red', 'sienna', 'darkorange',
+              'gold', 'yellow', 'palegreen', 'seagreen', 'cyan', 'dodgerblue', 'slategray', 'navy']
+
+    plt.rcParams['figure.dpi'] = 300
+    plt.rcParams['savefig.dpi'] = 300
+
+    max_time = ops.max_time(tokens, seconds=False)
+    grid = np.zeros([max_time, MAX_PITCH])
+    instruments = list(sorted(list(ops.get_instruments(tokens).keys())))
+    if 128 in instruments:
+        instruments.remove(128)
+
+    for j, (tm, dur, note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
+        if note == SEPARATOR:
+            assert tm == SEPARATOR and dur == SEPARATOR
+            print(j, 'SEPARATOR')
+            continue
+
+        if note == REST:
+            continue
+
+        assert note < CONTROL_OFFSET
+
+        tm = tm - TIME_OFFSET
+        dur = dur - DUR_OFFSET
+        note = note - NOTE_OFFSET
+        instr = note//2**7
+        pitch = note - (2**7)*instr
+
+        if instr == 128: # drums
+            continue     # we don't visualize this
+
+        if selected and instr not in selected:
+            continue
+
+        grid[tm:tm+dur, pitch] = 1+instruments.index(instr)
+
+    plt.clf()
+    plt.axis('off')
+    cmap = matplotlib.colors.ListedColormap(colors)
+    bounds = list(range(MAX_TRACK_INSTR)) + [16]
+    norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
+    plt.imshow(np.flipud(grid.T), aspect=16, cmap=cmap, norm=norm, interpolation='none')
+
+    patches = [matplotlib.patches.Patch(color=colors[i+1], label=f"{instruments[i]}")
+               for i in range(len(instruments))]
+    plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
+
+    plt.tight_layout()
+    plt.savefig(output)
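
Usage sketch (the events list and output paths are illustrative):

    from anticipation.visuals import visualize

    visualize(events, 'pianoroll.png')                 # piano roll of all pitched instruments
    visualize(events, 'piano_only.png', selected=[0])  # restrict to instrument 0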
anticipation/vocab.py
ADDED
@@ -0,0 +1,58 @@
+"""
+The vocabularies used for arrival-time and interarrival-time encodings.
+"""
+
+# training sequence vocab
+
+from anticipation.config import *
+
+# the event block
+EVENT_OFFSET = 0
+TIME_OFFSET = EVENT_OFFSET
+DUR_OFFSET = TIME_OFFSET + MAX_TIME
+NOTE_OFFSET = DUR_OFFSET + MAX_DUR
+REST = NOTE_OFFSET + MAX_NOTE
+
+# the control block
+CONTROL_OFFSET = NOTE_OFFSET + MAX_NOTE + 1
+ATIME_OFFSET = CONTROL_OFFSET + 0
+ADUR_OFFSET = ATIME_OFFSET + MAX_TIME
+ANOTE_OFFSET = ADUR_OFFSET + MAX_DUR
+
+# the special block
+SPECIAL_OFFSET = ANOTE_OFFSET + MAX_NOTE
+SEPARATOR = SPECIAL_OFFSET
+AUTOREGRESS = SPECIAL_OFFSET + 1
+ANTICIPATE = SPECIAL_OFFSET + 2
+VOCAB_SIZE = ANTICIPATE+1
+
+# interarrival-time (MIDI-like) vocab
+MIDI_TIME_OFFSET = 0
+MIDI_START_OFFSET = MIDI_TIME_OFFSET + MAX_INTERARRIVAL
+MIDI_END_OFFSET = MIDI_START_OFFSET + MAX_NOTE
+MIDI_SEPARATOR = MIDI_END_OFFSET + MAX_NOTE
+MIDI_VOCAB_SIZE = MIDI_SEPARATOR + 1
+
+if __name__ == '__main__':
+    print('Arrival-Time Training Sequence Format:')
+    print('Event Offset: ', EVENT_OFFSET)
+    print('  -> time offset :', TIME_OFFSET)
+    print('  -> duration offset :', DUR_OFFSET)
+    print('  -> note offset :', NOTE_OFFSET)
+    print('  -> rest token: ', REST)
+    print('Anticipated Control Offset: ', CONTROL_OFFSET)
+    print('  -> anticipated time offset :', ATIME_OFFSET)
+    print('  -> anticipated duration offset :', ADUR_OFFSET)
+    print('  -> anticipated note offset :', ANOTE_OFFSET)
+    print('Special Token Offset: ', SPECIAL_OFFSET)
+    print('  -> separator token: ', SEPARATOR)
+    print('  -> autoregression flag: ', AUTOREGRESS)
+    print('  -> anticipation flag: ', ANTICIPATE)
+    print('Arrival Encoding Vocabulary Size: ', VOCAB_SIZE)
+    print('')
+    print('Interarrival-Time Training Sequence Format:')
+    print('  -> time offset: ', MIDI_TIME_OFFSET)
+    print('  -> note-on offset: ', MIDI_START_OFFSET)
+    print('  -> note-off offset: ', MIDI_END_OFFSET)
+    print('  -> separator token: ', MIDI_SEPARATOR)
+    print('Interarrival Encoding Vocabulary Size: ', MIDI_VOCAB_SIZE)
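
Since an event's note token packs (instrument, pitch) as NOTE_OFFSET + MAX_PITCH*instrument + pitch, with MAX_PITCH = 2**7 = 128 implied by the arithmetic used throughout this commit, a hypothetical decoding helper looks like:

    from anticipation.vocab import NOTE_OFFSET

    def decode_note(token):
        # split an event's note token into (instrument, pitch)
        note = token - NOTE_OFFSET
        return note // 128, note % 128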
api.py
ADDED
@@ -0,0 +1,240 @@
+from agents.agents import harmonizer, infiller, change_melody
+from flask import Flask, request, jsonify
+from flask_cors import CORS
+import mido
+import tempfile
+import os
+import music21
+import traceback
+from uuid import uuid4
+import threading
+from transformers import AutoModelForCausalLM
+
+
+app = Flask(__name__)
+CORS(app)
+
+@app.after_request
+def add_cors_headers(response):
+    # Allow only the frontend domain
+    response.headers['Access-Control-Allow-Origin'] = 'https://inscoreai.netlify.app'
+    response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
+    response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
+    return response
+
+def midi_to_musicxml(midi_file_path):
+    """Convert a MIDI file to a MusicXML string."""
+    try:
+        midi_path_str = str(midi_file_path)
+
+        # Parse and convert to MusicXML
+        score = music21.converter.parse(midi_path_str)
+
+        # Create a temporary output file path
+        temp_output = os.path.join(tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml")
+
+        # Write to the temporary file
+        score.write('musicxml', temp_output)
+
+        # Read back as a string
+        with open(temp_output, 'r') as f:
+            musicxml_str = f.read()
+
+        # Clean up
+        os.unlink(temp_output)
+
+        return musicxml_str
+    except Exception as e:
+        print(f"Conversion error: {str(e)}")
+        traceback.print_exc()
+        raise
+
+def load_model():
+    global MODEL
+    with MODEL_LOCK:
+        if MODEL is None:
+            print("⏳ Loading music generation model...")
+            MODEL = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-small-800k',
+                                                         local_files_only=True,  # prevent re-downloads
+                                                         force_download=False)
+            # Add .cuda() here if using a GPU
+            print("✅ Model loaded successfully!")
+    return MODEL
+
+# Model loading setup
+MODEL = None
+MODEL_LOCK = threading.Lock()
+
+# Initialize the model when the app starts
+with app.app_context():
+    load_model()
+
+@app.route('/upload', methods=['POST'])
+def handle_upload():
+    temp_midi_path = None
+    top_p = float(request.form.get('top_p', '0.95'))
+
+    try:
+        # Validate input
+        if 'midi_file' not in request.files:
+            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+        midi_file = request.files['midi_file']
+        start_time = request.form.get('start_time', '0')
+        end_time = request.form.get('end_time', '0')
+
+        # Create a temporary MIDI file with a random name
+        temp_dir = tempfile.gettempdir()
+        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+        # Save the uploaded MIDI to the temp file
+        midi_file.save(temp_midi_path)
+
+        # Process the MIDI
+        midi = mido.MidiFile(temp_midi_path)
+        model = load_model()
+        harmonized_midi = harmonizer(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+        # Save the harmonized MIDI (overwriting the temp file)
+        harmonized_midi.save(temp_midi_path)
+
+        # Convert to a MusicXML string
+        musicxml_str = midi_to_musicxml(temp_midi_path)
+
+        # Final type verification
+        if not isinstance(musicxml_str, str):
+            raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+        return jsonify({
+            "status": "success",
+            "musicxml": musicxml_str
+        })
+
+    except Exception as e:
+        print(f"Error processing request: {str(e)}")
+        traceback.print_exc()
+        return jsonify({
+            "status": "error",
+            "message": str(e)
+        }), 400
+    finally:
+        # Clean up the temp file
+        if temp_midi_path and os.path.exists(temp_midi_path):
+            try:
+                os.unlink(temp_midi_path)
+            except Exception as e:
+                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+@app.route('/uploadinfill', methods=['POST'])
+def handle_upload_infilling():
+    temp_midi_path = None
+    top_p = float(request.form.get('top_p', '0.95'))
+
+    try:
+        # Validate input
+        if 'midi_file' not in request.files:
+            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+        midi_file = request.files['midi_file']
+        start_time = request.form.get('start_time', '0')
+        end_time = request.form.get('end_time', '0')
+
+        # Create a temporary MIDI file with a random name
+        temp_dir = tempfile.gettempdir()
+        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+        # Save the uploaded MIDI to the temp file
+        midi_file.save(temp_midi_path)
+
+        # Process the MIDI
+        midi = mido.MidiFile(temp_midi_path)
+        model = load_model()
+        infilled_midi = infiller(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+        # Save the infilled MIDI (overwriting the temp file)
+        infilled_midi.save(temp_midi_path)
+
+        # Convert to a MusicXML string
+        musicxml_str = midi_to_musicxml(temp_midi_path)
+
+        # Final type verification
+        if not isinstance(musicxml_str, str):
+            raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+        return jsonify({
+            "status": "success",
+            "musicxml": musicxml_str
+        })
+
+    except Exception as e:
+        print(f"Error processing request: {str(e)}")
+        traceback.print_exc()
+        return jsonify({
+            "status": "error",
+            "message": str(e)
+        }), 400
+    finally:
+        # Clean up the temp file
+        if temp_midi_path and os.path.exists(temp_midi_path):
+            try:
+                os.unlink(temp_midi_path)
+            except Exception as e:
+                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+@app.route('/uploadchangemelody', methods=['POST'])
+def handle_upload_changemelody():
+    temp_midi_path = None
+    top_p = float(request.form.get('top_p', '0.95'))
+
+    try:
+        # Validate input
+        if 'midi_file' not in request.files:
+            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+        midi_file = request.files['midi_file']
+        start_time = request.form.get('start_time', '0')
+        end_time = request.form.get('end_time', '0')
+
+        # Create a temporary MIDI file with a random name
+        temp_dir = tempfile.gettempdir()
+        temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+        # Save the uploaded MIDI to the temp file
+        midi_file.save(temp_midi_path)
+
+        # Process the MIDI
+        midi = mido.MidiFile(temp_midi_path)
+        model = load_model()
+        changed_melody_midi = change_melody(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+        # Save the modified MIDI (overwriting the temp file)
+        changed_melody_midi.save(temp_midi_path)
+
+        # Convert to a MusicXML string
+        musicxml_str = midi_to_musicxml(temp_midi_path)
+
+        # Final type verification
+        if not isinstance(musicxml_str, str):
+            raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+        return jsonify({
+            "status": "success",
+            "musicxml": musicxml_str
+        })
+
+    except Exception as e:
+        print(f"Error processing request: {str(e)}")
+        traceback.print_exc()
+        return jsonify({
+            "status": "error",
+            "message": str(e)
+        }), 400
+    finally:
+        # Clean up the temp file
+        if temp_midi_path and os.path.exists(temp_midi_path):
+            try:
+                os.unlink(temp_midi_path)
+            except Exception as e:
+                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+
+if __name__ == '__main__':
+    app.run(debug=True, port=5000)
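
A hypothetical client call against a local run of this API (python api.py). The routes expect start_time and end_time in milliseconds and return the MusicXML as a string:

    import requests

    with open('examples/strawberry.mid', 'rb') as f:
        resp = requests.post('http://localhost:5000/upload',
                             files={'midi_file': f},
                             data={'start_time': '0', 'end_time': '5000', 'top_p': '0.95'})
    resp.raise_for_status()
    musicxml = resp.json()['musicxml']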
examples/full-score3.mid
ADDED
Binary file (1.36 kB)

examples/strawberry.mid
ADDED
Binary file (24.2 kB)
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+matplotlib==3.7.1
+midi2audio==0.1.1
+mido==1.2.10
+numpy>=1.22.4
+torch>=2.0.1
+transformers==4.29.2
+tqdm==4.65.0
+flask==3.1.1
+flask-cors==5.0.1
+music21
+gunicorn