forked from WenjieDu/PyPOTS
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request WenjieDu#442 from WenjieDu/dev
Add `inverse_sliding_window()` and enable TimesNet to work with len>5000 samples
- Loading branch information
Showing
9 changed files
with
109 additions
and
247 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,14 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
import math | ||
from typing import Union | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from ..utils.logging import logger | ||
|
||
|
||
def turn_data_into_specified_dtype( | ||
data: Union[np.ndarray, torch.Tensor, list], | ||
|
@@ -194,8 +197,13 @@ def sliding_window(time_series, window_len, sliding_len=None): | |
start_indices = np.asarray(range(total_len // sliding_len)) * sliding_len | ||
|
||
# remove the last one if left length is not enough | ||
if total_len - start_indices[-1] * sliding_len < window_len: | ||
start_indices = start_indices[:-1] | ||
if total_len - start_indices[-1] < window_len: | ||
to_drop = math.ceil(window_len / sliding_len) | ||
left_len = total_len - start_indices[-1] | ||
start_indices = start_indices[:-to_drop] | ||
logger.warning( | ||
f"The last {to_drop} samples are dropped due to the left length {left_len} is not enough." | ||
) | ||
|
||
sample_collector = [] | ||
for idx in start_indices: | ||
|
@@ -204,3 +212,51 @@ def sliding_window(time_series, window_len, sliding_len=None): | |
samples = np.asarray(sample_collector).astype("float32") | ||
|
||
return samples | ||
|
||
|
||
def inverse_sliding_window(X, sliding_len): | ||
"""Restore the original time-series data from the generated sliding window samples. | ||
Note that this is the inverse operation of the `sliding_window` function, but there is no guarantee that | ||
the restored data is the same as the original data considering that | ||
1. the sliding length may be larger than the window size and there will be gaps between restored data; | ||
2. if values in the samples get changed, the overlap part may not be the same as the original data after averaging; | ||
3. some incomplete samples at the tail may be dropped during the sliding window operation, hence the restored data | ||
may be shorter than the original data. | ||
Parameters | ||
---------- | ||
X : | ||
The generated time-series samples with sliding window method, shape of [n_samples, n_steps, n_features], | ||
where n_steps is the window size of the used sliding window method. | ||
sliding_len : | ||
The sliding length of the window for each moving step in the sliding window method used to generate X. | ||
Returns | ||
------- | ||
restored_data : | ||
The restored time-series data with shape of [total_length, n_features]. | ||
""" | ||
assert len(X.shape) == 3, f"X should be a 3D array, but got {X.shape}" | ||
n_samples, window_size, n_features = X.shape | ||
|
||
if sliding_len >= window_size: | ||
if sliding_len > window_size: | ||
logger.warning( | ||
f"sliding_len {sliding_len} is larger than the window size {window_size}, " | ||
f"hence there will be gaps between restored data." | ||
) | ||
restored_data = X.reshape(n_samples * window_size, n_features) | ||
else: | ||
collector = [X[0][:sliding_len]] | ||
overlap = X[0][sliding_len:] | ||
for x in X[1:]: | ||
overlap_avg = (overlap + x[:-sliding_len]) / 2 | ||
collector.append(overlap_avg[:sliding_len]) | ||
overlap = np.concatenate( | ||
[overlap_avg[sliding_len:], x[-sliding_len:]], axis=0 | ||
) | ||
collector.append(overlap) | ||
restored_data = np.concatenate(collector, axis=0) | ||
return restored_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.