-
Notifications
You must be signed in to change notification settings - Fork 744
Open
Description
1.Describe the current behavior / 问题描述 (Mandatory / 必填)
使用 VGG16 模型和 CIFAR-10 数据集时,batch_size 分别为 1 和 64 的情况下,相同数据在同一模型上的输出不一致。
2.Environment / 环境信息 (Mandatory / 必填)
- Hardware Environment / 硬件环境(Mandatory / 必填):
后端类型 | 硬件具体类别 |
---|---|
GPU | NVIDIA GeForce RTX 3090 * 8 |
CPU | Intel(R) Xeon(R) Gold 6226R CPU @ 2.90GHz |
RAM | 96GB |
- Software Environment / 软件环境 (Mandatory / 必填):
Software | Version(根据实际修改,必填) |
---|---|
MindSpore | MindSpore 2.0.0 |
MindFormers | 1.1.0 |
MindIE | |
CANN | |
Python | Python 3.9.19 |
OS platform | Ubuntu 18.04.2 LTS |
GCC/Compiler version |
3.Related testcase / 关联用例 (Mandatory / 必填)
4.Steps to reproduce the issue / 重现步骤 (Mandatory / 必填)
import mindspore.dataset as ds
import mindspore
from mindspore import Tensor
import torch
import os
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from mindspore import context
import model_util as model_util
# Explicit MindSpore context setup is currently disabled:
# mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="GPU", device_id=0)
# Presumably silences TensorFlow C++ logging, if TensorFlow is pulled in indirectly - confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Enumerate GPUs by PCI bus order and expose only GPU 0 to this process.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
# cuDNN benchmarks the available convolution algorithms and picks the fastest one.
# NOTE(review): the chosen algorithm may differ per input shape (e.g. per batch size),
# which can change floating-point results - confirm whether that matters for this comparison.
torch.backends.cudnn.benchmark = True
# Model names listed by the author: vgg16 deeplabv3 deeplabv3plus openpose patchcore unet resnet50 textcnn ssimae
model_name = "vgg16"
# Dataset names listed by the author: cifar10 Pascal_VOC_deeplab openposecoco2017 ischanllge patchcoreMVTecAD
dataset_name = "cifar10"
# Batch size used for the model input shape and for both data loaders.
batch_size = 1
# Device settings passed to model_util.get_model (MindSpore side).
device_target = "GPU"
# Device string used for the PyTorch side.
device = "cuda"
device_id = 0
DEVICE = torch.device(device)
# Name of the sample archive; the script reads "./<attack>.npz".
attack = "wanet"
# Per-dataset input shape; batch_size is prepended to it to form the model input size.
data_size = {
    'cifar10': (3, 224, 224),
}
def get_model_and_load_weight():
    """Build the MindSpore and PyTorch models and load their saved weights.

    Reads the module-level settings ``model_name``, ``dataset_name``,
    ``batch_size``, ``device``, ``DEVICE``, ``device_target`` and ``device_id``.

    Side effects:
        Writes the lower-cased path of the model's YAML config to
        ``./config.txt`` (presumably consumed by ``model_util.get_model`` -
        confirm).

    Returns:
        tuple: ``(ms_model, pt_model)``; the PyTorch model is in ``eval()``
        mode on ``DEVICE`` and the MindSpore model has ``set_train(False)``.
    """
    pt_model_filepath = f'/data2/CKPTS/{model_name}/{model_name}.pth'
    ms_model_filepath = f'/data2/CKPTS/{model_name}/{model_name}.ckpt'

    # Use a context manager so the handle is released even if the write fails.
    config_path = r'../common/config/' + model_name + '.yaml'
    with open('./config.txt', 'w') as f:
        f.write(config_path.lower())

    input_size = (batch_size,) + data_size[dataset_name]

    # Build both models, then load their weights.
    ms_model, pt_model = model_util.get_model(model_name, device_target, device_id, input_size)

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    loaded = torch.load(pt_model_filepath, map_location=device)
    # The checkpoint may be a bare state dict or a fully pickled module.
    state_dict = loaded if isinstance(loaded, dict) else loaded.state_dict()

    # strict=False tolerates key mismatches; surface them instead of silently
    # running with partially initialised weights, which would itself cause
    # the two frameworks to disagree.
    load_result = pt_model.load_state_dict(state_dict=state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: PyTorch state_dict mismatch - "
              f"missing keys: {list(load_result.missing_keys)}, "
              f"unexpected keys: {list(load_result.unexpected_keys)}")
    pt_model.eval()
    pt_model.to(DEVICE)

    mindspore.load_checkpoint(ms_model_filepath, ms_model)
    ms_model.set_train(False)  # switch to inference mode
    return ms_model, pt_model
