TorchSuite is a versatile, feature-rich, and PyTorch-based library designed to simplify and accelerate deep learning development for multimedia applications — including audio, image, and video processing — and medical imaging segmentation.
It provides essential modules and utilities for training, inference, data handling, and model optimization, making it a valuable tool for researchers, engineers, and practitioners.
Key highlights of the framework include:
- Simplified model training and evaluation
- CPU and GPU acceleration
- Support for knowledge distillation for classification and regression
- Easily customizable architectures: ViTs, U-Net, ConvNeXt, R-CNNs, YOLO, Wav2Vec2
- Extensive set of loss functions and schedulers
- Intuitive API for rapid experimentation
TorchSuite is currently optimized for:
- Image and audio classification
- Image and audio regression
- Object detection
- Image segmentation, including medical imaging
Additional capabilities:
- Augmentation scheduling for training stabilization
- Automatic Mixed Precision (AMP) for faster training
- Gradient accumulation for memory-efficient learning
- In-notebook live monitoring
- Training resume and checkpointing
Future versions will include support for other data types, such as video and text.
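Two of the capabilities listed above, Automatic Mixed Precision (AMP) and gradient accumulation, are standard PyTorch techniques. The sketch below shows how they are typically combined in a training loop; it is a generic illustration only, not TorchSuite's training engine, and it assumes that `model`, `dataloader`, `loss_fn`, and `optimizer` have already been defined:

```python
import torch

# Generic illustration of AMP + gradient accumulation (not TorchSuite's engine code).
# Assumes model, dataloader, loss_fn, and optimizer are already defined.
scaler = torch.amp.GradScaler("cuda")   # Scales the loss to avoid float16 underflow
accum_steps = 4                         # Number of mini-batches to accumulate per update

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(dataloader):
    inputs, targets = inputs.cuda(), targets.cuda()

    # Run the forward pass in mixed precision
    with torch.amp.autocast(device_type="cuda"):
        outputs = model(inputs)
        loss = loss_fn(outputs, targets) / accum_steps  # Normalize for accumulation

    # Accumulate scaled gradients
    scaler.scale(loss).backward()

    # Update the weights only every accum_steps mini-batches
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
```

The core modules and features of the library are the following: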
- CPU and GPU-accelerated computation: optimized for both CPU and GPU, enabling flexible model training and inference.
- Training and inference engines: `classification.py`, `regression.py`, `obj_detection.py`, and `segmentation.py` for training and evaluation workflows.
- Flexible data loading: loaders for image, audio, detection, and segmentation (`*_dataloaders.py`) to streamline dataset preparation and augmentation.
- Utility functions: helper scripts (`*_utils.py`) for productivity and experiment reproducibility.
- Vision Transformer (ViT) support: `vision_transformer.py` for implementing ViTs with PyTorch.
- ConvNeXt support: `convnext.py`, the original implementation by Meta Platforms, Inc.
- Wav2Vec2 support: `wav2vec2.py` includes a Transformer-based acoustic model for audio classification.
- Region-based CNN (R-CNN) support: `faster_rcnn.py` to create deep learning networks for object detection.
- U-Net support: `unet.py` for a flexible design of U-Net models tailored to image segmentation. Pretrained U-Net modeling is taken from mberkay0; the library also includes the canonical vanilla U-Net architecture.
- Learning rate scheduling: `schedulers.py` provides adaptive learning rate strategies. Some classes have been taken from kamrulhasanrony.
- Custom loss functions: `loss_functions.py` includes various loss formulations for different tasks.
- Code examples: a series of notebooks demonstrating Python code for training deep learning models.
To install and set up this project, follow these steps:
- Clone the repository:
```bash
git clone https://github.com/sergio-sanz-rodriguez/TorchSuite
```
- Navigate into the project directory:
```bash
cd TorchSuite
```
- Create a virtual environment with GPU support:
```bash
conda create --name torchsuite_gpu python=3.11.10
conda activate torchsuite_gpu
```
(or using venv)
```bash
python -m venv .venv
source .venv/bin/activate  # On Windows: .venv\Scripts\activate
```
- Install additional dependencies:
```bash
pip install -r requirements.txt
```
- Install PyTorch with GPU support (modify the CUDA version as needed):
```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
```
- Verify the installation:
```python
import torch
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")
```

The required libraries are listed in `requirements.txt`, generated with `pipreqs /path/to/your/project --force --ignore <comma_separated_folders_or_files>`:

```
gdown==5.2.0
gradio==5.49.1
ipython==8.12.3
matplotlib==3.10.7
numpy==2.3.4
opencv_python==4.12.0.88
opencv_python_headless==4.9.0.80
pandas==2.3.3
Pillow==12.0.0
Requests==2.32.5
scikit_learn==1.7.2
scipy==1.16.2
seaborn==0.13.2
timm==1.0.20
torch==2.5.1+cu121
torch==2.5.1
torchaudio==2.5.1+cu121
torchaudio==2.5.1
torchvision==0.20.1+cu121
torchvision==0.20.1
tqdm==4.66.6
transformers==4.48.3
ultralytics==8.3.174
```

The following notebooks demonstrate how to implement and train deep learning models using the modules described above:
- `image_classification.ipynb`: transformer-based image classification.
- `image_distillation.ipynb`: model distillation for image tasks.
- `audio_waveform_classification.ipynb` and `audio_spectrogram_classification.ipynb`: waveform- and spectrogram-based audio classification.
- `image_regression.ipynb`: ConvNeXt-Large model predicting the image quality of filtered images.
- `object_detection_rcnn_custom.ipynb` and `object_detection_rcnn_standard.ipynb`: custom and standard R-CNN object detection and segmentation workflows.
- `object_detection_yolo.ipynb`: YOLO-based object detection.
- `image_segmentation_organs.ipynb` and `image_segmentation_vehicles.ipynb`: U-Net-based segmentation notebooks for medical imaging and vehicle detection use cases.
- `demos/deep_count`: Gradio-based human detection app using YOLO.
Data augmentation is of paramount importance to ensure the model's generalization. `TrivialAugmentWide()` is an efficient method that applies diverse image transformations with a single command. It should be applied during preprocessing of the image dataset, together with the transformations that adjust the data format (e.g., image resolution, `torch.Tensor` conversion, color normalization) to match the network's requirements.
```python
# Specify transformations
import torch
from torchvision.transforms import v2

transform_train = v2.Compose([
    v2.TrivialAugmentWide(),                  # Data augmentation
    v2.Resize(256),                           # Resize the shorter side to 256 pixels
    v2.RandomCrop((224, 224)),                # Random 224x224 crop
    v2.ToImage(),                             # Convert to a tensor image
    v2.ToDtype(torch.float32, scale=True),    # Cast to float32 and scale to [0, 1]
    v2.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize with ImageNet statistics
                 std=[0.229, 0.224, 0.225])
])
```

Similar to image recognition tasks, generalization in audio recognition can be improved by applying augmentation techniques to the audio signal. Whether pattern recognition is performed in the time domain (waveform) or the frequency domain (spectrograms), the following transformations can be applied:
- Time domain
```python
import torch
import torchaudio
import librosa

# Load audio file
waveform, sample_rate = torchaudio.load("wave.wav", normalize=True)

# Apply pitch shifting (librosa operates on NumPy arrays)
waveform = librosa.effects.pitch_shift(
    waveform.numpy(),  # Audio waveform
    sr=sample_rate,    # Sample rate of the waveform
    n_steps=2          # Number of semitones to shift the pitch
)

# Convert back to a tensor and add random noise
waveform = torch.from_numpy(waveform)
waveform = waveform + torch.randn_like(waveform) * 0.005
```
- Frequency domain
```python
import torchaudio

# Compute a mel spectrogram to obtain an image-like tensor of shape [1, 384, 381]
stride = round(waveform.shape[-1] / (384 - 1))
transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,  # Sample rate of the waveform
    n_fft=1024,               # Number of FFT bins
    win_length=1024,          # Window length of the waveform for FFT analysis
    hop_length=stride,        # Stride in samples between two analysis windows
    n_mels=384,               # Number of mel filter banks
    power=2                   # Exponent for the magnitude spectrogram
)
spectrogram = transform(waveform)

# Apply 10% masking in the frequency domain
freq_mask_param = int(0.10 * spectrogram.shape[-2])
spectrogram = torchaudio.transforms.FrequencyMasking(
    freq_mask_param=freq_mask_param)(spectrogram)

# Apply 10% masking in the time domain
time_mask_param = int(0.10 * spectrogram.shape[-1])
spectrogram = torchaudio.transforms.TimeMasking(
    time_mask_param=time_mask_param)(spectrogram)
```

The AdamW optimizer has been shown to improve generalization. Additionally, CrossEntropyLoss is the most commonly used loss function in classification tasks, and a moderate level of label smoothing (e.g., 0.1) can further enhance generalization. Adding a scheduler to regulate the learning rate is also good practice for optimizing parameter updates. An initial learning rate between 1e-4 and 1e-5 and a final learning rate down to 1e-7 are recommended. Optionally, the custom FixedLRSchedulerWrapper scheduler can be used to keep the learning rate fixed during the final epochs, helping stabilize the model parameters.
```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
# FixedLRSchedulerWrapper is the custom wrapper provided in schedulers.py

# Create AdamW optimizer
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    weight_decay=0.01
)

# Create loss function
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# Set scheduler: from epoch #1 to #10 use CosineAnnealingLR, from epoch #11 to #20 a fixed learning rate
scheduler = FixedLRSchedulerWrapper(
    scheduler=CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6),
    fixed_lr=1e-6,
    fixed_epoch=10)
```

Distillation is a technique where a smaller, lightweight model (the "student") is trained to mimic the behavior of a larger, pre-trained model (the "teacher"). This approach can significantly reduce complexity and speed up inference while maintaining comparable accuracy.
A custom cross-entropy-based distillation loss function has been created. This loss function consists of a weighted combination of two components:
- Soft Loss (KL divergence): Encourages the student model to match the teacher model’s probability distribution, allowing it to learn fine-grained relationships between classes.
- Hard Loss (cross-entropy): Ensures the student model learns directly from the ground truth labels for correct classification.
A good starting point for configuring this loss function is:
```python
# Create loss function
loss_fn = DistillationLoss(alpha=0.4, temperature=2, label_smoothing=0.1)
```

where `alpha` controls the weighting between the soft and hard losses, `temperature` smooths the teacher's probability distribution, making it easier for the student to learn from, and `label_smoothing` prevents overconfidence by redistributing a small portion of the probability mass to all classes.
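As a reference for how the two components can be combined, the sketch below shows a typical KL-plus-cross-entropy distillation loss. It is an illustrative example only and may differ in details from the `DistillationLoss` implemented in `loss_functions.py`; the function name and argument order here are hypothetical:

```python
import torch.nn.functional as F

# Illustrative sketch of a combined distillation loss (not the library's exact code)
def distillation_loss(student_logits, teacher_logits, targets,
                      alpha=0.4, temperature=2.0, label_smoothing=0.1):
    # Soft loss: KL divergence between the temperature-scaled student and
    # teacher distributions, scaled by T^2 to keep gradient magnitudes stable
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean"
    ) * (temperature ** 2)

    # Hard loss: cross-entropy against the ground-truth labels
    hard_loss = F.cross_entropy(student_logits, targets,
                                label_smoothing=label_smoothing)

    # Weighted combination of the soft and hard components
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```

Scaling the KL term by the squared temperature follows the original knowledge distillation formulation and keeps the soft-loss gradients comparable in magnitude to the hard loss.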
If you want to contribute to this project, contact me via email at sergio.sanz.rodriguez@gmail.com.
This project is licensed under the MIT License - see the LICENSE file for details.
For any questions, feel free to contact me via email at sergio.sanz.rodriguez@gmail.com or connect with me on LinkedIn: linkedin.com/in/sergio-sanz-rodriguez/.
