"""
API functions for sampling from anticipatory infilling models.
"""
import math
import torch
import torch.nn.functional as F
from tqdm import tqdm
from anticipation import ops
from anticipation.config import *
from anticipation.vocab import *


def safe_logits(logits, idx):
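    """ mask out logits that are invalid at sequence position idx: events are
    (time, duration, note) triplets, so idx % 3 determines which vocabulary
    range is legal to sample next """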
logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls
logits[SPECIAL_OFFSET:] = -float('inf') # don't generate special tokens
# don't generate stuff in the wrong time slot
if idx % 3 == 0:
logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
elif idx % 3 == 1:
logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf')
elif idx % 3 == 2:
logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf')
logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf')
return logits


def nucleus(logits, top_p):
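    """ nucleus (top-p) filtering: mask logits outside the smallest set of
    tokens whose cumulative probability mass exceeds top_p """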
# from HF implementation
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right so the first token above the threshold is also kept
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float("inf")
return logits


def future_logits(logits, curtime):
""" don't sample events in the past """
if curtime > 0:
logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf')
return logits


def instr_logits(logits, full_history):
""" don't sample more than 16 instruments """
instrs = ops.get_instruments(full_history)
if len(instrs) < 15: # 16 - 1 to account for the reserved drum track
return logits
for instr in range(MAX_INSTR):
if instr not in instrs:
logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf')
return logits


def add_token(model, z, tokens, top_p, current_time, debug=False):
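    """ sample one new (time, duration, note) triplet from the model,
    conditioned on the control code z and the token history """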
assert len(tokens) % 3 == 0
history = tokens.copy()
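    # truncate the history to fit the model context: a window of 1017 tokens
    # keeps the triplet alignment (1017 = 3*339) and, presumably, leaves room
    # for the control code z plus the up-to-two tokens of the partially
    # sampled triplet within a 1024-position context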
lookback = max(len(tokens) - 1017, 0)
history = history[lookback:] # Markov window
offset = ops.min_time(history, seconds=False)
history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer
new_token = []
with torch.no_grad():
for i in range(3):
input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device)
logits = model(input_tokens).logits[0,-1]
idx = input_tokens.shape[1]-1
logits = safe_logits(logits, idx)
if i == 0:
logits = future_logits(logits, current_time - offset)
elif i == 2:
logits = instr_logits(logits, tokens)
logits = nucleus(logits, top_p)
probs = F.softmax(logits, dim=-1)
token = torch.multinomial(probs, 1)
new_token.append(int(token))
new_token[0] += offset # revert to full sequence timing
if debug:
print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}')
return new_token


def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
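    """ generate events in [start_time, end_time) (in seconds), anticipating
    the given controls and any input events that fall after start_time """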
if inputs is None:
inputs = []
if controls is None:
controls = []
start_time = int(TIME_RESOLUTION*start_time)
end_time = int(TIME_RESOLUTION*end_time)
# prompt is events up to start_time
prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
# treat events beyond start_time as controls
future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
if debug:
print('Future')
ops.print_tokens(future)
    # clip controls that precede the sequence
controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)
if debug:
print('Controls')
ops.print_tokens(controls)
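    # select the control code: anticipatory mode if there is anything to anticipate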
z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
if debug:
print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')
# interleave the controls with the events
tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))
if debug:
print('Prompt')
ops.print_tokens(tokens)
current_time = ops.max_time(prompt, seconds=False)
if debug:
print('Current time:', current_time)
with tqdm(range(end_time-start_time)) as progress:
if controls:
atime, adur, anote = controls[0:3]
anticipated_tokens = controls[3:]
anticipated_time = atime - ATIME_OFFSET
else:
# nothing to anticipate
anticipated_time = math.inf
while True:
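            # flush any anticipated controls that fall within delta of the current time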
while current_time >= anticipated_time - delta:
tokens.extend([atime, adur, anote])
if debug:
note = anote - ANOTE_OFFSET
instr = note//2**7
print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)
if len(anticipated_tokens) > 0:
atime, adur, anote = anticipated_tokens[0:3]
anticipated_tokens = anticipated_tokens[3:]
anticipated_time = atime - ATIME_OFFSET
else:
# nothing more to anticipate
anticipated_time = math.inf
            new_token = add_token(model, z, tokens, top_p, max(start_time, current_time))
new_time = new_token[0] - TIME_OFFSET
if new_time >= end_time:
break
if debug:
new_note = new_token[2] - NOTE_OFFSET
new_instr = new_note//2**7
new_pitch = new_note - (2**7)*new_instr
print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
tokens.extend(new_token)
dt = new_time - current_time
assert dt >= 0
current_time = new_time
progress.update(dt)
events, _ = ops.split(tokens)
return ops.sort(ops.unpad(events) + future)


def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
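    """ autoregressive baseline: generate events in [start_time, end_time)
    (in seconds), treating controls as ordinary events and backfilling them
    into the sequence as generation passes their onset times """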
if inputs is None:
inputs = []
if controls is None:
controls = []
else:
# treat controls as ordinary tokens
controls = [token-CONTROL_OFFSET for token in controls]
start_time = int(TIME_RESOLUTION*start_time)
end_time = int(TIME_RESOLUTION*end_time)
inputs = ops.sort(inputs + controls)
# prompt is events up to start_time
prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
if debug:
print('Prompt')
ops.print_tokens(prompt)
# treat events beyond start_time as controls
controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
if debug:
print('Future')
ops.print_tokens(controls)
z = [AUTOREGRESS]
if debug:
print('AR Mode')
current_time = ops.max_time(prompt, seconds=False)
if debug:
print('Current time:', current_time)
tokens = prompt
with tqdm(range(end_time-start_time)) as progress:
if controls:
atime, adur, anote = controls[0:3]
anticipated_tokens = controls[3:]
anticipated_time = atime - TIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_tokens = []
            anticipated_time = math.inf
while True:
            new_token = add_token(model, z, tokens, top_p, max(start_time, current_time))
new_time = new_token[0] - TIME_OFFSET
if new_time >= end_time:
break
dt = new_time - current_time
assert dt >= 0
current_time = new_time
# backfill anything that should have come before the new token
while current_time >= anticipated_time:
tokens.extend([atime, adur, anote])
if debug:
note = anote - NOTE_OFFSET
instr = note//2**7
print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)
if len(anticipated_tokens) > 0:
atime, adur, anote = anticipated_tokens[0:3]
anticipated_tokens = anticipated_tokens[3:]
anticipated_time = atime - TIME_OFFSET
else:
# nothing more to anticipate
anticipated_time = math.inf
if debug:
new_note = new_token[2] - NOTE_OFFSET
new_instr = new_note//2**7
new_pitch = new_note - (2**7)*new_instr
print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)
tokens.extend(new_token)
progress.update(dt)
    if anticipated_time != math.inf:
        tokens.extend([atime, adur, anote])

    # append only the controls that were never backfilled; appending the full
    # `controls` list here would duplicate the events already backfilled above
    return ops.sort(ops.unpad(tokens) + anticipated_tokens)