def collect_diff_data(poison=False):
    """Run the same samples through both models and report classification mismatches.

    Loads ``./<attack>.npz`` (arrays ``x`` and ``y``), feeds the data in
    identical order and batch size to the PyTorch and MindSpore models, and
    prints every sample whose argmax prediction differs between the two,
    together with the raw outputs of the affected batch entries.

    Args:
        poison: Unused; kept so existing callers keep working.

    Returns:
        None. All results are printed.
    """
    ms_model, pt_model = get_model_and_load_weight()
    print("PT Model loaded successfully")
    print("MS Model loaded successfully")

    data_path = "./" + attack + ".npz"
    # NpzFile holds the archive open until closed; indexing materialises the
    # arrays, so they stay valid after the ``with`` block.
    with np.load(data_path) as npz_data:
        x_data = npz_data['x']  # image data
        y_data = npz_data['y']  # label data
    y_data = y_data.squeeze()
    print(f"x_data shape: {x_data.shape}")  # e.g. (N, 3, 224, 224)
    print(f"y_data shape: {y_data.shape}")  # e.g. (N,)

    # PyTorch data pipeline
    inputs_torch = torch.tensor(x_data, dtype=torch.float32).to(DEVICE)
    labels_torch = torch.tensor(y_data, dtype=torch.long).to(DEVICE)
    dataset_torch = TensorDataset(inputs_torch, labels_torch)
    dataloader_torch = DataLoader(dataset_torch, batch_size=batch_size, shuffle=False)

    # MindSpore data pipeline
    dataset_mindspore = ds.NumpySlicesDataset({"inputs": x_data, "labels": y_data}, shuffle=False)
    dataloader_mindspore = dataset_mindspore.batch(batch_size)

    # Accumulators for the mismatching samples.
    mismatch_data = []
    mismatch_pytorch_results = []
    mismatch_mindspore_results = []
    mismatch_labels = []
    mismatch_indices = []  # dataset-wide indices of mismatching samples

    global_index = 0  # dataset-wide index of the first sample in the current batch
    count = 0  # number of batches processed so far

    # Batched inference and comparison.
    for batch_torch, batch_mindspore in zip(dataloader_torch, dataloader_mindspore.create_dict_iterator()):
        # PyTorch inference. Each model runs once per batch and its raw output
        # is reused for the debug prints below.
        inputs_batch_torch, _ = batch_torch
        with torch.no_grad():
            logits_torch = pt_model(inputs_batch_torch)
        outputs_torch = logits_torch.argmax(dim=1).cpu().numpy()

        # MindSpore inference
        inputs_batch_mindspore = Tensor(batch_mindspore["inputs"], mindspore.float32)
        logits_mindspore = ms_model(inputs_batch_mindspore).asnumpy()
        outputs_mindspore = logits_mindspore.argmax(axis=1)

        # Positions inside this batch where the predicted classes differ.
        batch_indices = np.where(outputs_torch != outputs_mindspore)[0]

        # Debug aid: dump the raw outputs of batch number 15 when it agrees
        # (with batch_size == 1 this is the sample at index 15).
        if len(batch_indices) == 0 and count == 15:
            print("outputs_torch:")
            print(logits_torch[0])
            print("outputs_mindspore")
            print(logits_mindspore[0])

        if len(batch_indices) > 0:
            print("outputs_torch:")
            print(logits_torch[batch_indices])
            print("outputs_mindspore")
            print(logits_mindspore[batch_indices])

            # Translate in-batch positions to dataset-wide indices.
            global_hits = batch_indices + global_index
            mismatch_indices.extend(global_hits.tolist())
            mismatch_data.extend(x_data[global_hits])
            mismatch_pytorch_results.extend(outputs_torch[batch_indices])
            mismatch_mindspore_results.extend(outputs_mindspore[batch_indices])
            mismatch_labels.extend(y_data[global_hits])

        # Advance the dataset-wide offset by the actual batch length.
        global_index += len(inputs_batch_torch)
        count += 1

    # Report the dataset-wide indices of the mismatching samples.
    print(f"Indices of mismatched samples: {mismatch_indices}")
    print(f"Number of mismatched samples: {len(mismatch_data)}")

    # Report each mismatching sample.
    results = zip(mismatch_pytorch_results, mismatch_mindspore_results, mismatch_labels)
    for i, (pt_result, ms_result, label) in enumerate(results):
        print(f"Sample Index: {i}")
        print(f"PyTorch Result: {pt_result}, MindSpore Result: {ms_result}")
        print(f"True Label: {label}")
# Script entry point: run the cross-framework comparison when executed directly.
if __name__ == "__main__":
    collect_diff_data()
wanet.npz 中保存了 64 条数据,包括形状为 (3, 224, 224) 的图像和对应的标签。batch_size 分别为 64 和 1 时,索引为 15 的数据经 mindspore.dataset 加载后,模型的输出不一致。
当 batch_size 为 1 时,MindSpore 和 PyTorch 框架下的分类结果一致,均分类为 3。
当 batch_size 为 64 时,MindSpore 和 PyTorch 框架下的分类结果不一致:PyTorch 分类为 3,MindSpore 分类为 5。
Metadata
Metadata
Assignees
Labels
No labels