forked from ContinualAI/avalanche
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathjoint_training.py
76 lines (64 loc) · 2.54 KB
/
joint_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 20-11-2020 #
# Author(s): Vincenzo Lomonaco #
# E-mail: [email protected] #
# Website: avalanche.continualai.org #
################################################################################
"""
This is a simple example showing how an "offline" upper bound can be
computed. It is useful for seeing the maximum accuracy a model can reach
without the hindrance of learning continually. This is often referred to as
the "cumulative", "joint-training" or "offline" upper bound.
"""
import argparse
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.classic import PermutedMNIST
from avalanche.models import SimpleMLP
from avalanche.training.supervised import JointTraining
def main(args):
    """Train a joint ("offline") baseline on Permuted MNIST and evaluate it.

    A ``SimpleMLP`` is trained with the ``JointTraining`` strategy on the
    whole training stream in a single ``train`` call, then evaluated on the
    test stream.

    Args:
        args: Parsed command-line namespace. Only ``args.cuda`` is read: the
            zero-indexed CUDA device to use, or a negative value to force
            the CPU.

    Returns:
        A one-element list holding whatever ``joint_train.eval`` returned
        for the test stream (presumably Avalanche's metric dictionary --
        confirm against the Avalanche version in use).
    """
    # Config: use the GPU only when one is available AND one was requested.
    use_cuda = torch.cuda.is_available() and args.cuda >= 0
    device = torch.device(f"cuda:{args.cuda}" if use_cuda else "cpu")

    # model
    model = SimpleMLP(num_classes=10)

    # CL Benchmark Creation
    perm_mnist = PermutedMNIST(n_experiences=5)
    train_stream = perm_mnist.train_stream
    test_stream = perm_mnist.test_stream

    # Prepare for training & testing
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = CrossEntropyLoss()

    # Joint training strategy
    joint_train = JointTraining(
        model,
        optimizer,
        criterion,
        train_mb_size=32,
        train_epochs=1,
        eval_mb_size=32,
        device=device,
    )

    # train and test loop
    results = []
    print("Starting training.")
    # Differently from other avalanche strategies, you NEED to call train
    # on the entire stream.
    joint_train.train(train_stream)
    results.append(joint_train.eval(test_stream))

    # Hand the evaluation output back to the caller instead of discarding it.
    return results
if __name__ == "__main__":
    # Command-line interface: the single option picks the compute device.
    cli = argparse.ArgumentParser()
    cuda_option = dict(
        type=int,
        default=0,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    cli.add_argument("--cuda", **cuda_option)

    # Parse the command line and run the example.
    main(cli.parse_args())