# InScoreAPI / agents/agents.py — initial commit (572abf8)
from anticipation import ops
from anticipation.sample import generate
from anticipation.tokenize import extract_instruments
from anticipation.convert import events_to_midi,midi_to_events, compound_to_midi
from anticipation.config import *
from anticipation.vocab import *
from anticipation.convert import midi_to_compound
import mido
from agents.utils import load_midi_metadata
# Pretrained anticipatory music transformer checkpoints (Hugging Face model IDs).
SMALL_MODEL = 'stanford-crfm/music-small-800k' # faster inference, worse sample quality
MEDIUM_MODEL = 'stanford-crfm/music-medium-800k' # slower inference, better sample quality
LARGE_MODEL = 'stanford-crfm/music-large-800k' # slowest inference, best sample quality
def harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Generate an accompaniment for the melody of *midi* between start_time and end_time.

    The melody (instrument 0) is held fixed as a control sequence; the model
    fills in accompaniment for the [start_time, end_time) window, prompted by
    the material before start_time and anticipating the material after
    end_time.

    Args:
        model: anticipatory music transformer used for generation.
        midi: mido.MidiFile to harmonize.
        start_time: start of the generation window (unit per the anticipation
            ops — presumably seconds; confirm against caller).
        end_time: end of the generation window.
        original_tempo: tempo (microseconds per beat) for the output meta track.
        original_time_sig: (numerator, denominator) time-signature tuple.
        top_p: nucleus-sampling threshold passed through to generate().

    Returns:
        mido.MidiFile combining the fixed melody and the generated accompaniment.
    """
    # Tokenize the full MIDI into anticipation event tokens.
    events = midi_to_events(midi)
    print("Midi converted to events")
    # Clip to the full span and shift so the segment starts at time 0.
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
    # Split instrument 0 (melody) out of the mix; the rest stays in `events`.
    events, melody = extract_instruments(segment, [0])
    print("Melody extracted")
    print("Start time:", start_time)
    print("End time:", end_time)
    # Prompt: everything before the generation window.
    history = ops.clip(events, 0, start_time, clip_duration=False)
    # Future context after the window, marked as anticipated controls.
    anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
    # Generate accompaniment for the window, conditioning on the melody.
    accompaniment = generate(model, start_time, end_time, inputs=history, controls=melody, top_p=top_p, debug=False)
    # Re-attach the untouched continuation after the window.
    accompaniment = ops.combine(accompaniment, anticipated)
    print("Accompaniment generated")
    # Render melody and accompaniment to separate MIDI files.
    mel_mid = events_to_midi(melody)
    acc_mid = events_to_midi(accompaniment)
    # Assemble a fresh output file at the melody's tick resolution.
    combined = mido.MidiFile()
    combined.ticks_per_beat = mel_mid.ticks_per_beat  # or TIME_RESOLUTION//2
    print("Midi built")
    # Meta track carrying the original tempo and time signature.
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)
    # Append melody then accompaniment, dropping each rendered file's first
    # track (assumed to be its own meta track — TODO confirm events_to_midi layout).
    combined.tracks.extend(mel_mid.tracks[1:])
    combined.tracks.extend(acc_mid.tracks[1:])
    # Clamp note messages into the valid 0-127 MIDI range in place.
    # note_on/note_off messages always carry velocity and note in mido,
    # so the previous hasattr() guards were redundant.
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    print(f"Melody tracks: {len(mel_mid.tracks)}")
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")
    # Drop byte-identical duplicate tracks, preserving first-seen order.
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def harmonizer(ai_model, midi_file, start_time, end_time, top_p):
    """Harmonize the melody of a MIDI file with a generated accompaniment.

    Args:
        ai_model: anticipatory music transformer used for generation.
        midi_file: mido.MidiFile containing the melody (an already-loaded
            file object, not a path — the code reads midi_file.tracks directly).
        start_time: start of the window to harmonize (unit per the
            anticipation ops — presumably seconds; confirm against caller).
        end_time: end of the window to harmonize.
        top_p: nucleus-sampling threshold.

    Returns:
        mido.MidiFile with the original melody plus generated accompaniment.
    """
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")
    # Report (but do not fix) any out-of-range note parameters in the input.
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")
    # Extract tempo / time-signature metadata from the original file.
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model  # add .cuda() here if a GPU is available
    print("Model loaded")
    harmonized_midi = harmonize_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
    print("Midi generated")
    print(f"Harmonized MIDI tracks: {len(harmonized_midi.tracks)}")
    # Clamp generated note values into the valid 0-127 MIDI range.
    for track in harmonized_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    # Previous log said "Midi saved", but nothing is written to disk here.
    print("Midi validated")
    return harmonized_midi
def infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Regenerate the content of *midi* between start_time and end_time.

    Events before start_time form the prompt; events after end_time are fed
    back as anticipated controls, so the generated infill connects smoothly
    with the existing continuation.

    Args:
        model: anticipatory music transformer used for generation.
        midi: mido.MidiFile to infill.
        start_time: start of the window to regenerate (unit per the
            anticipation ops — presumably seconds; confirm against caller).
        end_time: end of the window to regenerate.
        original_tempo: tempo (microseconds per beat) for the output meta track.
        original_time_sig: (numerator, denominator) time-signature tuple.
        top_p: nucleus-sampling threshold passed through to generate().

    Returns:
        mido.MidiFile with the window replaced by generated material.
    """
    # Tokenize the full MIDI into anticipation event tokens.
    events = midi_to_events(midi)
    print("Midi converted to events")
    # Clip to the full span and shift so the segment starts at time 0.
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
    # Prompt: everything before the generation window.
    history = ops.clip(events, 0, start_time, clip_duration=False)
    # Future context after the window, marked as anticipated controls.
    anticipated = [CONTROL_OFFSET + tok for tok in ops.clip(events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
    # Generate the infill for the window, conditioning on the future context.
    infilling = generate(model, start_time, end_time, inputs=history, controls=anticipated, top_p=top_p, debug=False)
    # Re-attach the untouched continuation after the window.
    full_events = ops.combine(infilling, anticipated)
    # Previous log said "Accompaniment generated" — copy-paste from harmonize_midi.
    print("Infill generated")
    # Render the combined event stream to MIDI.
    full_mid = events_to_midi(full_events)
    # Assemble a fresh output file at the rendered tick resolution.
    combined = mido.MidiFile()
    combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2
    print("Midi built")
    # Meta track carrying the original tempo and time signature.
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)
    # Append ALL rendered tracks; the stale "skip meta track" comment was wrong
    # here — duplicate tracks are pruned below instead.
    combined.tracks.extend(full_mid.tracks[:])
    # Clamp note messages into the valid 0-127 MIDI range in place.
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    # Previously this printed the same count twice under "Melody tracks" and
    # "Accompaniment tracks" — there is only one rendered file here.
    print(f"Generated tracks: {len(full_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")
    # Drop byte-identical duplicate tracks, preserving first-seen order.
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def infiller(ai_model, midi_file, start_time, end_time, top_p):
    """Infill (regenerate) a selected time window of a MIDI file.

    The previous docstring was copy-pasted from harmonizer and described the
    wrong operation.

    Args:
        ai_model: anticipatory music transformer used for generation.
        midi_file: mido.MidiFile to infill (an already-loaded file object,
            not a path — the code reads midi_file.tracks directly).
        start_time: start of the window to regenerate (unit per the
            anticipation ops — presumably seconds; confirm against caller).
        end_time: end of the window to regenerate.
        top_p: nucleus-sampling threshold.

    Returns:
        mido.MidiFile with the selected window regenerated.
    """
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")
    # Report (but do not fix) any out-of-range note parameters in the input.
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")
    # Extract tempo / time-signature metadata from the original file.
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model  # add .cuda() here if a GPU is available
    print("Model loaded")
    infilled_midi = infill_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
    print("Midi generated")
    # Previous label said "Harmonized" — wrong operation for this function.
    print(f"Infilled MIDI tracks: {len(infilled_midi.tracks)}")
    # Clamp generated note values into the valid 0-127 MIDI range.
    for track in infilled_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    # Previous log said "Midi saved", but nothing is written to disk here.
    print("Midi validated")
    return infilled_midi
def change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p):
    """Regenerate the melody (instrument 0) between start_time and end_time.

    The accompaniment (every instrument except 0) is held fixed as controls;
    the model writes a new melody line for the window, prompted by the melody
    before start_time and anticipating the original melody after end_time.

    Args:
        model: anticipatory music transformer used for generation.
        midi: mido.MidiFile whose melody should be replaced.
        start_time: start of the window to regenerate (unit per the
            anticipation ops — presumably seconds; confirm against caller).
        end_time: end of the window to regenerate.
        original_tempo: tempo (microseconds per beat) for the output meta track.
        original_time_sig: (numerator, denominator) time-signature tuple.
        top_p: nucleus-sampling threshold passed through to generate().

    Returns:
        mido.MidiFile combining the new melody and the original accompaniment.
    """
    events = midi_to_events(midi)
    # Clip to the full span and shift so the segment starts at time 0.
    segment = ops.clip(events, 0, ops.max_time(events, seconds=True))
    segment = ops.translate(segment, -ops.min_time(segment, seconds=False))
    # Split the segment: melody events stay, all other instruments become controls.
    instruments = list(ops.get_instruments(segment).keys())
    accompaniment_instruments = [instr for instr in instruments if instr != 0]
    melody_events, accompaniment_controls = extract_instruments(segment, accompaniment_instruments)
    # Prompt: the melody before the generation window.
    history = ops.clip(melody_events, 0, start_time, clip_duration=False)
    # Condition on the full accompaniment for the entire duration.
    controls = accompaniment_controls
    # Generate a new melody for the window, conditioned on the accompaniment.
    infilling = generate(model, start_time, end_time, inputs=history, controls=controls, top_p=top_p, debug=False)
    # Re-attach the original melody continuation after the window.
    anticipated_melody = [CONTROL_OFFSET + tok for tok in ops.clip(melody_events, end_time, ops.max_time(segment, seconds=True), clip_duration=False)]
    full_events = ops.combine(infilling, anticipated_melody)
    # Render the accompaniment and the new melody separately.
    acc_mid = events_to_midi(accompaniment_controls)
    full_mid = events_to_midi(full_events)
    combined = mido.MidiFile()
    combined.ticks_per_beat = full_mid.ticks_per_beat  # or TIME_RESOLUTION//2
    print("Midi built")
    # Meta track carrying the original tempo and time signature.
    meta = mido.MidiTrack()
    meta.append(mido.MetaMessage('set_tempo', tempo=original_tempo))
    meta.append(mido.MetaMessage('time_signature',
                                 numerator=original_time_sig[0],
                                 denominator=original_time_sig[1]))
    combined.tracks.append(meta)
    # Append melody then accompaniment; duplicate tracks are pruned below.
    # (The stale "skip existing meta track" comments did not match tracks[:].)
    combined.tracks.extend(full_mid.tracks[:])
    combined.tracks.extend(acc_mid.tracks[:])
    # Clamp note messages into the valid 0-127 MIDI range in place.
    for track in combined.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    print(f"Melody tracks: {len(full_mid.tracks)}")
    # Fixed copy-paste bug: this previously reported len(full_mid.tracks)
    # under the "Accompaniment tracks" label instead of acc_mid.
    print(f"Accompaniment tracks: {len(acc_mid.tracks)}")
    print(f"Combined tracks before cleanup: {len(combined.tracks)}")
    # Drop byte-identical duplicate tracks, preserving first-seen order.
    unique_tracks = []
    seen = set()
    for track in combined.tracks:
        track_hash = str([msg.hex() for msg in track])
        if track_hash not in seen:
            unique_tracks.append(track)
            seen.add(track_hash)
    combined.tracks = unique_tracks
    print(f"Final track count: {len(combined.tracks)}")
    print("Output Midi metadata added")
    return combined
