Skip to content

Commit

Permalink
Update models.
Browse files Browse the repository at this point in the history
  • Loading branch information
lwwang1995 authored and you-n-g committed Dec 9, 2020
1 parent 752f17e commit ec0d783
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 206 deletions.
Binary file modified examples/benchmarks/GRU/csi300_gru_ts.pkl
Binary file not shown.
Binary file modified examples/benchmarks/LSTM/csi300_lstm_ts.pkl
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ task:
num_layers: 2
dropout: 0.0
n_epochs: 200
lr: 1e-3
lr: 1e-2
early_stop: 10
batch_size: 800
metric: loss
Expand Down
7 changes: 6 additions & 1 deletion qlib/contrib/model/pytorch_gats_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from ...model.base import Model
from ...data.dataset import DatasetH
Expand Down Expand Up @@ -62,7 +63,9 @@ def __init__(
model_path=None,
optimizer="adam",
GPU="0",
n_jobs=10,
seed=None,
batch_size=800,
**kwargs
):
# Set logger.
Expand All @@ -84,8 +87,10 @@ def __init__(
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
self.n_jobs = n_jobs
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.batch_size = batch_size

self.logger.info(
"GATs parameters setting:"
Expand Down Expand Up @@ -218,7 +223,7 @@ def test_epoch(self, data_loader):

def fit(
self,
dataset: DatasetH,
dataset,
evals_result=dict(),
verbose=True,
save_path=None,
Expand Down
Loading

0 comments on commit ec0d783

Please sign in to comment.