# InScoreAPI/anticipation/tokenize.py
"""
Top-level functions for preprocessing data to be used for training.
"""
from tqdm import tqdm
import numpy as np
from anticipation import ops
from anticipation.config import *
from anticipation.vocab import *
from anticipation.convert import compound_to_events, midi_to_interarrival
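
# Note: events are handled throughout as flat [time, dur, note, ...] lists (one
# triple per event), with times and durations quantized to 1/TIME_RESOLUTION-second
# ticks; anticipated events ("controls") are the same triples shifted by CONTROL_OFFSET.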


def extract_spans(all_events, rate):
    events = []
    controls = []
    span = True
    next_span = end_span = TIME_OFFSET+0
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST]  # shouldn't be in the sequence yet

        # end of an anticipated span; decide when to do it again (next_span)
        if span and time >= end_span:
            span = False
            next_span = time + int(TIME_RESOLUTION*np.random.exponential(1./rate))

        # anticipate a 3-second span
        if (not span) and time >= next_span:
            span = True
            end_span = time + DELTA*TIME_RESOLUTION

        if span:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
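
# Illustrative sketch (not part of the original file): extract_spans splits a flat
# [time, dur, note, ...] event list into events kept in the main sequence and
# events anticipated as controls; every input event lands in exactly one of the two.
#
#   events, controls = extract_spans(all_events, rate=.05)
#   assert len(events) + len(controls) == len(all_events)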


ANTICIPATION_RATES = 10

def extract_random(all_events, rate):
    events = []
    controls = []
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST]  # shouldn't be in the sequence yet

        if np.random.random() < rate/float(ANTICIPATION_RATES):
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
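
# Illustrative sketch: each event is independently marked as a control with
# probability rate/ANTICIPATION_RATES, so rate=5 anticipates roughly half of the events.
#
#   events, controls = extract_random(all_events, rate=5)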


def extract_instruments(all_events, instruments):
    events = []
    controls = []
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note < CONTROL_OFFSET          # shouldn't be in the sequence yet
        assert note not in [SEPARATOR, REST]  # these shouldn't either

        instr = (note-NOTE_OFFSET)//2**7
        if instr in instruments:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
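
# Illustrative sketch: anticipate every event played by the listed instruments
# (instrument codes as decoded by (note-NOTE_OFFSET)//2**7, e.g. the keys returned
# by ops.get_instruments). The code 0 below is just an example value.
#
#   events, controls = extract_instruments(all_events, [0])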


def maybe_tokenize(compound_tokens):
    # skip sequences with very few events
    if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
        return None, None, 1  # short track

    events, truncations = compound_to_events(compound_tokens, stats=True)
    end_time = ops.max_time(events, seconds=False)

    # don't want to deal with extremely short tracks
    if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
        return None, None, 1  # short track

    # don't want to deal with extremely long tracks
    if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
        return None, None, 2  # long track

    # skip sequences with more instruments than MIDI channels (16)
    if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
        return None, None, 3  # too many instruments

    return events, truncations, 0
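
# Illustrative sketch (hypothetical file path): the status code distinguishes
# usable tracks (0) from tracks that are too short (1), too long (2), or use
# too many instruments (3).
#
#   with open('track.mid.compound.txt') as f:
#       events, truncations, status = maybe_tokenize([int(t) for t in f.read().split()])
#   if status == 0:
#       ...  # events is a flat [time, dur, note, ...] list in ticks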


def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
    assert augment_factor == 1  # can't augment interarrival-tokenized data

    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0]  # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                _, _, status = maybe_tokenize([int(token) for token in f.read().split()])
                if status > 0:
                    stats[status-1] += 1
                    continue

            filename = filename[:-len('.compound.txt')]  # get the original MIDI

            # already parsed; shouldn't raise an exception
            tokens, truncations = midi_to_interarrival(filename, stats=True)

            tokens[0:0] = [MIDI_SEPARATOR]
            concatenated_tokens.extend(tokens)
            all_truncations += truncations

            # write out full sequences to file
            while len(concatenated_tokens) >= CONTEXT_SIZE:
                seq = concatenated_tokens[0:CONTEXT_SIZE]
                concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
                outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                seqcount += 1

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
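
# Illustrative usage sketch (file list and output path are hypothetical): each line
# written to the output holds CONTEXT_SIZE interarrival tokens, with tracks
# separated by MIDI_SEPARATOR.
#
#   counts = tokenize_ia(compound_files, 'train.interarrival.txt', augment_factor=1, idx=0, debug=True)
#   seqcount = counts[0]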


def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
    tokens = []
    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0]  # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])
                if status > 0:
                    stats[status-1] += 1
                    continue

            instruments = list(ops.get_instruments(all_events).keys())
            end_time = ops.max_time(all_events, seconds=False)

            # different random augmentations
            for k in range(augment_factor):
                if k % 10 == 0:
                    # no augmentation
                    events = all_events.copy()
                    controls = []
                elif k % 10 == 1:
                    # span augmentation
                    lmbda = .05
                    events, controls = extract_spans(all_events, lmbda)
                elif k % 10 < 6:
                    # random augmentation
                    r = np.random.randint(1, ANTICIPATION_RATES)
                    events, controls = extract_random(all_events, r)
                else:
                    if len(instruments) > 1:
                        # instrument augmentation: at least one, but not all instruments
                        u = 1+np.random.randint(len(instruments)-1)
                        subset = np.random.choice(instruments, u, replace=False)
                        events, controls = extract_instruments(all_events, subset)
                    else:
                        # no augmentation
                        events = all_events.copy()
                        controls = []

                if len(concatenated_tokens) == 0:
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

                all_truncations += truncations
                events = ops.pad(events, end_time)
                rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
                tokens, controls = ops.anticipate(events, controls)
                assert len(controls) == 0  # should have consumed all controls (because of padding)
                tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
                concatenated_tokens.extend(tokens)

                # write out full sequences to file
                while len(concatenated_tokens) >= EVENT_SIZE*M:
                    seq = concatenated_tokens[0:EVENT_SIZE*M]
                    concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]

                    # relativize time to the context
                    seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
                    assert ops.min_time(seq, seconds=False) == 0
                    if ops.max_time(seq, seconds=False) >= MAX_TIME:
                        stats[3] += 1
                        continue

                    # if seq contains SEPARATOR, global controls describe the first sequence
                    seq.insert(0, z)

                    outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                    seqcount += 1

                    # grab the current augmentation controls if we didn't already
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
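
# Illustrative usage sketch (paths, glob pattern, and augment_factor are hypothetical):
# callers typically shard `datafiles` across worker processes, each writing its own
# output file with a distinct `idx`. Each output line holds 1 + EVENT_SIZE*M tokens:
# a global control (AUTOREGRESS or ANTICIPATE) followed by M events.
#
#   from glob import glob
#   compound_files = sorted(glob('data/**/*.compound.txt', recursive=True))
#   (seqcount, rest_count, too_short, too_long,
#    too_many_instr, inexpressible, truncations) = tokenize(
#       compound_files, 'tokenized-events.0.txt', augment_factor=10, idx=0, debug=True)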