manuel-l01 committed on
Commit 572abf8 · 1 Parent(s): 753bd5a

Initial commit

Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.9-slim
+
+ # Create dedicated user with home directory
+ RUN useradd -m -u 1000 user
+
+ # Set Hugging Face cache to user's writable directory
+ ENV HF_HOME=/home/user/.cache/huggingface
+ ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
+
+ # Create cache directory with proper permissions
+ RUN mkdir -p ${HF_HOME} && chown -R user:user /home/user
+
+ # Set working directory (app will live here)
+ WORKDIR /app
+
+ # Install dependencies as root
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt gunicorn
+
+ # Copy app files (maintain ownership)
+ COPY --chown=user:user . .
+
+ RUN rm -rf /root/.cache/pip
+
+ # Switch to non-root user
+ USER user
+
+ EXPOSE 7860
+ CMD ["gunicorn", "--workers", "1", "--timeout", "120", "--bind", "0.0.0.0:7860", "api:app"]
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
- title: InScoreAPI
- emoji: 📉
- colorFrom: green
- colorTo: yellow
+ title: InScoreAI
+ emoji: 📚
+ colorFrom: gray
+ colorTo: purple
 sdk: docker
 pinned: false
 ---
__pycache__/utils.cpython-311.pyc ADDED
Binary file (229 Bytes).
agents/__innit__.py ADDED
File without changes
agents/__pycache__/agents.cpython-311.pyc ADDED
Binary file (18.1 kB).
agents/__pycache__/harmonize.cpython-311.pyc ADDED
Binary file (19.2 kB).
agents/__pycache__/harmonize.cpython-312.pyc ADDED
Binary file (11 kB).
agents/__pycache__/utils.cpython-311.pyc ADDED
Binary file (691 Bytes).
agents/agents.py ADDED
@@ -0,0 +1,422 @@
+ from anticipation import ops
+ from anticipation.sample import generate
+ from anticipation.tokenize import extract_instruments
+ from anticipation.convert import events_to_midi, midi_to_events, compound_to_midi
+ from anticipation.config import *
+ from anticipation.vocab import *
+ from anticipation.convert import midi_to_compound
+ import mido
+ from agents.utils import load_midi_metadata
+
+
+ SMALL_MODEL = 'stanford-crfm/music-small-800k'    # faster inference, worse sample quality
+ MEDIUM_MODEL = 'stanford-crfm/music-medium-800k'  # slower inference, better sample quality
+ LARGE_MODEL = 'stanford-crfm/music-large-800k'    # slowest inference, best sample quality
+
+
+ def harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
+
+     # Turn the full MIDI into events
+     events = midi_to_events(midi)
+
+     print("Midi converted to events")
+
+     # Clip from 0 to the end of the full MIDI and shift it to start at time 0
+     segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
+     segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
+
+     # Extract the melody (instrument 0) from the accompaniment
+     events, melody = extract_instruments(segment, [0])
+
+     print("Melody extracted")
+
+     print("Start time:", start_time)
+     print("End time:", end_time)
+
+     # Initial prompt: everything before the selected region
+     history = ops.clip(events, 0, start_time, clip_duration=False)
+
+     # Everything after the selected region becomes anticipated controls
+     anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
+
+     # Generate accompaniment conditioned on the melody
+     accompaniment = generate(model, start_time, end_time, inputs=history, controls=melody, top_p=top_p, debug=False)
+
+     # Append the anticipated continuation to the accompaniment
+     accompaniment = ops.combine(accompaniment, anticipated)
+
+     print("Accompaniment generated")
+
+     # 1) render each voice separately
+     mel_mid = events_to_midi(melody)
+     acc_mid = events_to_midi(accompaniment)
+
+     # 2) build a fresh MidiFile
+     combined = mido.MidiFile()
+     combined.ticks_per_beat = mel_mid.ticks_per_beat  # or TIME_RESOLUTION//2
+
+     print("Midi built")
+
+     # 3) meta track with tempo & time signature
+     meta = mido.MidiTrack()
+     meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
+     meta.append(mido.MetaMessage('time_signature',
+                                  numerator=original_time_sig[0],
+                                  denominator=original_time_sig[1]))
+     combined.tracks.append(meta)
+
+     # 4) append melody *then* accompaniment, skipping each file's own meta track
+     combined.tracks.extend(mel_mid.tracks[1:])
+     combined.tracks.extend(acc_mid.tracks[1:])
+
+     # 5) clamp note messages to valid MIDI ranges
+     for track in combined.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if hasattr(msg, 'velocity'):
+                     msg.velocity = min(max(msg.velocity, 0), 127)
+                 if hasattr(msg, 'note'):
+                     msg.note = min(max(msg.note, 0), 127)
+
+     print(f"Melody tracks: {len(mel_mid.tracks)}")
+     print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
+     print(f"Combined tracks before cleanup: {len(combined.tracks)}")
+
+     # Track cleanup: keep only unique tracks
+     unique_tracks = []
+     seen = set()
+     for track in combined.tracks:
+         track_hash = str([msg.hex() for msg in track])
+         if track_hash not in seen:
+             unique_tracks.append(track)
+             seen.add(track_hash)
+     combined.tracks = unique_tracks
+
+     print(f"Final track count: {len(combined.tracks)}")
+
+     print("Output Midi metadata added")
+
+     return combined
+
+
+ def harmonizer(ai_model, midi_file, start_time, end_time, top_p):
+     """
+     Harmonize the melody in a MIDI file and return the harmonized MIDI.
+
+     Args:
+         ai_model: a pretrained anticipatory music transformer
+         midi_file: a mido.MidiFile object
+         start_time: start time of the selected region (the melody to harmonize), in seconds
+         end_time: end time of the selected region, in seconds
+         top_p: nucleus sampling threshold
+     """
+
+     print(f"Original MIDI tracks: {len(midi_file.tracks)}")
+
+     # Log any out-of-range note parameters in the input
+     for track in midi_file.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if msg.velocity > 127 or msg.velocity < 0:
+                     print(f"Invalid velocity: {msg.velocity}")
+                 if msg.note > 127 or msg.note < 0:
+                     print(f"Invalid pitch: {msg.note}")
+
+     # Load the original MIDI and extract its metadata
+     midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
+
+     print("Midi metadata loaded")
+
+     # load an anticipatory music transformer
+     model = ai_model  # add .cuda() if you have a GPU
+
+     print("Model loaded")
+
+     harmonized_midi = harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
+
+     print("Midi generated")
+
+     print(f"Harmonized MIDI tracks: {len(harmonized_midi.tracks)}")
+
+     # Clamp any remaining out-of-range values
+     for track in harmonized_midi.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 msg.velocity = min(max(msg.velocity, 0), 127)
+                 msg.note = min(max(msg.note, 0), 127)
+
+     print("Midi validated")
+
+     return harmonized_midi
+
+
+ def infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
+
+     # Turn the full MIDI into events
+     events = midi_to_events(midi)
+
+     print("Midi converted to events")
+
+     # Clip from 0 to the end of the full MIDI and shift it to start at time 0
+     segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
+     segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
+
+     # Initial prompt: everything before the selected region
+     history = ops.clip(events, 0, start_time, clip_duration=False)
+
+     # Everything after the selected region becomes anticipated controls
+     anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
+
+     # Generate the infilled region, conditioned on the anticipated continuation
+     infilling = generate(model, start_time, end_time, inputs=history, controls=anticipated, top_p=top_p, debug=False)
+
+     # Append the anticipated continuation to the infilled events
+     full_events = ops.combine(infilling, anticipated)
+
+     print("Infilling generated")
+
+     # 1) render the result
+     full_mid = events_to_midi(full_events)
+
+     # 2) build a fresh MidiFile
+     combined = mido.MidiFile()
+     combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2
+
+     print("Midi built")
+
+     # 3) meta track with tempo & time signature
+     meta = mido.MidiTrack()
+     meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
+     meta.append(mido.MetaMessage('time_signature',
+                                  numerator=original_time_sig[0],
+                                  denominator=original_time_sig[1]))
+     combined.tracks.append(meta)
+
+     # 4) append the generated tracks (duplicates are removed below)
+     combined.tracks.extend(full_mid.tracks[:])
+
+     # 5) clamp note messages to valid MIDI ranges
+     for track in combined.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if hasattr(msg, 'velocity'):
+                     msg.velocity = min(max(msg.velocity, 0), 127)
+                 if hasattr(msg, 'note'):
+                     msg.note = min(max(msg.note, 0), 127)
+
+     print(f"Generated tracks: {len(full_mid.tracks)}")
+     print(f"Combined tracks before cleanup: {len(combined.tracks)}")
+
+     # Track cleanup: keep only unique tracks
+     unique_tracks = []
+     seen = set()
+     for track in combined.tracks:
+         track_hash = str([msg.hex() for msg in track])
+         if track_hash not in seen:
+             unique_tracks.append(track)
+             seen.add(track_hash)
+     combined.tracks = unique_tracks
+
+     print(f"Final track count: {len(combined.tracks)}")
+
+     print("Output Midi metadata added")
+
+     return combined
+
+
+ def infiller(ai_model, midi_file, start_time, end_time, top_p):
+     """
+     Infill the selected region of a MIDI file and return the infilled MIDI.
+
+     Args:
+         ai_model: a pretrained anticipatory music transformer
+         midi_file: a mido.MidiFile object
+         start_time: start time of the selected region, in seconds
+         end_time: end time of the selected region, in seconds
+         top_p: nucleus sampling threshold
+     """
+
+     print(f"Original MIDI tracks: {len(midi_file.tracks)}")
+
+     # Log any out-of-range note parameters in the input
+     for track in midi_file.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if msg.velocity > 127 or msg.velocity < 0:
+                     print(f"Invalid velocity: {msg.velocity}")
+                 if msg.note > 127 or msg.note < 0:
+                     print(f"Invalid pitch: {msg.note}")
+
+     # Load the original MIDI and extract its metadata
+     midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
+
+     print("Midi metadata loaded")
+
+     # load an anticipatory music transformer
+     model = ai_model  # add .cuda() if you have a GPU
+
+     print("Model loaded")
+
+     infilled_midi = infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
+
+     print("Midi generated")
+
+     print(f"Infilled MIDI tracks: {len(infilled_midi.tracks)}")
+
+     # Clamp any remaining out-of-range values
+     for track in infilled_midi.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 msg.velocity = min(max(msg.velocity, 0), 127)
+                 msg.note = min(max(msg.note, 0), 127)
+
+     print("Midi validated")
+
+     return infilled_midi
+
+
+ def change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
+
+     events = midi_to_events(midi)
+     segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
+     segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
+
+     # Keep the melody (instrument 0) as events and extract the accompaniment as controls
+     instruments = list(ops.get_instruments(segment).keys())
+     accompaniment_instruments = [instr for instr in instruments if instr != 0]
+     melody_events, accompaniment_controls = extract_instruments(segment, accompaniment_instruments)
+
+     # Initial prompt: the melody before start_time
+     history = ops.clip(melody_events, 0, start_time, clip_duration=False)
+
+     # Include accompaniment controls for the entire duration
+     controls = accompaniment_controls  # full accompaniment as controls
+
+     # Generate a new melody conditioned on the accompaniment
+     infilling = generate(model, start_time, end_time, inputs=history, controls=controls, top_p=top_p, debug=False)
+
+     # Append the anticipated melody continuation
+     anticipated_melody = [CONTROL_OFFSET + tok for tok in ops.clip(melody_events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
+     full_events = ops.combine(infilling, anticipated_melody)
+
+     acc_mid = events_to_midi(accompaniment_controls)
+
+     # Render and combine MIDI tracks
+     full_mid = events_to_midi(full_events)
+     combined = mido.MidiFile()
+     combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2
+
+     print("Midi built")
+
+     # meta track with tempo & time signature
+     meta = mido.MidiTrack()
+     meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
+     meta.append(mido.MetaMessage('time_signature',
+                                  numerator=original_time_sig[0],
+                                  denominator=original_time_sig[1]))
+     combined.tracks.append(meta)
+
+     # append the new melody *then* the accompaniment (duplicates are removed below)
+     combined.tracks.extend(full_mid.tracks[:])
+     combined.tracks.extend(acc_mid.tracks[:])
+
+     # clamp note messages to valid MIDI ranges
+     for track in combined.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if hasattr(msg, 'velocity'):
+                     msg.velocity = min(max(msg.velocity, 0), 127)
+                 if hasattr(msg, 'note'):
+                     msg.note = min(max(msg.note, 0), 127)
+
+     print(f"Melody tracks: {len(full_mid.tracks)}")
+     print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
+     print(f"Combined tracks before cleanup: {len(combined.tracks)}")
+
+     # Track cleanup: keep only unique tracks
+     unique_tracks = []
+     seen = set()
+     for track in combined.tracks:
+         track_hash = str([msg.hex() for msg in track])
+         if track_hash not in seen:
+             unique_tracks.append(track)
+             seen.add(track_hash)
+     combined.tracks = unique_tracks
+
+     print(f"Final track count: {len(combined.tracks)}")
+
+     print("Output Midi metadata added")
+
+     return combined
+
+
+ def change_melody(ai_model, midi_file, start_time, end_time, top_p):
+     """
+     Regenerate the melody in the selected region of a MIDI file, conditioned
+     on the accompaniment, and return the resulting MIDI.
+
+     Args:
+         ai_model: a pretrained anticipatory music transformer
+         midi_file: a mido.MidiFile object
+         start_time: start time of the selected region, in seconds
+         end_time: end time of the selected region, in seconds
+         top_p: nucleus sampling threshold
+     """
+
+     print(f"Original MIDI tracks: {len(midi_file.tracks)}")
+
+     # Log any out-of-range note parameters in the input
+     for track in midi_file.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 if msg.velocity > 127 or msg.velocity < 0:
+                     print(f"Invalid velocity: {msg.velocity}")
+                 if msg.note > 127 or msg.note < 0:
+                     print(f"Invalid pitch: {msg.note}")
+
+     # Load the original MIDI and extract its metadata
+     midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
+
+     print("Midi metadata loaded")
+
+     # load an anticipatory music transformer
+     model = ai_model  # add .cuda() if you have a GPU
+
+     print("Model loaded")
+
+     change_melody_gen_midi = change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
+
+     print("Midi generated")
+
+     print(f"New-melody MIDI tracks: {len(change_melody_gen_midi.tracks)}")
+
+     # Clamp any remaining out-of-range values
+     for track in change_melody_gen_midi.tracks:
+         for msg in track:
+             if msg.type in ['note_on', 'note_off']:
+                 msg.velocity = min(max(msg.velocity, 0), 127)
+                 msg.note = min(max(msg.note, 0), 127)
+
+     print("Midi validated")
+
+     return change_melody_gen_midi
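
A hedged usage sketch (not part of this commit): the upstream anticipation project loads these checkpoints through Hugging Face transformers, so harmonizing a region of a file might look like the following. The file name and time bounds are illustrative.

    import mido
    from transformers import AutoModelForCausalLM
    from agents.agents import harmonizer, SMALL_MODEL

    model = AutoModelForCausalLM.from_pretrained(SMALL_MODEL)  # assumed loading path
    midi = mido.MidiFile('melody.mid')                         # hypothetical input file
    out = harmonizer(model, midi, start_time=2.0, end_time=6.0, top_p=0.95)
    out.save('harmonized.mid')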
agents/utils.py ADDED
@@ -0,0 +1,15 @@
+
+
+ def load_midi_metadata(midi_file):
+
+     original_tempo = 500000     # default tempo (120 BPM)
+     original_time_sig = (4, 4)  # default time signature
+
+     # iterating a MidiFile yields merged messages from all tracks, so this
+     # keeps the last tempo / time-signature message in the file
+     for msg in midi_file:
+         if msg.type == 'set_tempo':
+             original_tempo = msg.tempo
+         elif msg.type == 'time_signature':
+             original_time_sig = (msg.numerator, msg.denominator)
+
+     return midi_file, original_tempo, original_time_sig
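
A quick sanity check of the defaults (sketch; 'song.mid' is a placeholder path):

    import mido
    from agents.utils import load_midi_metadata

    mid = mido.MidiFile('song.mid')        # placeholder path
    _, tempo, timesig = load_midi_metadata(mid)
    print(mido.tempo2bpm(tempo), timesig)  # 120.0 (4, 4) when no such messages exist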
anticipation/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """ Infrastructure for constructing anticipatory infilling models.
+
+ This module provides infrastructure to preprocess MIDI music datasets
+ for training anticipatory music infilling models. For more context, see:
+
+   Anticipatory Music Transformer
+   John Thickstun, David Hall, Chris Donahue, Percy Liang
+   Preprint Report, 2023
+ """
anticipation/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (598 Bytes).
anticipation/__pycache__/config.cpython-311.pyc ADDED
Binary file (2.08 kB).
anticipation/__pycache__/convert.cpython-311.pyc ADDED
Binary file (19 kB).
anticipation/__pycache__/ops.cpython-311.pyc ADDED
Binary file (12.1 kB).
anticipation/__pycache__/sample.cpython-311.pyc ADDED
Binary file (13.6 kB).
anticipation/__pycache__/tokenize.cpython-311.pyc ADDED
Binary file (12.8 kB).
anticipation/__pycache__/visuals.cpython-311.pyc ADDED
Binary file (4.02 kB).
anticipation/__pycache__/vocab.cpython-311.pyc ADDED
Binary file (2.62 kB).
anticipation/config-original.py ADDED
@@ -0,0 +1,60 @@
+ """
+ Global configuration for anticipatory infilling models.
+ """
+
+ # model hyper-parameters
+
+ CONTEXT_SIZE = 1024  # model context
+ EVENT_SIZE = 3       # each event/control is encoded as 3 tokens
+ M = 341              # model context (1024 = 1 + EVENT_SIZE*M)
+ DELTA = 5            # anticipation time in seconds
+
+ assert CONTEXT_SIZE == 1+EVENT_SIZE*M
+
+ # vocabulary constants
+
+ MAX_TIME_IN_SECONDS = 100     # exclude very long training sequences
+ MAX_DURATION_IN_SECONDS = 10  # maximum duration of a note
+ TIME_RESOLUTION = 100         # 10ms time resolution = 100 bins/second
+
+ MAX_PITCH = 128                 # 128 MIDI pitches
+ MAX_INSTR = 129                 # 129 MIDI instruments (128 + drums)
+ MAX_NOTE = MAX_PITCH*MAX_INSTR  # note = pitch x instrument
+
+ MAX_INTERARRIVAL_IN_SECONDS = 10  # maximum interarrival time (for MIDI-like encoding)
+
+ # preprocessing settings
+
+ PREPROC_WORKERS = 16
+
+ COMPOUND_SIZE = 5                 # event size in the intermediate compound tokenization
+ MAX_TRACK_INSTR = 16              # exclude tracks with large numbers of instruments
+ MAX_TRACK_TIME_IN_SECONDS = 3600  # exclude very long tracks (longer than 1 hour)
+ MIN_TRACK_TIME_IN_SECONDS = 10    # exclude very short tracks (less than 10 seconds)
+ MIN_TRACK_EVENTS = 100            # exclude very short tracks (less than 100 events)
+
+ # LakhMIDI dataset splits
+
+ LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
+ LAKH_VALID = ['e']
+ LAKH_TEST = ['f']
+
+ # derived quantities
+
+ MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
+ MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS
+
+ MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS
+
+
+ if __name__ == '__main__':
+     print('Model constants:')
+     print(f' -> anticipation interval: {DELTA}s')
+     print('Vocabulary constants:')
+     print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
+     print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
+     print(f' -> time resolution: {TIME_RESOLUTION} bins/s ({1000//TIME_RESOLUTION}ms)')
+     print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
+     print('Preprocessing constants:')
+     print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
+     print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}')
anticipation/config.py ADDED
@@ -0,0 +1,60 @@
+ """
+ Global configuration for anticipatory infilling models.
+ """
+
+ # model hyper-parameters
+
+ CONTEXT_SIZE = 1024  # model context
+ EVENT_SIZE = 3       # each event/control is encoded as 3 tokens
+ M = 341              # model context (1024 = 1 + EVENT_SIZE*M)
+ DELTA = 5            # anticipation time in seconds
+
+ assert CONTEXT_SIZE == 1+EVENT_SIZE*M
+
+ # vocabulary constants
+
+ MAX_TIME_IN_SECONDS = 100     # exclude very long training sequences
+ MAX_DURATION_IN_SECONDS = 10  # maximum duration of a note
+ TIME_RESOLUTION = 100         # 10ms time resolution = 100 bins/second
+
+ MAX_PITCH = 128                 # 128 MIDI pitches
+ MAX_INSTR = 129                 # 129 MIDI instruments (128 + drums)
+ MAX_NOTE = MAX_PITCH*MAX_INSTR  # note = pitch x instrument
+
+ MAX_INTERARRIVAL_IN_SECONDS = 10  # maximum interarrival time (for MIDI-like encoding)
+
+ # preprocessing settings
+
+ PREPROC_WORKERS = 16
+
+ COMPOUND_SIZE = 5                 # event size in the intermediate compound tokenization
+ MAX_TRACK_INSTR = 16              # exclude tracks with large numbers of instruments
+ MAX_TRACK_TIME_IN_SECONDS = 3600  # exclude very long tracks (longer than 1 hour)
+ MIN_TRACK_TIME_IN_SECONDS = 10    # exclude very short tracks (less than 10 seconds)
+ MIN_TRACK_EVENTS = 100            # exclude very short tracks (less than 100 events)
+
+ # LakhMIDI dataset splits
+
+ LAKH_SPLITS = ['0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f']
+ LAKH_VALID = ['e']
+ LAKH_TEST = ['f']
+
+ # derived quantities
+
+ MAX_TIME = TIME_RESOLUTION*MAX_TIME_IN_SECONDS
+ MAX_DUR = TIME_RESOLUTION*MAX_DURATION_IN_SECONDS
+
+ MAX_INTERARRIVAL = TIME_RESOLUTION*MAX_INTERARRIVAL_IN_SECONDS
+
+
+ if __name__ == '__main__':
+     print('Model constants:')
+     print(f' -> anticipation interval: {DELTA}s')
+     print('Vocabulary constants:')
+     print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
+     print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
+     print(f' -> time resolution: {TIME_RESOLUTION} bins/s ({1000//TIME_RESOLUTION}ms)')
+     print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
+     print('Preprocessing constants:')
+     print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
+     print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}')
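
With the default constants above, the derived quantities work out to plain products:

    MAX_TIME = 100 * 100         # 10000 ticks: 100 s sequence cap at 10 ms/tick
    MAX_DUR = 100 * 10           # 1000 ticks: 10 s note-duration cap
    MAX_INTERARRIVAL = 100 * 10  # 1000 ticks: 10 s interarrival cap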
anticipation/convert-original.py ADDED
@@ -0,0 +1,342 @@
+ """
+ Utilities for converting to and from Midi data and encoded/tokenized data.
+ """
+
+ from collections import defaultdict
+
+ import mido
+
+ from anticipation.config import *
+ from anticipation.vocab import *
+ from anticipation.ops import unpad
+
+
+ def midi_to_interarrival(midifile, debug=False, stats=False):
+     midi = mido.MidiFile(midifile)
+
+     tokens = []
+     dt = 0
+
+     instruments = defaultdict(int)  # default to code 0 = piano
+     tempo = 500000                  # default tempo: 500000 microseconds per beat
+     truncations = 0
+     for message in midi:
+         dt += message.time
+
+         # sanity check: negative time?
+         if message.time < 0:
+             raise ValueError
+
+         if message.type == 'program_change':
+             instruments[message.channel] = message.program
+         elif message.type in ['note_on', 'note_off']:
+             delta_ticks = min(round(TIME_RESOLUTION*dt), MAX_INTERARRIVAL-1)
+             if delta_ticks != round(TIME_RESOLUTION*dt):
+                 truncations += 1
+
+             if delta_ticks > 0:  # if time elapsed since last token
+                 tokens.append(MIDI_TIME_OFFSET + delta_ticks)  # add a time step event
+
+             # special case: channel 9 is drums!
+             inst = 128 if message.channel == 9 else instruments[message.channel]
+             offset = MIDI_START_OFFSET if message.type == 'note_on' and message.velocity > 0 else MIDI_END_OFFSET
+             tokens.append(offset + (2**7)*inst + message.note)
+             dt = 0
+         elif message.type == 'set_tempo':
+             tempo = message.tempo
+         elif message.type == 'time_signature':
+             pass  # we use real time
+         elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
+             pass  # we don't attempt to model these
+         elif message.type == 'control_change':
+             pass  # this includes pedal and per-track volume: ignore for now
+         elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
+                               'copyright', 'marker', 'instrument_name', 'cue_marker',
+                               'device_name', 'sequence_number']:
+             pass  # possibly useful metadata but ignore for now
+         elif message.type == 'channel_prefix':
+             pass  # relatively common, but can we ignore this?
+         elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
+             pass  # I have no idea what this is
+         else:
+             if debug:
+                 print('UNHANDLED MESSAGE', message.type, message)
+
+     if stats:
+         return tokens, truncations
+
+     return tokens
+
+
+ def interarrival_to_midi(tokens, debug=False):
+     mid = mido.MidiFile()
+     mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120
+
+     track_idx = {}  # maps instrument to (track, current time, channel index)
+     time_in_ticks = 0
+     num_tracks = 0
+     for token in tokens:
+         if token == MIDI_SEPARATOR:
+             continue
+
+         if token < MIDI_START_OFFSET:
+             time_in_ticks += token - MIDI_TIME_OFFSET
+         elif token < MIDI_END_OFFSET:
+             token -= MIDI_START_OFFSET
+             instrument = token // 2**7
+             pitch = token - (2**7)*instrument
+
+             try:
+                 track, previous_time, idx = track_idx[instrument]
+             except KeyError:
+                 idx = num_tracks
+                 previous_time = 0
+                 track = mido.MidiTrack()
+                 mid.tracks.append(track)
+                 if instrument == 128:  # drums always go on channel 9
+                     idx = 9
+                     message = mido.Message('program_change', channel=idx, program=0)
+                 else:
+                     message = mido.Message('program_change', channel=idx, program=instrument)
+                 track.append(message)
+                 num_tracks += 1
+                 if num_tracks == 9:
+                     num_tracks += 1  # skip the drums channel
+
+             track.append(mido.Message('note_on', note=pitch, channel=idx, velocity=96, time=time_in_ticks-previous_time))
+             track_idx[instrument] = (track, time_in_ticks, idx)
+         else:
+             token -= MIDI_END_OFFSET
+             instrument = token // 2**7
+             pitch = token - (2**7)*instrument
+
+             try:
+                 track, previous_time, idx = track_idx[instrument]
+             except KeyError:
+                 # shouldn't happen because we should have a corresponding onset
+                 if debug:
+                     print('IGNORING bad offset')
+
+                 continue
+
+             track.append(mido.Message('note_off', note=pitch, channel=idx, time=time_in_ticks-previous_time))
+             track_idx[instrument] = (track, time_in_ticks, idx)
+
+     return mid
+
+
+ def midi_to_compound(midifile, debug=False):
+     if type(midifile) == str:
+         midi = mido.MidiFile(midifile)
+     else:
+         midi = midifile
+
+     tokens = []
+     note_idx = 0
+     open_notes = defaultdict(list)
+
+     time = 0
+     instruments = defaultdict(int)  # default to code 0 = piano
+     tempo = 500000                  # default tempo: 500000 microseconds per beat
+     for message in midi:
+         time += message.time
+
+         # sanity check: negative time?
+         if message.time < 0:
+             raise ValueError
+
+         if message.type == 'program_change':
+             instruments[message.channel] = message.program
+         elif message.type in ['note_on', 'note_off']:
+             # special case: channel 9 is drums!
+             instr = 128 if message.channel == 9 else instruments[message.channel]
+
+             if message.type == 'note_on' and message.velocity > 0:  # onset
+                 # time quantization
+                 time_in_ticks = round(TIME_RESOLUTION*time)
+
+                 # Our compound word is: (time, duration, note, instr, velocity)
+                 tokens.append(time_in_ticks)  # 10ms resolution
+                 tokens.append(-1)             # placeholder (we'll fill this in later)
+                 tokens.append(message.note)
+                 tokens.append(instr)
+                 tokens.append(message.velocity)
+
+                 open_notes[(instr,message.note,message.channel)].append((note_idx, time))
+                 note_idx += 1
+             else:  # offset
+                 try:
+                     open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
+                 except IndexError:
+                     if debug:
+                         print('WARNING: ignoring bad offset')
+                 else:
+                     duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
+                     tokens[5*open_idx + 1] = duration_ticks
+                     #del open_notes[(instr,message.note,message.channel)]
+         elif message.type == 'set_tempo':
+             tempo = message.tempo
+         elif message.type == 'time_signature':
+             pass  # we use real time
+         elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
+             pass  # we don't attempt to model these
+         elif message.type == 'control_change':
+             pass  # this includes pedal and per-track volume: ignore for now
+         elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
+                               'copyright', 'marker', 'instrument_name', 'cue_marker',
+                               'device_name', 'sequence_number']:
+             pass  # possibly useful metadata but ignore for now
+         elif message.type == 'channel_prefix':
+             pass  # relatively common, but can we ignore this?
+         elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
+             pass  # I have no idea what this is
+         else:
+             if debug:
+                 print('UNHANDLED MESSAGE', message.type, message)
+
+     unclosed_count = 0
+     for _, v in open_notes.items():
+         unclosed_count += len(v)
+
+     if debug and unclosed_count > 0:
+         print(f'WARNING: {unclosed_count} unclosed notes')
+         print(' ', midifile)
+
+     return tokens
+
+
+ def compound_to_midi(tokens, debug=False):
+     mid = mido.MidiFile()
+     mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120
+
+     it = iter(tokens)
+     time_index = defaultdict(list)
+     for (time_in_ticks, duration, note, instrument, velocity) in zip(it,it,it,it,it):
+         time_index[(time_in_ticks,0)].append((note, instrument, velocity))           # 0 = onset
+         time_index[(time_in_ticks+duration,1)].append((note, instrument, velocity))  # 1 = offset
+
+     track_idx = {}  # maps instrument to (track, current time, channel index)
+     num_tracks = 0
+     for time_in_ticks, event_type in sorted(time_index.keys()):
+         for (note, instrument, velocity) in time_index[(time_in_ticks, event_type)]:
+             if event_type == 0:  # onset
+                 try:
+                     track, previous_time, idx = track_idx[instrument]
+                 except KeyError:
+                     idx = num_tracks
+                     previous_time = 0
+                     track = mido.MidiTrack()
+                     mid.tracks.append(track)
+                     if instrument == 128:  # drums always go on channel 9
+                         idx = 9
+                         message = mido.Message('program_change', channel=idx, program=0)
+                     else:
+                         message = mido.Message('program_change', channel=idx, program=instrument)
+                     track.append(message)
+                     num_tracks += 1
+                     if num_tracks == 9:
+                         num_tracks += 1  # skip the drums channel
+
+                 track.append(mido.Message(
+                     'note_on', note=note, channel=idx, velocity=velocity,
+                     time=time_in_ticks-previous_time))
+                 track_idx[instrument] = (track, time_in_ticks, idx)
+             else:  # offset
+                 try:
+                     track, previous_time, idx = track_idx[instrument]
+                 except KeyError:
+                     # shouldn't happen because we should have a corresponding onset
+                     if debug:
+                         print('IGNORING bad offset')
+
+                     continue
+
+                 track.append(mido.Message(
+                     'note_off', note=note, channel=idx,
+                     time=time_in_ticks-previous_time))
+                 track_idx[instrument] = (track, time_in_ticks, idx)
+
+     return mid
+
+
+ def compound_to_events(tokens, stats=False):
+     assert len(tokens) % 5 == 0
+     tokens = tokens.copy()
+
+     # remove velocities
+     del tokens[4::5]
+
+     # combine (note, instrument)
+     assert all(-1 <= tok < 2**7 for tok in tokens[2::4])
+     assert all(-1 <= tok < 129 for tok in tokens[3::4])
+     tokens[2::4] = [SEPARATOR if note == -1 else MAX_PITCH*instr + note
+                     for note, instr in zip(tokens[2::4],tokens[3::4])]
+     tokens[2::4] = [NOTE_OFFSET + tok for tok in tokens[2::4]]
+     del tokens[3::4]
+
+     # max duration cutoff and set unknown durations to 250ms
+     truncations = sum([1 for tok in tokens[1::3] if tok >= MAX_DUR])
+     tokens[1::3] = [TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1)
+                     for tok in tokens[1::3]]
+     tokens[1::3] = [DUR_OFFSET + tok for tok in tokens[1::3]]
+
+     assert min(tokens[0::3]) >= 0
+     tokens[0::3] = [TIME_OFFSET + tok for tok in tokens[0::3]]
+
+     assert len(tokens) % 3 == 0
+
+     if stats:
+         return tokens, truncations
+
+     return tokens
+
+
+ def events_to_compound(tokens, debug=False):
+     tokens = unpad(tokens)
+
+     # move all tokens to zero-offset for synthesis
+     tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
+               for tok in tokens]
+
+     # remove type offsets
+     tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
+     tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
+     tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]
+
+     offset = 0     # add max time from previous track for synthesis
+     track_max = 0  # keep track of max time in track
+     for j, (time,dur,note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
+         if note == SEPARATOR:
+             offset += track_max
+             track_max = 0
+             if debug:
+                 print('Sequence Boundary')
+         else:
+             track_max = max(track_max, time+dur)
+             tokens[3*j] += offset
+
+     # strip sequence separators
+     assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
+     tokens = [tok for tok in tokens if tok != SEPARATOR]
+
+     assert len(tokens) % 3 == 0
+     out = 5*(len(tokens)//3)*[0]
+     out[0::5] = tokens[0::3]
+     out[1::5] = tokens[1::3]
+     out[2::5] = [tok - (2**7)*(tok//2**7) for tok in tokens[2::3]]
+     out[3::5] = [tok//2**7 for tok in tokens[2::3]]
+     out[4::5] = (len(tokens)//3)*[72]  # default velocity
+
+     assert max(out[1::5]) < MAX_DUR
+     assert max(out[2::5]) < MAX_PITCH
+     assert max(out[3::5]) < MAX_INSTR
+     assert all(tok >= 0 for tok in out)
+
+     return out
+
+
+ def events_to_midi(tokens, debug=False):
+     return compound_to_midi(events_to_compound(tokens, debug=debug), debug=debug)
+
+
+ def midi_to_events(midifile, debug=False):
+     return compound_to_events(midi_to_compound(midifile, debug=debug))
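
For reference, each note in the compound encoding is five integers, (onset, duration, pitch, instrument, velocity), with times in ticks at TIME_RESOLUTION = 100 ticks/second. An illustrative example:

    # middle C on piano (program 0), onset 1.5 s, duration 0.5 s, velocity 64
    note = [150, 50, 60, 0, 64]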
anticipation/convert.py ADDED
@@ -0,0 +1,365 @@
+ """
+ Utilities for converting to and from Midi data and encoded/tokenized data.
+ """
+
+ from collections import defaultdict
+
+ import mido
+
+ from anticipation.config import *
+ from anticipation.vocab import *
+ from anticipation.ops import unpad
+
+
+ def midi_to_interarrival(midifile, debug=False, stats=False):
+     midi = mido.MidiFile(midifile)
+
+     tokens = []
+     dt = 0
+
+     instruments = defaultdict(int)  # default to code 0 = piano
+     tempo = 500000                  # default tempo: 500000 microseconds per beat
+     truncations = 0
+     for message in midi:
+         dt += message.time
+
+         # sanity check: negative time?
+         if message.time < 0:
+             raise ValueError
+
+         if message.type == 'program_change':
+             instruments[message.channel] = message.program
+         elif message.type in ['note_on', 'note_off']:
+             delta_ticks = min(round(TIME_RESOLUTION*dt), MAX_INTERARRIVAL-1)
+             if delta_ticks != round(TIME_RESOLUTION*dt):
+                 truncations += 1
+
+             if delta_ticks > 0:  # if time elapsed since last token
+                 tokens.append(MIDI_TIME_OFFSET + delta_ticks)  # add a time step event
+
+             # special case: channel 9 is drums!
+             inst = 128 if message.channel == 9 else instruments[message.channel]
+             offset = MIDI_START_OFFSET if message.type == 'note_on' and message.velocity > 0 else MIDI_END_OFFSET
+             tokens.append(offset + (2**7)*inst + message.note)
+             dt = 0
+         elif message.type == 'set_tempo':
+             tempo = message.tempo
+         elif message.type == 'time_signature':
+             pass  # we use real time
+         elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
+             pass  # we don't attempt to model these
+         elif message.type == 'control_change':
+             pass  # this includes pedal and per-track volume: ignore for now
+         elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
+                               'copyright', 'marker', 'instrument_name', 'cue_marker',
+                               'device_name', 'sequence_number']:
+             pass  # possibly useful metadata but ignore for now
+         elif message.type == 'channel_prefix':
+             pass  # relatively common, but can we ignore this?
+         elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
+             pass  # I have no idea what this is
+         else:
+             if debug:
+                 print('UNHANDLED MESSAGE', message.type, message)
+
+     if stats:
+         return tokens, truncations
+
+     return tokens
+
+
+ def interarrival_to_midi(tokens, debug=False):
+     mid = mido.MidiFile()
+     mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120
+
+     track_idx = {}  # maps instrument to (track, current time, channel index)
+     time_in_ticks = 0
+     num_tracks = 0
+     for token in tokens:
+         if token == MIDI_SEPARATOR:
+             continue
+
+         if token < MIDI_START_OFFSET:
+             time_in_ticks += token - MIDI_TIME_OFFSET
+         elif token < MIDI_END_OFFSET:
+             token -= MIDI_START_OFFSET
+             instrument = token // 2**7
+             pitch = token - (2**7)*instrument
+
+             try:
+                 track, previous_time, idx = track_idx[instrument]
+             except KeyError:
+                 idx = num_tracks
+                 previous_time = 0
+                 track = mido.MidiTrack()
+                 mid.tracks.append(track)
+                 if instrument == 128:  # drums always go on channel 9
+                     idx = 9
+                     message = mido.Message('program_change', channel=idx, program=0)
+                 else:
+                     message = mido.Message('program_change', channel=idx, program=instrument)
+                 track.append(message)
+                 num_tracks += 1
+                 if num_tracks == 9:
+                     num_tracks += 1  # skip the drums channel
+
+             track.append(mido.Message('note_on', note=pitch, channel=idx, velocity=96, time=time_in_ticks-previous_time))
+             track_idx[instrument] = (track, time_in_ticks, idx)
+         else:
+             token -= MIDI_END_OFFSET
+             instrument = token // 2**7
+             pitch = token - (2**7)*instrument
+
+             try:
+                 track, previous_time, idx = track_idx[instrument]
+             except KeyError:
+                 # shouldn't happen because we should have a corresponding onset
+                 if debug:
+                     print('IGNORING bad offset')
+
+                 continue
+
+             track.append(mido.Message('note_off', note=pitch, channel=idx, time=time_in_ticks-previous_time))
+             track_idx[instrument] = (track, time_in_ticks, idx)
+
+     return mid
+
+
+ def midi_to_compound(midifile, debug=False):
+     if type(midifile) == str:
+         midi = mido.MidiFile(midifile)
+     else:
+         midi = midifile
+
+     tokens = []
+     note_idx = 0
+     open_notes = defaultdict(list)
+
+     time = 0
+     instruments = defaultdict(lambda: {'program': 0, 'channel': None})  # per-source-channel state
+     next_channel = 0
+
+     tempo = 500000  # default tempo: 500000 microseconds per beat
+     for message in midi:
+         time += message.time
+
+         # sanity check: negative time?
+         if message.time < 0:
+             raise ValueError
+
+         if message.type == 'program_change':
+             # reserve output channels 0-8, 10-15 (9 is reserved for drums)
+             if message.channel != 9:
+                 if instruments[message.channel]['channel'] is None:
+                     instruments[message.channel]['channel'] = next_channel
+                     next_channel += 1
+                     if next_channel == 9:  # skip channel 9 (drums)
+                         next_channel = 10
+                 instruments[message.channel]['program'] = message.program
+         elif message.type in ['note_on', 'note_off']:
+             # special case: channel 9 is drums!
+             # a note may arrive before any program_change: assign a channel lazily
+             if message.channel != 9 and instruments[message.channel]['channel'] is None:
+                 instruments[message.channel]['channel'] = next_channel
+                 next_channel += 1
+                 if next_channel == 9:  # skip channel 9 (drums)
+                     next_channel = 10
+             instr = 128 if message.channel == 9 else instruments[message.channel]['program']
+             channel = 9 if message.channel == 9 else instruments[message.channel]['channel']
+             compound_instr = (instr << 4) | channel
+             if message.type == 'note_on' and message.velocity > 0:  # onset
+                 # time quantization
+                 time_in_ticks = round(TIME_RESOLUTION*time)
+
+                 # Our compound word is: (time, duration, note, instr, velocity)
+                 tokens.append(time_in_ticks)  # 10ms resolution
+                 tokens.append(-1)             # placeholder (we'll fill this in later)
+                 tokens.append(message.note)
+                 tokens.append(compound_instr)
+                 tokens.append(message.velocity)
+
+                 open_notes[(instr,message.note,message.channel)].append((note_idx, time))
+                 note_idx += 1
+             else:  # offset
+                 try:
+                     open_idx, onset_time = open_notes[(instr,message.note,message.channel)].pop(0)
+                 except IndexError:
+                     if debug:
+                         print('WARNING: ignoring bad offset')
+                 else:
+                     duration_ticks = round(TIME_RESOLUTION*(time-onset_time))
+                     tokens[5*open_idx + 1] = duration_ticks
+                     #del open_notes[(instr,message.note,message.channel)]
+         elif message.type == 'set_tempo':
+             tempo = message.tempo
+         elif message.type == 'time_signature':
+             pass  # we use real time
+         elif message.type in ['aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific']:
+             pass  # we don't attempt to model these
+         elif message.type == 'control_change':
+             pass  # this includes pedal and per-track volume: ignore for now
+         elif message.type in ['track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
+                               'copyright', 'marker', 'instrument_name', 'cue_marker',
+                               'device_name', 'sequence_number']:
+             pass  # possibly useful metadata but ignore for now
+         elif message.type == 'channel_prefix':
+             pass  # relatively common, but can we ignore this?
+         elif message.type in ['midi_port', 'smpte_offset', 'sysex']:
+             pass  # I have no idea what this is
+         else:
+             if debug:
+                 print('UNHANDLED MESSAGE', message.type, message)
+
+     unclosed_count = 0
+     for _, v in open_notes.items():
+         unclosed_count += len(v)
+
+     if debug and unclosed_count > 0:
+         print(f'WARNING: {unclosed_count} unclosed notes')
+         print(' ', midifile)
+
+     return tokens
+
+
+ def compound_to_midi(tokens, debug=False):
+     mid = mido.MidiFile()
+     mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120
+
+     # pre-create one track per (program, channel) pair; the instrument field
+     # of each compound word packs these as (program << 4) | channel
+     tracks = {}
+     for token in tokens[3::5]:  # instrument field of each compound word
+         program = (token >> 4) & 0x7F
+         channel = token & 0x0F
+
+         if (program, channel) not in tracks:
+             track = mido.MidiTrack()
+             mid.tracks.append(track)
+             tracks[(program, channel)] = track
+             track.append(mido.Message('program_change',
+                                       program=program,
+                                       channel=channel))
+
+     it = iter(tokens)
+     time_index = defaultdict(list)
+     for (time_in_ticks, duration, note, instrument, velocity) in zip(it,it,it,it,it):
+         time_index[(time_in_ticks,0)].append((note, instrument, velocity))           # 0 = onset
+         time_index[(time_in_ticks+duration,1)].append((note, instrument, velocity))  # 1 = offset
+
+     track_idx = {}  # maps packed instrument to (track, current time, channel)
+     for time_in_ticks, event_type in sorted(time_index.keys()):
+         for (note, instrument, velocity) in time_index[(time_in_ticks, event_type)]:
+             if event_type == 0:  # onset
+                 try:
+                     track, previous_time, idx = track_idx[instrument]
+                 except KeyError:
+                     # first note for this instrument: unpack its track and channel
+                     program = (instrument >> 4) & 0x7F
+                     idx = instrument & 0x0F
+                     previous_time = 0
+                     track = tracks[(program, idx)]
+
+                 track.append(mido.Message(
+                     'note_on', note=note, channel=idx, velocity=velocity,
+                     time=time_in_ticks-previous_time))
+                 track_idx[instrument] = (track, time_in_ticks, idx)
+             else:  # offset
+                 try:
+                     track, previous_time, idx = track_idx[instrument]
+                 except KeyError:
+                     # shouldn't happen because we should have a corresponding onset
+                     if debug:
+                         print('IGNORING bad offset')
+
+                     continue
+
+                 track.append(mido.Message(
+                     'note_off', note=note, channel=idx,
+                     time=time_in_ticks-previous_time))
+                 track_idx[instrument] = (track, time_in_ticks, idx)
+
+     return mid
+
+
+ def compound_to_events(tokens, stats=False):
+     assert len(tokens) % 5 == 0
+     tokens = tokens.copy()
+
+     # remove velocities
+     del tokens[4::5]
+
+     # combine (note, instrument)
+     assert all(-1 <= tok < 2**7 for tok in tokens[2::4])
+     assert all(-1 <= tok < 129 for tok in tokens[3::4])
+     tokens[2::4] = [SEPARATOR if note == -1 else MAX_PITCH*instr + note
+                     for note, instr in zip(tokens[2::4],tokens[3::4])]
+     tokens[2::4] = [NOTE_OFFSET + tok for tok in tokens[2::4]]
+     del tokens[3::4]
+
+     # max duration cutoff and set unknown durations to 250ms
+     truncations = sum([1 for tok in tokens[1::3] if tok >= MAX_DUR])
+     tokens[1::3] = [TIME_RESOLUTION//4 if tok == -1 else min(tok, MAX_DUR-1)
+                     for tok in tokens[1::3]]
+     tokens[1::3] = [DUR_OFFSET + tok for tok in tokens[1::3]]
+
+     assert min(tokens[0::3]) >= 0
+     tokens[0::3] = [TIME_OFFSET + tok for tok in tokens[0::3]]
+
+     assert len(tokens) % 3 == 0
+
+     if stats:
+         return tokens, truncations
+
+     return tokens
+
+
+ def events_to_compound(tokens, debug=False):
+     tokens = unpad(tokens)
+
+     # move all tokens to zero-offset for synthesis
+     tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
+               for tok in tokens]
+
+     # remove type offsets
+     tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
+     tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
+     tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]
+
+     offset = 0     # add max time from previous track for synthesis
+     track_max = 0  # keep track of max time in track
+     for j, (time,dur,note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
+         if note == SEPARATOR:
+             offset += track_max
+             track_max = 0
+             if debug:
+                 print('Sequence Boundary')
+         else:
+             track_max = max(track_max, time+dur)
+             tokens[3*j] += offset
+
+     # strip sequence separators
+     assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
+     tokens = [tok for tok in tokens if tok != SEPARATOR]
+
+     assert len(tokens) % 3 == 0
+     out = 5*(len(tokens)//3)*[0]
+     out[0::5] = tokens[0::3]
+     out[1::5] = tokens[1::3]
+     out[2::5] = [tok - (2**7)*(tok//2**7) for tok in tokens[2::3]]
+     out[3::5] = [tok//2**7 for tok in tokens[2::3]]
+     out[4::5] = (len(tokens)//3)*[72]  # default velocity
+
+     assert max(out[1::5]) < MAX_DUR
+     assert max(out[2::5]) < MAX_PITCH
+     assert max(out[3::5]) < MAX_INSTR
+     assert all(tok >= 0 for tok in out)
+
+     return out
+
+
+ def events_to_midi(tokens, debug=False):
+     return compound_to_midi(events_to_compound(tokens, debug=debug), debug=debug)
+
+
+ def midi_to_events(midifile, debug=False):
+     return compound_to_events(midi_to_compound(midifile, debug=debug))
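
Unlike convert-original.py, this version packs the source channel into the instrument slot of the compound word; the arithmetic is:

    compound_instr = (program << 4) | channel  # encode, as in midi_to_compound
    program = (compound_instr >> 4) & 0x7F     # decode
    channel = compound_instr & 0x0F
    # e.g. violin (program 40) on channel 2: (40 << 4) | 2 == 642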
anticipation/ops.py ADDED
@@ -0,0 +1,285 @@
1
+ """
2
+ Utilities for operating on encoded Midi sequences.
3
+ """
4
+
5
+ from collections import defaultdict
6
+
7
+ from anticipation.config import *
8
+ from anticipation.vocab import *
9
+
10
+
11
+ def print_tokens(tokens):
12
+ print('---------------------')
13
+ for j, (tm, dur, note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
14
+ if note == SEPARATOR:
15
+ assert tm == SEPARATOR and dur == SEPARATOR
16
+ print(j, 'SEPARATOR')
17
+ continue
18
+
19
+ if note == REST:
20
+ assert tm < CONTROL_OFFSET
21
+ assert dur == DUR_OFFSET+0
22
+ print(j, tm, 'REST')
23
+ continue
24
+
25
+ if note < CONTROL_OFFSET:
26
+ tm = tm - TIME_OFFSET
27
+ dur = dur - DUR_OFFSET
28
+ note = note - NOTE_OFFSET
29
+ instr = note//2**7
30
+ pitch = note - (2**7)*instr
31
+ print(j, tm, dur, instr, pitch)
32
+ else:
33
+ tm = tm - ATIME_OFFSET
34
+ dur = dur - ADUR_OFFSET
35
+ note = note - ANOTE_OFFSET
36
+ instr = note//2**7
37
+ pitch = note - (2**7)*instr
38
+ print(j, tm, dur, instr, pitch, '(A)')
39
+
40
+
41
+ def clip(tokens, start, end, clip_duration=True, seconds=True):
42
+ if seconds:
43
+ start = int(TIME_RESOLUTION*start)
44
+ end = int(TIME_RESOLUTION*end)
45
+
46
+ new_tokens = []
47
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note < CONTROL_OFFSET:
+             this_time = time - TIME_OFFSET
+             this_dur = dur - DUR_OFFSET
+         else:
+             this_time = time - ATIME_OFFSET
+             this_dur = dur - ADUR_OFFSET
+
+         if this_time < start or end < this_time:
+             continue
+
+         # truncate extended notes
+         if clip_duration and end < this_time + this_dur:
+             dur -= this_time + this_dur - end
+
+         new_tokens.extend([time, dur, note])
+
+     return new_tokens
+
+
+ def mask(tokens, start, end):
+     new_tokens = []
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note < CONTROL_OFFSET:
+             this_time = (time - TIME_OFFSET)/float(TIME_RESOLUTION)
+         else:
+             this_time = (time - ATIME_OFFSET)/float(TIME_RESOLUTION)
+
+         if start < this_time < end:
+             continue
+
+         new_tokens.extend([time, dur, note])
+
+     return new_tokens
+
+
+ def delete(tokens, criterion):
+     new_tokens = []
+     for token in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if criterion(token):
+             continue
+
+         new_tokens.extend(token)
+
+     return new_tokens
+
+
+ def sort(tokens):
+     """ sort a sequence of events or controls (but not both) """
+
+     times = tokens[0::3]
+     indices = sorted(range(len(times)), key=times.__getitem__)
+
+     sorted_tokens = []
+     for idx in indices:
+         sorted_tokens.extend(tokens[3*idx:3*(idx+1)])
+
+     return sorted_tokens
+
+
+ def split(tokens):
+     """ split a sequence into events and controls """
+
+     events = []
+     controls = []
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note < CONTROL_OFFSET:
+             events.extend([time, dur, note])
+         else:
+             controls.extend([time, dur, note])
+
+     return events, controls
+
+
+ def pad(tokens, end_time=None, density=TIME_RESOLUTION):
+     end_time = TIME_OFFSET+(end_time if end_time else max_time(tokens, seconds=False))
+     new_tokens = []
+     previous_time = TIME_OFFSET+0
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         # must pad before separation, anticipation
+         assert note < CONTROL_OFFSET
+
+         # insert pad tokens to ensure the desired density
+         while time > previous_time + density:
+             new_tokens.extend([previous_time+density, DUR_OFFSET+0, REST])
+             previous_time += density
+
+         new_tokens.extend([time, dur, note])
+         previous_time = time
+
+     while end_time > previous_time + density:
+         new_tokens.extend([previous_time+density, DUR_OFFSET+0, REST])
+         previous_time += density
+
+     return new_tokens
+
+
+ def unpad(tokens):
+     new_tokens = []
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note == REST: continue
+
+         new_tokens.extend([time, dur, note])
+
+     return new_tokens
+
+
+ def anticipate(events, controls, delta=DELTA*TIME_RESOLUTION):
+     """
+     Interleave a sequence of events with anticipated controls.
+
+     Inputs:
+         events   : a sequence of events
+         controls : a sequence of time-localized controls
+         delta    : the anticipation interval
+
+     Returns:
+         tokens   : interleaved events and anticipated controls
+         controls : unconsumed controls (control time > max_time(events) + delta)
+     """
+
+     if len(controls) == 0:
+         return events, controls
+
+     tokens = []
+     event_time = 0
+     control_time = controls[0] - ATIME_OFFSET
+     for time, dur, note in zip(events[0::3],events[1::3],events[2::3]):
+         while event_time >= control_time - delta:
+             tokens.extend(controls[0:3])
+             controls = controls[3:]  # consume this control
+             control_time = controls[0] - ATIME_OFFSET if len(controls) > 0 else float('inf')
+
+         assert note < CONTROL_OFFSET
+         event_time = time - TIME_OFFSET
+         tokens.extend([time, dur, note])
+
+     return tokens, controls
+
+
+ def sparsity(tokens):
+     max_dt = 0
+     previous_time = TIME_OFFSET+0
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note == SEPARATOR: continue
+         assert note < CONTROL_OFFSET  # don't operate on interleaved sequences
+
+         max_dt = max(max_dt, time - previous_time)
+         previous_time = time
+
+     return max_dt
+
+
+ def min_time(tokens, seconds=True, instr=None):
+     mt = None
+     for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         # stop calculating at sequence separator
+         if note == SEPARATOR: break
+
+         if note < CONTROL_OFFSET:
+             time -= TIME_OFFSET
+             note -= NOTE_OFFSET
+         else:
+             time -= ATIME_OFFSET
+             note -= ANOTE_OFFSET
+
+         # min time of a particular instrument
+         if instr is not None and instr != note//2**7:
+             continue
+
+         mt = time if mt is None else min(mt, time)
+
+     if mt is None: mt = 0
+     return mt/float(TIME_RESOLUTION) if seconds else mt
+
+
+ def max_time(tokens, seconds=True, instr=None):
+     mt = 0
+     for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         # keep checking for max_time, even if it appears after a separator
+         # (this is important because we use this check for vocab overflow in tokenization)
+         if note == SEPARATOR: continue
+
+         if note < CONTROL_OFFSET:
+             time -= TIME_OFFSET
+             note -= NOTE_OFFSET
+         else:
+             time -= ATIME_OFFSET
+             note -= ANOTE_OFFSET
+
+         # max time of a particular instrument
+         if instr is not None and instr != note//2**7:
+             continue
+
+         mt = max(mt, time)
+
+     return mt/float(TIME_RESOLUTION) if seconds else mt
+
+
+ def get_instruments(tokens):
+     instruments = defaultdict(int)
+     for time, dur, note in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         if note >= SPECIAL_OFFSET: continue
+
+         if note < CONTROL_OFFSET:
+             note -= NOTE_OFFSET
+         else:
+             note -= ANOTE_OFFSET
+
+         instr = note//2**7
+         instruments[instr] += 1
+
+     return instruments
+
+
+ def translate(tokens, dt, seconds=False):
+     if seconds:
+         dt = int(TIME_RESOLUTION*dt)
+
+     new_tokens = []
+     for (time, dur, note) in zip(tokens[0::3],tokens[1::3],tokens[2::3]):
+         # stop translating after EOT
+         if note == SEPARATOR:
+             new_tokens.extend([time, dur, note])
+             dt = 0
+             continue
+
+         if note < CONTROL_OFFSET:
+             this_time = time - TIME_OFFSET
+         else:
+             this_time = time - ATIME_OFFSET
+
+         assert 0 <= this_time + dt
+         new_tokens.extend([time+dt, dur, note])
+
+     return new_tokens
+
+
+ def combine(events, controls):
+     return sort(events + [token - CONTROL_OFFSET for token in controls])
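For orientation, here is a small usage sketch of how these ops compose (hedged: it assumes the anticipation package is importable; the token values are illustrative, not from this repo):

    from anticipation import ops
    from anticipation.vocab import TIME_OFFSET, DUR_OFFSET, NOTE_OFFSET

    # two events, each an (arrival time, duration, note) triple
    events = [TIME_OFFSET+0,   DUR_OFFSET+50, NOTE_OFFSET+60,
              TIME_OFFSET+100, DUR_OFFSET+50, NOTE_OFFSET+64]

    clipped = ops.clip(events, 0, 100, clip_duration=False, seconds=False)
    shifted = ops.translate(clipped, 10, seconds=False)
    assert ops.min_time(shifted, seconds=False) == 10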
anticipation/sample.py ADDED
@@ -0,0 +1,280 @@
+ """
+ API functions for sampling from anticipatory infilling models.
+ """
+
+ import math
+
+ import torch
+ import torch.nn.functional as F
+
+ from tqdm import tqdm
+
+ from anticipation import ops
+ from anticipation.config import *
+ from anticipation.vocab import *
+
+
+ def safe_logits(logits, idx):
+     logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf')  # don't generate controls
+     logits[SPECIAL_OFFSET:] = -float('inf')                # don't generate special tokens
+
+     # don't generate stuff in the wrong time slot
+     if idx % 3 == 0:
+         logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
+         logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
+     elif idx % 3 == 1:
+         logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
+         logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
+     elif idx % 3 == 2:
+         logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
+         logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
+
+     return logits
+
+
+ def nucleus(logits, top_p):
+     # from the HF implementation
+     if top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
+         sorted_indices_to_remove = cumulative_probs > top_p
+
+         # shift the indices to the right to also keep the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # scatter sorted tensors back to the original indexing
+         indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
+         logits[indices_to_remove] = -float("inf")
+
+     return logits
+
+
+ def future_logits(logits, curtime):
+     """ don't sample events in the past """
+     if curtime > 0:
+         logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')
+
+     return logits
+
+
+ def instr_logits(logits, full_history):
+     """ don't sample more than 16 instruments """
+     instrs = ops.get_instruments(full_history)
+     if len(instrs) < 15:  # 16 - 1 to account for the reserved drum track
+         return logits
+
+     for instr in range(MAX_INSTR):
+         if instr not in instrs:
+             logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')
+
+     return logits
+
+
+ def add_token(model, z, tokens, top_p, current_time, debug=False):
+     assert len(tokens) % 3 == 0
+
+     history = tokens.copy()
+     lookback = max(len(tokens) - 1017, 0)
+     history = history[lookback:]  # Markov window
+     offset = ops.min_time(history, seconds=False)
+     history[::3] = [tok - offset for tok in history[::3]]  # relativize time in the history buffer
+
+     new_token = []
+     with torch.no_grad():
+         for i in range(3):
+             input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
+             logits = model(input_tokens).logits[0,-1]
+
+             idx = input_tokens.shape[1]-1
+             logits = safe_logits(logits, idx)
+             if i == 0:
+                 logits = future_logits(logits, current_time - offset)
+             elif i == 2:
+                 logits = instr_logits(logits, tokens)
+             logits = nucleus(logits, top_p)
+
+             probs = F.softmax(logits, dim=-1)
+             token = torch.multinomial(probs, 1)
+             new_token.append(int(token))
+
+     new_token[0] += offset  # revert to full sequence timing
+     if debug:
+         print(f'  OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')
+
+     return new_token
+
+
+ def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
+     if inputs is None:
+         inputs = []
+
+     if controls is None:
+         controls = []
+
+     start_time = int(TIME_RESOLUTION*start_time)
+     end_time = int(TIME_RESOLUTION*end_time)
+
+     # the prompt is events up to start_time
+     prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
+
+     # treat events beyond start_time as controls
+     future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
+     if debug:
+         print('Future')
+         ops.print_tokens(future)
+
+     # clip controls that precede the sequence
+     controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)
+
+     if debug:
+         print('Controls')
+         ops.print_tokens(controls)
+
+     z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
+     if debug:
+         print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')
+
+     # interleave the controls with the events
+     tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))
+
+     if debug:
+         print('Prompt')
+         ops.print_tokens(tokens)
+
+     current_time = ops.max_time(prompt, seconds=False)
+     if debug:
+         print('Current time:', current_time)
+
+     with tqdm(range(end_time-start_time)) as progress:
+         if controls:
+             atime, adur, anote = controls[0:3]
+             anticipated_tokens = controls[3:]
+             anticipated_time = atime - ATIME_OFFSET
+         else:
+             # nothing to anticipate
+             anticipated_time = math.inf
+
+         while True:
+             while current_time >= anticipated_time - delta:
+                 tokens.extend([atime, adur, anote])
+                 if debug:
+                     note = anote - ANOTE_OFFSET
+                     instr = note//2**7
+                     print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)
+
+                 if len(anticipated_tokens) > 0:
+                     atime, adur, anote = anticipated_tokens[0:3]
+                     anticipated_tokens = anticipated_tokens[3:]
+                     anticipated_time = atime - ATIME_OFFSET
+                 else:
+                     # nothing more to anticipate
+                     anticipated_time = math.inf
+
+             new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
+             new_time = new_token[0] - TIME_OFFSET
+             if new_time >= end_time:
+                 break
+
+             if debug:
+                 new_note = new_token[2] - NOTE_OFFSET
+                 new_instr = new_note//2**7
+                 new_pitch = new_note - (2**7)*new_instr
+                 print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
+
+             tokens.extend(new_token)
+             dt = new_time - current_time
+             assert dt >= 0
+             current_time = new_time
+             progress.update(dt)
+
+     events, _ = ops.split(tokens)
+     return ops.sort(ops.unpad(events) + future)
+
+
+ def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
+     if inputs is None:
+         inputs = []
+
+     if controls is None:
+         controls = []
+     else:
+         # treat controls as ordinary tokens
+         controls = [token-CONTROL_OFFSET for token in controls]
+
+     start_time = int(TIME_RESOLUTION*start_time)
+     end_time = int(TIME_RESOLUTION*end_time)
+
+     inputs = ops.sort(inputs + controls)
+
+     # the prompt is events up to start_time
+     prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
+     if debug:
+         print('Prompt')
+         ops.print_tokens(prompt)
+
+     # treat events beyond start_time as controls
+     controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
+     if debug:
+         print('Future')
+         ops.print_tokens(controls)
+
+     z = [AUTOREGRESS]
+     if debug:
+         print('AR Mode')
+
+     current_time = ops.max_time(prompt, seconds=False)
+     if debug:
+         print('Current time:', current_time)
+
+     tokens = prompt
+     with tqdm(range(end_time-start_time)) as progress:
+         if controls:
+             atime, adur, anote = controls[0:3]
+             anticipated_tokens = controls[3:]
+             anticipated_time = atime - TIME_OFFSET
+         else:
+             # nothing to anticipate
+             anticipated_time = math.inf
+
+         while True:
+             new_token = add_token(model, z, tokens, top_p, max(start_time,current_time))
+             new_time = new_token[0] - TIME_OFFSET
+             if new_time >= end_time:
+                 break
+
+             dt = new_time - current_time
+             assert dt >= 0
+             current_time = new_time
+
+             # backfill anything that should have come before the new token
+             while current_time >= anticipated_time:
+                 tokens.extend([atime, adur, anote])
+                 if debug:
+                     note = anote - NOTE_OFFSET
+                     instr = note//2**7
+                     print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)
+
+                 if len(anticipated_tokens) > 0:
+                     atime, adur, anote = anticipated_tokens[0:3]
+                     anticipated_tokens = anticipated_tokens[3:]
+                     anticipated_time = atime - TIME_OFFSET
+                 else:
+                     # nothing more to anticipate
+                     anticipated_time = math.inf
+
+             if debug:
+                 new_note = new_token[2] - NOTE_OFFSET
+                 new_instr = new_note//2**7
+                 new_pitch = new_note - (2**7)*new_instr
+                 print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
+
+             tokens.extend(new_token)
+             progress.update(dt)
+
+     if anticipated_time != math.inf:
+         tokens.extend([atime, adur, anote])
+
+     return ops.sort(ops.unpad(tokens) + controls)
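A minimal end-to-end sketch of calling generate (hedged: the model download and the choice of example file are assumptions for illustration; CPU inference with even the small model is slow):

    from transformers import AutoModelForCausalLM
    from anticipation.sample import generate
    from anticipation.convert import midi_to_events, events_to_midi

    model = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-small-800k')
    events = midi_to_events('examples/strawberry.mid')

    # regenerate seconds 5-10, conditioned on the surrounding material
    new_events = generate(model, 5, 10, inputs=events, top_p=0.95)
    events_to_midi(new_events).save('regenerated.mid')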
anticipation/tokenize.py ADDED
@@ -0,0 +1,219 @@
+ """
+ Top-level functions for preprocessing data to be used for training.
+ """
+
+ from tqdm import tqdm
+
+ import numpy as np
+
+ from anticipation import ops
+ from anticipation.config import *
+ from anticipation.vocab import *
+ from anticipation.convert import compound_to_events, midi_to_interarrival
+
+
+ def extract_spans(all_events, rate):
+     events = []
+     controls = []
+     span = True
+     next_span = end_span = TIME_OFFSET+0
+     for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+         assert note not in [SEPARATOR, REST]  # shouldn't be in the sequence yet
+
+         # end of an anticipated span; decide when to do it again (next_span)
+         if span and time >= end_span:
+             span = False
+             next_span = time + int(TIME_RESOLUTION*np.random.exponential(1./rate))
+
+         # anticipate a 3-second span
+         if (not span) and time >= next_span:
+             span = True
+             end_span = time + DELTA*TIME_RESOLUTION
+
+         if span:
+             # mark this event as a control
+             controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+         else:
+             events.extend([time, dur, note])
+
+     return events, controls
+
+
+ ANTICIPATION_RATES = 10
+ def extract_random(all_events, rate):
+     events = []
+     controls = []
+     for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+         assert note not in [SEPARATOR, REST]  # shouldn't be in the sequence yet
+
+         if np.random.random() < rate/float(ANTICIPATION_RATES):
+             # mark this event as a control
+             controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+         else:
+             events.extend([time, dur, note])
+
+     return events, controls
+
+
+ def extract_instruments(all_events, instruments):
+     events = []
+     controls = []
+     for time, dur, note in zip(all_events[0::3],all_events[1::3],all_events[2::3]):
+         assert note < CONTROL_OFFSET          # shouldn't be in the sequence yet
+         assert note not in [SEPARATOR, REST]  # these shouldn't either
+
+         instr = (note-NOTE_OFFSET)//2**7
+         if instr in instruments:
+             # mark this event as a control
+             controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
+         else:
+             events.extend([time, dur, note])
+
+     return events, controls
+
+
+ def maybe_tokenize(compound_tokens):
+     # skip sequences with very few events
+     if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
+         return None, None, 1  # short track
+
+     events, truncations = compound_to_events(compound_tokens, stats=True)
+     end_time = ops.max_time(events, seconds=False)
+
+     # don't want to deal with extremely short tracks
+     if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
+         return None, None, 1  # short track
+
+     # don't want to deal with extremely long tracks
+     if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
+         return None, None, 2  # long track
+
+     # skip sequences with more instruments than MIDI channels (16)
+     if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
+         return None, None, 3  # too many instruments
+
+     return events, truncations, 0
+
+
+ def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
+     assert augment_factor == 1  # can't augment interarrival-tokenized data
+
+     all_truncations = 0
+     seqcount = rest_count = 0
+     stats = 4*[0]  # (short, long, too many instruments, inexpressible)
+     np.random.seed(0)
+
+     with open(output, 'w') as outfile:
+         concatenated_tokens = []
+         for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
+             with open(filename, 'r') as f:
+                 _, _, status = maybe_tokenize([int(token) for token in f.read().split()])
+
+             if status > 0:
+                 stats[status-1] += 1
+                 continue
+
+             filename = filename[:-len('.compound.txt')]  # get the original MIDI
+
+             # already parsed; shouldn't raise an exception
+             tokens, truncations = midi_to_interarrival(filename, stats=True)
+             tokens[0:0] = [MIDI_SEPARATOR]
+             concatenated_tokens.extend(tokens)
+             all_truncations += truncations
+
+             # write out full sequences to file
+             while len(concatenated_tokens) >= CONTEXT_SIZE:
+                 seq = concatenated_tokens[0:CONTEXT_SIZE]
+                 concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
+                 outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
+                 seqcount += 1
+
+     if debug:
+         fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
+         print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))
+
+     return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
+
+
+ def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
+     tokens = []
+     all_truncations = 0
+     seqcount = rest_count = 0
+     stats = 4*[0]  # (short, long, too many instruments, inexpressible)
+     np.random.seed(0)
+
+     with open(output, 'w') as outfile:
+         concatenated_tokens = []
+         for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
+             with open(filename, 'r') as f:
+                 all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])
+
+             if status > 0:
+                 stats[status-1] += 1
+                 continue
+
+             instruments = list(ops.get_instruments(all_events).keys())
+             end_time = ops.max_time(all_events, seconds=False)
+
+             # different random augmentations
+             for k in range(augment_factor):
+                 if k % 10 == 0:
+                     # no augmentation
+                     events = all_events.copy()
+                     controls = []
+                 elif k % 10 == 1:
+                     # span augmentation
+                     lmbda = .05
+                     events, controls = extract_spans(all_events, lmbda)
+                 elif k % 10 < 6:
+                     # random augmentation
+                     r = np.random.randint(1, ANTICIPATION_RATES)
+                     events, controls = extract_random(all_events, r)
+                 else:
+                     if len(instruments) > 1:
+                         # instrument augmentation: at least one, but not all instruments
+                         u = 1 + np.random.randint(len(instruments)-1)
+                         subset = np.random.choice(instruments, u, replace=False)
+                         events, controls = extract_instruments(all_events, subset)
+                     else:
+                         # no augmentation
+                         events = all_events.copy()
+                         controls = []
+
+                 if len(concatenated_tokens) == 0:
+                     z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS
+
+                 all_truncations += truncations
+                 events = ops.pad(events, end_time)
+                 rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
+                 tokens, controls = ops.anticipate(events, controls)
+                 assert len(controls) == 0  # should have consumed all controls (because of padding)
+                 tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
+                 concatenated_tokens.extend(tokens)
+
+                 # write out full sequences to file
+                 while len(concatenated_tokens) >= EVENT_SIZE*M:
+                     seq = concatenated_tokens[0:EVENT_SIZE*M]
+                     concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]
+
+                     # relativize time to the context
+                     seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
+                     assert ops.min_time(seq, seconds=False) == 0
+                     if ops.max_time(seq, seconds=False) >= MAX_TIME:
+                         stats[3] += 1
+                         continue
+
+                     # if seq contains a SEPARATOR, the global control describes the first sequence
+                     seq.insert(0, z)
+
+                     outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
+                     seqcount += 1
+
+                     # grab the current augmentation controls if we didn't already
+                     z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS
+
+     if debug:
+         fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
+         print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))
+
+     return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
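The augmentation helpers can be exercised in isolation; here is a hedged sketch of extract_instruments (the token values are illustrative only):

    from anticipation.tokenize import extract_instruments
    from anticipation.vocab import TIME_OFFSET, DUR_OFFSET, NOTE_OFFSET, CONTROL_OFFSET

    # one piano event (instrument 0) and one violin event (instrument 40)
    all_events = [TIME_OFFSET+0, DUR_OFFSET+10, NOTE_OFFSET + 0*128 + 60,
                  TIME_OFFSET+5, DUR_OFFSET+10, NOTE_OFFSET + 40*128 + 64]

    events, controls = extract_instruments(all_events, [40])
    assert len(events) == 3 and len(controls) == 3
    assert controls[2] == CONTROL_OFFSET + NOTE_OFFSET + 40*128 + 64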
anticipation/visuals.py ADDED
@@ -0,0 +1,65 @@
+ """
+ Utilities for inspecting encoded music data.
+ """
+
+ import numpy as np
+
+ import matplotlib
+ import matplotlib.pyplot as plt
+
+ import anticipation.ops as ops
+ from anticipation.config import *
+ from anticipation.vocab import *
+
+ def visualize(tokens, output, selected=None):
+     #colors = ['white', 'silver', 'red', 'sienna', 'darkorange', 'gold', 'yellow', 'palegreen', 'seagreen', 'cyan',
+     #          'dodgerblue', 'slategray', 'navy', 'mediumpurple', 'mediumorchid', 'magenta', 'lightpink']
+     colors = ['white', '#426aa0', '#b26789', '#de9283', '#eac29f', 'silver', 'red', 'sienna', 'darkorange', 'gold', 'yellow', 'palegreen', 'seagreen', 'cyan', 'dodgerblue', 'slategray', 'navy']
+
+     plt.rcParams['figure.dpi'] = 300
+     plt.rcParams['savefig.dpi'] = 300
+
+     max_time = ops.max_time(tokens, seconds=False)
+     grid = np.zeros([max_time, MAX_PITCH])
+     instruments = list(sorted(list(ops.get_instruments(tokens).keys())))
+     if 128 in instruments:
+         instruments.remove(128)
+
+     for j, (tm, dur, note) in enumerate(zip(tokens[0::3],tokens[1::3],tokens[2::3])):
+         if note == SEPARATOR:
+             assert tm == SEPARATOR and dur == SEPARATOR
+             print(j, 'SEPARATOR')
+             continue
+
+         if note == REST:
+             continue
+
+         assert note < CONTROL_OFFSET
+
+         tm = tm - TIME_OFFSET
+         dur = dur - DUR_OFFSET
+         note = note - NOTE_OFFSET
+         instr = note//2**7
+         pitch = note - (2**7)*instr
+
+         if instr == 128:  # drums
+             continue      # we don't visualize this
+
+         if selected and instr not in selected:
+             continue
+
+         grid[tm:tm+dur, pitch] = 1 + instruments.index(instr)
+
+     plt.clf()
+     plt.axis('off')
+     cmap = matplotlib.colors.ListedColormap(colors)
+     bounds = list(range(MAX_TRACK_INSTR)) + [16]
+     norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N)
+     plt.imshow(np.flipud(grid.T), aspect=16, cmap=cmap, norm=norm, interpolation='none')
+
+     patches = [matplotlib.patches.Patch(color=colors[i+1], label=f"{instruments[i]}")
+                for i in range(len(instruments))]
+     plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
+
+     plt.tight_layout()
+     plt.savefig(output)
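Usage is a one-liner; a hedged sketch (the example MIDI path is one of the files shipped in this commit):

    from anticipation.convert import midi_to_events
    from anticipation.visuals import visualize

    events = midi_to_events('examples/strawberry.mid')
    visualize(events, 'strawberry.png')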
anticipation/vocab.py ADDED
@@ -0,0 +1,58 @@
+ """
+ The vocabularies used for arrival-time and interarrival-time encodings.
+ """
+
+ # training sequence vocab
+
+ from anticipation.config import *
+
+ # the event block
+ EVENT_OFFSET = 0
+ TIME_OFFSET = EVENT_OFFSET
+ DUR_OFFSET = TIME_OFFSET + MAX_TIME
+ NOTE_OFFSET = DUR_OFFSET + MAX_DUR
+ REST = NOTE_OFFSET + MAX_NOTE
+
+ # the control block
+ CONTROL_OFFSET = NOTE_OFFSET + MAX_NOTE + 1
+ ATIME_OFFSET = CONTROL_OFFSET + 0
+ ADUR_OFFSET = ATIME_OFFSET + MAX_TIME
+ ANOTE_OFFSET = ADUR_OFFSET + MAX_DUR
+
+ # the special block
+ SPECIAL_OFFSET = ANOTE_OFFSET + MAX_NOTE
+ SEPARATOR = SPECIAL_OFFSET
+ AUTOREGRESS = SPECIAL_OFFSET + 1
+ ANTICIPATE = SPECIAL_OFFSET + 2
+ VOCAB_SIZE = ANTICIPATE + 1
+
+ # interarrival-time (MIDI-like) vocab
+ MIDI_TIME_OFFSET = 0
+ MIDI_START_OFFSET = MIDI_TIME_OFFSET + MAX_INTERARRIVAL
+ MIDI_END_OFFSET = MIDI_START_OFFSET + MAX_NOTE
+ MIDI_SEPARATOR = MIDI_END_OFFSET + MAX_NOTE
+ MIDI_VOCAB_SIZE = MIDI_SEPARATOR + 1
+
+ if __name__ == '__main__':
+     print('Arrival-Time Training Sequence Format:')
+     print('Event Offset: ', EVENT_OFFSET)
+     print('  -> time offset :', TIME_OFFSET)
+     print('  -> duration offset :', DUR_OFFSET)
+     print('  -> note offset :', NOTE_OFFSET)
+     print('  -> rest token: ', REST)
+     print('Anticipated Control Offset: ', CONTROL_OFFSET)
+     print('  -> anticipated time offset :', ATIME_OFFSET)
+     print('  -> anticipated duration offset :', ADUR_OFFSET)
+     print('  -> anticipated note offset :', ANOTE_OFFSET)
+     print('Special Token Offset: ', SPECIAL_OFFSET)
+     print('  -> separator token: ', SEPARATOR)
+     print('  -> autoregression flag: ', AUTOREGRESS)
+     print('  -> anticipation flag: ', ANTICIPATE)
+     print('Arrival Encoding Vocabulary Size: ', VOCAB_SIZE)
+     print('')
+     print('Interarrival-Time Training Sequence Format:')
+     print('  -> time offset: ', MIDI_TIME_OFFSET)
+     print('  -> note-on offset: ', MIDI_START_OFFSET)
+     print('  -> note-off offset: ', MIDI_END_OFFSET)
+     print('  -> separator token: ', MIDI_SEPARATOR)
+     print('Interarrival Encoding Vocabulary Size: ', MIDI_VOCAB_SIZE)
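The note blocks pack instrument and pitch into a single token: within the event block, note = NOTE_OFFSET + 128*instrument + pitch, which is why the code elsewhere decodes with instr = note//2**7 and pitch = note - (2**7)*instr after subtracting the offset. A worked example (hedged; the violin/E4 values are illustrative):

    from anticipation.vocab import NOTE_OFFSET

    token = NOTE_OFFSET + 128*40 + 64        # violin (program 40) playing E4 (MIDI pitch 64)
    note = token - NOTE_OFFSET
    assert note // 2**7 == 40                # instrument
    assert note - (2**7)*(note//2**7) == 64  # pitch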
api.py ADDED
@@ -0,0 +1,240 @@
+ from agents.agents import harmonizer, infiller, change_melody
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ import mido
+ import tempfile
+ import os
+ import music21
+ import traceback
+ from uuid import uuid4
+ import threading
+ from transformers import AutoModelForCausalLM
+
+
+ app = Flask(__name__)
+ CORS(app)
+
+ @app.after_request
+ def add_cors_headers(response):
+     # Allow only the frontend's origin
+     response.headers['Access-Control-Allow-Origin'] = 'https://inscoreai.netlify.app'
+     response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
+     response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
+     return response
+
+ def midi_to_musicxml(midi_file_path):
+     """Convert a MIDI file to a MusicXML string, cleaning up the temporary output file."""
+     try:
+         midi_path_str = str(midi_file_path)
+
+         # Parse and convert to MusicXML
+         score = music21.converter.parse(midi_path_str)
+
+         # Create temporary output file path
+         temp_output = os.path.join(tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml")
+
+         # Write to temporary file
+         score.write('musicxml', temp_output)
+
+         # Read back as string
+         with open(temp_output, 'r') as f:
+             musicxml_str = f.read()
+
+         # Clean up
+         os.unlink(temp_output)
+
+         return musicxml_str
+     except Exception as e:
+         print(f"Conversion error: {str(e)}")
+         traceback.print_exc()
+         raise
+
+ def load_model():
+     global MODEL
+     with MODEL_LOCK:
+         if MODEL is None:
+             print("⏳ Loading music generation model...")
+             MODEL = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-small-800k', local_files_only=True, force_download=False)  # prevent re-downloads
+             # Add .cuda() here if using a GPU
+             print("✅ Model loaded successfully!")
+     return MODEL
+
+ # Model loading setup
+ MODEL = None
+ MODEL_LOCK = threading.Lock()
+
+ # Initialize model when app starts
+ with app.app_context():
+     load_model()
+
+ @app.route('/upload', methods=['POST'])
+ def handle_upload():
+     temp_midi_path = None
+     top_p = float(request.form.get('top_p', '0.95'))
+
+     try:
+         # Validate input
+         if 'midi_file' not in request.files:
+             return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+         midi_file = request.files['midi_file']
+         start_time = request.form.get('start_time', '0')
+         end_time = request.form.get('end_time', '0')
+
+         # Create temporary MIDI file with random name
+         temp_dir = tempfile.gettempdir()
+         temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+         # Save uploaded MIDI to temp file
+         midi_file.save(temp_midi_path)
+
+         # Process MIDI (start/end times arrive in milliseconds)
+         midi = mido.MidiFile(temp_midi_path)
+         model = load_model()
+         harmonized_midi = harmonizer(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+         # Save harmonized MIDI (overwriting temp file)
+         harmonized_midi.save(temp_midi_path)
+
+         # Convert to MusicXML string
+         musicxml_str = midi_to_musicxml(temp_midi_path)
+
+         # Final type verification
+         if not isinstance(musicxml_str, str):
+             raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+         return jsonify({
+             "status": "success",
+             "musicxml": musicxml_str
+         })
+
+     except Exception as e:
+         print(f"Error processing request: {str(e)}")
+         traceback.print_exc()
+         return jsonify({
+             "status": "error",
+             "message": str(e)
+         }), 400
+     finally:
+         # Clean up temp file
+         if temp_midi_path and os.path.exists(temp_midi_path):
+             try:
+                 os.unlink(temp_midi_path)
+             except Exception as e:
+                 print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+ @app.route('/uploadinfill', methods=['POST'])
+ def handle_upload_infilling():
+     temp_midi_path = None
+     top_p = float(request.form.get('top_p', '0.95'))
+
+     try:
+         # Validate input
+         if 'midi_file' not in request.files:
+             return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+         midi_file = request.files['midi_file']
+         start_time = request.form.get('start_time', '0')
+         end_time = request.form.get('end_time', '0')
+
+         # Create temporary MIDI file with random name
+         temp_dir = tempfile.gettempdir()
+         temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+         # Save uploaded MIDI to temp file
+         midi_file.save(temp_midi_path)
+
+         # Process MIDI
+         midi = mido.MidiFile(temp_midi_path)
+         model = load_model()
+         infilled_midi = infiller(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+         # Save infilled MIDI (overwriting temp file)
+         infilled_midi.save(temp_midi_path)
+
+         # Convert to MusicXML string
+         musicxml_str = midi_to_musicxml(temp_midi_path)
+
+         # Final type verification
+         if not isinstance(musicxml_str, str):
+             raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+         return jsonify({
+             "status": "success",
+             "musicxml": musicxml_str
+         })
+
+     except Exception as e:
+         print(f"Error processing request: {str(e)}")
+         traceback.print_exc()
+         return jsonify({
+             "status": "error",
+             "message": str(e)
+         }), 400
+     finally:
+         # Clean up temp file
+         if temp_midi_path and os.path.exists(temp_midi_path):
+             try:
+                 os.unlink(temp_midi_path)
+             except Exception as e:
+                 print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+ @app.route('/uploadchangemelody', methods=['POST'])
+ def handle_upload_changemelody():
+     temp_midi_path = None
+     top_p = float(request.form.get('top_p', '0.95'))
+
+     try:
+         # Validate input
+         if 'midi_file' not in request.files:
+             return jsonify({"status": "error", "message": "No MIDI file provided"}), 400
+
+         midi_file = request.files['midi_file']
+         start_time = request.form.get('start_time', '0')
+         end_time = request.form.get('end_time', '0')
+
+         # Create temporary MIDI file with random name
+         temp_dir = tempfile.gettempdir()
+         temp_midi_path = os.path.join(temp_dir, f"temp_{uuid4().hex}.mid")
+
+         # Save uploaded MIDI to temp file
+         midi_file.save(temp_midi_path)
+
+         # Process MIDI
+         midi = mido.MidiFile(temp_midi_path)
+         model = load_model()
+         changed_melody_midi = change_melody(model, midi, int(start_time)/1000, int(end_time)/1000, top_p=top_p)
+
+         # Save modified MIDI (overwriting temp file)
+         changed_melody_midi.save(temp_midi_path)
+
+         # Convert to MusicXML string
+         musicxml_str = midi_to_musicxml(temp_midi_path)
+
+         # Final type verification
+         if not isinstance(musicxml_str, str):
+             raise TypeError(f"Expected string but got {type(musicxml_str)}")
+
+         return jsonify({
+             "status": "success",
+             "musicxml": musicxml_str
+         })
+
+     except Exception as e:
+         print(f"Error processing request: {str(e)}")
+         traceback.print_exc()
+         return jsonify({
+             "status": "error",
+             "message": str(e)
+         }), 400
+     finally:
+         # Clean up temp file
+         if temp_midi_path and os.path.exists(temp_midi_path):
+             try:
+                 os.unlink(temp_midi_path)
+             except Exception as e:
+                 print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")
+
+
+ if __name__ == '__main__':
+     app.run(debug=True, port=5000)
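A hedged client-side sketch for the /upload endpoint (assumes the requests package, which is not pinned in requirements.txt; field names and the millisecond convention follow the handler above, and the example file ships with this commit):

    import requests

    with open('examples/full-score3.mid', 'rb') as f:
        resp = requests.post(
            'http://localhost:5000/upload',
            files={'midi_file': f},
            data={'start_time': '0', 'end_time': '5000', 'top_p': '0.95'},  # times in ms
        )
    print(resp.json()['status'])  # 'success' plus a 'musicxml' string on success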
examples/full-score3.mid ADDED
Binary file (1.36 kB). View file
 
examples/strawberry.mid ADDED
Binary file (24.2 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ matplotlib == 3.7.1
+ midi2audio == 0.1.1
+ mido == 1.2.10
+ numpy >= 1.22.4
+ torch >= 2.0.1
+ transformers == 4.29.2
+ tqdm == 4.65.0
+ flask == 3.1.1
+ flask-cors == 5.0.1
+ music21
+ gunicorn