Skip to content

Error loading state dict for SwinIR, Missing key(s) in state_dict [...] #165

Open
@AwaaX

Description

@AwaaX

Hello there,
I try to load SwinIR but got errors ,

We use this model https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth

Renamed swinir_model.pth here

` def load_swinir(self):
try:
logger.info("Creating SwinIR model instance")

        # Import SwinIR
        try:
            from basicsr.archs.swinir_arch import SwinIR
        except ImportError as e:
            logger.error(f"First import attempt failed: {str(e)}")
            from basicsr.models.archs.swinir_arch import SwinIR
        
        logger.info("SwinIR class imported successfully")
        
        # Create model with correct configurations
        model = SwinIR(
            upscale=4,
            in_chans=3,
            img_size=64,
            window_size=8,
            img_range=1.0,
            depths=[6, 6, 6, 6, 6, 6],
            embed_dim=240,        # Fixed dimension
            num_heads=[8, 8, 8, 8, 8, 8],  # Fixed heads
            mlp_ratio=2.0,
            upsampler='pixelshuffel',  # Changed from nearest+conv
            resi_connection='3conv'     # Changed from 1conv
        )
        
        # Load pre-trained weights
        model_path = 'models/swinir_model.pth'
        logger.info(f"Loading SwinIR weights from {model_path}")
        loadnet = torch.load(model_path, map_location=self.device)
        
        # Convert weight keys to match model's expected format
        new_state_dict = {}
        for k, v in loadnet.items():
            if 'params' in k:
                continue
            
            # Handle conv layers
            if '.conv.' in k:
                parts = k.split('.')
                if parts[-1] == '0':
                    new_k = '.'.join(parts[:-1]) + '.weight'
                elif parts[-1] == '1':
                    new_k = '.'.join(parts[:-1]) + '.bias'
                else:
                    new_k = k
                new_state_dict[new_k] = v
            
            # Handle conv_after_body
            elif 'conv_after_body' in k:
                parts = k.split('.')
                if parts[-1] == '0':
                    new_k = 'conv_after_body.weight'
                elif parts[-1] == '1':
                    new_k = 'conv_after_body.bias'
                else:
                    new_k = k
                new_state_dict[new_k] = v
            
            else:
                new_state_dict[k] = v

        # Load state dict with detailed error reporting
        try:
            model.load_state_dict(new_state_dict, strict=True)
            logger.info("SwinIR weights loaded successfully")
        except Exception as e:
            logger.error(f"Error loading state dict: {str(e)}")
            logger.error("Expected keys:")
            logger.error(model.state_dict().keys())
            logger.error("Provided keys:")
            logger.error(new_state_dict.keys())
            raise
            
        model.eval()
        logger.info("SwinIR model initialized in eval mode")
        
        return model.to(self.device)
        
    except Exception as e:
        logger.error(f"Error loading SwinIR model: {str(e)}")
        logger.error(f"Current sys.path: {sys.path}")
        raise`

I got this errors
"Error loading state dict: Error(s) in loading state_dict for SwinIR"
"Missing key(s) in state_dict [...]"

Here is more logs details :

https://pastebin.com/KwqgYje8

We search for model architecture details, configurations used for the pre-trained weights but we are not sure where to find it

Ty very much if u can help

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions