librosax.istft
- istft(stft_matrix: Array, hop_length: int = None, win_length: int = None, n_fft: int = None, window: str = 'hann', center: bool = True, length: int = None)
Compute the Inverse Short-Time Fourier Transform (ISTFT).
This function reconstructs a waveform from an STFT matrix using JAX's
scipy.signal.istft implementation (jax.scipy.signal.istft).
- Parameters:
stft_matrix – Complex STFT matrix. The second-to-last axis must be frequency (n_fft // 2 + 1 bins) and the last axis must be time frames. Supported shapes:
  - (F, N): a single spectrogram
  - (B, F, N): a batch of B spectrograms
  - (B, C, F, N): a batch of B spectrograms with C channels
where F = n_fft // 2 + 1 and N is the number of frames.
hop_length – Number of samples between successive frames. Default is win_length // 4.
win_length – Window size. Default is n_fft.
n_fft – FFT size. Default is (stft_matrix.shape[-2] - 1) * 2, inferred from the frequency axis of the input.
window – Window function type. Default is "hann". "sqrt_hann" is also supported.
center – If True, assumes the waveform was padded before the forward STFT so that frames are centered. Default is True. Only center=True is currently supported (see Raises).
length – Target length of the reconstructed signal, in samples. If None, the entire signal is returned.
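The defaults above form a chain: n_fft is inferred from the frequency axis, win_length falls back to n_fft, and hop_length falls back to win_length // 4. The sketch below spells out that chain in plain Python. The helper name resolve_istft_params is hypothetical and not part of librosax; it only mirrors the documented defaults, not the library's internal code.

```python
def resolve_istft_params(stft_shape, n_fft=None, win_length=None, hop_length=None):
    """Hypothetical helper: fill in istft defaults as documented above."""
    n_freq = stft_shape[-2]  # frequency is the second-to-last axis
    if n_fft is None:
        n_fft = (n_freq - 1) * 2  # invert n_freq = n_fft // 2 + 1 (even n_fft)
    if win_length is None:
        win_length = n_fft
    if hop_length is None:
        hop_length = win_length // 4
    return n_fft, win_length, hop_length


print(resolve_istft_params((1025, 44)))  # (2048, 2048, 512)
```

Because the defaults depend on each other, the order matters: an explicit win_length changes the default hop_length, but an explicit hop_length never affects win_length or n_fft.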
- Returns:
Reconstructed time-domain signal. The last axis is time (samples). Output shapes correspond to the input shapes:
  - (F, N) → (T,)
  - (B, F, N) → (B, T)
  - (B, C, F, N) → (B, C, T)
where T is the reconstructed signal length.
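The shape mapping above keeps every leading (batch/channel) axis and replaces the trailing (F, N) pair with a single time axis of length T. The sketch below predicts that shape. It assumes librosa's convention for centered STFTs when length is None, namely T = n_fft + hop_length * (N - 1) - 2 * (n_fft // 2), which agrees with the examples on this page; the helper expected_istft_shape is hypothetical and not a librosax function.

```python
def expected_istft_shape(stft_shape, hop_length, n_fft=None, length=None):
    """Hypothetical helper: predict the istft output shape for center=True."""
    *leading, n_freq, n_frames = stft_shape
    if n_fft is None:
        n_fft = (n_freq - 1) * 2  # documented default
    if length is None:
        # Assumed librosa-style formula: full overlap-add length minus the
        # n_fft // 2 samples of centering padding removed from each end.
        length = n_fft + hop_length * (n_frames - 1) - 2 * (n_fft // 2)
    return (*leading, length)


print(expected_istft_shape((1025, 44), hop_length=512))     # (22016,)
print(expected_istft_shape((4, 1025, 44), hop_length=512))  # (4, 22016)
```

For an even n_fft the formula reduces to hop_length * (N - 1), which is why 44 frames at a hop of 512 give 22016 samples.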
- Raises:
AssertionError – If center is False, because the function has only been tested with center=True.
Examples
>>> import jax.numpy as jnp
>>> import librosax
>>> # Single STFT: (n_freq, n_frames) -> (n_samples,)
>>> S = jnp.zeros((1025, 44), dtype=jnp.complex64)
>>> y = librosax.istft(S, hop_length=512)
>>> y.shape
(22016,)
>>> # Batched STFT: (batch, n_freq, n_frames) -> (batch, n_samples)
>>> S_batch = jnp.zeros((4, 1025, 44), dtype=jnp.complex64)
>>> y_batch = librosax.istft(S_batch, hop_length=512)
>>> y_batch.shape
(4, 22016)
>>> # With target length
>>> y_trimmed = librosax.istft(S, hop_length=512, length=22050)
>>> y_trimmed.shape
(22050,)
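Since the reconstruction is delegated to jax.scipy.signal.istft, a direct round trip through jax.scipy.signal can help build intuition for the parameters. This is a sketch, not librosax's own code: it assumes a mapping of win_length to nperseg, n_fft to nfft, and hop_length to nperseg - noverlap. Note that jax.scipy.signal.stft follows SciPy's scaling convention, which differs from librosa's unscaled STFT, so its output is not a drop-in input for librosax.istft.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import istft, stft

# Assumed parameter mapping (illustrative, not taken from librosax).
n_fft = 512                          # -> nfft
win_length = 512                     # -> nperseg
hop_length = 128                     # -> nperseg - noverlap
noverlap = win_length - hop_length

# A random 4096-sample test signal.
x = jax.random.normal(jax.random.PRNGKey(0), (4096,))

# Forward STFT. The default boundary="zeros" pads nperseg // 2 samples at
# each end, which plays the role of librosa's centered framing.
_, _, Zxx = stft(x, window="hann", nperseg=win_length, noverlap=noverlap, nfft=n_fft)
print(Zxx.shape)  # (257, 33): (n_fft // 2 + 1, n_frames), the (F, N) layout

# Inverse STFT. The default boundary=True strips that padding again.
_, x_rec = istft(Zxx, window="hann", nperseg=win_length, noverlap=noverlap, nfft=n_fft)
print(x_rec.shape)  # (4096,)

# Maximum reconstruction error; it should be at float32 round-off level.
print(float(jnp.max(jnp.abs(x - x_rec))))
```

Here 33 frames at a hop of 128 give 128 * 32 = 4096 samples, the same hop_length * (N - 1) relationship seen in the librosax examples.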