""" API functions for sampling from anticipatory infilling models. """ import math import torch import torch.nn.functional as F from tqdm import tqdm from anticipation import ops from anticipation.config import * from anticipation.vocab import * def safe_logits(logits, idx): logits[CONTROL_OFFSET:SPECIAL_OFFSET] = -float('inf') # don't generate controls logits[SPECIAL_OFFSET:] = -float('inf') # don't generate special tokens # don't generate stuff in the wrong time slot if idx % 3 == 0: logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf') logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf') elif idx % 3 == 1: logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf') logits[NOTE_OFFSET:NOTE_OFFSET+MAX_NOTE] = -float('inf') elif idx % 3 == 2: logits[TIME_OFFSET:TIME_OFFSET+MAX_TIME] = -float('inf') logits[DUR_OFFSET:DUR_OFFSET+MAX_DUR] = -float('inf') return logits def nucleus(logits, top_p): # from HF implementation if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative probability above the threshold (token with 0 are kept) sorted_indices_to_remove = cumulative_probs > top_p # Shift the indices to the right to keep also the first token above the threshold sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 # scatter sorted tensors to original indexing indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = -float("inf") return logits def future_logits(logits, curtime): """ don't sample events in the past """ if curtime > 0: logits[TIME_OFFSET:TIME_OFFSET+curtime] = -float('inf') return logits def instr_logits(logits, full_history): """ don't sample more than 16 instruments """ instrs = ops.get_instruments(full_history) if len(instrs) < 15: # 16 - 1 to account for the reserved drum track return logits for instr in range(MAX_INSTR): if instr not in instrs: logits[NOTE_OFFSET+instr*MAX_PITCH:NOTE_OFFSET+(instr+1)*MAX_PITCH] = -float('inf') return logits def add_token(model, z, tokens, top_p, current_time, debug=False): assert len(tokens) % 3 == 0 history = tokens.copy() lookback = max(len(tokens) - 1017, 0) history = history[lookback:] # Markov window offset = ops.min_time(history, seconds=False) history[::3] = [tok - offset for tok in history[::3]] # relativize time in the history buffer new_token = [] with torch.no_grad(): for i in range(3): input_tokens = torch.tensor(z + history + new_token).unsqueeze(0).to(model.device) logits = model(input_tokens).logits[0,-1] idx = input_tokens.shape[1]-1 logits = safe_logits(logits, idx) if i == 0: logits = future_logits(logits, current_time - offset) elif i == 2: logits = instr_logits(logits, tokens) logits = nucleus(logits, top_p) probs = F.softmax(logits, dim=-1) token = torch.multinomial(probs, 1) new_token.append(int(token)) new_token[0] += offset # revert to full sequence timing if debug: print(f' OFFSET = {offset}, LEN = {len(history)}, TIME = {tokens[::3][-5:]}') return new_token def generate(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION): if inputs is None: inputs = [] if controls is None: controls = [] start_time = int(TIME_RESOLUTION*start_time) end_time = int(TIME_RESOLUTION*end_time) # prompt is events up to start_time prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), 
                     start_time)

    # treat events beyond start_time as controls
    future = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(future)

    # clip controls that precede the sequence
    controls = ops.clip(controls, DELTA, ops.max_time(controls, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Controls')
        ops.print_tokens(controls)

    z = [ANTICIPATE] if len(controls) > 0 or len(future) > 0 else [AUTOREGRESS]
    if debug:
        print('AR Mode' if z[0] == AUTOREGRESS else 'AAR Mode')

    # interleave the controls with the events
    tokens, controls = ops.anticipate(prompt, ops.sort(controls + [CONTROL_OFFSET+token for token in future]))
    if debug:
        print('Prompt')
        ops.print_tokens(tokens)

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - ATIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            while current_time >= anticipated_time - delta:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - ANOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - ATIME_OFFSET, adur - ADUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - ATIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            new_token = add_token(model, z, tokens, top_p, max(start_time, current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time
            progress.update(dt)

    events, _ = ops.split(tokens)
    return ops.sort(ops.unpad(events) + future)


def generate_ar(model, start_time, end_time, inputs=None, controls=None, top_p=1.0, debug=False, delta=DELTA*TIME_RESOLUTION):
    if inputs is None:
        inputs = []

    if controls is None:
        controls = []
    else:
        # treat controls as ordinary tokens
        controls = [token-CONTROL_OFFSET for token in controls]

    start_time = int(TIME_RESOLUTION*start_time)
    end_time = int(TIME_RESOLUTION*end_time)

    inputs = ops.sort(inputs + controls)

    # prompt is events up to start_time
    prompt = ops.pad(ops.clip(inputs, 0, start_time, clip_duration=False, seconds=False), start_time)
    if debug:
        print('Prompt')
        ops.print_tokens(prompt)

    # treat events beyond start_time as controls
    controls = ops.clip(inputs, start_time+1, ops.max_time(inputs, seconds=False), clip_duration=False, seconds=False)
    if debug:
        print('Future')
        ops.print_tokens(controls)

    z = [AUTOREGRESS]
    if debug:
        print('AR Mode')

    current_time = ops.max_time(prompt, seconds=False)
    if debug:
        print('Current time:', current_time)

    tokens = prompt
    with tqdm(range(end_time-start_time)) as progress:
        if controls:
            atime, adur, anote = controls[0:3]
            anticipated_tokens = controls[3:]
            anticipated_time = atime - TIME_OFFSET
        else:
            # nothing to anticipate
            anticipated_time = math.inf

        while True:
            new_token = add_token(model, z, tokens, top_p, max(start_time, current_time))
            new_time = new_token[0] - TIME_OFFSET
            if new_time >= end_time:
                break

            dt = new_time - current_time
            assert dt >= 0
            current_time = new_time

            # backfill anything that should have come before the new token
            while current_time >= anticipated_time:
                tokens.extend([atime, adur, anote])
                if debug:
                    note = anote - NOTE_OFFSET
                    instr = note//2**7
                    print('A', atime - TIME_OFFSET, adur - DUR_OFFSET, instr, note - (2**7)*instr)

                if len(anticipated_tokens) > 0:
                    atime, adur, anote = anticipated_tokens[0:3]
                    anticipated_tokens = anticipated_tokens[3:]
                    anticipated_time = atime - TIME_OFFSET
                else:
                    # nothing more to anticipate
                    anticipated_time = math.inf

            if debug:
                new_note = new_token[2] - NOTE_OFFSET
                new_instr = new_note//2**7
                new_pitch = new_note - (2**7)*new_instr
                print('C', new_time, new_token[1] - DUR_OFFSET, new_instr, new_pitch)

            tokens.extend(new_token)
            progress.update(dt)

    if anticipated_time != math.inf:
        tokens.extend([atime, adur, anote])

    return ops.sort(ops.unpad(tokens) + controls)
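

# Example usage (illustrative sketch only): sample ten seconds of music from
# scratch with nucleus sampling and write it to a MIDI file. The checkpoint
# name and the events_to_midi import below are assumptions about the
# surrounding package and published models, not something defined in this module.
if __name__ == '__main__':
    from transformers import AutoModelForCausalLM

    from anticipation.convert import events_to_midi  # assumed helper for event-token -> MIDI conversion

    # assumed pretrained anticipatory music transformer checkpoint
    model = AutoModelForCausalLM.from_pretrained('stanford-crfm/music-medium-800k')

    # sample events for the interval [0, 10) seconds in autoregressive mode
    events = generate(model, start_time=0, end_time=10, top_p=0.98)

    # convert the sampled event tokens to MIDI and save them
    mid = events_to_midi(events)
    mid.save('generated.mid')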