Skip to content

Commit

Permalink
Update cosmo2.py
Browse files Browse the repository at this point in the history
  • Loading branch information
AiDeveloper21 authored Sep 24, 2024
1 parent ff73dad commit 10c83fa
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions cosmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,14 +335,14 @@ def train_lstm_model(new_dataset_path, weights_path, tokenizer_path='tokenizer.p
model.load_weights(weights_path)
logger.info("LSTM model weights loaded.")
if retrain:
model.fit(input_padded, response_labels, epochs=5, batch_size=32, validation_split=0.2)
model.fit(input_padded, response_labels, epochs=100, batch_size=32, validation_split=0.2)
model.save_weights(weights_path)
logger.info("LSTM model retrained and weights updated.")
except Exception as e:
logger.error(f"Failed to load weights: {e}")
sys.exit(1)
else:
model.fit(input_padded, response_labels, epochs=10, batch_size=32, validation_split=0.2)
model.fit(input_padded, response_labels, epochs=100, batch_size=32, validation_split=0.2)
model.save_weights(weights_path)
logger.info("LSTM model trained and weights saved.")

Expand Down

0 comments on commit 10c83fa

Please sign in to comment.