Skip to content

Commit

Permalink
Update data_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
TsingZ0 committed Apr 21, 2024
1 parent 6e4e336 commit 8fab3e5
Showing 1 changed file with 23 additions and 60 deletions.
83 changes: 23 additions & 60 deletions system/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,12 @@
import os
import torch

# IMAGE_SIZE = 28
# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
# NUM_CHANNELS = 1

# IMAGE_SIZE_CIFAR = 32
# NUM_CHANNELS_CIFAR = 3


def batch_data(data, batch_size):
'''
data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client)
returns x, y, which are both numpy array of length: batch_size
'''
data_x = data['x']
data_y = data['y']

# randomly shuffle data
ran_state = np.random.get_state()
np.random.shuffle(data_x)
np.random.set_state(ran_state)
np.random.shuffle(data_y)

# loop through mini-batches
for i in range(0, len(data_x), batch_size):
batched_x = data_x[i:i+batch_size]
batched_y = data_y[i:i+batch_size]
yield (batched_x, batched_y)


def get_random_batch_sample(data_x, data_y, batch_size):
num_parts = len(data_x)//batch_size + 1
if(len(data_x) > batch_size):
batch_idx = np.random.choice(list(range(num_parts + 1)))
sample_index = batch_idx*batch_size
if(sample_index + batch_size > len(data_x)):
return (data_x[sample_index:], data_y[sample_index:])
else:
return (data_x[sample_index: sample_index+batch_size], data_y[sample_index: sample_index+batch_size])
else:
return (data_x, data_y)


def get_batch_sample(data, batch_size):
data_x = data['x']
data_y = data['y']

# np.random.seed(100)
ran_state = np.random.get_state()
np.random.shuffle(data_x)
np.random.set_state(ran_state)
np.random.shuffle(data_y)

batched_x = data_x[0:batch_size]
batched_y = data_y[0:batch_size]
return (batched_x, batched_y)


def read_data(dataset, idx, is_train=True):
if is_train:
train_data_dir = os.path.join('../dataset', dataset, 'train/')

train_file = train_data_dir + 'train' + str(idx) + '_.npz'
train_file = train_data_dir + str(idx) + '.npz'
with open(train_file, 'rb') as f:
train_data = np.load(f, allow_pickle=True)['data'].tolist()

Expand All @@ -72,16 +16,18 @@ def read_data(dataset, idx, is_train=True):
else:
test_data_dir = os.path.join('../dataset', dataset, 'test/')

test_file = test_data_dir + 'test' + str(idx) + '_.npz'
test_file = test_data_dir + str(idx) + '.npz'
with open(test_file, 'rb') as f:
test_data = np.load(f, allow_pickle=True)['data'].tolist()

return test_data


def read_client_data(dataset, idx, is_train=True):
if dataset[:2] == "ag" or dataset[:2] == "SS":
return read_client_data_text(dataset, idx)
if "News" in dataset:
return read_client_data_text(dataset, idx, is_train)
elif "Shakespeare" in dataset:
return read_client_data_Shakespeare(dataset, idx)

if is_train:
train_data = read_data(dataset, idx, is_train)
Expand Down Expand Up @@ -121,3 +67,20 @@ def read_client_data_text(dataset, idx, is_train=True):

test_data = [((x, lens), y) for x, lens, y in zip(X_test, X_test_lens, y_test)]
return test_data


def read_client_data_Shakespeare(dataset, idx, is_train=True):
if is_train:
train_data = read_data(dataset, idx, is_train)
X_train = torch.Tensor(train_data['x']).type(torch.int64)
y_train = torch.Tensor(train_data['y']).type(torch.int64)

train_data = [(x, y) for x, y in zip(X_train, y_train)]
return train_data
else:
test_data = read_data(dataset, idx, is_train)
X_test = torch.Tensor(test_data['x']).type(torch.int64)
y_test = torch.Tensor(test_data['y']).type(torch.int64)
test_data = [(x, y) for x, y in zip(X_test, y_test)]
return test_data

0 comments on commit 8fab3e5

Please sign in to comment.