1 Forced Alignment with Wav2Vec2

  • Originally written by Moto Hira <moto@meta.com>
  • Modified by Jalal Al-Tamimi (September 2025) to include saving outputs to Praat TextGrids


This tutorial shows how to align a transcript to speech with torchaudio, using the CTC segmentation algorithm described in CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition <https://arxiv.org/abs/2007.09127>.

Note:

This tutorial was originally written to illustrate a use case for the pretrained Wav2Vec2 model.

TorchAudio now has a set of APIs designed for forced alignment. The CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html> illustrates the usage of torchaudio.functional.forced_align, which is the core API.

If you are looking to align your corpus, we recommend using torchaudio.pipelines.Wav2Vec2FABundle, which combines torchaudio.functional.forced_align and other support functions with a pretrained model specifically trained for forced alignment. Please refer to the Forced alignment for multilingual data tutorial <forced_alignment_for_multilingual_data_tutorial.html>, which illustrates its usage.
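
Below is a minimal sketch of that bundle-based workflow, not part of the original tutorial: it assumes torchaudio >= 2.1, uses torchaudio.pipelines.MMS_FA (a Wav2Vec2FABundle instance) and a hypothetical file speech.wav; verify the names against your torchaudio version.

import torch
import torchaudio

# Sketch only: MMS_FA is a Wav2Vec2FABundle; "speech.wav" is a placeholder.
fa_bundle = torchaudio.pipelines.MMS_FA
fa_model = fa_bundle.get_model()
fa_tokenizer = fa_bundle.get_tokenizer()
fa_aligner = fa_bundle.get_aligner()

# The bundle expects audio at fa_bundle.sample_rate (16 kHz); resample if needed.
wav, sr = torchaudio.load("speech.wav")
words = "i had that curiosity beside me at this moment".split()
with torch.inference_mode():
    fa_emission, _ = fa_model(wav)
    token_spans = fa_aligner(fa_emission[0], fa_tokenizer(words))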

2 Install required modules

If you are using a CPU-only system (i.e., one that is not CUDA enabled), follow this link to install a CPU-only version of torchaudio: https://pytorch.org/get-started/previous-versions/#v201. To do so, uncomment the CPU-specific line below and run it, then restart the kernel after installation.

In addition, on Windows, you might need to install ffmpeg separately. Please refer to the official ffmpeg installation guide: https://ffmpeg.org/download.html. After installing, make sure ffmpeg has been added to your system PATH.

# General Python packages
## !pip install IPython matplotlib textgrid pandas

## !pip install torch torchaudio torchvision 

# CPU specific Python packages
## !pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cpu

## After restarting the kernel, you need to run the code below to upgrade numpy
## !pip install numpy --upgrade
import torch
import torchaudio

print(torch.__version__)
## 2.7.0+cpu
print(torchaudio.__version__)
## 2.7.0+cpu
torch.random.manual_seed(0)
## <torch._C.Generator object at 0x00000182AAC97850>
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
## cpu

3 Overview

The process of alignment looks like the following.

  1. Estimate the frame-wise label probability from the audio waveform.
  2. Generate the trellis matrix, which represents the probability of each label at each time step.
  3. Find the most likely path through the trellis matrix.

In this example, we use torchaudio’s Wav2Vec2 model for acoustic feature extraction.

4 Preparation

First we import the necessary packages and fetch the data that we work on.

from dataclasses import dataclass
import IPython
import matplotlib.pyplot as plt
from torchaudio.utils import download_asset

SPEECH_FILE = download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
# here is the address to download the audio file:
# https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav

5 Generate frame-wise label probability

The first step is to generate the label class probability for each audio frame. We can use a Wav2Vec2 model that is trained for ASR. Here we use torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.

torchaudio provides easy access to pretrained models with associated labels.

Note:

In the subsequent sections, we will compute the probability in the log domain to avoid numerical instability. For this purpose, we normalize the emission with torch.log_softmax.

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)
emission = emissions[0].cpu().detach()

print(labels)
## ('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')
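
Before visualizing, it is worth checking the emission's shape: the first axis is the number of acoustic frames and the second the number of labels (29 here). Wav2Vec2 strides over 16 kHz audio in steps of 320 samples, so each frame covers roughly 20 ms; for this file the model produces 169 frames.

num_frames, num_labels = emission.shape
print(num_frames, num_labels)
## 169 29
# Approximate duration of one emission frame (about 0.02 s for this model):
print(waveform.size(1) / bundle.sample_rate / num_frames)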

5.1 Visualization

def plot():
    fig, ax = plt.subplots(figsize=(15, 10))
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()
    fig.savefig("pyplot.png")
    plt.close(fig)

plot()

6 Generate alignment probability (trellis)

From the emission matrix, we next generate the trellis, which represents the probability of each transcript label occurring at each time frame.

The trellis is a 2D matrix with a time axis and a label axis. The label axis represents the transcript that we are aligning. In the following, we use $t$ to denote the index on the time axis and $j$ to denote the index on the label axis; $c_j$ represents the label at label index $j$.

To generate the probability for time step $t+1$, we look at the trellis from time step $t$ and the emission at time step $t+1$. There are two paths that reach time step $t+1$ with label $c_{j+1}$. The first is the case where the label was already $c_{j+1}$ at $t$ and no label change occurred from $t$ to $t+1$. The other is the case where the label was $c_j$ at $t$ and it transitioned to the next label $c_{j+1}$ at $t+1$.

The following diagram illustrates this transition.

[Diagram: the two CTC forward transitions, stay vs. change — https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png]

Since we are looking for the most likely transitions, we take the more likely path for the value of $k_{(t+1, j+1)}$, that is

$k_{(t+1, j+1)} = \max\big( k_{(t, j)} \, p(t+1, c_{j+1}),\; k_{(t, j+1)} \, p(t+1, \mathrm{repeat}) \big)$

where $k$ represents the trellis matrix and $p(t, c_j)$ represents the probability of label $c_j$ at time step $t$. $\mathrm{repeat}$ represents the blank token from the CTC formulation. (For details of the CTC algorithm, please refer to Sequence Modeling with CTC [distill.pub] <https://distill.pub/2017/ctc/>.)
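
To make the recursion concrete, here is a toy update with made-up values. Note that since we work in the log domain (see the note above), the products in the formula become sums:

# Hypothetical log-probabilities for one trellis-cell update (illustrative only).
k_move = torch.tensor(-2.0)   # k_(t, j): best score ending at c_j
k_stay = torch.tensor(-3.0)   # k_(t, j+1): best score already at c_{j+1}
p_label = torch.tensor(-0.5)  # log p(t+1, c_{j+1})
p_blank = torch.tensor(-0.1)  # log p(t+1, repeat)
k_next = torch.maximum(k_move + p_label, k_stay + p_blank)
print(k_next)
## tensor(-2.5000)
# The "change to the next label" path wins here.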

We enclose the transcript with space tokens, which represent SOS and EOS.

transcript = "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"
dictionary = {c: i for i, c in enumerate(labels)}

tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))
## [('|', 1), ('I', 7), ('|', 1), ('H', 8), ('A', 4), ('D', 11), ('|', 1), ('T', 3), ('H', 8), ('A', 4), ('T', 3), ('|', 1), ('C', 16), ('U', 13), ('R', 10), ('I', 7), ('O', 5), ('S', 9), ('I', 7), ('T', 3), ('Y', 19), ('|', 1), ('B', 21), ('E', 2), ('S', 9), ('I', 7), ('D', 11), ('E', 2), ('|', 1), ('M', 14), ('E', 2), ('|', 1), ('A', 4), ('T', 3), ('|', 1), ('T', 3), ('H', 8), ('I', 7), ('S', 9), ('|', 1), ('M', 14), ('O', 5), ('M', 14), ('E', 2), ('N', 6), ('T', 3), ('|', 1)]


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis


trellis = get_trellis(emission, tokens)
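
The trellis has one row per audio frame and one column per transcript token; for this file that is 169 frames by 47 tokens:

print(trellis.shape)
## torch.Size([169, 47])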

6.1 Visualization

def plot():
    fig, ax = plt.subplots(figsize=(15, 20))
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()
    fig.savefig("pyplot2.png")
    plt.close(fig)

plot()

In the above visualization, we can see that there is a trace of high probability crossing the matrix diagonally.

6.2 Find the most likely path (backtracking)

Once the trellis is generated, we will traverse it following the elements with high probability.

We start from the last label index at the last time step, then traverse back in time, picking stay ($c_j \rightarrow c_j$) or transition ($c_j \rightarrow c_{j+1}$), based on the post-transition probability $k_{(t, j)} \, p(t+1, c_{j+1})$ or $k_{(t, j+1)} \, p(t+1, \mathrm{repeat})$.

The traversal is complete once the label index reaches the beginning of the transcript.

The trellis matrix is used for path-finding, but for the final probability of each segment, we take the frame-wise probability from the emission matrix.

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen but just in case
        assert t > 0

        # 1. Figure out if the current position was stay or change
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # Now j == 0, which means it reached the SOS.
    # Fill up the rest for the sake of visualization.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


path = backtrack(trellis, emission, tokens)
for p in path:
    print(p)
## Point(token_index=0, time_index=0, score=0.9999998807907104)
## Point(token_index=0, time_index=1, score=0.9999998807907104)
## Point(token_index=0, time_index=2, score=0.9999998807907104)
## Point(token_index=0, time_index=3, score=0.9999998807907104)
## Point(token_index=0, time_index=4, score=0.9999995231628418)
## Point(token_index=0, time_index=5, score=0.9999995231628418)
## Point(token_index=0, time_index=6, score=0.9999995231628418)
## Point(token_index=0, time_index=7, score=0.9999995231628418)
## Point(token_index=0, time_index=8, score=1.0)
## Point(token_index=0, time_index=9, score=0.9999996423721313)
## Point(token_index=0, time_index=10, score=0.9999995231628418)
## Point(token_index=0, time_index=11, score=1.0)
## Point(token_index=0, time_index=12, score=0.9999995231628418)
## Point(token_index=0, time_index=13, score=0.9999997615814209)
## Point(token_index=0, time_index=14, score=0.9999995231628418)
## Point(token_index=0, time_index=15, score=0.9999995231628418)
## Point(token_index=0, time_index=16, score=0.9999997615814209)
## Point(token_index=0, time_index=17, score=0.9999995231628418)
## Point(token_index=0, time_index=18, score=1.0)
## Point(token_index=0, time_index=19, score=0.9999996423721313)
## Point(token_index=0, time_index=20, score=0.9999995231628418)
## Point(token_index=0, time_index=21, score=0.9999997615814209)
## Point(token_index=0, time_index=22, score=0.9999995231628418)
## Point(token_index=0, time_index=23, score=0.9999998807907104)
## Point(token_index=0, time_index=24, score=1.0)
## Point(token_index=0, time_index=25, score=1.0)
## Point(token_index=0, time_index=26, score=1.0)
## Point(token_index=0, time_index=27, score=1.0)
## Point(token_index=0, time_index=28, score=0.9999984502792358)
## Point(token_index=0, time_index=29, score=0.9999943971633911)
## Point(token_index=0, time_index=30, score=0.9999842643737793)
## Point(token_index=1, time_index=31, score=0.9847091436386108)
## Point(token_index=1, time_index=32, score=0.9999707937240601)
## Point(token_index=1, time_index=33, score=0.15399916470050812)
## Point(token_index=1, time_index=34, score=0.9999173879623413)
## Point(token_index=2, time_index=35, score=0.6080849766731262)
## Point(token_index=2, time_index=36, score=0.9997718930244446)
## Point(token_index=3, time_index=37, score=0.999713122844696)
## Point(token_index=3, time_index=38, score=0.9999357461929321)
## Point(token_index=4, time_index=39, score=0.9861576557159424)
## Point(token_index=4, time_index=40, score=0.9238603711128235)
## Point(token_index=5, time_index=41, score=0.9257326722145081)
## Point(token_index=5, time_index=42, score=0.015660803765058517)
## Point(token_index=5, time_index=43, score=0.9998378753662109)
## Point(token_index=6, time_index=44, score=0.9988442659378052)
## Point(token_index=7, time_index=45, score=0.10145310312509537)
## Point(token_index=7, time_index=46, score=0.9999426603317261)
## Point(token_index=8, time_index=47, score=0.9999946355819702)
## Point(token_index=8, time_index=48, score=0.997960090637207)
## Point(token_index=9, time_index=49, score=0.03603268042206764)
## Point(token_index=9, time_index=50, score=0.06164022535085678)
## Point(token_index=9, time_index=51, score=4.3323558202246204e-05)
## Point(token_index=10, time_index=52, score=0.9999802112579346)
## Point(token_index=11, time_index=53, score=0.9967091083526611)
## Point(token_index=11, time_index=54, score=0.9999257326126099)
## Point(token_index=11, time_index=55, score=0.9999982118606567)
## Point(token_index=12, time_index=56, score=0.9990690350532532)
## Point(token_index=12, time_index=57, score=0.9999996423721313)
## Point(token_index=12, time_index=58, score=0.9999996423721313)
## Point(token_index=12, time_index=59, score=0.8457302451133728)
## Point(token_index=12, time_index=60, score=0.9999996423721313)
## Point(token_index=13, time_index=61, score=0.9996013045310974)
## Point(token_index=13, time_index=62, score=0.999998927116394)
## Point(token_index=14, time_index=63, score=0.0035265316255390644)
## Point(token_index=14, time_index=64, score=1.0)
## Point(token_index=14, time_index=65, score=1.0)
## Point(token_index=14, time_index=66, score=0.9999914169311523)
## Point(token_index=15, time_index=67, score=0.9971597194671631)
## Point(token_index=15, time_index=68, score=0.9999991655349731)
## Point(token_index=15, time_index=69, score=0.9999992847442627)
## Point(token_index=15, time_index=70, score=0.9999998807907104)
## Point(token_index=15, time_index=71, score=0.9999998807907104)
## Point(token_index=15, time_index=72, score=0.9999881982803345)
## Point(token_index=15, time_index=73, score=0.011429330334067345)
## Point(token_index=15, time_index=74, score=0.9999977350234985)
## Point(token_index=16, time_index=75, score=0.999613344669342)
## Point(token_index=16, time_index=76, score=0.999998927116394)
## Point(token_index=16, time_index=77, score=0.972746729850769)
## Point(token_index=16, time_index=78, score=0.9999988079071045)
## Point(token_index=17, time_index=79, score=0.9949318766593933)
## Point(token_index=17, time_index=80, score=0.999998927116394)
## Point(token_index=17, time_index=81, score=0.9999121427536011)
## Point(token_index=17, time_index=82, score=0.9999774694442749)
## Point(token_index=18, time_index=83, score=0.6577526926994324)
## Point(token_index=18, time_index=84, score=0.9984303116798401)
## Point(token_index=18, time_index=85, score=0.9999874830245972)
## Point(token_index=19, time_index=86, score=0.9993744492530823)
## Point(token_index=19, time_index=87, score=0.9999988079071045)
## Point(token_index=19, time_index=88, score=0.10427592694759369)
## Point(token_index=19, time_index=89, score=0.9999967813491821)
## Point(token_index=20, time_index=90, score=0.3978538513183594)
## Point(token_index=20, time_index=91, score=0.9999933242797852)
## Point(token_index=21, time_index=92, score=1.698439064057311e-06)
## Point(token_index=21, time_index=93, score=0.9861314296722412)
## Point(token_index=21, time_index=94, score=0.9999960660934448)
## Point(token_index=22, time_index=95, score=0.9992737174034119)
## Point(token_index=22, time_index=96, score=0.9993410706520081)
## Point(token_index=22, time_index=97, score=0.9999983310699463)
## Point(token_index=23, time_index=98, score=0.9999971389770508)
## Point(token_index=23, time_index=99, score=0.9999997615814209)
## Point(token_index=23, time_index=100, score=0.9999995231628418)
## Point(token_index=23, time_index=101, score=0.9999732971191406)
## Point(token_index=24, time_index=102, score=0.998322069644928)
## Point(token_index=24, time_index=103, score=0.9999991655349731)
## Point(token_index=24, time_index=104, score=0.9999997615814209)
## Point(token_index=24, time_index=105, score=1.0)
## Point(token_index=24, time_index=106, score=1.0)
## Point(token_index=24, time_index=107, score=0.9998630285263062)
## Point(token_index=24, time_index=108, score=0.9999980926513672)
## Point(token_index=25, time_index=109, score=0.9988586902618408)
## Point(token_index=25, time_index=110, score=0.9999797344207764)
## Point(token_index=26, time_index=111, score=0.8573068976402283)
## Point(token_index=26, time_index=112, score=0.999984860420227)
## Point(token_index=27, time_index=113, score=0.9870262742042542)
## Point(token_index=27, time_index=114, score=1.9047982277697884e-05)
## Point(token_index=27, time_index=115, score=0.9999794960021973)
## Point(token_index=28, time_index=116, score=0.9998255372047424)
## Point(token_index=28, time_index=117, score=0.9999990463256836)
## Point(token_index=29, time_index=118, score=0.9999734163284302)
## Point(token_index=29, time_index=119, score=0.0009004927123896778)
## Point(token_index=29, time_index=120, score=0.9993483424186707)
## Point(token_index=30, time_index=121, score=0.9975456595420837)
## Point(token_index=30, time_index=122, score=0.0003051277599297464)
## Point(token_index=30, time_index=123, score=0.9999344348907471)
## Point(token_index=31, time_index=124, score=6.079461854824331e-06)
## Point(token_index=31, time_index=125, score=0.9833155870437622)
## Point(token_index=32, time_index=126, score=0.9974578022956848)
## Point(token_index=33, time_index=127, score=0.0008234027773141861)
## Point(token_index=33, time_index=128, score=0.9965150356292725)
## Point(token_index=34, time_index=129, score=0.017463643103837967)
## Point(token_index=34, time_index=130, score=0.9989169836044312)
## Point(token_index=35, time_index=131, score=0.9999697208404541)
## Point(token_index=36, time_index=132, score=0.9999842643737793)
## Point(token_index=36, time_index=133, score=0.9997640252113342)
## Point(token_index=37, time_index=134, score=0.5096980929374695)
## Point(token_index=37, time_index=135, score=0.9998302459716797)
## Point(token_index=38, time_index=136, score=0.08524630963802338)
## Point(token_index=38, time_index=137, score=0.004073922522366047)
## Point(token_index=38, time_index=138, score=0.9999815225601196)
## Point(token_index=39, time_index=139, score=0.01204877533018589)
## Point(token_index=39, time_index=140, score=0.9999979734420776)
## Point(token_index=39, time_index=141, score=0.0005776838515885174)
## Point(token_index=39, time_index=142, score=0.9999066591262817)
## Point(token_index=40, time_index=143, score=0.9999960660934448)
## Point(token_index=40, time_index=144, score=0.9999980926513672)
## Point(token_index=40, time_index=145, score=0.9999915361404419)
## Point(token_index=41, time_index=146, score=0.9971170425415039)
## Point(token_index=41, time_index=147, score=0.9981802701950073)
## Point(token_index=41, time_index=148, score=0.9999310970306396)
## Point(token_index=42, time_index=149, score=0.9879518151283264)
## Point(token_index=42, time_index=150, score=0.9997628331184387)
## Point(token_index=42, time_index=151, score=0.9999533891677856)
## Point(token_index=43, time_index=152, score=0.9999715089797974)
## Point(token_index=44, time_index=153, score=0.31862306594848633)
## Point(token_index=44, time_index=154, score=0.999782145023346)
## Point(token_index=45, time_index=155, score=0.016034986823797226)
## Point(token_index=45, time_index=156, score=0.999901294708252)
## Point(token_index=46, time_index=157, score=0.46723461151123047)
## Point(token_index=46, time_index=158, score=0.9999995231628418)
## Point(token_index=46, time_index=159, score=0.9999996423721313)
## Point(token_index=46, time_index=160, score=0.9999996423721313)
## Point(token_index=46, time_index=161, score=0.9999996423721313)
## Point(token_index=46, time_index=162, score=0.9999996423721313)
## Point(token_index=46, time_index=163, score=0.9999996423721313)
## Point(token_index=46, time_index=164, score=0.9999995231628418)
## Point(token_index=46, time_index=165, score=0.9999995231628418)
## Point(token_index=46, time_index=166, score=0.9999995231628418)
## Point(token_index=46, time_index=167, score=0.9999995231628418)
## Point(token_index=46, time_index=168, score=0.9999995231628418)

6.3 Visualization

def plot_trellis_with_path(trellis, path):
    # To plot the trellis with the path, we take advantage of 'nan' values
    trellis_with_path = trellis.clone()
    for _, p in enumerate(path):
        trellis_with_path[p.time_index, p.token_index] = float("nan")
    plt.imshow(trellis_with_path.T, origin="lower")
    plt.title("The path found by backtracking")
    plt.tight_layout()

plot_trellis_with_path(trellis, path)

Looking good.

6.4 Segment the path

Now this path contains repetitions of the same labels, so let’s merge them to make it closer to the original transcript.

When merging multiple path points, we simply take the average probability over the merged segment.

6.5 Merge the labels

@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


segments = merge_repeats(path)
for seg in segments:
    print(seg)
## |    (1.00): [    0,    31)
## I    (0.78): [   31,    35)
## |    (0.80): [   35,    37)
## H    (1.00): [   37,    39)
## A    (0.96): [   39,    41)
## D    (0.65): [   41,    44)
## |    (1.00): [   44,    45)
## T    (0.55): [   45,    47)
## H    (1.00): [   47,    49)
## A    (0.03): [   49,    52)
## T    (1.00): [   52,    53)
## |    (1.00): [   53,    56)
## C    (0.97): [   56,    61)
## U    (1.00): [   61,    63)
## R    (0.75): [   63,    67)
## I    (0.88): [   67,    75)
## O    (0.99): [   75,    79)
## S    (1.00): [   79,    83)
## I    (0.89): [   83,    86)
## T    (0.78): [   86,    90)
## Y    (0.70): [   90,    92)
## |    (0.66): [   92,    95)
## B    (1.00): [   95,    98)
## E    (1.00): [   98,   102)
## S    (1.00): [  102,   109)
## I    (1.00): [  109,   111)
## D    (0.93): [  111,   113)
## E    (0.66): [  113,   116)
## |    (1.00): [  116,   118)
## M    (0.67): [  118,   121)
## E    (0.67): [  121,   124)
## |    (0.49): [  124,   126)
## A    (1.00): [  126,   127)
## T    (0.50): [  127,   129)
## |    (0.51): [  129,   131)
## T    (1.00): [  131,   132)
## H    (1.00): [  132,   134)
## I    (0.75): [  134,   136)
## S    (0.36): [  136,   139)
## |    (0.50): [  139,   143)
## M    (1.00): [  143,   146)
## O    (1.00): [  146,   149)
## M    (1.00): [  149,   152)
## E    (1.00): [  152,   153)
## N    (0.66): [  153,   155)
## T    (0.51): [  155,   157)
## |    (0.96): [  157,   169)

6.6 Visualization

def plot_trellis_with_segments(trellis, segments, transcript):
    # To plot trellis with path, we take advantage of 'nan' value
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    ax2.set_title("Label probability with and without repetation")
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")

    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.5, alpha=0.5)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()


plot_trellis_with_segments(trellis, segments, transcript)

Looks good.

6.7 Merge the segments into words

Now let’s merge the segments into words. The Wav2Vec2 model uses '|' as the word boundary, so we merge the segments before each occurrence of '|'.

Finally, we cut the original audio into word-level segments and listen to them to verify that the segmentation is correct.

7 Merge words

def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(segments)
for word in word_segments:
    print(word)
## I    (0.78): [   31,    35)
## HAD  (0.84): [   37,    44)
## THAT (0.52): [   45,    53)
## CURIOSITY    (0.89): [   56,    92)
## BESIDE   (0.94): [   95,   116)
## ME   (0.67): [  118,   124)
## AT   (0.66): [  126,   129)
## THIS (0.70): [  131,   139)
## MOMENT   (0.88): [  143,   157)

7.1 Visualization

def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1)

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # The original waveform
    ratio = waveform.size(0) / sample_rate / trellis.size(0)
    ax2.specgram(waveform, Fs=sample_rate)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()


plot_alignments(
    trellis,
    segments,
    word_segments,
    waveform[0],
)

8 Audio Samples

8.1 Generate the audio for full sample

print(transcript)
## |I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|
IPython.display.Audio(SPEECH_FILE)
## <IPython.lib.display.Audio object>

8.2 Generate the audio for each segment

def display_segment(i):
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)
  
display_segment(0)
## I (0.78): 0.624 - 0.704 sec
## <IPython.lib.display.Audio object>
display_segment(1)
## HAD (0.84): 0.744 - 0.885 sec
## <IPython.lib.display.Audio object>
display_segment(2)
## THAT (0.52): 0.905 - 1.066 sec
## <IPython.lib.display.Audio object>
display_segment(3)
## CURIOSITY (0.89): 1.127 - 1.851 sec
## <IPython.lib.display.Audio object>
display_segment(4)
## BESIDE (0.94): 1.911 - 2.334 sec
## <IPython.lib.display.Audio object>
display_segment(5)
## ME (0.67): 2.374 - 2.495 sec
## <IPython.lib.display.Audio object>
display_segment(6)
## AT (0.66): 2.535 - 2.595 sec
## <IPython.lib.display.Audio object>
display_segment(7)
## THIS (0.70): 2.635 - 2.796 sec
## <IPython.lib.display.Audio object>
display_segment(8)
## MOMENT (0.88): 2.877 - 3.159 sec
## <IPython.lib.display.Audio object>

9 Saving alignment results to CSV and TextGrid format

After obtaining the forced alignment with Wav2Vec2, we use the code below to transform the results into a Praat TextGrid. We first save the segments and word segments into dataframes, then transform them into the TextGrid format. We do this by separating the results into two levels: the word level and the segment level. It is of course possible to combine both into one step, without even saving the results to a CSV file; however, the code below is intended for exporting the segments from Python into a CSV file in case you plan to use them with other software or for further processing.

9.1 Word level

9.1.1 Generate a dataframe

import pandas as pd
word_segments = pd.DataFrame(word_segments)
word_segments.to_csv("word_segments.csv")
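
A quick note on the exported table (a pandas detail, not part of the alignment itself): building a DataFrame from a list of Segment dataclasses yields the columns label, start, end, and score, where start and end are trellis frame indices, not seconds.

print(word_segments.columns.tolist())
## ['label', 'start', 'end', 'score']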

9.1.2 Transform to TextGrid

import csv
import textgrid # install with pip install textgrid
# Load the CSV data
with open("word_segments.csv",
          "r", encoding="utf-8") as f:
    reader = csv.DictReader(f,
                            delimiter="," 
                            )
    data = [row for row in reader]

# Create a TextGrid object
tg = textgrid.TextGrid()

# Create IntervalTier objects
transcript_tier = textgrid.IntervalTier(name="label")
# Populate the interval tier. The start/end columns are trellis frame indices;
# (2*x - 0.7)/100 converts a frame index to seconds, approximating the model's
# ~50 frames-per-second rate with a small offset.
for row in data:
    if row["label"] != "|":
        start_time = (float(row["start"]) + (float(row["start"]) - 0.7)) / 100
        end_time = (float(row["end"]) + (float(row["end"]) - 0.7)) / 100
        transcript_tier.add(start_time, end_time, row["label"])

# Add the interval tiers to the TextGrid
tg.append(transcript_tier)

# Write the TextGrid to a file
with open("words.TextGrid", "w", encoding="utf-8") as f:
    tg.write(f)   
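
The (2*x - 0.7)/100 conversion used above approximates the model's frame rate with fixed constants. As a sketch of an alternative that avoids magic numbers, you can reuse the exact samples-per-frame ratio from display_segment (this assumes waveform, trellis, and bundle are still in scope):

# Exact frame-index-to-seconds conversion, mirroring display_segment above.
ratio = waveform.size(1) / trellis.size(0)  # audio samples per trellis frame

def frame_to_seconds(frame_index):
    return frame_index * ratio / bundle.sample_rate

# e.g.: transcript_tier.add(frame_to_seconds(float(row["start"])),
#                           frame_to_seconds(float(row["end"])),
#                           row["label"])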
    

9.2 Segment level

9.2.1 Generate a dataframe

import pandas as pd
segments = pd.DataFrame(segments)
segments2 = segments[segments.label != "|"]
segments2.to_csv("segments.csv")

9.2.2 Transform to TextGrid

import csv
import textgrid # install with pip install textgrid
# Load the CSV data
with open("segments.csv",
          "r", encoding="utf-8") as f:
    reader = csv.DictReader(f,
                            delimiter="," 
                            )
    data = [row for row in reader]

# Create a TextGrid object
tg = textgrid.TextGrid()

# Create IntervalTier objects
transcript_tier = textgrid.IntervalTier(name="label")
# Populate the interval tier, using the same frame-to-seconds conversion as above
for row in data:
    start_time = (float(row["start"]) + (float(row["start"]) - 0.7)) / 100
    end_time = (float(row["end"]) + (float(row["end"]) - 0.7)) / 100
    transcript_tier.add(start_time, end_time, row["label"])

# Add the interval tiers to the TextGrid
tg.append(transcript_tier)

# Write the TextGrid to a file
with open("segments.TextGrid", "w", encoding="utf-8") as f:
    tg.write(f)   
    

9.3 Merging the two TextGrids

import textgrid # install with pip install textgrid
# Load the TextGrid files
tg_words = textgrid.TextGrid.fromFile("words.TextGrid")
tg_segments = textgrid.TextGrid.fromFile("segments.TextGrid")
# Merge the two TextGrids
tg_merged = textgrid.TextGrid()
# Add the tiers from both TextGrids
for tier in tg_words.tiers:
    tg_merged.append(tier)
for tier in tg_segments.tiers:
    tg_merged.append(tier)
# Write the merged TextGrid to a file
with open("merged.TextGrid", "w", encoding="utf-8") as f:
    tg_merged.write(f)
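
As a final sanity check (a sketch relying on the textgrid package's tier attributes), you can read the merged file back and list its tiers. Note that, as written, both tiers share the name "label"; rename them above if you want them distinguished in Praat.

tg_check = textgrid.TextGrid.fromFile("merged.TextGrid")
for tier in tg_check.tiers:
    print(tier.name, len(tier.intervals))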
    

10 Conclusion

In this tutorial, we looked at how to use torchaudio’s Wav2Vec2 model to perform CTC segmentation for forced alignment.