
Pruning problem #6

Open
LKAMING97 opened this issue May 23, 2023 · 3 comments
@LKAMING97

Hello, LightX3ECG is a nice piece of work, and I am very interested in the pruning part of your code. Could you explain how it is implemented? If the code cannot be made public, could you provide some reference code instead?

@lhkhiem28 (Owner)

Hi @LKAMING97, thanks for your interest in my work.
Please refer to this code for pruning; note that it should be kept synchronized with the main repo:
https://gist.github.com/lhkhiem28/96494db44ca9278b5cf226190061f7b0

and here are some utility functions

import torch
import torch.nn.utils.prune as prune

def get_pruned_parameters(model, pruned=False):
    """Collect (module, "weight") pairs for all Conv1d/Linear layers and
    report the current sparsity. Before pruning (pruned=False) the weights
    live in the module's parameters; after pruning they live in its buffers
    ("weight_orig"/"weight_mask" re-parametrization)."""
    total_parameters, total_zero_parameters = 0, 0

    pruned_parameters = []
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
            tensors = module.named_buffers() if pruned else module.named_parameters()
            for _, tensor in tensors:
                total_parameters += tensor.nelement()
                total_zero_parameters += torch.sum(tensor == 0).item()
            # one entry per module, as expected by prune.global_unstructured
            pruned_parameters.append((module, "weight"))

    print(f"sparsity: {total_zero_parameters / total_parameters:.2%}")
    return pruned_parameters

def remove_pruned_parameters(model):
    """Make pruning permanent: drop the "weight_orig"/"weight_mask"
    re-parametrization and bake the masks into the plain weights."""
    for _, module in model.named_modules():
        if isinstance(module, (torch.nn.Conv1d, torch.nn.Linear)):
            for name in ("weight", "bias"):
                try:
                    prune.remove(module, name)
                except ValueError:
                    # raised when this parameter was never pruned
                    pass

    return model
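For context, here is a minimal self-contained sketch (not the authors' exact pipeline; the toy model is hypothetical) of how such helpers fit around `torch.nn.utils.prune.global_unstructured`: collect `(module, "weight")` pairs, prune globally by L1 magnitude, check sparsity, then make the pruning permanent with `prune.remove`:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy 1D-conv model standing in for the real network (hypothetical).
model = nn.Sequential(
    nn.Conv1d(12, 16, kernel_size=7),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(16, 5),
)

# Collect (module, "weight") pairs for all Conv1d/Linear layers,
# then prune 50% of those weights globally by L1 magnitude.
targets = [
    (m, "weight") for m in model.modules()
    if isinstance(m, (nn.Conv1d, nn.Linear))
]
prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=0.5)

# Check sparsity via the masked "weight" attributes.
zeros = sum(int((m.weight == 0).sum()) for m, _ in targets)
total = sum(m.weight.nelement() for m, _ in targets)
print(f"global sparsity: {zeros / total:.2%}")  # → global sparsity: 50.00%

# Make pruning permanent before export/inference.
for m, name in targets:
    prune.remove(m, name)
```

Note that unstructured pruning only zeroes weights; to realize speed-ups the zeros must be exploited by sparse kernels or followed by structured removal of channels.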

@LKAMING97 (Author)

I would like to ask how the lead-wise Grad-CAM is calculated. When I use Captum's GradCAM, the channel dimension of the output is collapsed to 1. How do you deal with that?
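For anyone comparing, here is a minimal plain-PyTorch Grad-CAM sketch for a 1D CNN (not the paper's lead-wise implementation, and the toy model is hypothetical). It shows why the channel dimension collapses: the class-score gradients at a chosen Conv1d layer are averaged into per-channel weights, and the weighted activations are summed over channels, leaving a (N, 1, L') map:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def grad_cam_1d(model, layer, x, target_class):
    """Grad-CAM for a 1D CNN: returns a (N, 1, L') relevance map
    over the chosen layer's temporal axis."""
    acts, grads = [], []
    h1 = layer.register_forward_hook(lambda m, i, o: acts.append(o))
    h2 = layer.register_full_backward_hook(lambda m, gi, go: grads.append(go[0]))
    try:
        logits = model(x)
        model.zero_grad()
        logits[:, target_class].sum().backward()
    finally:
        h1.remove()
        h2.remove()
    a, g = acts[0], grads[0]                         # each (N, C, L')
    w = g.mean(dim=2, keepdim=True)                  # channel weights (N, C, 1)
    cam = F.relu((w * a).sum(dim=1, keepdim=True))   # channels collapse -> (N, 1, L')
    return cam

# Toy 3-lead ECG classifier (hypothetical stand-in).
model = nn.Sequential(
    nn.Conv1d(3, 8, kernel_size=7, padding=3), nn.ReLU(),
    nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(8, 5),
)
cam = grad_cam_1d(model, model[0], torch.randn(2, 3, 500), target_class=0)
print(cam.shape)  # torch.Size([2, 1, 500])
```

One way to obtain a per-lead map, assuming a lead-wise architecture with separate per-lead branches (an assumption, not necessarily the paper's method), is to apply this to each lead's own convolutional branch separately rather than to a layer shared across leads.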

@XLHFSJ commented Aug 30, 2023

Hello, I read your LightX3ECG paper and am interested in the visualization (Grad-CAM) section. Could you publish that code, or provide some reference code for it?
