upstream

(s3prl.nn.upstream)

S3PRL Upstream Collection and some utilities

Authors:
  • Leo 2022

S3PRLUpstream

class s3prl.nn.upstream.S3PRLUpstream(name: str, path_or_url: Optional[str] = None, refresh: bool = False, normalize: bool = False, extra_conf: Optional[dict] = None, randomize: bool = False)

Bases: Module

An easy interface for using all the upstream models in S3PRL. See S3PRL Upstream Collection for example usage and the full list of supported models.

Parameters:
  • name (str) – For example, “apc”, “hubert”, or “wav2vec2”. See available_names for all the supported names.

  • path_or_url (str) – (default, None) The source of the checkpoint. May be a local path or a URL.

  • refresh (bool) – (default, False) If False, download the checkpoint only if it has not been downloaded before. If True, force a re-download of the checkpoint.

  • extra_conf (dict) – (default, None) Extra arguments for the specific upstream. The available options are listed in each upstream's section.

  • randomize (bool) – (default, False) If True, randomize the upstream model's parameters.

Note

When S3PRLUpstream is used with refresh=True under multiprocessing (e.g. DDP), the checkpoint is downloaded only once. The other processes re-use the newly downloaded checkpoint instead of each downloading it again, which could cost a lot of time and bandwidth.

Example:

>>> import torch
>>> from s3prl.nn import S3PRLUpstream
...
>>> model = S3PRLUpstream("hubert")
>>> _ = model.eval()  # the assignment keeps the module repr out of the doctest output
...
>>> with torch.no_grad():
...     wavs = torch.randn(2, 16000 * 2)
...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
...     all_hs, all_hs_len = model(wavs, wavs_len)
...
>>> for hs, hs_len in zip(all_hs, all_hs_len):
...     assert isinstance(hs, torch.FloatTensor)
...     assert isinstance(hs_len, torch.LongTensor)
...     batch_size, max_seq_len, hidden_size = hs.shape
...     assert hs_len.dim() == 1

classmethod available_names(only_registered_ckpt: bool = False) → List[str]

Return all the names supported by S3PRLUpstream.

Parameters:

only_registered_ckpt (bool) – (default, False) If True, exclude the entry names that require path_or_url to be given, i.e. the names without a registered checkpoint source. These names end with _local (local path), _url (URL), or _custom (path or URL, determined automatically).
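The suffix rule behind only_registered_ckpt can be illustrated with a small pure-Python filter. This is only a sketch of the naming convention described above, not the library's implementation: the helper registered_only and the sample name list are hypothetical. With a real installation, S3PRLUpstream.available_names(only_registered_ckpt=True) returns the filtered list directly.

```python
from typing import List

# Suffixes that, per the description above, mark entries without a registered checkpoint
UNREGISTERED_SUFFIXES = ("_local", "_url", "_custom")


def registered_only(names: List[str]) -> List[str]:
    """Hypothetical helper: keep only the names that do not require path_or_url."""
    return [n for n in names if not n.endswith(UNREGISTERED_SUFFIXES)]


# Hypothetical name list for illustration only
names = ["apc", "hubert", "hubert_local", "hubert_url", "wav2vec2", "wav2vec2_custom"]
print(registered_only(names))  # ['apc', 'hubert', 'wav2vec2']
```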

property num_layers: int

Number of layers of hidden states. Every upstream has a deterministic number of layers; that is, layer drop is turned off by default.

property downsample_rates: List[int]

The downsampling rate of each layer, relative to the 16000 Hz input audio. Usually all layers share the same downsampling rate, but this might not hold for some advanced upstreams.

property hidden_sizes: List[int]

The hidden size of each layer.
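The downsampling rate links waveform samples to hidden-state frames. A small sketch of that arithmetic, assuming a rate of 320 samples per frame as an example value (with a real model, read it from model.downsample_rates). The helpers frame_shift_ms and approx_num_frames are hypothetical, and the frame count is only approximate: each upstream's convolution padding decides the exact length, so rely on the lengths returned by forward for exact values.

```python
SAMPLE_RATE = 16000  # S3PRL upstreams consume 16 kHz waveforms


def frame_shift_ms(downsample_rate: int) -> float:
    """Time between consecutive hidden-state frames, in milliseconds."""
    return 1000 * downsample_rate / SAMPLE_RATE


def approx_num_frames(num_samples: int, downsample_rate: int) -> int:
    """Rough frame count; the exact value depends on each upstream's padding and kernels."""
    return num_samples // downsample_rate


# Assumed example rate of 320 samples per frame
print(frame_shift_ms(320))                # 20.0
print(approx_num_frames(16000 * 2, 320))  # 100
```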

forward(wavs: FloatTensor, wavs_len: LongTensor)
Parameters:
  • wavs (torch.FloatTensor) – (batch_size, seqlen) or (batch_size, seqlen, 1)

  • wavs_len (torch.LongTensor) – (batch_size, )

Returns:

List[torch.FloatTensor], List[torch.LongTensor]

  1. The hidden states of all layers: List[ (batch_size, max_seq_len, hidden_size) ]

  2. The valid length of each layer's hidden states: List[ (batch_size, ) ]
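Because the utterances in a batch can differ in length (1 s and 2 s in the example above), the frames of each hidden-state tensor beyond its valid length are padding. A minimal sketch of turning the returned lengths into a boolean mask and using it for mean pooling; length_to_mask and masked_mean are hypothetical helpers, not part of S3PRL, and the tensors are dummies rather than real upstream output.

```python
import torch


def length_to_mask(hs_len: torch.Tensor, max_len: int) -> torch.Tensor:
    """True for valid frames, False for padding; shape (batch_size, max_len)."""
    return torch.arange(max_len, device=hs_len.device)[None, :] < hs_len[:, None]


def masked_mean(hs: torch.Tensor, hs_len: torch.Tensor) -> torch.Tensor:
    """Average each utterance over its valid frames only; shape (batch_size, hidden_size)."""
    mask = length_to_mask(hs_len, hs.size(1)).unsqueeze(-1).to(hs.dtype)
    return (hs * mask).sum(dim=1) / hs_len.clamp(min=1).unsqueeze(-1).to(hs.dtype)


# Dummy tensors standing in for one layer of S3PRLUpstream output
hs = torch.ones(2, 5, 4)
hs[0, 3:] = 100.0  # padding frames of the shorter utterance hold garbage
hs_len = torch.LongTensor([3, 5])
print(length_to_mask(hs_len, 5))
print(masked_mean(hs, hs_len))  # every entry is 1.0: the padding is ignored
```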

call_super_init: bool = False
dump_patches: bool = False
training: bool

Featurizer#

class s3prl.nn.upstream.Featurizer(upstream: S3PRLUpstream, layer_selections: Optional[List[int]] = None, normalize: bool = False)

Bases: Module

Featurizer takes the multiple layers of hidden states produced by S3PRLUpstream and reduces (standardizes) them into a single sequence of hidden states that can be connected to downstream NNs.

This basic Featurizer expects all layers to have the same stride and hidden_size. If the input upstream has only a single layer of hidden states, that layer is used directly. If multiple layers are present, a trainable weighted sum is added on top of them.

Parameters:
  • upstream (S3PRLUpstream) – The upstream that extracts the features. It is used only for initialization and is not kept in the Featurizer object.

  • layer_selections (List[int]) – Select a subset of the upstream's hidden states by layer index (0-indexed). If None (default), all layers of hidden states are selected.

  • normalize (bool) – Whether to apply layer norm to all the hidden states before the weighted sum. This can help convergence in some cases, but it is not used in SUPERB, in order to preserve the fidelity of each upstream's extracted representation.
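The reduction described above can be sketched as a weighted sum over the selected layers. This is a minimal re-implementation under stated assumptions, not S3PRL's Featurizer: the class name WeightedSumSketch is hypothetical, the softmax normalization of the weights and their zero (uniform) initialization are assumptions, and all layers are assumed to share the same stride and hidden size, as the text requires.

```python
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedSumSketch(nn.Module):
    """Hypothetical re-implementation of the idea described above, not S3PRL's Featurizer."""

    def __init__(self, num_layers: int, layer_selections: Optional[List[int]] = None,
                 normalize: bool = False):
        super().__init__()
        self.layer_selections = (list(layer_selections) if layer_selections is not None
                                 else list(range(num_layers)))
        self.normalize = normalize
        # Assumption: one learnable scalar per selected layer; zeros give a uniform average
        self.weights = nn.Parameter(torch.zeros(len(self.layer_selections)))

    def forward(self, all_hs: List[torch.Tensor], all_lens: List[torch.Tensor]):
        hs = [all_hs[i] for i in self.layer_selections]
        lens = all_lens[self.layer_selections[0]]  # same stride, so all lengths agree
        if len(hs) == 1:  # a single layer is used directly
            return hs[0], lens
        stacked = torch.stack(hs, dim=0)  # (num_selected, batch_size, seq_len, hidden_size)
        if self.normalize:
            stacked = F.layer_norm(stacked, stacked.shape[-1:])
        norm_weights = F.softmax(self.weights, dim=0).view(-1, 1, 1, 1)
        return (norm_weights * stacked).sum(dim=0), lens


# Three dummy layers filled with 0.0, 1.0 and 2.0
all_hs = [torch.full((2, 5, 4), float(i)) for i in range(3)]
all_lens = [torch.LongTensor([3, 5])] * 3
hs, hs_len = WeightedSumSketch(num_layers=3)(all_hs, all_lens)
print(hs.shape)  # torch.Size([2, 5, 4])
```

Normalizing the weights with a softmax keeps the output on the same scale as a single layer, so the uniform initialization starts from a plain average of the selected layers.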

Example:

>>> import torch
>>> from s3prl.nn import S3PRLUpstream, Featurizer
...
>>> model = S3PRLUpstream("hubert")
>>> _ = model.eval()
...
>>> with torch.no_grad():
...     wavs = torch.randn(2, 16000 * 2)
...     wavs_len = torch.LongTensor([16000 * 1, 16000 * 2])
...     all_hs, all_hs_len = model(wavs, wavs_len)
...
>>> featurizer = Featurizer(model)
>>> hs, hs_len = featurizer(all_hs, all_hs_len)
...
>>> assert isinstance(hs, torch.FloatTensor)
>>> assert isinstance(hs_len, torch.LongTensor)
>>> batch_size, max_seq_len, hidden_size = hs.shape
>>> assert hs_len.dim() == 1

property output_size: int

The hidden size of the final weighted-sum output.

property downsample_rate: int

The downsampling rate (relative to the 16 kHz waveform) of the final weighted-sum output.

forward(all_hs: List[FloatTensor], all_lens: List[LongTensor])
Parameters:
  • all_hs (List[torch.FloatTensor]) – List[ (batch_size, seq_len, hidden_size) ]

  • all_lens (List[torch.LongTensor]) – List[ (batch_size, ) ]

Returns:

torch.FloatTensor, torch.LongTensor

  1. The weighted-sum result, (batch_size, seq_len, hidden_size)

  2. The valid length of the result, (batch_size, )

call_super_init: bool = False
dump_patches: bool = False
training: bool

UpstreamDownstreamModel#

class s3prl.nn.upstream.UpstreamDownstreamModel(upstream: S3PRLUpstream, featurizer: Featurizer, downstream, upstream_trainable: bool = False)

Bases: Module

property input_size
property downsample_rate
call_super_init: bool = False
dump_patches: bool = False
property output_size
training: bool
forward(wav, wav_len, *args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
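UpstreamDownstreamModel is documented only by its signature. One plausible reading of that signature, offered purely as an assumption: the wrapper runs the upstream (frozen unless upstream_trainable=True), reduces its layers with the featurizer, and passes the result plus any extra forward arguments to the downstream. The sketch follows that reading. Every class in it is hypothetical, the toy modules merely stand in for S3PRLUpstream, Featurizer, and a real downstream, and the input_size, output_size, and downsample_rate properties are omitted.

```python
import torch
import torch.nn as nn


class UpstreamDownstreamSketch(nn.Module):
    """Hypothetical wiring of upstream -> featurizer -> downstream; not S3PRL's source."""

    def __init__(self, upstream: nn.Module, featurizer: nn.Module, downstream: nn.Module,
                 upstream_trainable: bool = False):
        super().__init__()
        self.upstream = upstream
        self.featurizer = featurizer
        self.downstream = downstream
        self.upstream_trainable = upstream_trainable

    def forward(self, wav, wav_len, *args, **kwargs):
        # Assumption: a frozen upstream stays in eval mode and builds no autograd graph
        if not self.upstream_trainable:
            self.upstream.eval()
        with torch.set_grad_enabled(self.upstream_trainable and torch.is_grad_enabled()):
            all_hs, all_lens = self.upstream(wav, wav_len)
        # Reduce the layers to one sequence, then hand it (plus any extra args) downstream
        hs, hs_len = self.featurizer(all_hs, all_lens)
        return self.downstream(hs, hs_len, *args, **kwargs)


# Toy stand-ins so the sketch runs without downloading a checkpoint
class ToyUpstream(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(1, hidden_size)

    def forward(self, wav, wav_len):
        hs = self.proj(wav.unsqueeze(-1))  # (batch_size, seqlen, hidden_size), no downsampling
        return [hs, hs], [wav_len, wav_len]  # pretend there are two layers


class LastLayerFeaturizer(nn.Module):
    def forward(self, all_hs, all_lens):
        return all_hs[-1], all_lens[-1]


class ToyDownstream(nn.Module):
    def __init__(self, hidden_size: int = 8, num_classes: int = 3):
        super().__init__()
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, hs, hs_len):
        return self.head(hs.mean(dim=1))  # hs_len ignored for brevity


model = UpstreamDownstreamSketch(ToyUpstream(), LastLayerFeaturizer(), ToyDownstream())
logits = model(torch.randn(2, 100), torch.LongTensor([50, 100]))
print(logits.shape)  # torch.Size([2, 3])
```

With upstream_trainable=False, gradients reach only the featurizer and downstream, which matches the usual SUPERB-style setup of probing a frozen upstream.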