The #2140 PR implements
a streaming model in SpeechBrain, based on a Conformer-Transducer architecture.
The "transducer" part refers to (part of) the loss used (RNN-T loss), but the
vast majority of changes needed for streaming were specific to the Conformer
model.
Two main parameters exist:
- Chunk size, which basically corresponds to how many new audio frames we need to throw at the model to get new predictions. With most models, a smaller chunk size will worsen the model's accuracy, but improve latency. Lower chunk sizes also tend to worsen the RTF (decoding speed).
- Left context size, which here corresponds to the number of left context frames at each transformer layer (detailed later). When not using streaming, the left and right context are technically infinite. Larger contexts mean higher accuracy, but also higher memory and computational costs.
This model leverages several techniques for streaming, and is trained to be able to cope with different chunk and left context sizes chosen at runtime (with certain limits).
Both the chunk size and left context sizes are defined in terms of frames at transformer level. However, by knowing the entirety of the architecture, we deduce how many audio samples this corresponds to, which is required at inference time.
When we stream, we have to constrain our models in certain ways. The most
important is that we have to perfectly control and understand what outputs
depend on what inputs: We need to be able to infer continuously, for a
potentially long time, and predictions need to be constrained to a low-ish
latency (the prediction at time t cannot depend on, say, t + 5s).
Additionally, these constraints need to be enforced during training, and this
enforcement is not necessarily implemented the same between training and
inference, as the former usually depends on masking and the latter needs special
code to operate over individual chunks of data.
These constraints were generally verified by using the newly introduced
speechbrain.utils.streaming.infer_dependency_matrix (and
plot_dependency_matrix), which attempts to determine which inputs any given
output depends on by randomizing the input data.
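The idea behind such a dependency check can be illustrated with a small, self-contained sketch. This is not the SpeechBrain implementation: infer_dependencies and chunked_mean are hypothetical names, and the "model" is a toy function over plain Python lists.

```python
import random

def infer_dependencies(fn, num_inputs, trials=4, seed=0):
    """Return deps[o][i] == True if output o changed when input i was randomized."""
    rng = random.Random(seed)
    base_in = [rng.random() for _ in range(num_inputs)]
    base_out = fn(base_in)
    deps = [[False] * num_inputs for _ in base_out]
    for i in range(num_inputs):
        for _ in range(trials):
            perturbed = list(base_in)
            perturbed[i] = rng.random()  # randomize a single input frame
            out = fn(perturbed)
            for o, (a, b) in enumerate(zip(base_out, out)):
                if a != b:
                    deps[o][i] = True
    return deps

def chunked_mean(xs, chunk_size=2):
    """Toy 'model': each output is the mean of the chunk it belongs to."""
    out = []
    for i in range(len(xs)):
        start = (i // chunk_size) * chunk_size
        chunk = xs[start:start + chunk_size]
        out.append(sum(chunk) / len(chunk))
    return out

deps = infer_dependencies(chunked_mean, num_inputs=4)
for row in deps:
    print(row)
```

Each row of deps corresponds to one output frame. For a correctly chunked model, no row should have a True entry in a column belonging to a later chunk.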
The streaming approach has the benefit of having a fixed memory cost long term. This means that the model can be used to decode long files which would otherwise run out of memory or require workarounds with typical models.
At training, for each different batch, we select a random chunk size and a random left context size. This is so that the model can adapt to different situations.
The strategy we adopt is (with the default hparams):
- A 60% chance to use chunking for this batch, with a chunk size randomly sampled in the range 8..32
- Then, a 75% chance to limit the left context to a value randomly sampled in the range 16..64
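The sampling strategy above can be sketched as follows. This is a minimal illustration, not the actual training code: sample_chunk_config is a hypothetical helper, the ranges are assumed to be inclusive, and None is used here to mean "unlimited".

```python
import random

def sample_chunk_config(rng, chunk_prob=0.6, chunk_range=(8, 32),
                        limit_left_prob=0.75, left_range=(16, 64)):
    """Return (chunk_size, left_context) for one batch; None means unlimited."""
    if rng.random() >= chunk_prob:
        return None, None  # train this batch without chunking
    chunk_size = rng.randint(*chunk_range)  # bounds assumed inclusive
    left_context = None
    if rng.random() < limit_left_prob:
        left_context = rng.randint(*left_range)
    return chunk_size, left_context

rng = random.Random(1234)
configs = [sample_chunk_config(rng) for _ in range(1000)]
print(configs[:5])
```

Batches that come out as (None, None) are trained exactly like a non-streaming model, which is what lets the same weights be used with an infinite chunk size.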
Some of the prior art describes obtaining better results with this approach than
if the model were trained normally for half the epochs and then exclusively with
chunking for the other half (e.g. https://arxiv.org/abs/2012.05481).
Training the model for the entire duration with a mix of both methods provides
the added benefit that the model can be used with an infinite chunk size, thus
behaving like a non-streaming model.
When training our model on the LibriSpeech dataset, we found that for this
particular architecture and setup, the error rate is very similar between the
non-streaming and the streaming model (in non-streaming mode for both).
The objective of chunked attention is to prevent an output at a given time point
from depending on data far into the future.
Typically, you would be able to achieve this by preventing a frame in the
transformer at time t from depending on any frame at a time > t, so that a
given output can only depend on past data (i.e. "left attention").
However, this generally tends to harm accuracy, so we prefer to give the ability
to restrain the attention to non-overlapping chunks (of e.g. 8 frames), where
any output frame can depend on any input frame from within the same chunk
(and left context, i.e. past data, as described later).
Figure 2 in https://arxiv.org/abs/2012.05481 demonstrates the difference. Note that this is not specific to the Conformer model but can sometimes be applied to other Transformer-based models.
A nice property of this approach is that this constraint persists across layers, meaning that any given output at the last layer will only ever depend on the input frames from within the same chunk.
For example, with a chunk size of 4 (and infinite left context), output frames
at indices (4,5,6,7) at the last layer will all depend on input frames at
indices (0,1,2,3,4,5,6,7), but not (8,9,...), which we wouldn't know yet
in a streaming inference context.
Note that output frame 4 can directly depend on 7 in the above example,
because it is within the same chunk (i//chunk_size == 1), which is the
difference with pure left context (in which case 4 could only depend on
inputs (0,1,2,3,4)).
At train time, chunked attention is enforced through masking.
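As an illustration, a chunked attention mask with infinite left context could be built as below. This is a sketch using plain Python lists rather than tensors; chunk_attention_mask is a hypothetical helper, and True here means "attention allowed" (the real code may use the opposite convention).

```python
def chunk_attention_mask(num_frames, chunk_size):
    """allowed[q][k] is True if query frame q may attend to key frame k."""
    return [
        [k // chunk_size <= q // chunk_size for k in range(num_frames)]
        for q in range(num_frames)
    ]

mask = chunk_attention_mask(num_frames=12, chunk_size=4)

# Frame 4 sees the previous chunk and its own chunk, but nothing after it.
visible_from_4 = [k for k, ok in enumerate(mask[4]) if ok]
print(visible_from_4)
```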
For each layer of the conformer, we save (with the default hparams) a certain number of left frames (16 or 32 are reasonable values).
At train time, this is enforced by masking (at chunk level) any input to the
layer that is further to the left than the enforced limit.
At inference time, we save and reinject that left context as necessary.
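The save-and-reinject logic at inference time can be sketched roughly as follows. This is a toy model over Python lists: stream_layer and toy_layer are hypothetical names, and a real layer would cache tensors rather than numbers.

```python
def stream_layer(chunks, left_context_size, layer_fn):
    """Run layer_fn chunk by chunk, re-injecting the cached left context.

    layer_fn(context_frames, chunk_frames) returns outputs for chunk_frames only.
    """
    cache = []  # the last `left_context_size` input frames seen so far
    outputs = []
    for chunk in chunks:
        outputs.extend(layer_fn(cache, chunk))
        if left_context_size > 0:
            cache = (cache + chunk)[-left_context_size:]
    return outputs

def toy_layer(context, chunk):
    """Toy 'attention': every output is the sum of all visible frames."""
    total = sum(context) + sum(chunk)
    return [total] * len(chunk)

outputs = stream_layer([[1, 2], [3, 4], [5, 6]], left_context_size=2,
                       layer_fn=toy_layer)
print(outputs)
```

The cache never grows beyond left_context_size frames, which is what keeps the memory cost fixed regardless of stream length.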
This works because, due to the chunking, no past chunk can ever depend on a
future chunk. Additionally, since we do this at each layer, the earliest frame
of the left context of the current layer depends on the earliest frame of the
left context of the previous layer, and so on.
This effectively means that the "receptive field" of the model can be very wide
with not-so-high left context sizes, even though we only need to save that
fairly small context.
Note that this does not necessarily mean that the model makes meaningful use of
information that goes back many seconds.
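A rough back-of-the-envelope estimate of that receptive field is sketched below. The layer count of 12 is an assumption made purely for illustration, the 40 ms frame stride is the one derived for the feature extractor in this document, and the estimate ignores chunk alignment effects.

```python
def approx_left_receptive_field(num_layers, left_context_frames, frame_ms=40):
    """Each layer can reach left_context_frames further back than the one below."""
    frames = num_layers * left_context_frames
    return frames, frames * frame_ms / 1000.0

# num_layers=12 is an illustrative assumption, not a value from the recipe.
frames, seconds = approx_left_receptive_field(num_layers=12,
                                               left_context_frames=32)
print(frames, seconds)
```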
Familiarity with the Conformer paper may prove helpful here.
The Conformer model is composed of stacked conformer blocks, each of which is composed of feed-forward modules and, most importantly here, of a multi-head self-attention module followed by the convolution module. The streaming changes relevant to the self-attention were described above, and those relevant to the convolution module are described next.
Usually, in the convolution module, you would support streaming by using a
causal convolution, which shifts the output so that the n-th output frame
would only depend on the input frames from n-(kernel_size-1) to n (instead of
n-((kernel_size-1)//2) to n+((kernel_size-1)//2)). This generally comes at a
significant accuracy cost, even for the non-streaming case.
The Dynamic Chunk Convolution (DCConv) approach, described in this Amazon paper (https://assets.amazon.science/18/80/2126d1f5416aa7143505694ae013/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr.pdf), is to instead mask frames that come after the current chunk (see Fig. 2 of that paper). This happens to make the convolution identical to normal in a non-streaming context, and performs pretty well in a streaming context.
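To make the difference between the three variants concrete, the sketch below computes which input frames a given output frame depends on. conv_input_range is a hypothetical helper written for this illustration; indices below 0 correspond to padding.

```python
def conv_input_range(n, kernel_size, mode, chunk_size=None):
    """Inclusive (first, last) input indices that output frame n depends on."""
    half = (kernel_size - 1) // 2
    if mode == "normal":
        return n - half, n + half
    if mode == "causal":
        return n - (kernel_size - 1), n
    if mode == "dynamic_chunk":
        # Last frame index of the chunk that n belongs to.
        chunk_end = (n // chunk_size + 1) * chunk_size - 1
        return n - half, min(n + half, chunk_end)
    raise ValueError(f"unknown mode: {mode}")

for n in (3, 4):
    print(n,
          conv_input_range(n, 5, "normal"),
          conv_input_range(n, 5, "causal"),
          conv_input_range(n, 5, "dynamic_chunk", chunk_size=4))
```

With a chunk size at least as long as the input, the dynamic chunk range is the same as the normal one, which matches the observation that the convolution is unchanged in a non-streaming context.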
The implementation for this is slightly trickier as we actually have to
functionally split the chunks that are fed to the convolution.
We could simply split the tensors into a list of chunks and perform calculations
this way, but this adds a lot of overhead, so we instead pack all those chunks
to the batch dimension and concatenate them later. Doing things this way makes
training roughly as fast as when not streaming.
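The packing trick can be illustrated with plain Python lists standing in for tensors. This is only a sketch: pack_chunks and unpack_chunks are hypothetical helpers, all sequences are assumed to have the same length, and the left context that each chunk also needs for the convolution is omitted.

```python
def pack_chunks(batch, chunk_size, pad_value=0.0):
    """[batch][time] -> ([batch * num_chunks][chunk_size], num_chunks)."""
    time = len(batch[0])
    num_chunks = -(-time // chunk_size)  # ceiling division
    packed = []
    for seq in batch:
        padded = seq + [pad_value] * (num_chunks * chunk_size - time)
        for c in range(num_chunks):
            packed.append(padded[c * chunk_size:(c + 1) * chunk_size])
    return packed, num_chunks

def unpack_chunks(packed, num_chunks, time):
    """Inverse of pack_chunks: concatenate chunks and drop the padding."""
    batch = []
    for b in range(len(packed) // num_chunks):
        rows = packed[b * num_chunks:(b + 1) * num_chunks]
        seq = [x for chunk in rows for x in chunk]
        batch.append(seq[:time])
    return batch

batch = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
packed, num_chunks = pack_chunks(batch, chunk_size=2)
restored = unpack_chunks(packed, num_chunks, time=5)
print(packed)
```

Once the chunks sit on the batch dimension, a single convolution call processes all of them at once, and no chunk can see frames from another chunk.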
We use a feature extraction module at the start of the model which combines a
filter bank (fbank) extractor, some normalization and a downsampling CNN, which
reduces the frame rate by a factor of 4 vs the frame rate at the output of
the fbank extractor.
Changes to feature extraction were not really necessary for the training of a streaming model in this case. We do, however, need to understand the properties of this part of the model to know exactly how many frames to feed for each chunk.
To do this, we can consider the feature extraction module as a filter, with a specific window size and stride.
With a kernel size of 3, a stride of 2, and two stacked layers, the downsampling CNN has an effective kernel size of 7 and a stride of 4.
The following function then calculates the same for feature extraction as a whole:
def get_filter_properties():
    # FIXME: this should not be hardcoded
    sample_rate = 16000  # Hz
    frames_per_ms = sample_rate // 1000

    # win_length (32 in hparams) and hop_length (defaults to 10) respectively,
    # specified in milliseconds
    fbank_win_size_ms = 32
    fbank_stride_ms = 10

    fbank_win_size_frames = fbank_win_size_ms * frames_per_ms
    fbank_stride_frames = fbank_stride_ms * frames_per_ms

    # the configuration of the downsampling CNN has the following properties.
    # we express them in "fbank frames" as it consumes the fbanks
    # this is determined from its architecture (window size, stride, layers)
    # NOTE: ideally i feel like we should have a way to poll these properties
    # from the model directly, or at least provide a mechanism to compute
    # combined filter properties (see downcnn_*_frames below)
    downcnn_win_size_fbanks = 7
    downcnn_stride_fbanks = 4

    # we can consider the fbank+featcnn combination as a filter and thus we can
    # determine the properties of the whole filter
    downcnn_win_size_frames = (
        fbank_win_size_frames
        + (fbank_stride_frames * (downcnn_win_size_fbanks - 1))
    )
    downcnn_stride_frames = fbank_stride_frames * downcnn_stride_fbanks

    return FilterProperties(
        window_size_frames=downcnn_win_size_frames,
        stride_frames=downcnn_stride_frames
    )
The result is a window size of 92ms and a stride of 40ms.
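The same arithmetic can be written as a generic composition of stacked filters, which also gives the number of audio samples spanned by a chunk. This is a sketch under stated assumptions: Filter, stack and samples_spanned are hypothetical names (not the FilterProperties class used above), and how the real code slices the audio stream may differ.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Filter:
    window: int  # window size, in units of the filter's input
    stride: int  # stride, in units of the filter's input

def stack(first, second):
    """Properties of `second` applied on top of the output of `first`."""
    return Filter(
        window=first.window + first.stride * (second.window - 1),
        stride=first.stride * second.stride,
    )

def samples_spanned(filt, num_frames):
    """Input samples covered by `num_frames` consecutive output frames."""
    return filt.window + filt.stride * (num_frames - 1)

conv = Filter(window=3, stride=2)
downcnn = stack(conv, conv)                      # two stacked conv layers
fbank = Filter(window=32 * 16, stride=10 * 16)   # 32 ms / 10 ms at 16 kHz
frontend = stack(fbank, downcnn)

print(downcnn, frontend)
print(samples_spanned(frontend, 8), 8 * frontend.stride)
```

With a chunk size of 8, the frames of one chunk span 5952 samples (372 ms), while each new chunk advances the stream by 8 strides, i.e. 5120 samples (320 ms).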
In order to avoid a discrepancy between training time and inference time, feature extraction needs some care (preserving left and some fixed right context). The code is available, but the logic is complicated and wasn't thoroughly checked for correctness and effect on performance.
I found from limited testing that it roughly seems to work even without doing this.
This is still TODO, but if I don't get around to it, this might deserve a better look.
Currently, only greedy search (GS) is implemented, as it is the easiest to implement for streaming. Here, we only really need to do two things:
- Save the decoder context (here, an LSTM's) to reuse for subsequent calls, which is fairly straightforward as it is fully unidirectional and only depends on the last hidden state
- Adapt the token decoding (token ID -> string) that is done by SentencePiece.
The reason why we need to do the latter is the way that SentencePiece tokenizers
actually deal with whitespaces. If a token begins with the '▁' symbol, then
SentencePiece will emit a space only if it is not the first token.
This makes sense generally, but as we're streaming, the "first" token might just
be in the middle of a sentence which SentencePiece cannot see. Thus, we work
around the issue by peeking at the first token and inserting a space if needed.
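The workaround can be sketched without the SentencePiece library by operating directly on piece strings. In this illustration, decode_pieces only mimics the whitespace behaviour described above, and decode_chunk is a hypothetical helper; neither is the actual implementation.

```python
SPACE = "\u2581"  # the '▁' word-boundary marker used by SentencePiece

def decode_pieces(pieces):
    """Mimic the behaviour described above: the leading space is dropped."""
    return "".join(pieces).replace(SPACE, " ").lstrip(" ")

def decode_chunk(pieces, is_first_chunk):
    """Decode one chunk of pieces, restoring the space lost at its start."""
    text = decode_pieces(pieces)
    if pieces and not is_first_chunk and pieces[0].startswith(SPACE):
        text = " " + text
    return text

chunks = [[SPACE + "hello", SPACE + "wor"], ["ld"], [SPACE + "again"]]
transcript = "".join(decode_chunk(p, i == 0) for i, p in enumerate(chunks))
print(transcript)
```

Note that the second chunk starts mid-word, so no space is inserted there, whereas the third chunk starts a new word and gets its space back.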
We let encoders optionally define a make_streaming_context method if they
support streaming. A streaming context is a simple mutable dataclass that holds
streaming metadata (chunk size, etc.) and the tensors that need to be saved at
inference time, due to the techniques outlined previously.
This is done fairly generically so not all TransformerASR encoders actually
need to support streaming contexts, but can do so if someone desires to
implement such a model for a particular setup.
The streaming encoding/forward/etc. methods take the context as a parameter and
will update it, and forward any sub-context required for the operation of e.g.
an encoder (like the ConformerEncoder, which in turn holds
ConformerEncoderLayer streaming contexts).
The same context object passed to these methods, which is initially blank when
initialized by make_streaming_context, gets updated by them.
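The shape of such a context could look roughly like the sketch below. The class names, fields and the free-standing make_streaming_context signature are all hypothetical and chosen for illustration; the toy encode_chunk only updates the caches and assumes left_context_size >= 1.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class LayerStreamingContext:
    # Cached left context frames; None until the first chunk is processed.
    left_context: Optional[list] = None

@dataclass
class EncoderStreamingContext:
    chunk_size: int
    left_context_size: int
    layers: List[LayerStreamingContext] = field(default_factory=list)

def make_streaming_context(num_layers, chunk_size, left_context_size):
    """Create a blank context holding one sub-context per layer."""
    return EncoderStreamingContext(
        chunk_size=chunk_size,
        left_context_size=left_context_size,
        layers=[LayerStreamingContext() for _ in range(num_layers)],
    )

def encode_chunk(chunk, context):
    """Toy forward pass: identity layers that only update their caches."""
    for layer_ctx in context.layers:
        previous = layer_ctx.left_context or []
        layer_ctx.left_context = (previous + chunk)[-context.left_context_size:]
    return chunk

ctx = make_streaming_context(num_layers=2, chunk_size=2, left_context_size=3)
encode_chunk([1, 2], ctx)
encode_chunk([3, 4], ctx)
print(ctx.layers[0].left_context)
```

The caller owns the context and passes the same object to every call, so resetting a stream amounts to creating a fresh context.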
NOTE: This issue was resolved. It was caused by a mismatch in the conformer convolution streaming code path. This required a model retrain.
Using a chunk size of 16 and a left context of 32, in real streaming mode, we obtain 4.69% WER on test-clean which is way too high compared to emulation (i.e. masking) where we get 3.46% with half the chunk size.
Several things:
- Minor: The chunk size calculation or splitting logic appears to be very slightly incorrect. Sometimes, the expected chunk count and the calculated chunk count are off-by-one but this is rare. Additionally, calculating features for the entire audio and splitting those features instead of splitting the audio and calculating features yields a very slight WER difference (4.67% vs 4.69%). This might not need immediate attention.
- Maybe something is acting unexpectedly inside the RelMHAXL, especially regarding the use of positional embeddings. To my understanding, theoretically, this should not be the case.
- My assumption is that the MHAXL doesn't need any sort of change to account for the offset within the stream in streaming code and the code is written as such.
- The left context seems to be inserted correctly for the MHAXL and the DCConv, and for the latter, padding seems to match.
- Some of the normalization inside the model might behave differently, but it seems surprising that this would have such a dramatic effect.
- Greedy decoding yields identical results between the streaming mode and the non-streaming mode, so it is not the culprit.
I believe that for debugging this issue one would need to really carefully observe what difference there is between emulating streaming and actually doing streaming at each layer in order to find where the discrepancy occurs.
Hopefully that isn't the case as it would require a model retrain, but maybe the masking (i.e. streaming "emulation") is wrong as well. Not very easy to debug, however.
NOTE: The RTF figures here are not final since there were some adjustments.
Using a single CPU thread on an AMD Ryzen 9 3900X (Zen 2), a chunk size of 8
(~700ms latency, TODO check) and left context of 32, a 6x speed (~0.16 RTF)
could be achieved (using a single stream, including decoding).
This RTF is currently not very good compared to certain other options and could
be a future area of improvement. Note that this is on a full fp32 model running
in eager PyTorch mode. Also note that this was tested on a single test audio;
more accurate RTF measurements should be done on test-clean/test-other.
Please take these figures with a grain of salt.
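For reference, the relation between RTF and speed-up used above is simply the ratio of processing time to audio duration. The durations in this sketch are made-up illustrative values, not measurements.

```python
def real_time_factor(processing_seconds, audio_seconds):
    """RTF < 1 means decoding runs faster than real time."""
    return processing_seconds / audio_seconds

# Illustrative values only: 30 s of audio decoded in 5 s.
rtf = real_time_factor(processing_seconds=5.0, audio_seconds=30.0)
speedup = 1.0 / rtf
print(round(rtf, 3), round(speedup, 1))
```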
At the time of writing, a proper inference server to serve transcription over the network has not been written.
Anecdotally, we found that using more PyTorch threads did not result in better performance on that system for CPU decoding, even when increasing the number of batches (which is more complicated to set up).
GPU decoding was only briefly explored, but batching was hugely more beneficial
there. It is unclear how much of that performance can be achieved without
batching and exclusively through multithreading/multiprocessing.
On a 31s test audio, on a 2080 Ti, using a chunk size of 8 and left context of
16:
- batch size = 256, autocast on: RTF=0.0011
- batch size = 256, autocast off: RTF=0.0014
- batch size = 128, autocast on: RTF=0.0012
- batch size = 128, autocast off: RTF=0.0015
- batch size = 64, autocast on: RTF=0.0016
- batch size = 64, autocast off: RTF=0.0017
- batch size = 64, chunk size=+inf: RTF=OutOfMemoryError
- batch size = 16, chunk size=+inf, autocast on: RTF=0.0015
You can achieve an even lower RTF if you can tolerate higher latency and you only
care about batched transcription: with batch size = 256, autocast and a chunk
size of 32 (left context=16), an RTF of 0.00086 is achievable on that GPU.
You could also explore running a thread pool of two with half the batch size, with both threads consuming the stream of inferences. This might result in better performance, as one thread can keep the GPU busy while the other is busy CPU-side, for example.
For an inference server on CPU, we would recommend disabling PyTorch's
parallelism using torch.set_num_threads and then using a thread or process pool
to concurrently process requests and not bother with batching, as processing
each request this way will scale much better than relying on PyTorch's own
parallelism.
Performance using a thread pool should globally be acceptable despite the
GIL, because PyTorch's functions will generally release it.
A process pool might scale better on a high number of cores, but this option was
not explored.
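A minimal sketch of that setup is shown below, assuming a thread pool from the standard library. The transcribe function is a placeholder standing in for per-request streaming inference, serve is a hypothetical helper, and the PyTorch import is made optional only so that the sketch runs without it.

```python
from concurrent.futures import ThreadPoolExecutor

try:
    import torch
    # Disable PyTorch's intra-op parallelism; the pool provides concurrency.
    torch.set_num_threads(1)
except ImportError:
    torch = None  # the sketch still runs without PyTorch installed

def transcribe(request_id):
    """Placeholder for running streaming inference on one request."""
    return f"transcript for request {request_id}"

def serve(requests, max_workers=4):
    """Process requests concurrently, one request per worker thread."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(transcribe, requests))

results = serve(range(8))
print(results[:2])
```

Swapping ThreadPoolExecutor for ProcessPoolExecutor would give the process pool variant, at the cost of loading the model once per process.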
Currently, batching is tricky as you may want to interrupt or reset streams. This is not yet well supported.