import gradio as gr
import os
import datetime
import pytz
from pathlib import Path


def current_time():
    """Return the current wall-clock time in Asia/Shanghai as 'YYYY-MM-DD HH:MM:SS'."""
    current = datetime.datetime.now(
        pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M:%S")
    return current


def _log(message):
    """Print *message* prefixed with a timestamp, matching the app's log format."""
    print(f"[{current_time()}] {message}")


def _run(message, command):
    """Log a setup step, then execute *command* in a shell.

    NOTE(review): os.system runs a shell string; the commands below are fixed
    literals (no untrusted input), but subprocess.run([...], shell=False)
    would be the safer idiom if these ever become dynamic.
    """
    _log(f"Log: {message}")
    os.system(command)


print(f"[{current_time()}] Starting to deploy space...")

# --- Environment setup (runs at import time, as required by the Space) ------
_run("Installing - gsutil", "pip install gsutil")

_run("Git - Cloning Github's T5X training framework to the current directory",
     "git clone --branch=main https://github.com/google-research/t5x")
_run("File - Moving t5x to the current directory, renaming it to t5x_tmp, and then deleting the temp directory",
     "mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
# T5X's setup.py pins jax[tpu]; this Space runs on CPU/GPU, so strip the extra.
_run("Edit - Replacing the text 'jax[tpu]' with 'jax' in setup.py",
     "sed -i 's:jax\\[tpu\\]:jax:' setup.py")
_run("Python - Using pip to install the Python package in the current directory",
     "python3 -m pip install -e .")
_run("Python - Upgrading the Python package manager pip",
     "python3 -m pip install --upgrade pip")

_run("Installing - langchain", "pip install langchain")
_run("Installing - sentence-transformers", "pip install sentence-transformers")

_run("Git - Cloning Github's airio to the current directory",
     "git clone --branch=main https://github.com/google/airio")
_run("File - Moving airio to the current directory, renaming it to airio_tmp, and then deleting the temp directory",
     "mv airio airio_tmp; mv airio_tmp/* .; rm -r airio_tmp")
_run("Python - Using pip to install the Python package in the current directory",
     "python3 -m pip install -e .")

_run("Git - Cloning Github's MT3 model to the current directory",
     "git clone --branch=main https://github.com/magenta/mt3")
_run("File - Renaming the mt3 directory to mt3_tmp...", "mv mt3 mt3_tmp")
_run("File - Moving all files from the mt3_tmp directory to the current directory",
     "mv mt3_tmp/* .")
_run("File - Deleting the mt3_tmp directory", "rm -r mt3_tmp")

_log("Log: Importing - Necessary tools from the mt3 directory")
from mt3 import metrics_utils
from mt3 import models
from mt3 import network
from mt3 import note_sequences
from mt3 import preprocessors
from mt3 import spectrograms
from mt3 import vocabularies

_run("Python - Using pip to install jax[cuda12_local] nest-asyncio pyfluidsynth from storage.googleapis.com",
     "python3 -m pip install jax[cuda12_local] nest-asyncio pyfluidsynth==1.3.0 -e . "
     "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
_run("Installing - Upgrading jaxlib", "pip install --upgrade jaxlib")
_run("Python - Using pip to install the Python package in the current directory",
     "python3 -m pip install -e .")
_run("Installing - TensorFlow CPU", "pip install tensorflow_cpu")

_run("gsutil - Copying MT3 checkpoints to the current directory",
     "gsutil -q -m cp -r gs://mt3/checkpoints .")
_run("gsutil - Copying SoundFont file to the current directory",
     "gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")

_log("Log: Importing - Necessary tools")
import functools
import os
import numpy as np
import tensorflow.compat.v2 as tf
import gin
import jax
import librosa
import note_seq
import seqio
import t5
import t5x
import nest_asyncio

# Allow nested event loops (needed because gradio + jax both drive asyncio).
nest_asyncio.apply()

SAMPLE_RATE = 16000                              # model expects 16 kHz mono audio
SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'    # SoundFont for synthesis


def upload_audio(audio, sample_rate):
    """Decode raw audio file bytes into a float sample array via librosa.

    Args:
        audio: raw bytes of an audio file (wav/mp3/flac...).
        sample_rate: target sample rate for resampling.

    Returns:
        1-D numpy array of audio samples.
    """
    return note_seq.audio_io.wav_data_to_samples_librosa(
        audio, sample_rate=sample_rate)


_log("Log: Starting to wrap the model...")


class InferenceModel(object):
    """A T5X model wrapper for MT3 music transcription.

    Wraps checkpoint loading, gin configuration, audio-to-spectrogram
    preprocessing, partitioned prediction, and token postprocessing into a
    single callable: `model(audio_samples) -> NoteSequence`.
    """

    def __init__(self, checkpoint_path, model_type='mt3'):
        """Configure the model variant and restore its checkpoint.

        Args:
            checkpoint_path: directory containing the T5X checkpoint.
            model_type: 'mt3' (multitrack) or 'ismir2021' (piano-only).

        Raises:
            ValueError: if model_type is not recognized.
        """
        if model_type == 'ismir2021':
            num_velocity_bins = 127
            self.encoding_spec = note_sequences.NoteEncodingSpec
            self.inputs_length = 512
        elif model_type == 'mt3':
            num_velocity_bins = 1
            self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
            self.inputs_length = 256
        else:
            raise ValueError('unknown model_type: %s' % model_type)

        gin_files = ['/home/user/app/mt3/gin/model.gin',
                     '/home/user/app/mt3/gin/mt3.gin']

        self.batch_size = 8
        self.outputs_length = 1024
        self.sequence_length = {'inputs': self.inputs_length,
                                'targets': self.outputs_length}

        self.partitioner = t5x.partitioning.PjitPartitioner(
            model_parallel_submesh=None, num_partitions=1)

        _log("Log: Building codec")
        self.spectrogram_config = spectrograms.SpectrogramConfig()
        self.codec = vocabularies.build_codec(
            vocab_config=vocabularies.VocabularyConfig(
                num_velocity_bins=num_velocity_bins)
        )
        self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
        self.output_features = {
            'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
            'targets': seqio.Feature(vocabulary=self.vocabulary),
        }

        _log("Log: Creating T5X model")
        self._parse_gin(gin_files)
        self.model = self._load_model()

        _log("Log: Restoring model checkpoint")
        self.restore_from_checkpoint(checkpoint_path)

    @property
    def input_shapes(self):
        """Batch shapes used to initialize the train state."""
        return {
            'encoder_input_tokens': (self.batch_size, self.inputs_length),
            'decoder_input_tokens': (self.batch_size, self.outputs_length)
        }

    def _parse_gin(self, gin_files):
        """Parses the gin files used to train the model."""
        _log("Log: Parsing gin files")
        gin_bindings = [
            'from __gin__ import dynamic_registration',
            'from mt3 import vocabularies',
            'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
            'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
        ]
        with gin.unlock_config():
            gin.parse_config_files_and_bindings(
                gin_files, gin_bindings, finalize_config=False)

    def _load_model(self):
        """Loads the T5X `Model` after parsing the training gin config."""
        _log("Log: Loading T5X model")
        model_config = gin.get_configurable(network.T5Config)()
        module = network.Transformer(config=model_config)
        return models.ContinuousInputsEncoderDecoderModel(
            module=module,
            input_vocabulary=self.output_features['inputs'].vocabulary,
            output_vocabulary=self.output_features['targets'].vocabulary,
            optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
            input_depth=spectrograms.input_depth(self.spectrogram_config))

    def restore_from_checkpoint(self, checkpoint_path):
        """Restores the training state from a checkpoint, resetting self._predict_fn()."""
        _log("Log: Restoring training state from checkpoint")
        train_state_initializer = t5x.utils.TrainStateInitializer(
            optimizer_def=self.model.optimizer_def,
            init_fn=self.model.get_initial_variables,
            input_shapes=self.input_shapes,
            partitioner=self.partitioner)

        restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
            path=checkpoint_path, mode='specific', dtype='float32')

        train_state_axes = train_state_initializer.train_state_axes
        self._predict_fn = self._get_predict_fn(train_state_axes)
        self._train_state = train_state_initializer.from_checkpoint_or_scratch(
            [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))

    # NOTE(review): lru_cache on an instance method keeps the instance alive
    # for the cache's lifetime; harmless here (one long-lived singleton), but
    # worth knowing if more instances are ever created.
    @functools.lru_cache()
    def _get_predict_fn(self, train_state_axes):
        """Generates a partitioned prediction function for decoding."""
        _log("Log: Generating prediction function for decoding")

        def partial_predict_fn(params, batch, decode_rng):
            # decode_rng is intentionally discarded: decoding is deterministic.
            return self.model.predict_batch_with_aux(
                params, batch, decoder_params={'decode_rng': None})

        return self.partitioner.partition(
            partial_predict_fn,
            in_axis_resources=(
                train_state_axes.params,
                t5x.partitioning.PartitionSpec('data',), None),
            out_axis_resources=t5x.partitioning.PartitionSpec('data',)
        )

    def predict_tokens(self, batch, seed=0):
        """Predicts tokens from a preprocessed dataset batch."""
        print(f"[{current_time()}] Running: Predicting note sequence from preprocessed dataset (seed: {seed})")
        prediction, _ = self._predict_fn(
            self._train_state.params, batch, jax.random.PRNGKey(seed))
        return self.vocabulary.decode_tf(prediction).numpy()

    def __call__(self, audio):
        """Infers a note sequence from audio samples.

        Args:
            audio: 1-D numpy array of a single audio sample at 16kHz.

        Returns:
            A note sequence of the transcribed audio.
        """
        _log("Running: Inferring note sequence from audio samples")
        ds = self.audio_to_dataset(audio)
        ds = self.preprocess(ds)

        model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
            ds, task_feature_lengths=self.sequence_length)
        model_ds = model_ds.batch(self.batch_size)

        inferences = (tokens for batch in model_ds.as_numpy_iterator()
                      for tokens in self.predict_tokens(batch))

        predictions = []
        for example, tokens in zip(ds.as_numpy_iterator(), inferences):
            predictions.append(self.postprocess(tokens, example))

        result = metrics_utils.event_predictions_to_ns(
            predictions, codec=self.codec, encoding_spec=self.encoding_spec)
        return result['est_ns']

    def audio_to_dataset(self, audio):
        """Creates a TF Dataset with spectrograms from input audio."""
        _log("Running: Creating TF Dataset with spectrograms from audio")
        frames, frame_times = self._audio_to_frames(audio)
        return tf.data.Dataset.from_tensors({
            'inputs': frames,
            'input_times': frame_times,
        })

    def _audio_to_frames(self, audio):
        """Computes spectrogram frames from audio.

        Returns:
            (frames, frame_times): split audio frames and their start times
            in seconds.
        """
        _log("Running: Computing spectrogram frames from audio")
        frame_size = self.spectrogram_config.hop_width
        # Bug fix: the original padded `frame_size - len(audio) % frame_size`
        # zeros, which appends a whole extra silent frame when len(audio) is
        # already a multiple of frame_size. (-len) % frame_size pads 0 there.
        padding = [0, (-len(audio)) % frame_size]
        audio = np.pad(audio, padding, mode='constant')
        frames = spectrograms.split_audio(audio, self.spectrogram_config)
        num_frames = len(audio) // frame_size
        times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
        return frames, times

    def preprocess(self, ds):
        """Applies the MT3 preprocessing chain to a spectrogram dataset."""
        pp_chain = [
            functools.partial(
                t5.data.preprocessors.split_tokens_to_inputs_length,
                sequence_length=self.sequence_length,
                output_features=self.output_features,
                feature_key='inputs',
                additional_feature_keys=['input_times']),
            # Cache during training.
            preprocessors.add_dummy_targets,
            functools.partial(
                preprocessors.compute_spectrograms,
                spectrogram_config=self.spectrogram_config)
        ]
        for pp in pp_chain:
            ds = pp(ds)
        return ds

    def postprocess(self, tokens, example):
        """Trims EOS and aligns predicted tokens with their segment start time."""
        tokens = self._trim_eos(tokens)
        start_time = example['input_times'][0]
        # Round down to the nearest tokenized time step.
        start_time -= start_time % (1 / self.codec.steps_per_second)
        return {
            'est_tokens': tokens,
            'start_time': start_time,
            # The internal MT3 code expects the raw inputs, which are not used here.
            'raw_inputs': []
        }

    @staticmethod
    def _trim_eos(tokens):
        """Truncates a token array at the first decoded EOS marker, if any."""
        tokens = np.array(tokens, np.int32)
        if vocabularies.DECODED_EOS_ID in tokens:
            tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
        return tokens


inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')


def inference(audio):
    """Transcribe an uploaded audio file to MIDI.

    Args:
        audio: filesystem path to the uploaded audio file (gradio filepath).

    Returns:
        Path of the written MIDI file.
    """
    filename = os.path.basename(audio)
    # Bug fix: the original computed `filename` but logged the literal
    # "(unknown)" instead of using it.
    print(f"[{current_time()}] Running: Input file: {filename}")
    with open(audio, 'rb') as fd:
        contents = fd.read()
    audio_data = upload_audio(contents, sample_rate=16000)
    est_ns = inference_model(audio_data)
    note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
    return './transcribed.mid'


with gr.Blocks(title="MT3", theme="Thatguy099/Sonix") as demo:
    # NOTE(review): the original HTML markup was lost to extraction; the text
    # content below is preserved verbatim, wrapped in minimal tags — confirm
    # against the deployed Space if exact styling matters.
    gr.HTML(
        """
        <div align="center">
        🤝 Community resource — please use responsibly to keep this service available for everyone
        </div>
        """
    )
    gr.HTML('''
        <div align="center">
        <h1>MT3: Multi-Task Multitrack Music Transcription</h1>
        <p>To use it, simply upload an audio file, or click on an example to see how it works. See the links below for more information.</p>
        </div>
    ''')
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Input",
                                   sources=['upload'], interactive=True)
        with gr.Column():
            midi_output = gr.File(label="Output")
    with gr.Row():
        gr.Examples(
            # The first example filename translates to "Good Morning, Great Forest.mp3"
            examples=[['早安大森林.mp3'], ['canon.flac'], ['download.wav']],
            label="Examples",
            fn=inference,
            inputs=audio_input,
            outputs=midi_output,
            cache_examples=True
        )
    with gr.Row():
        submit_btn = gr.Button("Run", variant="primary")
    submit_btn.click(
        fn=inference,
        inputs=audio_input,
        outputs=midi_output
    )
    # (removed a trailing empty gr.HTML('''''') — it rendered nothing)

demo.launch(
    server_name="0.0.0.0",
    share=True,
    mcp_server=True,
    show_api=True
)