Skip to content

Commit

Permalink
intialized
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Feb 18, 2023
1 parent ca481b4 commit 01dc6f6
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions federated/fed_digits.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import torchvision
import torchvision.transforms as transforms
from utils import data_utils
import wandb

def prepare_data(args):
# Prepare data
Expand Down Expand Up @@ -190,11 +191,6 @@ def communication(args, server_model, models, client_weights):


if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed= 1
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

print('Device:', device)
parser = argparse.ArgumentParser()
Expand All @@ -209,8 +205,28 @@ def communication(args, server_model, models, client_weights):
parser.add_argument('--mu', type=float, default=1e-2, help='The hyper parameter for fedprox')
parser.add_argument('--save_path', type = str, default='../checkpoint/digits', help='path to save the checkpoint')
parser.add_argument('--resume', action='store_true', help ='resume training from the save path checkpoint')
parser.add_argument('--project_name', type=str, default='fed_digit_2', help='name of wandb project')
parser.add_argument('--cuda_num', type=int, default=0, help='cuda num')
parser.add_argument('--runid', default=None, type=str)
args = parser.parse_args()

device = torch.device('cuda:' + str(args.cuda_num) if torch.cuda.is_available() else 'cpu')
seed = 1
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

log = args.log
if log:
# log_path = args.save_path + SAVE_PATH + '_log'
os.environ["WANDB_API_KEY"] = 'f87c7a64e4a4c89c4f1afc42620ac211ceb0f926'
if args.runid != None and args.resume:
wandb.init(project=args.project_name, entity="sanaayr", id=args.runid, resume="must", config=args)
else:
wandb.init(project=args.project_name, entity="sanaayr", config=args)
wandb.run.name = NAME
wandb.run.save()

exp_folder = 'federated_digits'

args.save_path = os.path.join(args.save_path, exp_folder)
Expand Down

0 comments on commit 01dc6f6

Please sign in to comment.