Skip to content

Commit

Permalink
fix: address some problems when using ts data in DSAD
Browse files Browse the repository at this point in the history
  • Loading branch information
xuhongzuo committed Dec 7, 2022
1 parent 670f8df commit a0e2ef5
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions deepod/models/dsad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
One-class classification
this is partially adapted from
this is partially adapted from https://github.com/lukasruff/Deep-SAD-PyTorch (MIT license)
@Author: Hongzuo Xu <[email protected], [email protected]>
"""

Expand Down Expand Up @@ -89,9 +89,10 @@ def __init__(self, data_type='tabular', epochs=100, batch_size=64, lr=1e-3,
return

def training_prepare(self, X, y):
# By following the original paper,
# use -1 to denote known anomalies, and 1 to denote known inliers
known_anom_id = np.where(y == 1) if len(y.shape) == 2 \
else np.where(y == 1)[0]

y = np.zeros_like(y)
y[known_anom_id] = -1

Expand All @@ -112,7 +113,8 @@ def training_prepare(self, X, y):
n_features=self.n_features,
n_hidden=self.hidden_dims,
n_output=self.rep_dim,
activation=self.act
activation=self.act,
bias=self.bias
).to(self.device)
else:
raise NotImplementedError('Not supported network structures')
Expand Down
4 changes: 2 additions & 2 deletions deepod/test/test_dsad.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def setUp(self):
y_semi[known_anom_id] = 1

# # ts data
train_file = '../../data/omi-1/omi-1_train.csv'
test_file = '../../data/omi-1/omi-1_test.csv'
train_file = 'data/omi-1/omi-1_train.csv'
test_file = 'data/omi-1/omi-1_test.csv'
train_df = pd.read_csv(train_file, index_col=0)
test_df = pd.read_csv(test_file, index_col=0)
y_train, y_test = train_df['label'].values, test_df['label'].values
Expand Down

0 comments on commit a0e2ef5

Please sign in to comment.