Error loading state dict for SwinIR, Missing key(s) in state_dict [...] #165

Open
AwaaX opened this issue Nov 10, 2024 · 2 comments

@AwaaX

AwaaX commented Nov 10, 2024

Hello there,
I'm trying to load SwinIR but I'm getting errors.

We are using this model: https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth

It is renamed to `swinir_model.pth` here. This is our loading code:

```python
def load_swinir(self):
    try:
        logger.info("Creating SwinIR model instance")

        # Import SwinIR
        try:
            from basicsr.archs.swinir_arch import SwinIR
        except ImportError as e:
            logger.error(f"First import attempt failed: {str(e)}")
            from basicsr.models.archs.swinir_arch import SwinIR

        logger.info("SwinIR class imported successfully")

        # Create model with correct configurations
        model = SwinIR(
            upscale=4,
            in_chans=3,
            img_size=64,
            window_size=8,
            img_range=1.0,
            depths=[6, 6, 6, 6, 6, 6],
            embed_dim=240,        # Fixed dimension
            num_heads=[8, 8, 8, 8, 8, 8],  # Fixed heads
            mlp_ratio=2.0,
            upsampler='pixelshuffel',  # Changed from nearest+conv
            resi_connection='3conv'     # Changed from 1conv
        )

        # Load pre-trained weights
        model_path = 'models/swinir_model.pth'
        logger.info(f"Loading SwinIR weights from {model_path}")
        loadnet = torch.load(model_path, map_location=self.device)

        # Convert weight keys to match model's expected format
        new_state_dict = {}
        for k, v in loadnet.items():
            if 'params' in k:
                continue

            # Handle conv layers
            if '.conv.' in k:
                parts = k.split('.')
                if parts[-1] == '0':
                    new_k = '.'.join(parts[:-1]) + '.weight'
                elif parts[-1] == '1':
                    new_k = '.'.join(parts[:-1]) + '.bias'
                else:
                    new_k = k
                new_state_dict[new_k] = v

            # Handle conv_after_body
            elif 'conv_after_body' in k:
                parts = k.split('.')
                if parts[-1] == '0':
                    new_k = 'conv_after_body.weight'
                elif parts[-1] == '1':
                    new_k = 'conv_after_body.bias'
                else:
                    new_k = k
                new_state_dict[new_k] = v

            else:
                new_state_dict[k] = v

        # Load state dict with detailed error reporting
        try:
            model.load_state_dict(new_state_dict, strict=True)
            logger.info("SwinIR weights loaded successfully")
        except Exception as e:
            logger.error(f"Error loading state dict: {str(e)}")
            logger.error("Expected keys:")
            logger.error(model.state_dict().keys())
            logger.error("Provided keys:")
            logger.error(new_state_dict.keys())
            raise

        model.eval()
        logger.info("SwinIR model initialized in eval mode")

        return model.to(self.device)

    except Exception as e:
        logger.error(f"Error loading SwinIR model: {str(e)}")
        logger.error(f"Current sys.path: {sys.path}")
        raise
```

I get these errors:
"Error loading state dict: Error(s) in loading state_dict for SwinIR"
"Missing key(s) in state_dict [...]"

More detailed logs are here:

https://pastebin.com/KwqgYje8
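A likely cause of every key being reported missing, going by how `main_test_swinir.py` loads real-SR checkpoints (it reads `param_key_g = 'params_ema'`, as far as I recall): the released `.pth` probably wraps the weights as `{'params_ema': {...}}`. If so, the posted loop iterates over that top-level dict, the `if 'params' in k: continue` filter skips its only entry, and `new_state_dict` ends up empty. The rename branches would not help either way: after `k.split('.')`, `parts[-1]` is `'weight'` or `'bias'`, never `'0'` or `'1'`, so they always fall through to `new_k = k`. The sketch below unwraps such a checkpoint and lists missing/unexpected keys; `unwrap_state_dict` and `summarize_key_mismatch` are hypothetical helpers, and the checkpoint in the demo is synthetic.

```python
from collections.abc import Mapping
from typing import Any, Dict, Iterable, List


def unwrap_state_dict(checkpoint: Mapping,
                      preferred_keys: Iterable[str] = ("params_ema", "params")) -> Dict[str, Any]:
    """Return the flat {param_name: tensor} dict, unwrapping {"params_ema": {...}} etc."""
    for key in preferred_keys:
        inner = checkpoint.get(key)
        if isinstance(inner, Mapping):
            return dict(inner)
    # No known wrapper key: assume the checkpoint is already a flat state dict.
    return dict(checkpoint)


def summarize_key_mismatch(expected: Iterable[str], provided: Iterable[str]) -> Dict[str, List[str]]:
    """Sorted lists of keys the model expects but did not get, and vice versa."""
    expected_set, provided_set = set(expected), set(provided)
    return {
        "missing": sorted(expected_set - provided_set),
        "unexpected": sorted(provided_set - expected_set),
    }


# Synthetic checkpoint shaped like {"params_ema": {...}}; names are illustrative.
ckpt = {"params_ema": {"conv_first.weight": 0, "conv_first.bias": 0}}
model_keys = ["conv_first.weight", "conv_first.bias"]

# The posted loop's filter: the only top-level key, "params_ema", contains
# "params", so it is skipped and nothing is kept.
kept = {k: v for k, v in ckpt.items() if "params" not in k}
print(summarize_key_mismatch(model_keys, kept))
# {'missing': ['conv_first.bias', 'conv_first.weight'], 'unexpected': []}

print(summarize_key_mismatch(model_keys, unwrap_state_dict(ckpt)))
# {'missing': [], 'unexpected': []}
```

In the posted method, the whole conversion loop could then be replaced by `model.load_state_dict(unwrap_state_dict(loadnet), strict=True)`, provided the model is built with the configuration the checkpoint was trained with. A report with only missing keys and no unexpected ones usually means the provided dict is empty or nearly so.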

We have looked for the model architecture details and the configuration used to train the pre-trained weights, but we are not sure where to find them.

Thank you very much if you can help.
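On where to find the configuration: besides the test script, the key names inside the checkpoint already reveal part of the architecture. The sketch below is a heuristic that assumes upstream SwinIR module naming (`layers.<i>.residual_group.blocks.<j>`, `conv_after_body`, `conv_up1`, `upsample`); `infer_swinir_hints` is a hypothetical helper and the demo keys are synthetic, not read from the real file.

```python
import re
from typing import Dict, Iterable, List

# Matches e.g. "layers.3.residual_group.blocks.5.attn.qkv.weight"
_BLOCK_RE = re.compile(r"^layers\.(\d+)\.residual_group\.blocks\.(\d+)\.")


def infer_swinir_hints(keys: Iterable[str]) -> Dict[str, object]:
    """Guess some SwinIR constructor arguments from state-dict key names (heuristic)."""
    keys = list(keys)

    def has(prefix: str) -> bool:
        return any(k.startswith(prefix) for k in keys)

    # depths[i] = number of Swin blocks in residual stage i (highest block index + 1).
    max_block: Dict[int, int] = {}
    for k in keys:
        m = _BLOCK_RE.match(k)
        if m:
            stage, block = int(m.group(1)), int(m.group(2))
            max_block[stage] = max(max_block.get(stage, -1), block)
    depths: List[int] = [max_block[s] + 1 for s in sorted(max_block)]

    # '3conv' is Sequential(conv, lrelu, conv, lrelu, conv): parameters at indices 0, 2, 4.
    if has("conv_after_body.4."):
        resi = "3conv"
    elif has("conv_after_body.weight"):
        resi = "1conv"
    else:
        resi = "unknown"

    if has("conv_up1."):
        upsampler = "nearest+conv"
    elif has("conv_before_upsample.") and has("upsample."):
        upsampler = "pixelshuffle"
    elif has("upsample."):
        upsampler = "pixelshuffledirect"
    else:
        upsampler = "unknown"

    return {"depths": depths, "resi_connection": resi, "upsampler": upsampler}


# Synthetic key list mimicking a 9-stage, 3conv, nearest+conv model.
demo_keys = [
    "conv_first.weight",
    *[f"layers.{i}.residual_group.blocks.{j}.attn.qkv.weight" for i in range(9) for j in range(6)],
    "conv_after_body.0.weight", "conv_after_body.2.weight", "conv_after_body.4.weight",
    "conv_before_upsample.0.weight", "conv_up1.weight", "conv_up2.weight",
    "conv_hr.weight", "conv_last.weight",
]
print(infer_swinir_hints(demo_keys))
# {'depths': [6, 6, 6, 6, 6, 6, 6, 6, 6], 'resi_connection': '3conv', 'upsampler': 'nearest+conv'}
```

Run it on the inner dict, e.g. `infer_swinir_hints(loadnet.get('params_ema', loadnet).keys())`. Tensor shapes give the rest: `conv_first.weight` has shape `(embed_dim, in_chans, 3, 3)`, and each `...attn.relative_position_bias_table` has shape `((2 * window_size - 1) ** 2, num_heads)`.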

@itboy2009

see arg: large_model

@AwaaX
Author

AwaaX commented Nov 28, 2024

> see arg: large_model

Any link?

Do you mean

`parser.add_argument('--large_model', action='store_true', help='use large model, only provided for real image sr')`

from https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py?
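For reference, this is how I recall the `--large_model` branch of the `real_sr` task in `main_test_swinir.py`; please check it against the script before relying on it. Compared with the posted code, it uses nine residual stages instead of six, and `upsampler='nearest+conv'` rather than `'pixelshuffel'`, which as far as I know is not a recognized option (SwinIR accepts `'pixelshuffle'`, `'pixelshuffledirect'`, `'nearest+conv'`, or an empty string). It also loads `params_ema` directly without renaming any keys. The dict name `SWINIR_L_REAL_SR_X4` and the helper `load_swinir_weights` are hypothetical, and passing the dict to BasicSR's `SwinIR` assumes its constructor takes the same arguments as upstream `network_swinir.py`.

```python
from typing import Any, Dict, Mapping

# SwinIR-L real-SR x4 arguments, as recalled from the --large_model branch
# of main_test_swinir.py (verify against the script).
SWINIR_L_REAL_SR_X4: Dict[str, Any] = dict(
    upscale=4,
    in_chans=3,
    img_size=64,
    window_size=8,
    img_range=1.0,
    depths=[6] * 9,             # nine residual stages, not six
    embed_dim=240,
    num_heads=[8] * 9,
    mlp_ratio=2,
    upsampler="nearest+conv",   # not 'pixelshuffle'
    resi_connection="3conv",
)


def load_swinir_weights(model, checkpoint: Mapping, param_key: str = "params_ema"):
    """Use checkpoint[param_key] when present, else treat the checkpoint as a
    flat state dict, then load strictly with no key renaming."""
    state_dict = checkpoint[param_key] if param_key in checkpoint else checkpoint
    model.load_state_dict(state_dict, strict=True)
    return model


# Hypothetical usage, with PyTorch and BasicSR installed:
# from basicsr.archs.swinir_arch import SwinIR
# model = SwinIR(**SWINIR_L_REAL_SR_X4)
# ckpt = torch.load("models/swinir_model.pth", map_location="cpu")
# model = load_swinir_weights(model, ckpt).eval()
```

Keeping the arguments in one dict makes it easy to diff them against whatever the checkpoint's key names and shapes suggest.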
