upstream
(s3prl.nn.upstream)
S3PRL Upstream Collection and some utilities
- Authors:
Leo 2022
S3PRLUpstream
- class s3prl.nn.upstream.S3PRLUpstream(name: str, path_or_url: Optional[str] = None, refresh: bool = False, normalize: bool = False, extra_conf: Optional[dict] = None, randomize: bool = False)
Bases:
Module
This is an easy interface for using all the models in S3PRL. See S3PRL Upstream Collection for the example usage and all the supported models.
- Parameters:
name (str) – can be “apc”, “hubert”, “wav2vec2”. See available_names for all the supported names
path_or_url (str) – The source of the checkpoint. Might be a local path or a URL
refresh (bool) – (default, False) If False, only download the checkpoint if it has not been downloaded before. If True, force a re-download of the checkpoint.
extra_conf (dict) – (default, None) The extra arguments for each specific upstream. The available options are shown in each upstream section
randomize (bool) – (default, False) If True, randomize the upstream model
Note
When using S3PRLUpstream with refresh=True and multiprocessing (e.g. DDP), the checkpoint is downloaded only once. The other processes reuse the newly downloaded checkpoint instead of each re-downloading it, which would be very time- and bandwidth-consuming.
Example:
>>> import torch
>>> from s3prl.nn import S3PRLUpstream
...
>>> model = S3PRLUpstream("hubert")
>>> model.eval()
...
>>> with torch.no_grad():
...     wavs = torch.randn(2, 16000 * 2)
...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
...     all_hs, all_hs_len = model(wavs, wavs_len)
...
>>> for hs, hs_len in zip(all_hs, all_hs_len):
...     assert isinstance(hs, torch.FloatTensor)
...     assert isinstance(hs_len, torch.LongTensor)
...
...     batch_size, max_seq_len, hidden_size = hs.shape
...     assert hs_len.dim() == 1
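The download-once behaviour described in the note above can be pictured with a small lock-file guard. This is only a sketch of the general pattern, not S3PRL's actual implementation: `fetch_once`, `fake_fetch` and the `.lock` / `.done` marker files are hypothetical names, and stale locks left by a crashed process are not handled.

```python
import os
import tempfile
import time
from pathlib import Path


def fetch_once(target: Path, fetch, poll_seconds: float = 0.05) -> bool:
    """Run fetch(target) in exactly one caller.

    Returns True if this caller performed the fetch, False if it
    waited for (or found) a fetch completed by another caller.
    """
    lock = target.with_name(target.name + ".lock")
    done = target.with_name(target.name + ".done")
    try:
        # O_EXCL makes the creation atomic: only one caller can succeed.
        fd = os.open(lock, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        # Someone else owns the download; wait for its completion marker.
        while not done.exists():
            time.sleep(poll_seconds)
        return False
    os.close(fd)
    fetch(target)
    done.touch()  # signal the waiting callers that the file is ready
    return True


# Demo with a fake download function standing in for the real transfer.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "checkpoint.bin"
    calls = []

    def fake_fetch(path: Path) -> None:
        calls.append(path)
        path.write_bytes(b"weights")

    first = fetch_once(ckpt, fake_fetch)
    second = fetch_once(ckpt, fake_fetch)
```

Here the second call finds the lock and the completion marker already in place and returns without fetching, which mirrors how the remaining processes reuse the checkpoint.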
- classmethod available_names(only_registered_ckpt: bool = False) → List[str]
All the available names supported by this S3PRLUpstream
- Parameters:
only_registered_ckpt (bool) – If True, ignore entry names that require path_or_url to be given, that is, entry names without a registered checkpoint source. These names end with _local (for a local path), _url (for a URL) or _custom (auto-determine path or URL)
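The suffix rule above can be illustrated with a small filter. This is a sketch only: `drop_unregistered` is a hypothetical helper, and the sample list is made up for illustration rather than taken from the real output of available_names.

```python
# Suffixes that, per the description above, mark entries needing path_or_url.
PLACEHOLDER_SUFFIXES = ("_local", "_url", "_custom")


def drop_unregistered(names):
    """Keep only the names that do not end with a placeholder suffix."""
    return [n for n in names if not n.endswith(PLACEHOLDER_SUFFIXES)]


# Illustrative sample, not the actual list returned by the library.
sample = ["hubert", "hubert_local", "hubert_url", "wav2vec2", "wav2vec2_custom", "apc"]
registered = drop_unregistered(sample)
```

In practice, calling available_names with only_registered_ckpt=True should make such manual filtering unnecessary.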
- property num_layers: int
Number of layers of hidden states. All the upstreams have a deterministic number of layers; that is, layer drop is turned off by default.
- property downsample_rates: List[int]
Downsampling rate of each layer, relative to 16000 Hz audio. Usually all layers share the same downsampling rate, but this might not be the case for some advanced upstreams.
- property hidden_sizes: List[int]
The hidden size of each layer
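To make the downsampling rate concrete, the sketch below converts a waveform length in samples into an approximate number of frames and a frame duration. The value 320 is an assumed example rate, and `approx_num_frames` and `frame_duration_ms` are hypothetical helpers. The exact frame count of a real upstream can differ slightly because of convolution boundary effects, so the valid lengths returned by forward should be preferred over this estimate.

```python
SAMPLE_RATE = 16000  # the upstreams expect 16000 Hz audio


def approx_num_frames(num_samples: int, downsample_rate: int) -> int:
    """Rough frame count: one frame per `downsample_rate` input samples."""
    return num_samples // downsample_rate


def frame_duration_ms(downsample_rate: int) -> float:
    """Duration of audio covered by one frame step, in milliseconds."""
    return 1000 * downsample_rate / SAMPLE_RATE


rate = 320  # assumed example value, not read from a real upstream
frames_for_two_seconds = approx_num_frames(SAMPLE_RATE * 2, rate)
step_ms = frame_duration_ms(rate)
```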
- forward(wavs: FloatTensor, wavs_len: LongTensor)
- Parameters:
wavs (torch.FloatTensor) – (batch_size, seqlen) or (batch_size, seqlen, 1)
wavs_len (torch.LongTensor) – (batch_size, )
- Returns:
List[torch.FloatTensor], List[torch.LongTensor]
all the layers of hidden states: List[ (batch_size, max_seq_len, hidden_size) ]
the valid lengths for each layer of hidden states: List[ (batch_size, ) ]
Featurizer
- class s3prl.nn.upstream.Featurizer(upstream: S3PRLUpstream, layer_selections: Optional[List[int]] = None, normalize: bool = False)
Bases:
Module
Featurizer takes the S3PRLUpstream’s multiple layers of hidden_states and reduces (standardizes) them into a single hidden_states, to connect with downstream NNs.
This basic Featurizer expects all the layers to have the same stride and hidden_size. When the input upstream has only a single layer of hidden states, that layer is used directly. If multiple layers are present, a trainable weighted sum is added on top of those layers.
- Parameters:
upstream (S3PRLUpstream) – the upstream to extract features from. This upstream is used only for initialization and is not kept in this Featurizer object
layer_selections (List[int]) – Selects a subset of hidden states from the given upstream by layer ids (0-indexed). If None (default), all the layers of hidden states are selected
normalize (bool) – Whether to apply layer norm on all the hidden states before the weighted sum. This can help convergence in some cases, but is not used in SUPERB, to ensure the fidelity of each upstream’s extracted representation.
Example:
>>> import torch
>>> from s3prl.nn import S3PRLUpstream, Featurizer
...
>>> model = S3PRLUpstream("hubert")
>>> model.eval()
...
>>> with torch.no_grad():
...     wavs = torch.randn(2, 16000 * 2)
...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
...     all_hs, all_hs_len = model(wavs, wavs_len)
...
>>> featurizer = Featurizer(model)
>>> hs, hs_len = featurizer(all_hs, all_hs_len)
...
>>> assert isinstance(hs, torch.FloatTensor)
>>> assert isinstance(hs_len, torch.LongTensor)
>>> batch_size, max_seq_len, hidden_size = hs.shape
>>> assert hs_len.dim() == 1
- property downsample_rate: int
The downsample rate (relative to the 16 kHz waveform) of the final weighted-sum output
- forward(all_hs: List[FloatTensor], all_lens: List[LongTensor])
- Parameters:
all_hs (List[torch.FloatTensor]) – List[ (batch_size, seq_len, hidden_size) ]
all_lens (List[torch.LongTensor]) – List[ (batch_size, ) ]
- Returns:
torch.FloatTensor, torch.LongTensor
The weighted-sum result, (batch_size, seq_len, hidden_size)
the valid length of the result, (batch_size, )
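The weighted sum over layers can be sketched without any tensor library. The example below assumes the trainable weights are normalized with a softmax before mixing, which the text above does not state. `softmax` and `weighted_sum` are hypothetical helpers operating on nested lists of shape (seq_len, hidden_size) per layer rather than on batched tensors.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(layers, raw_weights):
    """Mix layers of identical shape into one (seq_len, hidden_size) output.

    layers: list of per-layer hidden states, each a (seq_len, hidden_size) nested list.
    raw_weights: one unnormalized scalar weight per layer.
    """
    norm = softmax(raw_weights)  # assumption: weights are softmax-normalized
    seq_len = len(layers[0])
    hidden_size = len(layers[0][0])
    return [
        [sum(w * layer[t][d] for w, layer in zip(norm, layers)) for d in range(hidden_size)]
        for t in range(seq_len)
    ]


# Two layers, two frames, hidden size 2. Equal raw weights give a plain average.
layer_a = [[1.0, 2.0], [3.0, 4.0]]
layer_b = [[3.0, 4.0], [5.0, 6.0]]
mixed = weighted_sum([layer_a, layer_b], raw_weights=[0.0, 0.0])
```

Because every layer is summed position by position, this only works when all selected layers share the same stride and hidden size, which is the requirement stated for the basic Featurizer.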
UpstreamDownstreamModel
- class s3prl.nn.upstream.UpstreamDownstreamModel(upstream: S3PRLUpstream, featurizer: Featurizer, downstream, upstream_trainable: bool = False)
Bases:
Module
- forward(wav, wav_len, *args, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.