Skip to content

Batch_size不一致导致相同数据在同一模型上输出不一致 #328

@Let-the-bullet-fly

Description

@Let-the-bullet-fly

1.Describe the current behavior / 问题描述 (Mandatory / 必填)

使用 vgg16 模型和 cifar10 数据集时,batch_size 分别设为 1 和 64,相同数据在同一模型上的输出不一致。

2.Environment / 环境信息 (Mandatory / 必填)

  • Hardware Environment / 硬件环境(Mandatory / 必填):
后端类型 硬件具体类别
GPU NVIDIA GeForce RTX 3090 * 8
CPU Intel(R) Xeon(R) Gold 6226R CPU @ 2.90GHz
RAM 96GB
  • Software Environment / 软件环境 (Mandatory / 必填):
Software Version(根据实际修改,必填)
MindSpore MindSpore 2.0.0
MindFormers 1.1.0
MindIE
CANN
Python Python 3.9.19
OS platform Ubuntu 18.04.2 LTS
GCC/Compiler version

3.Related testcase / 关联用例 (Mandatory / 必填)

4.Steps to reproduce the issue / 重现步骤 (Mandatory / 必填)

import mindspore.dataset as ds
import mindspore
from mindspore import Tensor
import torch
import os
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from mindspore import context
import model_util as model_util

# Module-level configuration for the PyTorch-vs-MindSpore comparison script.
# mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="GPU", device_id=0)
# Environment variables are set at import time, before the functions below
# create any model.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# NOTE(review): benchmark mode may pick a different convolution algorithm for
# different input shapes, so PyTorch results are presumably not guaranteed to
# be bit-identical across batch sizes -- confirm against the PyTorch docs.
torch.backends.cudnn.benchmark = True  # cuDNN times several convolution algorithms from its library and picks the fastest one

# Model names listed by the author:
#  vgg16 deeplabv3 deeplabv3plus openpose patchcore unet resnet50 textcnn ssimae
# Used to build the checkpoint and YAML config paths and passed to model_util.get_model.
model_name = "vgg16"

# Dataset names listed by the author:
# cifar10 Pascal_VOC_deeplab openposecoco2017 ischanllge patchcoreMVTecAD
# Key into data_size below.
dataset_name = "cifar10"

# Batch size used for both data loaders and for the model input_size.
batch_size = 1

device_target = "GPU"  # passed to model_util.get_model
device = "cuda"  # PyTorch device string used for DEVICE and map_location
device_id = 0  # passed to model_util.get_model
DEVICE = torch.device(device)
attack = "wanet"  # base name of the .npz data file ("./" + attack + ".npz")

# Per-sample input shape for each dataset; batch_size is prepended to form the
# input_size given to model_util.get_model.
# presumably (channels, height, width) -- TODO confirm
data_size = {
    'cifar10': (3, 224, 224),
}


def get_model_and_load_weight():
    """Build the MindSpore and PyTorch models and load their checkpoints.

    Reads the module-level settings ``model_name``, ``dataset_name``,
    ``batch_size``, ``device_target``, ``device_id``, ``device`` and ``DEVICE``.

    Side effect: writes the lower-cased YAML config path for ``model_name``
    to ``./config.txt`` (presumably consumed by ``model_util`` -- confirm).

    Returns:
        tuple: ``(ms_model, pt_model)``; the PyTorch model is in ``eval()``
        mode on ``DEVICE`` and the MindSpore model has ``set_train(False)``.
    """
    pt_model_filepath = '/data2/CKPTS/' + model_name + '/' + model_name + '.pth'
    ms_model_filepath = '/data2/CKPTS/' + model_name + '/' + model_name + '.ckpt'

    # A context manager guarantees the file is closed even if the write fails.
    config_path = r'../common/config/' + model_name + '.yaml'
    with open('./config.txt', 'w') as f:
        f.write(config_path.lower())

    input_size = (batch_size,) + data_size[dataset_name]

    # Build both models.
    ms_model, pt_model = model_util.get_model(model_name, device_target, device_id, input_size)

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    loaded = torch.load(pt_model_filepath, map_location=device)
    # The checkpoint may hold either a state dict or a whole pickled module.
    state_dict = loaded if isinstance(loaded, dict) else loaded.state_dict()

    # strict=False tolerates key mismatches, but a silent mismatch would leave
    # layers at their initial values and invalidate the cross-framework
    # comparison, so report whatever did not line up.
    load_result = pt_model.load_state_dict(state_dict=state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"WARNING: PyTorch checkpoint key mismatch: "
              f"missing={list(load_result.missing_keys)}, "
              f"unexpected={list(load_result.unexpected_keys)}")
    pt_model.eval()
    pt_model.to(DEVICE)

    mindspore.load_checkpoint(ms_model_filepath, ms_model)
    ms_model.set_train(False)  # inference mode

    return ms_model, pt_model


