Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support longer audio files by reducing memory usage with chunking #2256

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

def test_audio():
audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
audio = load_audio(audio_path)
audio = next(load_audio(audio_path))
assert audio.ndim == 1
assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12
assert 0 < audio.std() < 1

mel_from_audio = log_mel_spectrogram(audio)
mel_from_file = log_mel_spectrogram(audio_path)
mel_from_audio = next(log_mel_spectrogram(audio))
mel_from_file = next(log_mel_spectrogram(audio_path))

assert np.allclose(mel_from_audio, mel_from_file)
assert mel_from_audio.max() - mel_from_audio.min() <= 2.0
45 changes: 32 additions & 13 deletions whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import subprocess
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union
from typing import Generator, Optional, Union

import numpy as np
import torch
Expand All @@ -21,6 +22,7 @@
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token

MAX_CHUNK_DURATION = 2 * 60 * 60 # maximum duration of one decoded audio chunk, in seconds (2 hours)

def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Expand Down Expand Up @@ -55,12 +57,16 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

while True:
out = process.stdout.read(MAX_CHUNK_DURATION * sr * 2)
if not out:
break
yield np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Expand Down Expand Up @@ -108,7 +114,7 @@ def mel_filters(device, n_mels: int) -> torch.Tensor:


def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
audio: Union[str, np.ndarray, torch.Tensor, Generator[np.ndarray, None, None]],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
Expand All @@ -135,13 +141,26 @@ def log_mel_spectrogram(
torch.Tensor, shape = (80, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
if isinstance(audio, str):
audio = load_audio(audio)
elif isinstance(audio, np.ndarray):
audio = [audio]
elif isinstance(audio, torch.Tensor):
audio = [audio]

for chunk in audio:
if not isinstance(chunk, torch.Tensor):
chunk = torch.from_numpy(chunk)
if device is not None:
chunk = chunk.to(device)
yield _log_mel_spectrogram(chunk, n_mels, padding)


def _log_mel_spectrogram(
audio: torch.Tensor,
n_mels: int = 80,
padding: int = 0,
):
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
Expand Down
Loading