diff --git a/README.md b/README.md
index 00acb6e..2f3c6d9 100644
--- a/README.md
+++ b/README.md
@@ -105,13 +105,19 @@ basic-pitch --help

 **predict()**

-Import `basic-pitch` into your own Python code and run the [`predict`](basic_pitch/inference.py) functions directly, providing an `<input audio path>` and returning the model's prediction results:
+Import `basic-pitch` into your own Python code and run the [`predict`](basic_pitch/inference.py) functions directly, providing an `<input audio path>` or an `<audio array>` and returning the model's prediction results:

 ```python
 from basic_pitch.inference import predict
-from basic_pitch import ICASSP_2022_MODEL_PATH
+import librosa

+# get model predictions given an audio file
 model_output, midi_data, note_events = predict(<input audio path>)
+
+# or alternatively, provide an array of samples
+audio_array, sample_rate = librosa.load(<input audio path>, mono=True, duration=10.0, offset=5.0)
+model_output, midi_data, note_events = predict(audio_array, sample_rate)
+
 ```

 - `<maximum-frequency>` & `<minimum-frequency>` (*float*s) set the maximum and minimum allowed note frequency, in Hz, returned by the model. Pitch events with frequencies outside of this range will be excluded from the prediction results.

diff --git a/basic_pitch/inference.py b/basic_pitch/inference.py
index 25b062c..2475cee 100644
--- a/basic_pitch/inference.py
+++ b/basic_pitch/inference.py
@@ -213,7 +213,7 @@ def window_audio_file(


 def get_audio_input(
-    audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray], sample_rate: Optional[int], overlap_len: int, hop_size: int
 ) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
     """
     Read wave file (as mono), pad appropriately, and return as
@@ -228,8 +228,20 @@
     """

     assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"
-
-    audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
+    # if a numpy array of samples is provided, use it directly
+    if isinstance(audio_path_or_array, np.ndarray):
+        audio_original = audio_path_or_array
+        if sample_rate is None:
+            raise ValueError("Sample rate must be provided when input is an array of audio samples.")
+        # convert to mono if necessary
+        if audio_original.ndim != 1:
+            audio_original = librosa.to_mono(audio_original)
+        # resample audio if required
+        if sample_rate != AUDIO_SAMPLE_RATE:
+            audio_original = librosa.resample(audio_original, orig_sr=sample_rate, target_sr=AUDIO_SAMPLE_RATE)
+    # load audio file
+    else:
+        audio_original, _ = librosa.load(str(audio_path_or_array), sr=AUDIO_SAMPLE_RATE, mono=True)

     original_length = audio_original.shape[0]
     audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
@@ -267,14 +279,16 @@ def unwrap_output(


 def run_inference(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: Optional[int],
     model_or_model_path: Union[Model, pathlib.Path, str],
     debug_file: Optional[pathlib.Path] = None,
 ) -> Dict[str, np.array]:
     """Run the model on the input audio path.

     Args:
-        audio_path: The audio to run inference on.
+        audio_path_or_array: The audio to run inference on. Can be either the path to an audio file or a numpy array of audio samples.
+        sample_rate: Sample rate of the provided audio. Required if audio_path_or_array is a numpy array; ignored otherwise.
         model_or_model_path: A loaded Model or path to a serialized model to load.
         debug_file: An optional path to output debug data to. Useful for testing/verification.

@@ -292,7 +306,7 @@ def run_inference(
     hop_size = AUDIO_N_SAMPLES - overlap_len

     output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
-    for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
+    for audio_windowed, _, audio_original_length in get_audio_input(audio_path_or_array, sample_rate, overlap_len, hop_size):
         for k, v in model.predict(audio_windowed).items():
             output[k].append(v)

@@ -415,7 +429,8 @@ def save_note_events(


 def predict(
-    audio_path: Union[pathlib.Path, str],
+    audio_path_or_array: Union[pathlib.Path, str, np.ndarray],
+    sample_rate: Optional[int] = None,
     model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
     onset_threshold: float = 0.5,
     frame_threshold: float = 0.3,
@@ -426,6 +441,7 @@ def predict(
     melodia_trick: bool = True,
     debug_file: Optional[pathlib.Path] = None,
     midi_tempo: float = 120,
+    verbose: bool = False,
 ) -> Tuple[
     Dict[str, np.array],
     pretty_midi.PrettyMIDI,
@@ -434,7 +450,8 @@ def predict(
     """Run a single prediction.

     Args:
-        audio_path: File path for the audio to run inference on.
+        audio_path_or_array: File path for the audio to run inference on, or an array of audio samples.
+        sample_rate: Sample rate of the provided audio. Required if audio_path_or_array is a numpy array; ignored otherwise.
         model_or_model_path: A loaded Model or path to a serialized model to load.
         onset_threshold: Minimum energy required for an onset to be considered present.
         frame_threshold: Minimum energy requirement for a frame to be considered present.
@@ -449,9 +466,13 @@
     """

     with no_tf_warnings():
-        print(f"Predicting MIDI for {audio_path}...")
+        if verbose:
+            if isinstance(audio_path_or_array, np.ndarray):
+                print("Predicting MIDI...")
+            else:
+                print(f"Predicting MIDI for {audio_path_or_array}...")

-        model_output = run_inference(audio_path, model_or_model_path, debug_file)
+        model_output = run_inference(audio_path_or_array, sample_rate, model_or_model_path, debug_file)
         min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
         midi_data, note_events = infer.model_output_to_notes(
             model_output,
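
Example usage of the patched `predict` (a minimal sketch, assuming the diff above is applied; `input.wav` and `output.mid` are hypothetical paths):

```python
import librosa

from basic_pitch.inference import predict

# load samples at their native sample rate and channel count; the patched
# get_audio_input downmixes to mono and resamples to the model's rate as needed
audio_array, sample_rate = librosa.load("input.wav", sr=None, mono=False)

# pass the raw samples together with their sample rate; omitting sample_rate
# for an array input raises a ValueError inside get_audio_input
model_output, midi_data, note_events = predict(
    audio_array,
    sample_rate=sample_rate,
    verbose=True,  # opt in to the "Predicting MIDI..." progress message
)

# midi_data is a pretty_midi.PrettyMIDI object
midi_data.write("output.mid")
```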