Open
Description
Hello! Thanks so much for easyfsl — it's fantastic. I am testing my Prototypical Network (trained on miniImageNet) with SupportSetFolder. When I test it on the attached folder dataset1 (containing photos from the internet), I get very accurate results. But when I test it on the attached folder dataset2 (containing photos I took myself), I get very inaccurate results. If you could help me figure out why this is, I'd appreciate it so much. Thanks again.
pip install easyfsl
import torch
import os
import csv
from pathlib import Path
import pandas as pd
from skimage import io
from typing import List, Tuple
from PIL import Image
from typing import Optional
from torch import nn, optim, Tensor
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder, DatasetFolder
from torchvision.models import resnet18
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
from easyfsl.methods.utils import compute_prototypes
from easyfsl.datasets import FewShotDataset, WrapFewShotDataset, SupportSetFolder
from easyfsl.methods import FewShotClassifier
class PrototypicalNetworks(FewShotClassifier):
    """
    Prototypical Networks few-shot classifier.

    Query images are scored by their L2 distance to prototypes computed from
    the support set.
    """

    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
    ):
        """
        Initialize the Prototypical Networks Few-Shot Classifier
        Args:
            backbone: the feature extractor used by the method. Must output a tensor of the
                appropriate shape (depending on the method).
                If None is passed, the backbone will be initialized as nn.Identity().
        """
        super().__init__(backbone=backbone)

    def forward(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
        query_images: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict classification labels.
        Args:
            support_images: images of the support set of shape (n_support, **image_shape)
            support_labels: labels of support set images of shape (n_support, )
            query_images: images of the query set of shape (n_query, **image_shape)
        Returns:
            a prediction of classification scores for query images of shape (n_query, n_classes)
        """
        # Only the query features are computed here. The prototype helper below
        # receives the raw support images, so it is expected to run the backbone
        # on them itself; extracting support features in this method as well
        # would be a second, unused backbone pass over the support set.
        z_query = self.compute_features(query_images)

        # Compute prototypes from the support set and store them on the model.
        self.compute_prototypes_and_store_support_set(support_images, support_labels)

        logits = self.l2_distance_to_prototypes(z_query)
        return self.softmax_if_specified(logits)

    @staticmethod
    def is_transductive() -> bool:
        """Return False: forward() scores queries against support-set prototypes only."""
        return False
# Pick the device once, before anything is moved to it; fall back to the CPU
# so the script still runs on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the backbone (pretrained ResNet18 with the fully connected layer replaced by a Flatten layer)
convolutional_network = resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

# Create the Prototypical Networks model using resnet18 as the feature extractor CNN
model_path = '/content/MIN_model.pth'
model = PrototypicalNetworks(convolutional_network).to(device)
# map_location lets a checkpoint saved from a GPU run load on whichever device was selected.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Define transformations to be applied to images.
# NOTE(review): no normalization is applied here; confirm this matches the
# preprocessing used when MIN_model.pth was trained.
transform = transforms.Compose(
    [
        transforms.Resize([348, 348]),
        # No-op after the Resize above (the image is already 348x348); kept so
        # the preprocessing pipeline stays exactly as before.
        transforms.CenterCrop(348),
        transforms.ToTensor(),
    ]
)

support_set = SupportSetFolder(root='/content/dataset2/support_set', transform=transform, device=device)

query_image_path = '/content/dataset2/query_set/chips1.jpg'
# Force 3 channels so grayscale / RGBA / CMYK files produce the input shape the
# backbone expects, and close the file handle once the pixels are read.
# NOTE(review): PIL does not apply the EXIF orientation tag on load, so camera
# photos may be fed in rotated; worth checking for the images in dataset2.
with Image.open(query_image_path) as query_image_PIL:
    query_images = transform(query_image_PIL.convert("RGB"))
# Add the batch dimension and move the single query to the model's device.
query_images = query_images.unsqueeze(0).to(device)

class_names = support_set.classes
print(f"Class names: {class_names}")

support_images = support_set.get_images().to(device)
support_labels = support_set.get_labels().to(device)

# forward() computes the prototypes from the support set itself, so no
# separate process_support_set() call is needed beforehand.
with torch.no_grad():
    scores = model(support_images, support_labels, query_images)
predicted_labels = scores.argmax(dim=1)

# Convert to plain ints before indexing the list of class names.
predicted_classes = [class_names[label] for label in predicted_labels.tolist()]
print(f"Predicted classes: {predicted_classes}")
Link to download MIN_model.pth: https://drive.google.com/file/d/1q6sfNYcYSTUJzEiHq1T-nJ5R31EZ8dio/view?usp=sharing
dataset1.zip
dataset1.zip
dataset2.zip
dataset2.zip