from agents.agents import harmonizer, infiller, change_melody
from flask import Flask, request, jsonify
from flask_cors import CORS
import mido
import tempfile
import os
import music21
import traceback
from uuid import uuid4
import threading
from transformers import AutoModelForCausalLM

app = Flask(__name__)
CORS(app)

# Only the production frontend may call this API. An origin is
# scheme://host[:port] with no path, and browsers compare it exactly against
# the page's Origin, so this value must NOT end with "/".
ALLOWED_ORIGIN = 'https://inscoreai.netlify.app'


@app.after_request
def add_cors_headers(response):
    """Attach CORS headers restricting access to ALLOWED_ORIGIN.

    Args:
        response: the outgoing Flask response.

    Returns:
        The same response object with the CORS headers set.
    """
    response.headers['Access-Control-Allow-Origin'] = ALLOWED_ORIGIN
    response.headers['Access-Control-Allow-Methods'] = 'GET, POST'
    response.headers['Access-Control-Allow-Headers'] = 'Content-Type'
    return response


def midi_to_musicxml(midi_file_path):
    """Convert a MIDI file on disk to a MusicXML string.

    The score is parsed with music21, written to a uniquely named temporary
    .musicxml file, read back as text, and the temporary file is removed.

    Args:
        midi_file_path: path (str or path-like) of the MIDI file to convert.

    Returns:
        The MusicXML document as a str.

    Raises:
        Any exception raised while parsing, writing or reading; it is logged
        and re-raised unchanged.
    """
    temp_output = os.path.join(
        tempfile.gettempdir(), f"output_{uuid4().hex}.musicxml"
    )
    try:
        score = music21.converter.parse(str(midi_file_path))
        score.write('musicxml', temp_output)
        # MusicXML is UTF-8; don't depend on the host's locale encoding.
        with open(temp_output, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"Conversion error: {str(e)}")
        traceback.print_exc()
        raise
    finally:
        # Clean up even when parsing/writing/reading failed part-way.
        if os.path.exists(temp_output):
            try:
                os.unlink(temp_output)
            except OSError as e:
                print(f"Warning: Could not remove {temp_output}: {str(e)}")


def load_model():
    """Load a fresh 'stanford-crfm/music-small-800k' causal LM.

    Uses $HF_HOME (default '/home/user/.cache/huggingface') as the cache
    directory and logs whether that directory is writable. Every call builds
    a new model instance; request handlers should use get_model() instead.

    Returns:
        The model returned by AutoModelForCausalLM.from_pretrained.
    """
    cache_dir = os.environ.get('HF_HOME', '/home/user/.cache/huggingface')
    print(f"Using cache directory: {cache_dir}")

    # Best-effort writability probe: report the result but never abort here.
    test_file = os.path.join(cache_dir, "test.txt")
    try:
        os.makedirs(cache_dir, exist_ok=True)
        with open(test_file, "w") as f:
            f.write("test")
    except Exception as e:
        print(f"❌ Cache directory not writable: {e}")
    else:
        print("✅ Cache directory is writable")
        # Don't leave the probe file behind in the model cache.
        try:
            os.unlink(test_file)
        except OSError:
            pass  # deliberate: leftover probe file is harmless

    return AutoModelForCausalLM.from_pretrained(
        'stanford-crfm/music-small-800k',
        cache_dir=cache_dir,
        local_files_only=False,
        force_download=False
    )


# Process-wide model instance, created once by get_model().
MODEL = None
MODEL_LOCK = threading.Lock()


def get_model():
    """Return the shared model, loading it on first use.

    load_model() is slow and memory-hungry, so it runs at most once per
    process. The double check under MODEL_LOCK keeps concurrent first
    requests from loading it twice.

    NOTE(review): the same instance is shared by all request threads; confirm
    the agents only run inference and never mutate the model.
    """
    global MODEL
    if MODEL is None:
        with MODEL_LOCK:
            if MODEL is None:
                MODEL = load_model()
    return MODEL


# Load the model eagerly at startup so the first request does not pay for it.
with app.app_context():
    get_model()


def _run_agent_on_upload(agent):
    """Shared implementation of the MIDI upload endpoints.

    Reads the multipart field 'midi_file' plus optional form fields
    'start_time' / 'end_time' (integers, default 0, divided by 1000 before
    use, presumably milliseconds to seconds) and 'top_p' (float, default
    0.95). It then calls
    ``agent(model, midi, start / 1000, end / 1000, top_p=top_p)`` and
    returns the generated MIDI converted to MusicXML.

    Args:
        agent: one of the callables from agents.agents.

    Returns:
        On success, a JSON response {"status": "success", "musicxml": str}.
        On any failure, {"status": "error", "message": str} with HTTP 400.
    """
    temp_midi_path = None
    try:
        if 'midi_file' not in request.files:
            return jsonify({"status": "error", "message": "No MIDI file provided"}), 400

        midi_file = request.files['midi_file']
        # Parsed inside the try so a malformed value yields a JSON 400
        # rather than an unhandled 500.
        top_p = float(request.form.get('top_p', '0.95'))
        start_time = request.form.get('start_time', '0')
        end_time = request.form.get('end_time', '0')

        # Save the upload under a random name so concurrent requests don't clash.
        temp_midi_path = os.path.join(tempfile.gettempdir(), f"temp_{uuid4().hex}.mid")
        midi_file.save(temp_midi_path)

        midi = mido.MidiFile(temp_midi_path)
        result_midi = agent(
            get_model(),
            midi,
            int(start_time) / 1000,
            int(end_time) / 1000,
            top_p=top_p,
        )

        # Overwrite the uploaded file with the generated MIDI, then convert.
        result_midi.save(temp_midi_path)
        musicxml_str = midi_to_musicxml(temp_midi_path)

        if not isinstance(musicxml_str, str):
            raise TypeError(f"Expected string but got {type(musicxml_str)}")

        return jsonify({
            "status": "success",
            "musicxml": musicxml_str
        })

    except Exception as e:
        print(f"Error processing request: {str(e)}")
        traceback.print_exc()
        return jsonify({
            "status": "error",
            "message": str(e)
        }), 400

    finally:
        if temp_midi_path and os.path.exists(temp_midi_path):
            try:
                os.unlink(temp_midi_path)
            except Exception as e:
                print(f"Warning: Could not remove {temp_midi_path}: {str(e)}")


@app.route('/upload', methods=['POST'])
def handle_upload():
    """Run the ``harmonizer`` agent on an uploaded MIDI file."""
    return _run_agent_on_upload(harmonizer)


@app.route('/uploadinfill', methods=['POST'])
def handle_upload_infilling():
    """Run the ``infiller`` agent on an uploaded MIDI file."""
    return _run_agent_on_upload(infiller)


@app.route('/uploadchangemelody', methods=['POST'])
def handle_upload_changemelody():
    """Run the ``change_melody`` agent on an uploaded MIDI file."""
    return _run_agent_on_upload(change_melody)


if __name__ == '__main__':
    # The Werkzeug interactive debugger allows arbitrary code execution from
    # the browser, so it is only enabled when explicitly requested.
    app.run(debug=os.environ.get('FLASK_DEBUG') == '1', port=5000)