forked from microsoft/qlib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtabnet.py
85 lines (75 loc) · 2.3 KB
/
tabnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from pytorch_tabnet.tab_model import TabNetRegressor
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
class TabNetModel(Model):
    """TabNet regression model built on ``pytorch_tabnet``'s ``TabNetRegressor``.

    The hyperparameters passed to the constructor are stored on the instance
    and forwarded, under the same keyword names, to ``TabNetRegressor`` when
    :meth:`fit` is called.
    """

    def __init__(
        self,
        n_d,
        n_a,
        n_steps,
        gamma,
        n_independent,
        n_shared,
        seed,
        momentum,
        lambda_sparse,
        optimizer_params,
        **kwargs
    ):
        """Store the TabNet hyperparameters; no network is built here.

        Every named argument is kept as an attribute of the same name and
        later handed to ``TabNetRegressor`` under that keyword.  Any extra
        ``**kwargs`` are accepted but not stored or used.
        """
        # The underlying regressor is only created inside ``fit``.
        self.model = None
        self.n_d = n_d
        self.n_a = n_a
        self.n_steps = n_steps
        self.gamma = gamma
        self.n_independent = n_independent
        self.n_shared = n_shared
        self.seed = seed
        self.momentum = momentum
        self.lambda_sparse = lambda_sparse
        self.optimizer_params = optimizer_params

    def fit(
        self,
        dataset: DatasetH,
        n_d=None,
        n_a=None,
        n_steps=None,
        gamma=None,
        n_independent=None,
        n_shared=None,
        seed=None,
        momentum=None,
        lambda_sparse=None,
        optimizer_params=None,
        **kwargs
    ):
        """Train a fresh ``TabNetRegressor`` on the dataset's train/valid segments.

        Parameters
        ----------
        dataset : DatasetH
            Must provide "train" and "valid" segments with "feature" and
            "label" column sets.
        n_d, n_a, n_steps, gamma, n_independent, n_shared, seed, momentum, \
lambda_sparse, optimizer_params : optional
            Per-call overrides of the constructor hyperparameters.  ``None``
            (the default) means "use the value given to ``__init__``".  The
            instance attributes themselves are not modified.
        **kwargs
            Forwarded unchanged to the ``TabNetRegressor`` constructor.

        Raises
        ------
        ValueError
            If the prepared train or valid frame is empty.
        """
        overrides = {
            "n_d": n_d,
            "n_a": n_a,
            "n_steps": n_steps,
            "gamma": gamma,
            "n_independent": n_independent,
            "n_shared": n_shared,
            "seed": seed,
            "momentum": momentum,
            "lambda_sparse": lambda_sparse,
            "optimizer_params": optimizer_params,
        }
        # Constructor values are the baseline; an explicit (non-None) argument
        # to ``fit`` wins.  Previously these arguments were silently ignored.
        params = {
            name: getattr(self, name) if value is None else value
            for name, value in overrides.items()
        }

        df_train, df_valid = dataset.prepare(
            ["train", "valid"],
            col_set=["feature", "label"],
            data_key=DataHandlerLP.DK_L,
        )
        if df_train.empty or df_valid.empty:
            raise ValueError("Empty data from dataset, please check your dataset config.")

        # Labels are scaled by 100 before training.  Note that ``predict``
        # does not undo this scaling, so predictions are on the scaled range.
        x_train, y_train = df_train["feature"].values, df_train["label"].values * 100
        x_valid, y_valid = df_valid["feature"].values, df_valid["label"].values * 100

        self.model = TabNetRegressor(**params, **kwargs)
        self.model.fit(x_train, y_train, eval_set=[(x_valid, y_valid)])

    def predict(self, dataset):
        """Predict on the dataset's "test" segment.

        Returns a ``pandas.Series`` of flattened predictions indexed like the
        prepared test feature frame.

        Raises
        ------
        ValueError
            If :meth:`fit` has not been called yet.
        """
        if self.model is None:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare("test", col_set="feature")
        test_pred = self.model.predict(x_test.values)
        # Flatten the regressor output to one value per row.
        return pd.Series(test_pred.reshape([-1]), index=x_test.index)