Loading TRAINED models with VISSL #577
Comments
Hello @davitpapikyan, Thanks a lot for using VISSL and thanks a lot for your questions :) I am not 100% sure I understand your use case correctly:
I would like to get some additional details so that I can help you appropriately. Thank you, |
Hi @QuentinDuval, I have appended a classification head to a Torchvision ResNet50 trunk and fine-tuned the classifier on MNIST. Now I have a "<checkpoint_name.torch>" checkpoint of that model (produced by run_distributed_engines.py) and want to load it like this:
model = load("checkpoint_name.torch")
predictions = model(images)
What would you recommend I do? Thank you very much. |
Hi @davitpapikyan, So there are several ways. One way is to go through the
The other way is the programmatic way in which you do it through code:
The overrides are a list of the configuration options relevant to the creation of your model. Here is an example:
Please tell me if that helps you :) |
Unfortunately, with your approach (load_model) the accuracy of my classifier degrades from 98% to 10%. I checked your repo and found another way which preserves the performance:
from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel
from classy_vision.generic.util import load_checkpoint
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
# `config` is assumed to hold the Hydra overrides describing the model before this call.
config = compose_hydra_configuration(config)
_, config = convert_to_attrdict(config)
# Build the model (trunk + head) from the configuration.
model = BaseSSLMultiInputOutputModel(config.MODEL, config.OPTIMIZER)
# Read the checkpoint file referenced in the configuration.
weights = load_checkpoint(checkpoint_path=config.MODEL.WEIGHTS_INIT.PARAMS_FILE)
# Restore the full model state saved under "classy_state_dict" -> "base_model".
vissl_state_dict = weights.get("classy_state_dict")
model.set_classy_state(vissl_state_dict["base_model"]) |
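The snippet above relies on the checkpoint being a nested dictionary with a "classy_state_dict" entry that in turn holds a "base_model" entry. Here is a minimal sketch of that lookup with explicit error messages, assuming only that layout. The helper name `extract_base_model_state` is hypothetical and not part of VISSL; it works on any plain dictionary.

```python
def extract_base_model_state(checkpoint):
    """Return checkpoint["classy_state_dict"]["base_model"], or raise a descriptive KeyError.

    Hypothetical helper for illustration; it only assumes the nested layout
    used in the snippet above.
    """
    classy_state = checkpoint.get("classy_state_dict")
    if classy_state is None:
        # List what the checkpoint does contain, to make a layout mismatch obvious.
        raise KeyError(
            f"no 'classy_state_dict' entry; top-level keys: {sorted(checkpoint)}"
        )
    base_model = classy_state.get("base_model")
    if base_model is None:
        raise KeyError(
            f"no 'base_model' entry; 'classy_state_dict' keys: {sorted(classy_state)}"
        )
    return base_model
```

With the code above, this would replace the last two lines with something like `model.set_classy_state(extract_base_model_state(weights))`, so that a checkpoint saved in a different layout fails loudly instead of with a `None` lookup.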
Oh, interesting, it works on my end. What version of VISSL are you using? I will check whether there is a bug there. |
Dear @QuentinDuval, I'm using version 0.1.6 of VISSL. |
Dear @QuentinDuval, given a fully fine-tuned model, how can one load it? |
Here are the logs when loading the checkpoint:
@QuentinDuval Do you have any suggestions as to why the performance degrades? It seems that the weights are being loaded correctly. |
@QuentinDuval Any news on this issue? |
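For context, 10% accuracy on the 10 MNIST classes is chance level, which would be consistent with at least part of the model (for example the classification head) keeping its random initialization. One way to check this is to compare the parameter names the model expects with the names stored in the checkpoint. The sketch below is plain Python and not a VISSL API: the function name `compare_state_keys` and the example key names are hypothetical. In PyTorch the model's names can be obtained with `model.state_dict().keys()`.

```python
def compare_state_keys(model_keys, checkpoint_keys):
    """Report parameter names present on only one side.

    Hypothetical diagnostic helper: names listed under "missing_in_checkpoint"
    would not be restored from the checkpoint and would keep their initial values.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        "missing_in_checkpoint": sorted(model_keys - checkpoint_keys),
        "unexpected_in_checkpoint": sorted(checkpoint_keys - model_keys),
    }


# Illustrative key names only; real names depend on the model definition.
report = compare_state_keys(
    model_keys=["trunk.conv1.weight", "head.fc.weight", "head.fc.bias"],
    checkpoint_keys=["trunk.conv1.weight"],
)
print(report)
```

If the head's parameters show up as missing, the checkpoint passed to the loader most likely contains only the trunk, or the names differ by a prefix.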
❓ How to load an already trained model on VISSL
I'm using a Torchvision ResNet50, have attached a linear classifier on top of it, and have already trained the model on MNIST using VISSL (accuracy ~98%). Now I want to load my model and have something like this:
I tried following the "Loading a pre-trained model in inference mode" tutorial, but it doesn't help. It seems that my model's weights are randomly initialized, and the accuracy on the same dataset is ~10%.
Note that the tutorial above shows how to load a pre-trained model; my case is different: I have modified the trunk by adding a linear classifier on top of it.