Open
Description
Hello there,
I try to load SwinIR but got errors ,
We use this model https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth
Renamed swinir_model.pth here
` def load_swinir(self):
try:
logger.info("Creating SwinIR model instance")
# Import SwinIR
try:
from basicsr.archs.swinir_arch import SwinIR
except ImportError as e:
logger.error(f"First import attempt failed: {str(e)}")
from basicsr.models.archs.swinir_arch import SwinIR
logger.info("SwinIR class imported successfully")
# Create model with correct configurations
model = SwinIR(
upscale=4,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=240, # Fixed dimension
num_heads=[8, 8, 8, 8, 8, 8], # Fixed heads
mlp_ratio=2.0,
upsampler='pixelshuffel', # Changed from nearest+conv
resi_connection='3conv' # Changed from 1conv
)
# Load pre-trained weights
model_path = 'models/swinir_model.pth'
logger.info(f"Loading SwinIR weights from {model_path}")
loadnet = torch.load(model_path, map_location=self.device)
# Convert weight keys to match model's expected format
new_state_dict = {}
for k, v in loadnet.items():
if 'params' in k:
continue
# Handle conv layers
if '.conv.' in k:
parts = k.split('.')
if parts[-1] == '0':
new_k = '.'.join(parts[:-1]) + '.weight'
elif parts[-1] == '1':
new_k = '.'.join(parts[:-1]) + '.bias'
else:
new_k = k
new_state_dict[new_k] = v
# Handle conv_after_body
elif 'conv_after_body' in k:
parts = k.split('.')
if parts[-1] == '0':
new_k = 'conv_after_body.weight'
elif parts[-1] == '1':
new_k = 'conv_after_body.bias'
else:
new_k = k
new_state_dict[new_k] = v
else:
new_state_dict[k] = v
# Load state dict with detailed error reporting
try:
model.load_state_dict(new_state_dict, strict=True)
logger.info("SwinIR weights loaded successfully")
except Exception as e:
logger.error(f"Error loading state dict: {str(e)}")
logger.error("Expected keys:")
logger.error(model.state_dict().keys())
logger.error("Provided keys:")
logger.error(new_state_dict.keys())
raise
model.eval()
logger.info("SwinIR model initialized in eval mode")
return model.to(self.device)
except Exception as e:
logger.error(f"Error loading SwinIR model: {str(e)}")
logger.error(f"Current sys.path: {sys.path}")
raise`
I got this errors
"Error loading state dict: Error(s) in loading state_dict for SwinIR"
"Missing key(s) in state_dict [...]"
Here is more logs details :
We search for model architecture details, configurations used for the pre-trained weights but we are not sure where to find it
Ty very much if u can help
Metadata
Metadata
Assignees
Labels
No labels