librosax.layers.SpecAugmentation
- class SpecAugmentation(*args: Any, **kwargs: Any)
A module that applies SpecAugment data augmentation to spectrograms.
SpecAugment is a data augmentation technique that applies both time and frequency masking to spectrograms for audio tasks. It randomly masks rectangular blocks along the time and frequency dimensions by setting values to zero.
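The masking described above can be sketched in plain JAX for a single 2-D spectrogram. This is an illustrative sketch only, not librosax's implementation: the helper name `mask_stripes` and its arguments are hypothetical, and it assumes `drop_width` does not exceed the size of the masked axis.

```python
import jax
import jax.numpy as jnp


def mask_stripes(spec, key, axis, drop_width, stripes_num):
    """Zero out `stripes_num` random stripes along `axis` of a (time, freq) array.

    Hypothetical helper for illustration; assumes drop_width <= spec.shape[axis].
    """
    size = spec.shape[axis]
    positions = jnp.arange(size)
    keep = jnp.ones((size,), dtype=spec.dtype)
    stripe_keys = jax.random.split(key, stripes_num)
    for i in range(stripes_num):
        width_key, start_key = jax.random.split(stripe_keys[i])
        # Width is drawn from [0, drop_width]; the start is chosen so the stripe fits.
        width = jax.random.randint(width_key, (), 0, drop_width + 1)
        start = jax.random.randint(start_key, (), 0, size - width + 1)
        inside = (positions >= start) & (positions < start + width)
        # Multiplying keep-vectors means overlapping stripes simply stay at zero.
        keep = keep * jnp.where(inside, 0.0, 1.0).astype(spec.dtype)
    shape = [1, 1]
    shape[axis] = size
    return spec * keep.reshape(shape)


key = jax.random.PRNGKey(0)
time_key, freq_key = jax.random.split(key)
spec = jnp.ones((200, 128))  # (time, freq)
out = mask_stripes(spec, time_key, axis=0, drop_width=30, stripes_num=2)
out = mask_stripes(out, freq_key, axis=1, drop_width=20, stripes_num=2)
```

Each stripe spans the full extent of the other axis, so a time mask removes whole frames and a frequency mask removes whole bins.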
- Variables:
time_drop_width – Maximum width (in time frames) of each time mask. Each mask will have a random width between 0 and this value. For example, if time_drop_width=30 and your spectrogram has 200 time frames, each mask can be 0-30 frames wide.
time_stripes_num – Number of time masks to apply. Each mask is independently positioned and sized.
freq_drop_width – Maximum width (in frequency bins) of each frequency mask. Each mask will have a random width between 0 and this value. For example, if freq_drop_width=20 and your spectrogram has 128 frequency bins, each mask can be 0-20 bins wide.
freq_stripes_num – Number of frequency masks to apply. Each mask is independently positioned and sized.
deterministic – If True, no augmentation is applied. Default is False.
rng_collection – The RNG collection name to use when requesting an RNG key. Default is "dropout".
rngs – Random number generator key for generating random masks.
Example
>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from librosax.layers import SpecAugmentation
>>>
>>> # Create augmentation layer
>>> spec_aug = SpecAugmentation(
...     time_drop_width=64,   # Max 64 time frames per mask
...     time_stripes_num=2,   # Apply 2 time masks
...     freq_drop_width=16,   # Max 16 freq bins per mask
...     freq_stripes_num=2,   # Apply 2 frequency masks
...     rngs=nnx.Rngs(jax.random.key(0))
... )
>>>
>>> # Apply to spectrogram (batch_size, channels, time, freq)
>>> spec = jnp.ones((4, 1, 200, 128))
>>> augmented = spec_aug(spec, deterministic=False)
Note
The masks are applied multiplicatively (masked regions are multiplied by 0), so overlapping masks have no additional effect. Each batch item receives different random masks, which increases augmentation diversity.
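A small sketch of why overlapping multiplicative masks add nothing extra: the product of two 0/1 masks is zero on the union of their regions, so the overlap is zeroed once and stays zero. The mask positions below are made up for illustration and are not drawn by the layer.

```python
import jax.numpy as jnp

frames = jnp.arange(10)
# Two hand-picked 0/1 masks: zeros on frames 2..5 and on frames 4..7.
mask_a = jnp.where((frames >= 2) & (frames < 6), 0.0, 1.0)
mask_b = jnp.where((frames >= 4) & (frames < 8), 0.0, 1.0)

# Multiplying the masks zeroes the union (frames 2..7); the overlap (4, 5) is not counted twice.
combined = mask_a * mask_b
signal = jnp.full((10,), 3.0)
masked = signal * combined
# masked is [3, 3, 0, 0, 0, 0, 0, 0, 3, 3]
```

Six frames are zeroed here rather than eight, because the two masks share frames 4 and 5.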
- __init__(time_drop_width: int, time_stripes_num: int, freq_drop_width: int, freq_stripes_num: int, deterministic: bool = False, rng_collection: str = 'dropout', rngs: Rngs | RngStream | None = None)
Methods
__init__(time_drop_width, time_stripes_num, ...)
eval(**attributes) – Sets the Module to evaluation mode.
iter_children() – Iterates over all child Modules of the current Module.
iter_modules() – Recursively iterates over all nested Modules of the current Module, including the current Module.
perturb(name, value[, variable_type]) – Adds a zero-value variable ("perturbation") to the intermediate value.
set_attributes(*filters[, raise_if_not_found]) – Sets the attributes of nested Modules, including the current Module.
sow(variable_type, name, value[, reduce_fn, ...]) – sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.
train(**attributes) – Sets the Module to training mode.