Skip to content

Commit

Permalink
four datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Feb 20, 2023
1 parent 2cd56c6 commit 17c311f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions federated/fed_digits.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def communication(args, server_model, models, client_weights):
parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
parser.add_argument('--batch', type = int, default= 32, help ='batch size')
parser.add_argument('--iters', type = int, default=100, help = 'iterations for communication')
parser.add_argument('--wk_iters', type = int, default=10, help = 'optimization iters in local worker between communication')
parser.add_argument('--wk_iters', type = int, default=50, help = 'optimization iters in local worker between communication')
parser.add_argument('--mode', type = str, default='fedbn', help='fedavg | fedprox | fedbn')
parser.add_argument('--mu', type=float, default=1e-2, help='The hyper parameter for fedprox')
parser.add_argument('--save_path', type = str, default='../checkpoint/digits', help='path to save the checkpoint')
Expand Down Expand Up @@ -264,7 +264,7 @@ def communication(args, server_model, models, client_weights):
test_loaders = test_loaders[1:]

# federated setting
client_num = len(datasets)-3
client_num = len(datasets)
client_weights = [1/client_num for i in range(client_num)]
models = [copy.deepcopy(server_model).to(device) for idx in range(client_num)]

Expand Down

0 comments on commit 17c311f

Please sign in to comment.