
Added hidden_dim to config so LSTMDecoderCell can save properly #197

Open
wants to merge 1 commit into
base: master

Conversation


@johntzwei johntzwei commented Jun 20, 2017

Hi,

I wasn't able to save the Seq2Seq model, and I wrote this small bit to try to fix it. I hope this is useful!

Let me know if there is anything else I can do!

@johntzwei johntzwei changed the title added hidden_dim to config so lstmdecoder cell can save properly Added hidden_dim to config so LSTMDecoderCell can save properly Jun 20, 2017

ChristopherLu commented Jul 3, 2017

Hi,
Can this modification make load_model() work? For now, I can only use the model.load_weights() method.

@johntzwei (Author)

Yes, it's been a while, but I think this should allow for saving and loading with keras.models.load_model(). Is this what you mean?

@ChristopherLu

ChristopherLu commented Jul 4, 2017

@johntzwei Hi, it saves properly, but when I try model = load_model(model_path, custom_objects={'_OptionalInputPlaceHolder': _OptionalInputPlaceHolder, 'LSTMDecoderCell': LSTMDecoderCell, 'RecurrentSequential': RecurrentSequential, 'AttentionDecoderCell': AttentionDecoderCell}), it gives me error messages as follows:
File "build/bdist.linux-x86_64/egg/seq2seq/cells.py", line 63, in __init__
AttributeError: 'AttentionDecoderCell' object has no attribute 'output_dim'

Edit: load_model works for Seq2Seq(), but it is not successful for AttentionSeq2Seq(). Could you help modify the corresponding AttentionSeq2Seq files?

@dbklim

dbklim commented Mar 13, 2019

@ChristopherLu hello, you had better add a get_config() method to AttentionDecoderCell in "build/bdist.linux-x86_64/egg/seq2seq/cells.py":

def get_config(self):
    # expose the attributes needed to reconstruct the cell on load
    config = {
        'hidden_dim': self.hidden_dim,
        'output_dim': self.output_dim
    }
    # merge in the base class's config so no inherited settings are lost
    base_config = super(AttentionDecoderCell, self).get_config()
    config.update(base_config)
    return config

It worked for me :)
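For anyone curious why this fixes saving, the get_config() pattern can be sketched in plain Python without Keras. The Cell base class below is a hypothetical stand-in for the real Keras base layer, and the attribute names just mirror the snippet above: the subclass merges its own constructor arguments into the base config, and the saved dict is enough to rebuild the object.

```python
class Cell:
    """Hypothetical stand-in for the Keras base layer class."""

    def __init__(self, units=0):
        self.units = units

    def get_config(self):
        # base class serializes only its own constructor arguments
        return {'units': self.units}


class AttentionDecoderCell(Cell):
    def __init__(self, hidden_dim=None, output_dim=None, **kwargs):
        super(AttentionDecoderCell, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def get_config(self):
        # same merge pattern as in the comment above: subclass keys
        # plus everything the base class already serializes
        config = {
            'hidden_dim': self.hidden_dim,
            'output_dim': self.output_dim,
        }
        base_config = super(AttentionDecoderCell, self).get_config()
        config.update(base_config)
        return config


# Round trip: serialize to a config dict, then rebuild the cell from it.
# This is what load_model does internally with each layer's config.
cell = AttentionDecoderCell(hidden_dim=128, output_dim=64, units=32)
config = cell.get_config()
restored = AttentionDecoderCell(**config)
```

Without get_config(), the saved config is missing hidden_dim and output_dim, so the rebuilt object never sets those attributes, which is exactly the AttributeError seen above.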
