Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
Derek-Wds authored and you-n-g committed Dec 9, 2020
1 parent 6ef339b commit a8ac56a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions qlib/contrib/model/pytorch_gats_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,12 @@


class DailyBatchSampler(Sampler):

def __init__(self, data_source):

self.data_source = data_source
self.data = self.data_source.data.loc[self.data_source.get_index()]
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_count = self.data.groupby(level=0).size().values # calculate number of samples in each batch
self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch
self.daily_index[0] = 0

def __iter__(self):
Expand Down

0 comments on commit a8ac56a

Please sign in to comment.