manuel-l01's picture
Initial commit
572abf8
"""
Global configuration for anticipatory infilling models.
"""
# ---------------------------------------------------------------------------
# Model hyper-parameters
# ---------------------------------------------------------------------------

# Number of tokens in one model context.
CONTEXT_SIZE = 1024
# Each event (or control) is encoded as this many tokens.
EVENT_SIZE = 3
# Number of events per context, chosen so that CONTEXT_SIZE == 1 + EVENT_SIZE*M.
M = 341
# Anticipation time, in seconds.
DELTA = 5

# Sanity check: the three sizes above must remain mutually consistent.
assert 1 + EVENT_SIZE * M == CONTEXT_SIZE

# ---------------------------------------------------------------------------
# Vocabulary constants
# ---------------------------------------------------------------------------

# Training sequences longer than this many seconds are excluded.
MAX_TIME_IN_SECONDS = 100
# Longest duration of a single note, in seconds.
MAX_DURATION_IN_SECONDS = 10
# Time bins per second (100 bins/second = 10ms resolution).
TIME_RESOLUTION = 100
# Number of MIDI pitches.
MAX_PITCH = 128
# Number of MIDI instruments (128 + drums).
MAX_INSTR = 129
# A note is a (pitch, instrument) pair, hence the product.
MAX_NOTE = MAX_PITCH * MAX_INSTR
# Longest interarrival time, in seconds (used by the MIDI-like encoding).
MAX_INTERARRIVAL_IN_SECONDS = 10

# ---------------------------------------------------------------------------
# Preprocessing settings
# ---------------------------------------------------------------------------

PREPROC_WORKERS = 16
# Event size in the intermediate compound tokenization.
COMPOUND_SIZE = 5
# Tracks with a large number of instruments are excluded.
MAX_TRACK_INSTR = 16
# Tracks longer than one hour are excluded.
MAX_TRACK_TIME_IN_SECONDS = 3600
# Tracks shorter than 10 seconds are excluded.
MIN_TRACK_TIME_IN_SECONDS = 10
# Tracks with fewer than 100 events are excluded.
MIN_TRACK_EVENTS = 100

# ---------------------------------------------------------------------------
# LakhMIDI dataset splits
# ---------------------------------------------------------------------------

# One split per hexadecimal digit: '0'-'9' followed by 'a'-'f'.
LAKH_SPLITS = list('0123456789abcdef')
LAKH_VALID = ['e']
LAKH_TEST = ['f']

# ---------------------------------------------------------------------------
# Derived quantities (seconds converted to time bins)
# ---------------------------------------------------------------------------

MAX_TIME = TIME_RESOLUTION * MAX_TIME_IN_SECONDS
MAX_DUR = TIME_RESOLUTION * MAX_DURATION_IN_SECONDS
MAX_INTERARRIVAL = TIME_RESOLUTION * MAX_INTERARRIVAL_IN_SECONDS
if __name__ == '__main__':
    # When run as a script, print a human-readable summary of the main
    # configuration constants defined in this module.
    print('Model constants:')
    print(f' -> anticipation interval: {DELTA}s')
    print('Vocabulary constants:')
    print(f' -> maximum time of a sequence: {MAX_TIME_IN_SECONDS}s')
    print(f' -> maximum duration of a note: {MAX_DURATION_IN_SECONDS}s')
    # 1000 // TIME_RESOLUTION converts bins/second into milliseconds per bin.
    print(f' -> time resolution: {TIME_RESOLUTION}bins/s ({1000//TIME_RESOLUTION}ms)')
    print(f' -> maximum interarrival-time (MIDI-like encoding): {MAX_INTERARRIVAL_IN_SECONDS}s')
    print('Preprocessing constants:')
    print(f' -> maximum time of a track: {MAX_TRACK_TIME_IN_SECONDS}s')
    # An event count, not a duration: no seconds suffix.
    print(f' -> minimum events in a track: {MIN_TRACK_EVENTS}')