manuel-l01's picture
Initial commit
572abf8
"""
Utilities for converting between MIDI data and encoded/tokenized data.
"""
from collections import defaultdict
import mido
from anticipation.config import *
from anticipation.vocab import *
from anticipation.ops import unpad
def midi_to_interarrival(midifile, debug=False, stats=False):
    """Tokenize a MIDI file as an interarrival-time sequence.

    Every note_on/note_off message becomes one note token, optionally
    preceded by a time-step token carrying the quantized time elapsed since
    the previous note message:

        time token: MIDI_TIME_OFFSET + ticks
        note token: MIDI_START_OFFSET (onset) or MIDI_END_OFFSET (offset)
                    + 128 * instrument + pitch

    A note_on with velocity 0 is treated as an offset. Channel 9 is always
    mapped to instrument 128 (drums); every other channel uses the program
    most recently set on it by a program_change (0 if none was seen).

    Args:
        midifile: passed unchanged to ``mido.MidiFile``.
        debug: if true, print messages whose type is not recognised.
        stats: if true, also return the number of truncated time steps.

    Returns:
        The token list, or ``(tokens, truncations)`` when ``stats`` is true.
        ``truncations`` counts time steps clipped to MAX_INTERARRIVAL-1.

    Raises:
        ValueError: if a message reports a negative delta time.
    """
    note_types = ('note_on', 'note_off')

    # Recognised message types that are deliberately not tokenized: tempo and
    # meter (we use real time), expressive controllers, textual metadata, and
    # assorted system messages.
    ignored_types = frozenset((
        'set_tempo', 'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    ))

    programs = defaultdict(int)  # channel -> program; code 0 = piano by default
    tokens = []
    truncations = 0
    elapsed = 0  # time accumulated since the last note message

    for message in mido.MidiFile(midifile):
        elapsed += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError

        kind = message.type
        if kind == 'program_change':
            programs[message.channel] = message.program
            continue

        if kind not in note_types:
            if debug and kind not in ignored_types:
                print('UNHANDLED MESSAGE', message.type, message)
            continue

        # Quantize the elapsed time and clip it to the largest encodable step.
        exact_ticks = round(TIME_RESOLUTION*elapsed)
        delta_ticks = min(exact_ticks, MAX_INTERARRIVAL-1)
        if delta_ticks != exact_ticks:
            truncations += 1

        # Only emit a time step if time actually advanced.
        if delta_ticks > 0:
            tokens.append(MIDI_TIME_OFFSET + delta_ticks)

        # special case: channel 9 is drums!
        inst = 128 if message.channel == 9 else programs[message.channel]
        is_onset = kind == 'note_on' and message.velocity > 0
        base = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
        tokens.append(base + (2**7)*inst + message.note)

        # The clock restarts at every note message, even if no time token
        # was emitted for it.
        elapsed = 0

    if stats:
        return tokens, truncations

    return tokens
def interarrival_to_midi(tokens, debug=False):
    """Rebuild a ``mido.MidiFile`` from an interarrival token sequence.

    Tokens are interpreted by range: MIDI_SEPARATOR is skipped, anything
    below MIDI_START_OFFSET advances the clock, anything below
    MIDI_END_OFFSET is a note onset, and everything else is a note offset.
    Each instrument gets its own track, opened with a program_change the
    first time one of its onsets is seen. Instrument 128 (drums) is placed
    on channel 9; other instruments take channels in order of appearance,
    skipping 9.

    Args:
        tokens: iterable of integer tokens.
        debug: if true, report offsets for instruments that have no track yet.

    Returns:
        The assembled ``mido.MidiFile``. Onsets are written with velocity 96.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # instrument -> (track, tick of the last event written to it, channel)
    state = {}
    now = 0           # absolute time in ticks
    next_channel = 0  # channel handed to the next new instrument

    for token in tokens:
        if token == MIDI_SEPARATOR:
            continue

        if token < MIDI_START_OFFSET:
            now += token - MIDI_TIME_OFFSET
            continue

        is_onset = token < MIDI_END_OFFSET
        base = MIDI_START_OFFSET if is_onset else MIDI_END_OFFSET
        instrument, pitch = divmod(token - base, 2**7)

        if instrument in state:
            track, last_tick, channel = state[instrument]
        elif is_onset:
            # First onset for this instrument: open a track for it.
            channel = next_channel
            last_tick = 0
            track = mido.MidiTrack()
            mid.tracks.append(track)
            if instrument == 128:  # drums always go on channel 9
                channel = 9
                program = 0
            else:
                program = instrument
            track.append(mido.Message('program_change', channel=channel, program=program))
            next_channel += 1
            if next_channel == 9:
                next_channel += 1  # skip the drums track
        else:
            # shouldn't happen because we should have a corresponding onset
            if debug:
                print('IGNORING bad offset')
            continue

        delta = now - last_tick
        if is_onset:
            event = mido.Message('note_on', note=pitch, channel=channel, velocity=96, time=delta)
        else:
            event = mido.Message('note_off', note=pitch, channel=channel, time=delta)
        track.append(event)
        state[instrument] = (track, now, channel)

    return mid
def midi_to_compound(midifile, debug=False):
    """Convert a MIDI file into a flat list of compound note tokens.

    Each note contributes five consecutive integers:

        (time, duration, note, instrument, velocity)

    ``time`` and ``duration`` are in ticks (``round(TIME_RESOLUTION * t)``).
    ``duration`` stays at the placeholder -1 for notes that never receive a
    matching offset. ``instrument`` is packed as ``(program << 4) | channel``,
    where drums (input channel 9) use program 128 and channel 9, and every
    other input channel is given an output channel in order of first use
    (0-8, then 10-15).

    NOTE(review): ``compound_to_events`` in this file asserts that the
    instrument field is below 129, which the packed value can exceed --
    confirm which encoding downstream code expects.

    Args:
        midifile: a path string (opened with ``mido.MidiFile``) or an
            already-loaded object that yields mido messages when iterated.
        debug: if true, print warnings about unhandled messages, unmatched
            offsets and unclosed notes.

    Returns:
        The flat token list described above.

    Raises:
        ValueError: if a message reports a negative delta time.
    """
    midi = mido.MidiFile(midifile) if isinstance(midifile, str) else midifile

    # Recognised message types that are deliberately not tokenized: tempo and
    # meter (we use real time), expressive controllers, textual metadata, and
    # assorted system messages.
    ignored_types = frozenset((
        'set_tempo', 'time_signature',
        'aftertouch', 'polytouch', 'pitchwheel', 'sequencer_specific',
        'control_change',
        'track_name', 'text', 'end_of_track', 'lyrics', 'key_signature',
        'copyright', 'marker', 'instrument_name', 'cue_marker',
        'device_name', 'sequence_number',
        'channel_prefix',
        'midi_port', 'smpte_offset', 'sysex',
    ))

    tokens = []
    note_idx = 0
    open_notes = defaultdict(list)  # (input channel, pitch) -> [(note index, onset time)]
    time = 0
    instruments = {}  # input channel -> {'program': int, 'channel': output channel}
    next_channel = 0

    def slot_for(midi_channel):
        """Return the bookkeeping entry for a non-drum input channel.

        The entry is created on first use (program 0 = piano) and is given
        the next free output channel, so a note that arrives before any
        program_change still has a valid channel to pack.
        """
        nonlocal next_channel
        slot = instruments.get(midi_channel)
        if slot is None:
            slot = {'program': 0, 'channel': next_channel}
            instruments[midi_channel] = slot
            next_channel += 1
            if next_channel == 9:  # skip channel 9 (drums)
                next_channel = 10
        return slot

    for message in midi:
        time += message.time

        # sanity check: negative time?
        if message.time < 0:
            raise ValueError(f'negative delta time in MIDI message: {message}')

        kind = message.type
        if kind == 'program_change':
            # Drums ignore program changes. For other channels the latest
            # program wins, matching midi_to_interarrival; the output channel
            # is allocated only once.
            if message.channel != 9:
                slot_for(message.channel)['program'] = message.program
            continue

        if kind not in ('note_on', 'note_off'):
            if debug and kind not in ignored_types:
                print('UNHANDLED MESSAGE', message.type, message)
            continue

        # special case: channel 9 is drums!
        if message.channel == 9:
            instr, channel = 128, 9
        else:
            slot = slot_for(message.channel)
            instr, channel = slot['program'], slot['channel']
        compound_instr = (instr << 4) | channel

        # Notes are matched on (input channel, pitch) so that a program
        # change between onset and offset cannot orphan the offset.
        pending = open_notes[(message.channel, message.note)]

        if kind == 'note_on' and message.velocity > 0:  # onset
            # Our compound word is: (time, duration, note, instr, velocity)
            tokens.extend((
                round(TIME_RESOLUTION*time),  # time quantization
                -1,                           # placeholder, filled in at the offset
                message.note,
                compound_instr,
                message.velocity,
            ))
            pending.append((note_idx, time))
            note_idx += 1
        elif pending:  # offset closing the oldest open note of this pitch
            open_idx, onset_time = pending.pop(0)
            tokens[5*open_idx + 1] = round(TIME_RESOLUTION*(time-onset_time))
        elif debug:
            print('WARNING: ignoring bad offset')

    unclosed_count = sum(len(pending) for pending in open_notes.values())
    if debug and unclosed_count > 0:
        print(f'WARNING: {unclosed_count} unclosed notes')
        print(' ', midifile)

    return tokens
def compound_to_midi(tokens, debug=False):
    """Render compound note tokens to a ``mido.MidiFile``.

    ``tokens`` is read five integers at a time as
    ``(time, duration, note, instrument, velocity)``, with times in ticks.
    Every note yields a note_on at ``time`` and a note_off at
    ``time + duration``. Each distinct ``instrument`` value gets its own
    track, opened with a program_change when its first onset is written.
    Instrument 128 (drums) is placed on channel 9 with program 0; other
    instruments take channels in order of appearance, skipping 9.

    The ``instrument`` field is used directly as the MIDI program number,
    which is the form ``events_to_compound`` produces.
    NOTE(review): the packed ``(program << 4) | channel`` value emitted by
    ``midi_to_compound`` is not decoded here -- confirm the intended encoding.

    Args:
        tokens: flat sequence of compound tokens; a trailing partial group
            is ignored.
        debug: if true, report offsets for instruments that have no track yet.

    Returns:
        The assembled ``mido.MidiFile``.
    """
    mid = mido.MidiFile()
    mid.ticks_per_beat = TIME_RESOLUTION // 2  # 2 beats/second at quarter=120

    # Bucket events by (tick, kind). kind 0 = onset, 1 = offset, so at equal
    # ticks onsets sort ahead of offsets.
    it = iter(tokens)
    time_index = defaultdict(list)
    for time_in_ticks, duration, note, instrument, velocity in zip(it, it, it, it, it):
        time_index[(time_in_ticks, 0)].append((note, instrument, velocity))
        time_index[(time_in_ticks+duration, 1)].append((note, instrument, velocity))

    track_idx = {}  # instrument -> (track, tick of last event written, channel)
    num_tracks = 0
    for time_in_ticks, event_type in sorted(time_index):
        for note, instrument, velocity in time_index[(time_in_ticks, event_type)]:
            if event_type == 0:  # onset
                if instrument in track_idx:
                    track, previous_time, idx = track_idx[instrument]
                else:
                    # First onset for this instrument: open a track for it.
                    idx = num_tracks
                    previous_time = 0
                    track = mido.MidiTrack()
                    mid.tracks.append(track)
                    if instrument == 128:  # drums always go on channel 9
                        idx = 9
                        message = mido.Message('program_change', channel=idx, program=0)
                    else:
                        message = mido.Message('program_change', channel=idx, program=instrument)
                    track.append(message)
                    num_tracks += 1
                    if num_tracks == 9:
                        num_tracks += 1  # skip the drums track

                track.append(mido.Message(
                    'note_on', note=note, channel=idx, velocity=velocity,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)
            else:  # offset
                if instrument not in track_idx:
                    # shouldn't happen because we should have a corresponding onset
                    if debug:
                        print('IGNORING bad offset')
                    continue

                track, previous_time, idx = track_idx[instrument]
                track.append(mido.Message(
                    'note_off', note=note, channel=idx,
                    time=time_in_ticks-previous_time))
                track_idx[instrument] = (track, time_in_ticks, idx)

    return mid
def compound_to_events(tokens, stats=False):
    """Convert compound tokens into (time, duration, note) event tokens.

    The input is read five integers at a time as
    ``(time, duration, note, instrument, velocity)``. Velocity is dropped
    and each note becomes three vocabulary tokens:

        TIME_OFFSET + time
        DUR_OFFSET  + duration   (clipped to MAX_DUR-1; a -1 placeholder
                                  becomes TIME_RESOLUTION//4 ticks)
        NOTE_OFFSET + MAX_PITCH*instrument + note

    A note value of -1 marks a separator: SEPARATOR is substituted for the
    combined note/instrument value before NOTE_OFFSET is added.

    Args:
        tokens: flat sequence whose length is a multiple of 5. The input is
            not modified. An empty sequence yields an empty result.
        stats: if true, also return the number of durations that reached
            MAX_DUR and were clipped.

    Returns:
        The event token list, or ``(tokens, truncations)`` when ``stats``
        is true.

    Raises:
        AssertionError: if the length is not a multiple of 5, a note is
            outside [-1, 128), an instrument is outside [-1, 129), or a
            time is negative.
    """
    assert len(tokens) % 5 == 0

    tokens = list(tokens)
    times = tokens[0::5]
    durations = tokens[1::5]
    notes = tokens[2::5]
    instruments = tokens[3::5]
    # tokens[4::5] holds the velocities, which are discarded.

    assert all(-1 <= note < 2**7 for note in notes)
    assert all(-1 <= instr < 129 for instr in instruments)
    # all() rather than min() so that an empty input does not raise ValueError.
    assert all(time >= 0 for time in times)

    # max duration cutoff: count how many durations will be clipped below
    truncations = sum(1 for dur in durations if dur >= MAX_DUR)

    events = []
    for time, dur, note, instr in zip(times, durations, notes, instruments):
        # unknown durations (-1) are set to a quarter of TIME_RESOLUTION
        dur = TIME_RESOLUTION//4 if dur == -1 else min(dur, MAX_DUR-1)
        # combine (note, instrument)
        combined = SEPARATOR if note == -1 else MAX_PITCH*instr + note

        events.append(TIME_OFFSET + time)
        events.append(DUR_OFFSET + dur)
        events.append(NOTE_OFFSET + combined)

    if stats:
        return events, truncations

    return events
def events_to_compound(tokens, debug=False):
    """Convert event tokens back into compound note tokens.

    After ``unpad``, the tokens are read as (time, duration, note) triples.
    Tokens at or above CONTROL_OFFSET are first shifted down by
    CONTROL_OFFSET, then the per-type vocabulary offsets are removed.
    Separator triples mark sequence boundaries: the events after a boundary
    are shifted in time by the latest end time (``time + duration``) seen
    before it, so the sequences play back to back. Separators themselves are
    removed from the output.

    Args:
        tokens: event token sequence, possibly padded.
        debug: if true, print a line at every sequence boundary.

    Returns:
        A flat list with five integers per note:
        ``(time, duration, pitch, instrument, velocity)``, where velocity is
        always 72. An input with no events yields an empty list.

    Raises:
        AssertionError: if separators or events do not come in whole
            triples, or a decoded value is negative or not below
            MAX_DUR / MAX_PITCH / MAX_INSTR.
    """
    tokens = unpad(tokens)

    # move all tokens to zero-offset for synthesis
    tokens = [tok - CONTROL_OFFSET if tok >= CONTROL_OFFSET and tok != SEPARATOR else tok
              for tok in tokens]

    # remove type offsets
    tokens[0::3] = [tok - TIME_OFFSET if tok != SEPARATOR else tok for tok in tokens[0::3]]
    tokens[1::3] = [tok - DUR_OFFSET if tok != SEPARATOR else tok for tok in tokens[1::3]]
    tokens[2::3] = [tok - NOTE_OFFSET if tok != SEPARATOR else tok for tok in tokens[2::3]]

    offset = 0     # add max time from previous track for synthesis
    track_max = 0  # keep track of max time in track
    # The slices are copies, so writing into `tokens` below does not disturb
    # the iteration; `time` is always the unshifted value.
    for j, (time, dur, note) in enumerate(zip(tokens[0::3], tokens[1::3], tokens[2::3])):
        if note == SEPARATOR:
            offset += track_max
            track_max = 0
            if debug:
                print('Sequence Boundary')
        else:
            track_max = max(track_max, time+dur)
            tokens[3*j] += offset

    # strip sequence separators
    assert len([tok for tok in tokens if tok == SEPARATOR]) % 3 == 0
    tokens = [tok for tok in tokens if tok != SEPARATOR]
    assert len(tokens) % 3 == 0

    out = []
    for time, dur, combined in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        instr, pitch = divmod(combined, 2**7)
        out.extend((time, dur, pitch, instr, 72))  # 72 = default velocity

    # all() rather than max() so that an empty result does not raise ValueError.
    assert all(dur < MAX_DUR for dur in out[1::5])
    assert all(pitch < MAX_PITCH for pitch in out[2::5])
    assert all(instr < MAX_INSTR for instr in out[3::5])
    assert all(tok >= 0 for tok in out)

    return out
def events_to_midi(tokens, debug=False):
    """Render event tokens to MIDI by way of the compound representation.

    Runs ``events_to_compound`` and then ``compound_to_midi``, forwarding
    ``debug`` to both, and returns the latter's result.
    """
    compound = events_to_compound(tokens, debug=debug)
    return compound_to_midi(compound, debug=debug)
def midi_to_events(midifile, debug=False):
    """Tokenize a MIDI file as event tokens by way of the compound representation.

    Runs ``midi_to_compound`` (with ``debug`` forwarded) and then
    ``compound_to_events`` with its default arguments, returning the
    resulting token list.
    """
    compound = midi_to_compound(midifile, debug=debug)
    return compound_to_events(compound)