librosax.stft
stft(waveform: Array, n_fft: int, hop_length: int = None, win_length: int = None, window: str = 'hann', center: bool = True, pad_mode: str = 'constant')
Compute the Short-Time Fourier Transform (STFT) of a waveform.
This function computes the STFT of the given waveform using JAX's scipy.signal.stft implementation.
Note
For JAX JIT compilation, the following arguments should be marked as static: n_fft, hop_length, win_length, window, center, pad_mode.
Parameters:
- waveform – Input signal waveform. The last axis must be time (samples). Supported shapes:
  - (T,) - a single waveform with T samples
  - (B, T) - a batch of B waveforms
  - (B, C, T) - a batch of B waveforms with C channels
- n_fft – FFT size. Determines the number of frequency bins in the output: n_fft // 2 + 1.
- hop_length – Number of samples between successive frames. Default is win_length // 4.
- win_length – Window size in samples. Default is n_fft.
- window – Window function type. Default is "hann". Also supports "sqrt_hann".
- center – If True, the waveform is padded so that frames are centered. Default is True.
- pad_mode – Padding mode for the waveform. Must be one of ["constant", "reflect"]. Default is "constant".
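The defaults above chain together: win_length falls back to n_fft, and hop_length then falls back to a quarter of the resolved win_length. The sketch below makes that order explicit and also builds the two named windows. The helpers resolve_frame_params and make_window are hypothetical, and the window construction is an assumption (a periodic Hann window, with "sqrt_hann" taken to be its elementwise square root); librosax may build its windows differently.

```python
import jax.numpy as jnp


def resolve_frame_params(n_fft, hop_length=None, win_length=None):
    """Hypothetical helper: apply the documented defaults in dependency order."""
    # First, win_length falls back to n_fft.
    if win_length is None:
        win_length = n_fft
    # Then hop_length falls back to a quarter of the resolved win_length.
    if hop_length is None:
        hop_length = win_length // 4
    return hop_length, win_length


def make_window(name, win_length):
    """Hypothetical helper; assumes a periodic Hann window and its square root."""
    n = jnp.arange(win_length)
    # Periodic Hann: divide by win_length (not win_length - 1).
    hann = 0.5 - 0.5 * jnp.cos(2.0 * jnp.pi * n / win_length)
    if name == "hann":
        return hann
    if name == "sqrt_hann":
        # Elementwise square root, so that sqrt_hann ** 2 == hann.
        return jnp.sqrt(hann)
    raise ValueError(f"unsupported window: {name!r}")


print(resolve_frame_params(2048))                   # (512, 2048)
print(resolve_frame_params(2048, win_length=1024))  # (256, 1024)
```

With only n_fft=2048 given, the defaults resolve to win_length = 2048 and hop_length = 512. Squaring "sqrt_hann" recovers "hann", which is why a square-root window is commonly chosen when the same window is applied at both analysis and synthesis.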
Returns:
Complex STFT matrix. The second-to-last axis is frequency (n_fft // 2 + 1 bins) and the last axis is time frames. Output shapes correspond to the input shapes:
- (T,) → (F, N)
- (B, T) → (B, F, N)
- (B, C, T) → (B, C, F, N)
where F = n_fft // 2 + 1 is the number of frequency bins and N is the number of time frames.
Raises:
- AssertionError – If pad_mode is not one of ["constant", "reflect"].
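The Returns description leaves N, the number of frames, implicit. The sketch below predicts the full output shape under librosa's framing convention, which is an assumption here: with center=True, n_fft // 2 samples of padding are added to each side, and only frames that fit entirely inside the (padded) signal are counted. Under that assumption it reproduces the shapes shown in the Examples, but it is not taken from the librosax source. num_frames and stft_shape are hypothetical helpers.

```python
def num_frames(n_samples, n_fft, hop_length, center=True):
    """Assumed librosa-style frame count (hypothetical helper)."""
    if center:
        # n_fft // 2 samples of padding on each side of the signal.
        n_samples = n_samples + 2 * (n_fft // 2)
    # Count only frames that fit entirely inside the (padded) signal.
    return 1 + (n_samples - n_fft) // hop_length


def stft_shape(input_shape, n_fft, hop_length, center=True):
    """Predicted output shape; leading batch/channel axes pass through."""
    *leading, n_samples = input_shape
    n_freq = n_fft // 2 + 1
    return (*leading, n_freq, num_frames(n_samples, n_fft, hop_length, center))


print(stft_shape((22050,), 2048, 512))       # (1025, 44)
print(stft_shape((4, 22050), 2048, 512))     # (4, 1025, 44)
print(stft_shape((4, 2, 22050), 2048, 512))  # (4, 2, 1025, 44)
```

For the one-second example, the padded length is 22050 + 2048 = 24098 samples, giving 1 + (24098 - 2048) // 512 = 44 frames.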
Examples
>>> import jax.numpy as jnp
>>> import librosax
>>> # Single waveform: (n_samples,) -> (n_freq, n_frames)
>>> y = jnp.zeros(22050)  # 1 second at 22050 Hz
>>> S = librosax.stft(y, n_fft=2048, hop_length=512)
>>> S.shape
(1025, 44)
>>> # Batched waveforms: (batch, n_samples) -> (batch, n_freq, n_frames)
>>> y_batch = jnp.zeros((4, 22050))
>>> S_batch = librosax.stft(y_batch, n_fft=2048, hop_length=512)
>>> S_batch.shape
(4, 1025, 44)
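To illustrate why the Note asks for static arguments, here is a toy framed STFT in plain JAX, jitted with static_argnames. It is not librosax's implementation: it assumes win_length == n_fft, a periodic Hann window, zero ("constant") padding, and no amplitude scaling, so its values are not claimed to match librosax.stft; only the output layout (frequency before time, leading axes preserved) is meant to agree with the documentation. Here n_fft, hop_length, and center determine array shapes and Python control flow, so they must be compile-time constants; calling with a new value triggers a retrace and recompile.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("n_fft", "hop_length", "center"))
def toy_stft(waveform, n_fft, hop_length, center=True):
    """Toy framed STFT for illustration only (not librosax's implementation)."""
    if center:  # Python control flow: only legal because center is static
        pad = n_fft // 2
        # Zero-pad the last (time) axis only.
        pad_width = [(0, 0)] * (waveform.ndim - 1) + [(pad, pad)]
        waveform = jnp.pad(waveform, pad_width, mode="constant")
    # Shapes are static under jit, so n_frames is a plain Python int.
    n_frames = 1 + (waveform.shape[-1] - n_fft) // hop_length
    # (n_frames, n_fft) gather indices; row i is the i-th frame.
    idx = hop_length * jnp.arange(n_frames)[:, None] + jnp.arange(n_fft)[None, :]
    frames = waveform[..., idx]  # (..., n_frames, n_fft)
    n = jnp.arange(n_fft)
    window = 0.5 - 0.5 * jnp.cos(2.0 * jnp.pi * n / n_fft)  # periodic Hann
    spec = jnp.fft.rfft(frames * window, axis=-1)  # (..., n_frames, n_fft // 2 + 1)
    # Move frequency before time: (..., n_fft // 2 + 1, n_frames).
    return jnp.swapaxes(spec, -1, -2)


print(toy_stft(jnp.zeros(22050), n_fft=2048, hop_length=512).shape)
# (1025, 44)
print(toy_stft(jnp.zeros((4, 2, 22050)), n_fft=2048, hop_length=512).shape)
# (4, 2, 1025, 44)
```

The same pattern should carry over to the documented function, e.g. jax.jit(librosax.stft, static_argnames=("n_fft", "hop_length", "win_length", "window", "center", "pad_mode")), using the six argument names listed in the Note.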