Have you ever wondered how to create an AI model that returns an embedding instead of a classification? This notebook is for you ;)

How to generate an embedding from a classification deep learning model

This Jupyter notebook will serve as a tutorial for anyone who wants to convert their deep learning model into a feature extractor.

In this example I will:

    1. Use the Wav2Vec2 model to classify an audio file.
    2. Use the same model to generate an embedding.

This approach extrapolates to any other model, because it is a general modeling technique, not an implementation detail of a particular library.

First of all, make sure you have the following libraries installed to run this notebook:
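
If any of these are missing, they can usually be installed from a notebook cell with pip (package names assumed to be the standard PyPI ones):

# Install the required packages (adjust to your environment if needed)
!pip install torch librosa transformers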

import torch
import librosa
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

Now we can start with the code.

Simple classification with wav2vec

First, we load the wav2vec2 model with a classification head (the emotion recognition classifier):

model_to_use = "superb/wav2vec2-base-superb-er"
featureExtractor = AutoFeatureExtractor.from_pretrained(model_to_use)
wav2vec_model = Wav2Vec2ForSequenceClassification.from_pretrained(model_to_use)
/home/amable/Desktop/embeddings/env_audio/lib/python3.12/site-packages/transformers/configuration_utils.py:334: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.
  warnings.warn(

OK, I have an audio file that I would like to classify. Let's classify it:

# Load and preprocess an audio file
audio, sr = librosa.load("audio_example.wav", sr=16000)  # wav2vec2 expects 16 kHz audio
inputs = featureExtractor(audio, sampling_rate=16000, return_tensors="pt", padding=True)

# Pass the audio through the classification model
with torch.no_grad():
    outputs = wav2vec_model(**inputs)
    
# The model returns logits (before softmax)
logits = outputs.logits

# Predicted class
predicted_class_id = torch.argmax(logits, dim=-1).item()
predicted_label = wav2vec_model.config.id2label[predicted_class_id]

print("Logits:", logits)
print("Predicted class ID:", predicted_class_id)
print("Predicted label:", predicted_label)
Logits: tensor([[ 0.9470, -0.2392,  1.8484, -3.6136]])
Predicted class ID: 2
Predicted label: ang
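
If class probabilities are preferred over the raw logits, a softmax can be applied on top (a small sketch reusing the logits tensor from above):

import torch.nn.functional as F

# Turn the raw logits into a probability distribution over the 4 emotion classes
probs = F.softmax(logits, dim=-1)
print("Probabilities:", probs)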

Everything looks normal: we have classified an audio file. Very good.

Now, can we see the model architecture?

print("Original model with classification head:")
print(wav2vec_model)
Original model with classification head:
Wav2Vec2ForSequenceClassification(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2GroupNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
          (activation): GELUActivation()
          (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
        )
        (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (projection): Linear(in_features=512, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): Wav2Vec2Encoder(
      (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
        (conv): ParametrizedConv1d(
          768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        (padding): Wav2Vec2SamePadLayer()
        (activation): GELUActivation()
      )
      (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0-11): 12 x Wav2Vec2EncoderLayer(
          (attention): Wav2Vec2Attention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.1, inplace=False)
          (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (feed_forward): Wav2Vec2FeedForward(
            (intermediate_dropout): Dropout(p=0.0, inplace=False)
            (intermediate_dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
            (output_dense): Linear(in_features=3072, out_features=768, bias=True)
            (output_dropout): Dropout(p=0.1, inplace=False)
          )
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (projector): Linear(in_features=768, out_features=256, bias=True)
  (classifier): Linear(in_features=256, out_features=4, bias=True)
)

Do you see the last two layers?

(projector): Linear(in_features=768, out_features=256, bias=True)

(classifier): Linear(in_features=256, out_features=4, bias=True)

The projector layer converts the final 768-component representation into 256 components. In the same way, the classifier performs the final conversion from 256 components to 4 components: the 4 values that make up this model's response.
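
These two layers, and the meaning of the 4 outputs, can also be inspected directly on the loaded model:

# Inspect the classification head and the label mapping
print(wav2vec_model.projector)        # Linear(in_features=768, out_features=256, bias=True)
print(wav2vec_model.classifier)       # Linear(in_features=256, out_features=4, bias=True)
print(wav2vec_model.config.id2label)  # which emotion each of the 4 outputs corresponds to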

Have you ever thought about what would happen if we stripped the classification head off the model?

Classification model to Embedding model

We can select only the encoder with wav2vec_model.wav2vec2. PyTorch offers other ways to remove layers, but this is the easiest one.
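
For instance, one alternative (just a sketch, not used in the rest of this notebook) is to keep the full model but neutralize its head by swapping the extra layers for identity modules:

import copy
import torch.nn as nn

# Work on a copy so the original classification model stays untouched
headless_model = copy.deepcopy(wav2vec_model)
headless_model.projector = nn.Identity()
headless_model.classifier = nn.Identity()
# With both layers replaced, the "logits" returned by this copy are no longer
# class scores but the pooled 768-dimensional hidden representation, i.e. an embedding.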

# Strip off the classification head
embedding_model = wav2vec_model.wav2vec2

print("Model after removing classification head (embedding extractor):")
print(embedding_model)
Model after removing classification head (embedding extractor):
Wav2Vec2Model(
  (feature_extractor): Wav2Vec2FeatureEncoder(
    (conv_layers): ModuleList(
      (0): Wav2Vec2GroupNormConvLayer(
        (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (activation): GELUActivation()
        (layer_norm): GroupNorm(512, 512, eps=1e-05, affine=True)
      )
      (1-4): 4 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
      (5-6): 2 x Wav2Vec2NoLayerNormConvLayer(
        (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
        (activation): GELUActivation()
      )
    )
  )
  (feature_projection): Wav2Vec2FeatureProjection(
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (projection): Linear(in_features=512, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Wav2Vec2Encoder(
    (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
      (conv): ParametrizedConv1d(
        768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
        (parametrizations): ModuleDict(
          (weight): ParametrizationList(
            (0): _WeightNorm()
          )
        )
      )
      (padding): Wav2Vec2SamePadLayer()
      (activation): GELUActivation()
    )
    (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x Wav2Vec2EncoderLayer(
        (attention): Wav2Vec2Attention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
        (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (feed_forward): Wav2Vec2FeedForward(
          (intermediate_dropout): Dropout(p=0.0, inplace=False)
          (intermediate_dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
          (output_dense): Linear(in_features=3072, out_features=768, bias=True)
          (output_dropout): Dropout(p=0.1, inplace=False)
        )
        (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
)

Now we no longer have the final layers. Perfect. Let's pass the audio through again and see what happens:

# Pass the audio through the embedding model
with torch.no_grad():
    outputs = embedding_model(**inputs)

# The raw embeddings (before pooling) are in last_hidden_state
hidden_states = outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)

# Get a fixed-size embedding (x-vector) by mean pooling across time
x_vector = hidden_states.mean(dim=1)  # (batch_size, hidden_dim)

print("X-vector embedding shape:", x_vector.shape)
print("X-vector embedding:", x_vector)
X-vector embedding shape: torch.Size([1, 768])
X-vector embedding: tensor([[ 8.6946e-02,  2.0988e-01,  1.4081e-01,  1.6515e-01,  3.5157e-01,
         -1.2929e-01,  1.0128e-01, -1.7905e-01, -9.6832e-02, -1.5092e-01,
         -2.4084e-01,  5.0127e-01,  1.8352e-01,  3.3696e-01,  2.4817e-01,
         -1.3278e-01,  2.7756e-01, -4.0567e-02, -9.2052e-02, -6.6830e-01,
         -5.1633e-02, -9.8745e-02,  3.2247e-01,  4.2257e-01,  3.3957e-01,
         -1.9503e-01,  2.2896e-01, -1.8060e-01, -2.7804e-02,  1.5503e-01,
         -3.5377e-01, -1.2443e-01, -4.1226e-01, -1.2335e-02, -1.0014e-01,
          7.5494e-02, -1.9365e-01,  1.3100e-01,  9.5487e-02, -1.1347e-01,
         -2.3310e-01, -8.7284e-02, -1.3224e-01, -2.0169e-01, -6.8098e-01,
         -1.9927e-01, -3.1116e-01,  1.3042e-01,  2.6950e-02, -2.0584e-01,
         -8.4703e-02,  2.8150e-01,  4.4262e-01,  1.0848e-02,  2.3769e-01,
         -1.0415e-01,  2.0475e-01, -2.2999e-02,  3.6658e-02, -6.8065e-02,
         -1.5170e-01,  8.9711e-02, -8.0588e-02,  1.2532e-02, -3.7965e-01,
          1.4530e-02, -1.6367e-01, -1.7028e-01, -7.7885e-02,  1.7602e-02,
          8.2119e-02, -1.5848e-01,  2.0984e-01, -3.3667e-02,  8.7700e-02,
          4.1502e-01,  4.2893e-01, -8.5900e-02, -2.0612e-01, -1.7328e-01,
          1.8840e-01, -3.8356e-02,  8.8701e-02, -2.8277e-01, -1.4308e-02,
          3.6897e-01, -4.2337e-02, -5.3423e-02, -2.5247e-01, -4.4887e-02,
         -4.9863e-01, -4.5243e-02, -1.7331e-02, -7.2172e-02, -2.1309e-01,
          2.1192e-03,  4.2654e-01, -2.0050e-01, -7.6087e-02,  3.4827e-03,
         -1.4142e-01, -5.0394e-02, -1.2198e-01, -1.4682e-01,  6.4946e-02,
          2.6498e-01,  1.4300e-01, -3.3160e-02, -1.4586e-01,  2.8439e-01,
         -1.7802e-01,  2.9644e-01,  2.9501e-01,  4.4525e-02, -1.4751e-01,
          1.1933e-01, -5.8670e-02, -2.8313e-01,  8.5155e-02, -2.0909e-01,
         -4.1474e-01,  6.9079e-02,  1.6961e-01, -6.5750e-02, -1.2792e-01,
          1.5040e-01,  3.5729e-01, -9.7014e-02,  4.4670e-01, -1.2755e-01,
          2.3765e-01,  2.0902e-01, -2.0602e-01, -1.6698e-01, -2.2022e-01,
          4.6073e-01, -5.9623e-02,  4.3860e-02,  2.8811e-01, -1.2830e-02,
         -1.1426e-04, -2.4598e-01,  7.9551e-02, -5.6403e-02, -8.0841e-02,
          1.7812e-02, -4.8133e-01,  3.0691e-01, -2.8828e-03,  2.7535e-01,
          1.2745e-01,  3.4564e-01, -2.0960e-02, -4.1170e-01,  4.2027e-02,
         -2.7241e-02, -3.8728e-02, -3.3342e-02, -1.8065e-01,  5.5978e-02,
          7.7837e-02,  3.1362e-02,  1.1818e-02, -5.3871e-02, -1.6361e-01,
         -1.9485e-01,  3.7057e-01,  4.5501e-01,  1.5337e-01, -3.0150e-01,
          1.0495e-01, -8.2195e-02, -1.8760e-01,  7.1806e-02, -7.1468e-02,
          3.4354e-01, -2.9288e-01, -1.2274e-01,  7.8707e-02, -6.8435e-02,
         -3.3413e-01, -7.5221e-02,  1.3046e-02,  7.0665e-03, -2.2640e-01,
         -1.5896e-02, -8.4640e-02,  3.2005e-01, -1.1023e-02, -5.7890e-02,
          3.5550e-01,  3.9788e-02, -9.6779e-02,  6.1224e-02, -1.0636e-01,
          1.8272e-01,  1.7317e-01,  2.0837e-01, -2.1222e-01,  1.6204e-01,
          6.7171e-02, -9.7062e-02, -5.2261e-01, -2.3529e-01,  8.9916e-02,
         -4.1063e-01, -8.2685e-03,  3.5603e-02, -1.2106e-01,  1.4215e-01,
          2.7054e-01, -2.9181e-01, -1.4289e-01, -1.7261e-01, -2.9020e-01,
          5.6692e-02, -3.2900e-01, -2.6791e-01, -9.8399e-02,  9.9202e-03,
          1.6404e-01, -2.9756e-01,  2.8839e-02,  8.5542e-03,  2.2670e-01,
          2.8387e-01,  1.9703e-01,  1.8825e-01,  1.2074e-01, -1.9759e-02,
         -9.0286e-02,  3.8643e-01,  3.9365e-02,  8.2036e-02,  1.4742e-01,
          3.9679e-02,  3.0909e-02,  2.7721e-02,  8.6717e-02, -3.1207e-01,
         -5.2392e-01, -1.5243e-02,  1.2067e-01, -1.9339e-02,  3.8234e-01,
          7.7681e-03, -6.5246e-02,  5.6399e-02,  3.0310e-01,  3.6010e-01,
          8.9212e-02, -1.1380e-01,  9.3853e-02,  1.1759e-01, -1.9539e-02,
         -7.5268e-02,  3.0517e-01,  6.7832e-02,  3.1257e-01,  2.0795e-03,
         -1.6781e-01, -2.0624e-01, -4.9047e-01,  8.9021e-02, -4.8869e-01,
          2.7109e-01,  9.7114e-02,  1.3620e-01,  2.2204e-02,  6.5262e-02,
          4.2267e-02,  2.8453e-01,  8.8643e-02,  1.0069e-01,  7.9461e-02,
          4.8795e-01,  6.0966e-02, -2.5339e-01,  2.6790e-01, -1.9822e-02,
         -9.4969e-02,  6.3960e-02, -8.7915e-02, -2.3109e-02,  3.8467e-01,
          6.2535e-02, -3.0470e-01,  6.1988e-02,  1.7438e-02, -1.2529e-02,
         -2.4090e-01,  5.1711e-01, -2.8790e-01,  1.4959e-01, -1.8237e-01,
         -7.2918e-02,  2.2302e-02, -1.3088e-01, -2.0976e-01, -1.7910e-01,
         -5.5897e-02, -1.4870e-01, -2.8453e-02,  1.2165e-01, -7.6795e-03,
         -2.5328e-01,  4.6149e-01,  2.2511e-01,  1.1284e-01,  4.1193e-02,
          5.4565e-01,  7.8619e-01, -4.3705e-01,  2.6051e-01, -9.1309e-02,
          4.5976e-03,  3.1618e-02, -1.6335e-01, -2.0930e-01, -8.6246e-02,
         -2.7809e-01,  3.4988e-02,  3.8848e-01,  3.8059e-01, -6.6192e-02,
         -2.0933e-01,  4.3955e-01, -4.2346e-01, -2.3746e-03,  1.2398e-01,
          1.9784e-01,  3.1139e-01,  3.1861e-02,  6.9496e-02, -1.4847e-01,
         -1.0827e-01, -2.7199e-01,  2.5093e-02, -5.2472e-02, -1.5192e-01,
          3.1719e-01,  7.3781e-02,  1.6941e-01, -2.8926e-01, -4.0540e-01,
          4.0647e-01, -5.5520e-02, -3.8259e-01,  2.2797e-01, -2.1741e-01,
          1.2530e-01, -4.7871e-02, -3.3642e-02, -4.4706e-02,  2.0079e-01,
          3.5729e-01, -6.6386e-04,  4.9221e-02,  1.3717e-01, -6.5993e-02,
          1.7775e-01, -1.8220e-01, -6.1361e-02, -2.5814e-01,  2.8462e-02,
         -3.2495e-01,  5.1516e-03, -1.6804e-01, -7.2339e-02, -1.0925e-01,
         -1.1470e-01, -5.7446e-01,  3.9203e-02,  8.5304e-02, -1.8025e-01,
         -6.0548e-02,  2.9438e-01, -2.9714e-01,  2.4058e-01, -6.3066e-02,
         -5.8853e-02, -1.8324e-02, -1.9370e-02,  4.8482e-01, -7.1397e-01,
          2.4785e-01, -5.7786e-01,  1.8072e-01, -2.2720e-01,  1.6635e-01,
         -6.2145e-03, -2.5553e-02, -3.8729e-01, -4.0088e-01,  3.8122e-01,
          1.9873e-01, -2.1970e-01, -7.3537e-02,  5.8734e-01,  6.7830e-03,
         -1.3014e-01,  1.8157e-01,  1.3034e-02, -2.6884e-01, -1.6194e-01,
          1.3162e-01,  3.0398e-01,  1.0347e-01, -3.3298e-02, -3.6055e-01,
          1.8854e-02,  4.4354e-01, -1.1624e-01, -5.8954e-02,  8.4650e-02,
         -1.5668e-01,  1.5441e-01, -1.1988e-01,  4.3403e-01,  3.9199e-02,
          1.3342e-01, -3.0168e-01, -1.5621e-01, -2.1547e-01, -2.3980e-01,
          3.4311e-01,  5.8059e-03,  2.2389e-02,  1.0664e-01, -1.5712e-01,
          1.2500e-02,  2.2568e-01, -1.2687e-02,  8.2794e-02, -1.0445e-01,
         -2.1359e-03,  6.1922e-02,  1.3416e-01,  1.4410e-01, -6.2430e-01,
         -2.5134e-01,  1.8748e-01,  1.0164e-01, -2.7316e-01, -3.7829e-02,
         -2.0521e-01, -1.3021e-01,  1.7669e-02, -2.7092e-01, -3.7699e-02,
         -2.6473e-01, -9.9748e-02, -1.6578e-01, -3.3541e-01, -9.5328e-02,
          4.2054e-02,  1.1569e-01,  1.0725e-01,  3.1933e-01, -2.9442e-03,
         -3.5642e-01, -3.4760e-02, -1.4611e-01, -3.0868e-01,  5.2004e-02,
          1.4428e-01,  2.0985e-01,  1.3911e-01, -4.9911e-02, -8.5841e-02,
         -8.1748e-02, -1.4590e-01,  1.2836e-01,  1.0299e-01,  3.9377e-02,
          2.5254e-02,  2.2689e-02,  2.0892e-01, -3.9834e-01, -5.4886e-02,
         -9.0952e-01,  9.6967e-03, -6.9458e-02, -2.9733e-02, -7.4314e-02,
         -1.4006e-01,  2.2001e-01, -1.7268e-01,  1.3299e-01,  2.6686e-01,
         -7.8849e-02,  5.4309e-02,  1.0157e-01, -1.0998e-01, -2.3125e-01,
          2.8301e-01,  5.8042e-02, -2.5059e-01,  3.7479e-02,  1.0202e-01,
          1.0634e-02,  1.9847e-01,  1.4989e-01, -1.0898e-01, -2.3768e-01,
         -1.7987e-01,  1.2295e-01,  3.3561e-01,  2.6437e-01, -1.8406e-01,
          7.3174e-02,  1.8835e-01,  1.8522e-01, -4.1971e-02, -6.7028e-02,
          7.3690e-02,  6.3314e-01,  2.0348e-01,  3.8468e-01,  1.9362e-01,
         -1.6777e-01,  1.3935e-01,  2.0597e-01,  1.3616e-01, -3.8034e-01,
         -2.0124e-01,  8.1477e-02, -2.8500e-01, -2.4414e-01,  5.3451e-01,
         -1.0349e-01,  8.2090e-02, -4.4940e-02, -6.9434e-02,  3.0520e-01,
          9.5877e-02,  3.5525e-01,  7.4328e-02,  3.3169e-01,  2.2752e-01,
          2.4501e-02, -2.7931e-01, -1.8020e-01, -1.9281e-01,  1.2669e-01,
         -2.1739e-01,  3.0783e-02,  4.0950e-01,  5.7406e-02, -1.1331e-02,
          6.3312e-01,  2.6685e-01,  4.7974e-02, -7.8970e-02, -6.9386e-02,
          3.7684e-01,  2.0445e-01, -4.7488e-02,  1.0344e-01,  1.5270e-01,
          2.9465e-03,  7.6271e-02,  5.9321e-02, -1.2665e-01, -1.1533e-01,
         -7.4709e-02, -8.9400e-02, -2.8587e-02,  1.7083e-01,  1.3932e-01,
          2.6748e-01, -5.2371e-01,  2.5220e-01,  4.8864e-03,  4.6810e-01,
         -1.8537e-01,  6.2583e-01, -2.8371e-01,  2.6778e-03,  2.2953e-01,
         -8.1751e-02, -4.2385e-01,  2.0033e-01,  8.0169e-03,  1.8887e-01,
          7.6481e-03, -7.4409e-02, -1.2445e-01,  6.8099e-02,  7.8805e-02,
         -8.9861e-02,  2.9578e-01, -9.7522e-02,  2.4435e-01, -4.5712e-02,
          4.2367e-01, -2.8260e-01, -1.1835e-01, -1.3279e-01,  2.8500e-01,
         -1.0637e-01, -4.1648e-01,  1.4115e-01,  2.5824e-02, -1.9656e-01,
          1.6353e-01, -1.4814e-01,  2.4213e-01,  1.2623e-02, -2.1931e-01,
          6.4695e-02, -3.6126e-01, -1.1576e-01, -2.5397e-01,  3.4354e-01,
         -2.3334e-01, -2.1040e-02,  2.5477e-02,  4.9265e-01,  2.9640e-01,
         -1.5071e-01, -7.0372e-02,  2.6653e-01,  4.9655e-02,  4.4080e-02,
         -9.3498e-02, -1.2610e-01, -3.4418e-01, -8.4982e-03,  1.9030e-01,
         -1.8883e-01, -2.0933e-01,  1.5472e-01,  1.7016e-01, -1.0562e-01,
          2.2627e-01,  2.7251e-01, -4.7672e-01, -6.8294e-01, -8.4626e-02,
         -2.7024e-01, -1.0735e-01, -4.2402e-02,  5.7343e-02,  3.8656e-02,
         -5.6007e-02, -3.6636e-01, -2.3866e-01, -3.2045e-01,  5.3438e-02,
          2.8592e-01, -1.5175e-01, -1.9146e-02,  9.0868e-02,  1.9629e-01,
         -1.0051e-02, -1.8033e-01, -2.0196e-01, -2.9602e-01, -4.2191e-02,
          3.8834e-01,  1.9399e-01, -1.7966e-01,  3.9846e-02, -1.7530e-01,
         -2.9701e-01,  6.7510e-02, -8.6376e-02, -1.8580e-01, -2.6900e-02,
         -2.8771e-01, -2.2083e-01, -1.5802e-01,  8.0829e-02, -1.1645e-02,
          1.7821e-01,  3.3995e-01,  6.6445e-02,  2.5586e-01, -9.1146e-02,
          3.0827e-01, -5.5081e-02,  2.2955e-01,  3.9893e-01,  1.6165e-01,
          1.5546e-01, -1.0180e-01,  5.3702e-01, -3.2218e-02,  5.0897e-01,
         -6.7983e-03,  4.2586e-01, -1.5410e-01, -3.9417e-01, -5.9003e-02,
          1.0391e-01,  1.0698e-01, -1.6199e-01,  1.5575e-01,  4.4614e-01,
         -3.4468e-01,  1.6429e-02,  1.3621e-01, -2.0787e-01, -2.0316e-01,
          1.1573e-01,  7.5611e-02,  9.0927e-02, -1.5930e-01, -9.1533e-03,
          4.1110e-02, -8.5310e-02,  6.5440e-02, -1.6863e-01, -4.4422e-02,
          3.0024e-01,  3.1807e-01,  1.0105e-01,  4.6243e-02,  1.4720e-01,
         -9.6682e-02, -2.4732e-01, -5.9085e-02,  2.8447e-01, -2.7828e-02,
         -1.1391e-01, -1.2189e-01, -4.0296e-01, -1.8533e-01, -1.4152e-01,
         -4.7412e-02,  4.5563e-02, -1.3914e-01,  5.6071e-02,  5.4437e-02,
          2.3609e-01,  1.9397e-01,  3.9884e-03, -8.4392e-02,  2.5361e-01,
         -4.2803e-01, -2.5822e-01, -4.7592e-01, -1.8651e-01, -1.3029e-01,
         -4.0664e-01, -2.7515e-01, -2.7198e-01, -2.9238e-01, -2.2215e-01,
         -5.2016e-01,  3.0412e-01, -5.9065e-02,  1.7501e-02,  2.0620e-01,
          1.6342e-01,  7.4594e-03, -1.6797e-01,  3.5212e-02,  1.4836e-01,
         -4.4118e-02,  7.7187e-02, -2.6953e-01, -2.6416e-02, -2.3423e-01,
          2.1405e-01,  4.6495e-01, -2.6026e-01]])

And this way we have an x-vector!

These x-vectors are audio embeddings produced by a deep learning model that was trained on other kinds of data. They are the 768 features the model learned to extract after being trained on large amounts of audio for other tasks.

Now think about the possibilities! We only have a vector for a single audio file, but imagine what we could achieve with a large number of these vectors.

  • This turns the classification problem into a clustering problem, allowing the use of other techniques (see the cosine-similarity sketch below).
  • The model now behaves as a feature extractor, producing features that can be combined with other features to improve the performance of other classification models.
  • It is not necessary to retrain the model for other classification tasks (although doing so, and then running other models on these vectors, would certainly improve the results).
  • It enables zero-shot and one-shot inference, since it is very easy to add new vectors to the generated vector space.
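
As a small illustration of these ideas, here is a sketch of comparing two such embeddings with cosine similarity. It assumes a hypothetical second embedding, x_vector_2, computed from another audio file exactly as x_vector was computed above:

import torch.nn.functional as F

# x_vector_2 is hypothetical: it would come from a second audio file,
# using the same embedding_model + mean-pooling steps shown above
similarity = F.cosine_similarity(x_vector, x_vector_2, dim=-1)
print("Cosine similarity between the two audio embeddings:", similarity.item())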

A technique worth discussing and sharing.

Thank you very much for reading!

Amable