def change_melody(ai_model, midi_file, start_time, end_time, top_p):
    """Replace the melody of a MIDI file in a selected time window.

    The previous docstring was copy-pasted from harmonizer and described the
    wrong operation.

    Args:
        ai_model: anticipatory music transformer used for generation.
        midi_file: mido.MidiFile whose melody should be replaced (an
            already-loaded file object, not a path — the code reads
            midi_file.tracks directly).
        start_time: start of the window to regenerate (unit per the
            anticipation ops — presumably seconds; confirm against caller).
        end_time: end of the window to regenerate.
        top_p: nucleus-sampling threshold.

    Returns:
        mido.MidiFile with a newly generated melody over the original
        accompaniment.
    """
    print(f"Original MIDI tracks: {len(midi_file.tracks)}")
    # Report (but do not fix) any out-of-range note parameters in the input.
    for track in midi_file.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                if msg.velocity > 127 or msg.velocity < 0:
                    print(f"Invalid velocity: {msg.velocity}")
                if msg.note > 127 or msg.note < 0:
                    print(f"Invalid pitch: {msg.note}")
    # Extract tempo / time-signature metadata from the original file.
    midi, original_tempo, original_time_sig = load_midi_metadata(midi_file)
    print("Midi metadata loaded")
    model = ai_model  # add .cuda() here if a GPU is available
    print("Model loaded")
    change_melody_gen_midi = change_melody_midi(model, midi, start_time, end_time, original_tempo, original_time_sig, top_p)
    print("Midi generated")
    # Previous label said "Harmonized" — wrong operation for this function.
    print(f"Melody-changed MIDI tracks: {len(change_melody_gen_midi.tracks)}")
    # Clamp generated note values into the valid 0-127 MIDI range.
    for track in change_melody_gen_midi.tracks:
        for msg in track:
            if msg.type in ['note_on', 'note_off']:
                msg.velocity = min(max(msg.velocity, 0), 127)
                msg.note = min(max(msg.note, 0), 127)
    # Previous log said "Midi saved", but nothing is written to disk here.
    print("Midi validated")
    return change_melody_gen_midi