def collect_diff_data(poison=False):
    """Run both models over the same data and report disagreeing predictions.

    Loads ``x`` and ``y`` from ``"./" + attack + ".npz"``, feeds identical
    batches (size ``batch_size``, no shuffling) to the PyTorch and MindSpore
    models, compares the arg-max class per sample, and prints the raw outputs
    of every mismatching sample followed by a summary.

    Args:
        poison: Currently unused; kept so existing callers keep working.

    Returns:
        None. All results are printed.
    """
    ms_model, pt_model = get_model_and_load_weight()
    print("PT Model loaded successfully")
    print("MS Model loaded successfully")

    data_path = "./" + attack + ".npz"

    # Close the archive once the arrays are read instead of leaking the handle.
    with np.load(data_path) as npz_data:
        x_data = npz_data['x']  # image data
        y_data = npz_data['y']  # label data
    y_data = y_data.squeeze()
    print(f"x_data shape: {x_data.shape}")  # e.g. (N, 3, 224, 224)
    print(f"y_data shape: {y_data.shape}")  # e.g. (N,)

    # PyTorch data loading.
    inputs_torch = torch.tensor(x_data, dtype=torch.float32).to(DEVICE)
    labels_torch = torch.tensor(y_data, dtype=torch.long).to(DEVICE)
    dataset_torch = TensorDataset(inputs_torch, labels_torch)
    dataloader_torch = DataLoader(dataset_torch, batch_size=batch_size, shuffle=False)

    # MindSpore data loading.
    dataset_mindspore = ds.NumpySlicesDataset({"inputs": x_data, "labels": y_data}, shuffle=False)
    dataloader_mindspore = dataset_mindspore.batch(batch_size)

    # Per-mismatch records, appended in lockstep.
    mismatch_data = []
    mismatch_pytorch_results = []
    mismatch_mindspore_results = []
    mismatch_labels = []
    mismatch_indices = []  # dataset-wide indices of the mismatching samples

    # Batch whose raw outputs are printed even when the predictions agree.
    debug_batch = 15

    # Dataset-wide index of the first sample in the current batch.
    global_index = 0

    batches = zip(dataloader_torch, dataloader_mindspore.create_dict_iterator())
    for count, (batch_torch, batch_mindspore) in enumerate(batches):
        # Run each model exactly once per batch so the printed outputs are the
        # very values that were compared (and no autograd graph is built).
        inputs_batch_torch = batch_torch[0]
        with torch.no_grad():
            logits_torch = pt_model(inputs_batch_torch)
        outputs_torch = logits_torch.argmax(dim=1).cpu().numpy()

        inputs_batch_mindspore = Tensor(batch_mindspore["inputs"], mindspore.float32)
        logits_mindspore = ms_model(inputs_batch_mindspore).asnumpy()
        outputs_mindspore = logits_mindspore.argmax(axis=1)

        # Positions within this batch where the predicted classes differ.
        batch_indices = np.where(outputs_torch != outputs_mindspore)[0]

        if len(batch_indices) == 0 and count == debug_batch:
            print("outputs_torch:")
            print(logits_torch[0])
            print("outputs_mindspore")
            print(logits_mindspore[0])

        if len(batch_indices) > 0:
            print("outputs_torch:")
            print(logits_torch[batch_indices])
            print("outputs_mindspore")
            print(logits_mindspore[batch_indices])

            # Translate batch-local positions into dataset-wide indices.
            global_indices = batch_indices + global_index
            mismatch_indices.extend(global_indices.tolist())
            mismatch_data.extend(x_data[global_indices])
            mismatch_pytorch_results.extend(outputs_torch[batch_indices])
            mismatch_mindspore_results.extend(outputs_mindspore[batch_indices])
            mismatch_labels.extend(y_data[global_indices])

        global_index += len(inputs_batch_torch)

    # Summary of mismatching samples.
    print(f"Indices of mismatched samples: {mismatch_indices}")
    print(f"Number of mismatched samples: {len(mismatch_data)}")

    # "Sample Index" is the position in the mismatch list, not in the dataset.
    records = zip(mismatch_pytorch_results, mismatch_mindspore_results, mismatch_labels)
    for i, (pt_result, ms_result, label) in enumerate(records):
        print(f"Sample Index: {i}")
        print(f"PyTorch Result: {pt_result}, MindSpore Result: {ms_result}")
        print(f"True Label: {label}")




# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    collect_diff_data()

wanet.npz 中保存了 64 条数据,每条包含形状为 (3, 224, 224) 的图像及其对应标签。分别以 batch_size=64 和 batch_size=1 通过 mindspore.dataset 加载并推理时,索引为 15 的样本在 MindSpore 模型上的输出不一致。
输入图片说明
当batch_size为1时,mindspore和PyTorch框架下的分类一致,都分类到3
输入图片说明
当batch_size为64时,mindspore和PyTorch框架下的分类不一致,PyTorch分类到3,mindspore分类到5

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions