""" | |
Top-level functions for preprocessing data to be used for training. | |
""" | |
from tqdm import tqdm | |
import numpy as np | |
from anticipation import ops | |
from anticipation.config import * | |
from anticipation.vocab import * | |
from anticipation.convert import compound_to_events, midi_to_interarrival | |

def extract_spans(all_events, rate):
    events = []
    controls = []
    span = True
    next_span = end_span = TIME_OFFSET+0
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet

        # end of an anticipated span; decide when to do it again (next_span)
        if span and time >= end_span:
            span = False
            next_span = time + int(TIME_RESOLUTION*np.random.exponential(1./rate))

        # anticipate a DELTA-second span
        if (not span) and time >= next_span:
            span = True
            end_span = time + DELTA*TIME_RESOLUTION

        if span:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
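
# Illustrative sketch, not part of the original pipeline: how extract_spans
# partitions a toy event stream. The event encoding assumed here -- triples of
# (TIME_OFFSET+time, DUR_OFFSET+duration, NOTE_OFFSET+note), with times in
# TIME_RESOLUTION ticks per second -- follows the conventions this module uses.
def _demo_extract_spans():
    np.random.seed(0)
    toy_events = []
    for i in range(20):
        time = TIME_OFFSET + i*TIME_RESOLUTION  # one event per second
        dur = DUR_OFFSET + TIME_RESOLUTION//2   # half-second durations
        note = NOTE_OFFSET + 60                 # pitch 60 on instrument 0
        toy_events.extend([time, dur, note])

    # spans begin at rate ~1/sec and each span covers DELTA seconds of events
    events, controls = extract_spans(toy_events, rate=1.)
    print(len(events)//3, 'events;', len(controls)//3, 'controls')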

ANTICIPATION_RATES = 10

def extract_random(all_events, rate):
    events = []
    controls = []
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note not in [SEPARATOR, REST] # shouldn't be in the sequence yet

        if np.random.random() < rate/float(ANTICIPATION_RATES):
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
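
# Illustrative sketch, not part of the original pipeline: rate/ANTICIPATION_RATES
# is the probability that any given event is marked as a control, so rate=5
# marks roughly half of the events. Every input token lands in one of the two
# output streams.
def _demo_extract_random(toy_events):
    np.random.seed(0)
    events, controls = extract_random(toy_events, rate=5)
    assert len(events) + len(controls) == len(toy_events)
    return events, controls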

def extract_instruments(all_events, instruments):
    events = []
    controls = []
    for time, dur, note in zip(all_events[0::3], all_events[1::3], all_events[2::3]):
        assert note < CONTROL_OFFSET          # shouldn't be in the sequence yet
        assert note not in [SEPARATOR, REST]  # these shouldn't either

        instr = (note-NOTE_OFFSET)//2**7
        if instr in instruments:
            # mark this event as a control
            controls.extend([CONTROL_OFFSET+time, CONTROL_OFFSET+dur, CONTROL_OFFSET+note])
        else:
            events.extend([time, dur, note])

    return events, controls
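
# Illustrative sketch, not part of the original pipeline: note tokens pack
# (instrument, pitch) as NOTE_OFFSET + instrument*2**7 + pitch (the inverse of
# the instr computation above), so selecting instrument 0 here routes all of
# its events to the control stream.
def _demo_extract_instruments():
    piano = NOTE_OFFSET + 0*2**7 + 60    # instrument 0, pitch 60
    violin = NOTE_OFFSET + 40*2**7 + 67  # instrument 40, pitch 67
    toy_events = [TIME_OFFSET+0, DUR_OFFSET+50, piano,
                  TIME_OFFSET+0, DUR_OFFSET+50, violin]
    events, controls = extract_instruments(toy_events, [0])
    assert len(events) == len(controls) == 3
    return events, controls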

def maybe_tokenize(compound_tokens):
    # skip sequences with very few events
    if len(compound_tokens) < COMPOUND_SIZE*MIN_TRACK_EVENTS:
        return None, None, 1 # short track

    events, truncations = compound_to_events(compound_tokens, stats=True)
    end_time = ops.max_time(events, seconds=False)

    # don't want to deal with extremely short tracks
    if end_time < TIME_RESOLUTION*MIN_TRACK_TIME_IN_SECONDS:
        return None, None, 1 # short track

    # don't want to deal with extremely long tracks
    if end_time > TIME_RESOLUTION*MAX_TRACK_TIME_IN_SECONDS:
        return None, None, 2 # long track

    # skip sequences with more instruments than MIDI channels (16)
    if len(ops.get_instruments(events)) > MAX_TRACK_INSTR:
        return None, None, 3 # too many instruments

    return events, truncations, 0
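
# Illustrative sketch, not part of the original pipeline: the typical call
# pattern for maybe_tokenize on a preprocessed '.compound.txt' file. The status
# code maps to (0: ok, 1: too short, 2: too long, 3: too many instruments).
def _demo_maybe_tokenize(compound_file):
    with open(compound_file, 'r') as f:
        tokens = [int(token) for token in f.read().split()]
    events, truncations, status = maybe_tokenize(tokens)
    if status == 0:
        print(f'{len(events)//3} events, {truncations} truncated durations')
    return status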

def tokenize_ia(datafiles, output, augment_factor, idx=0, debug=False):
    assert augment_factor == 1 # can't augment interarrival-tokenized data

    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                _, _, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            filename = filename[:-len('.compound.txt')] # get the original MIDI

            # already parsed; shouldn't raise an exception
            tokens, truncations = midi_to_interarrival(filename, stats=True)

            tokens[0:0] = [MIDI_SEPARATOR]
            concatenated_tokens.extend(tokens)
            all_truncations += truncations

            # write out full sequences to file
            while len(concatenated_tokens) >= CONTEXT_SIZE:
                seq = concatenated_tokens[0:CONTEXT_SIZE]
                concatenated_tokens = concatenated_tokens[CONTEXT_SIZE:]
                outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                seqcount += 1

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
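
# Illustrative sketch, not part of the original pipeline: tokenize_ia screens
# tracks using the '.compound.txt' intermediates but re-tokenizes the original
# MIDI with the interarrival (baseline) encoding; augmentation is unsupported
# for this format. The output path 'train.ia.txt' is a placeholder.
def _demo_tokenize_ia(datafiles):
    counts = tokenize_ia(datafiles, 'train.ia.txt', augment_factor=1, debug=True)
    seqcount = counts[0]
    print(f'wrote {seqcount} sequences of length CONTEXT_SIZE={CONTEXT_SIZE}')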

def tokenize(datafiles, output, augment_factor, idx=0, debug=False):
    tokens = []
    all_truncations = 0
    seqcount = rest_count = 0
    stats = 4*[0] # (short, long, too many instruments, inexpressible)
    np.random.seed(0)

    with open(output, 'w') as outfile:
        concatenated_tokens = []
        for j, filename in tqdm(list(enumerate(datafiles)), desc=f'#{idx}', position=idx+1, leave=True):
            with open(filename, 'r') as f:
                all_events, truncations, status = maybe_tokenize([int(token) for token in f.read().split()])

            if status > 0:
                stats[status-1] += 1
                continue

            instruments = list(ops.get_instruments(all_events).keys())
            end_time = ops.max_time(all_events, seconds=False)

            # different random augmentations
            for k in range(augment_factor):
                if k % 10 == 0:
                    # no augmentation
                    events = all_events.copy()
                    controls = []
                elif k % 10 == 1:
                    # span augmentation
                    lmbda = .05
                    events, controls = extract_spans(all_events, lmbda)
                elif k % 10 < 6:
                    # random augmentation
                    r = np.random.randint(1, ANTICIPATION_RATES)
                    events, controls = extract_random(all_events, r)
                else:
                    if len(instruments) > 1:
                        # instrument augmentation: at least one, but not all instruments
                        u = 1+np.random.randint(len(instruments)-1)
                        subset = np.random.choice(instruments, u, replace=False)
                        events, controls = extract_instruments(all_events, subset)
                    else:
                        # no augmentation
                        events = all_events.copy()
                        controls = []

                if len(concatenated_tokens) == 0:
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

                all_truncations += truncations
                events = ops.pad(events, end_time)
                rest_count += sum(1 if tok == REST else 0 for tok in events[2::3])
                tokens, controls = ops.anticipate(events, controls)
                assert len(controls) == 0 # should have consumed all controls (because of padding)
                tokens[0:0] = [SEPARATOR, SEPARATOR, SEPARATOR]
                concatenated_tokens.extend(tokens)

                # write out full sequences to file
                while len(concatenated_tokens) >= EVENT_SIZE*M:
                    seq = concatenated_tokens[0:EVENT_SIZE*M]
                    concatenated_tokens = concatenated_tokens[EVENT_SIZE*M:]

                    # relativize time to the context
                    seq = ops.translate(seq, -ops.min_time(seq, seconds=False), seconds=False)
                    assert ops.min_time(seq, seconds=False) == 0
                    if ops.max_time(seq, seconds=False) >= MAX_TIME:
                        stats[3] += 1
                        continue

                    # if seq contains SEPARATOR, global controls describe the first sequence
                    seq.insert(0, z)
                    outfile.write(' '.join([str(tok) for tok in seq]) + '\n')
                    seqcount += 1

                    # grab the current augmentation controls if we didn't already
                    z = ANTICIPATE if k % 10 != 0 else AUTOREGRESS

    if debug:
        fmt = 'Processed {} sequences (discarded {} tracks, discarded {} seqs, added {} rest tokens)'
        print(fmt.format(seqcount, stats[0]+stats[1]+stats[2], stats[3], rest_count))

    return (seqcount, rest_count, stats[0], stats[1], stats[2], stats[3], all_truncations)
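
# Illustrative sketch, not part of the original pipeline: each pass k over a
# track follows the augmentation schedule above (k % 10 == 0: none, 1: span,
# 2-5: random, 6-9: instrument), so augment_factor=10 yields one autoregressive
# copy and nine anticipatory copies of every track. The output path 'train.txt'
# is a placeholder.
def _demo_tokenize(datafiles):
    counts = tokenize(datafiles, 'train.txt', augment_factor=10, debug=True)
    seqcount, rest_count = counts[0], counts[1]
    print(f'wrote {seqcount} sequences; inserted {rest_count} REST tokens')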