forked from ContinualAI/avalanche
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathar1.py
74 lines (61 loc) · 2.55 KB
/
ar1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
################################################################################
# Copyright (c) 2021 ContinualAI. #
# Copyrights licensed under the MIT License. #
# See the accompanying LICENSE file for terms. #
# #
# Date: 08-02-2021 #
# Author(s): Lorenzo Pellegrini #
# E-mail: [email protected] #
# Website: avalanche.continualai.org #
################################################################################
"""
This is a simple example on how to use the AR1 strategy.
"""
import argparse
import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.transforms import ToTensor, Resize
from avalanche.benchmarks import SplitCIFAR10
from avalanche.training.supervised.ar1 import AR1
# Per-channel (R, G, B) mean / std commonly cited for the CIFAR-10 training
# set. The previous constants (0.1307,) / (0.3081,) are the single-channel
# MNIST statistics and do not describe CIFAR-10's RGB images.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def _make_transform():
    """Build the deterministic preprocessing pipeline used for train and eval.

    Images are resized to 224, converted to tensors and normalized per
    channel with CIFAR-10 statistics.
    """
    return transforms.Compose(
        [
            Resize(224),
            ToTensor(),
            # NOTE(review): if AR1's default backbone is ImageNet-pretrained,
            # ImageNet statistics may be preferable here — confirm.
            transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
        ]
    )


def main(args):
    """Train AR1 on Split CIFAR-10, evaluating on the full test stream after
    every training experience.

    Args:
        args: parsed command-line namespace. Only ``args.cuda`` is read: a
            zero-indexed CUDA device, or a negative value to force the CPU.

    Returns:
        list: one entry per training experience, each being the value
        returned by ``cl_strategy.eval`` on the whole test stream.
    """
    # Fall back to the CPU when CUDA is unavailable or a negative index is
    # requested on the command line.
    use_cuda = torch.cuda.is_available() and args.cuda >= 0
    device = torch.device(f"cuda:{args.cuda}" if use_cuda else "cpu")

    # --- TRANSFORMATIONS
    # Train and eval share the same pipeline (no augmentation); the
    # transforms are stateless, so one instance can serve both streams.
    transform = _make_transform()

    # --- BENCHMARK CREATION
    # First positional argument is presumably the number of experiences
    # (5 splits of CIFAR-10's 10 classes) — see SplitCIFAR10 docs.
    benchmark = SplitCIFAR10(
        5, train_transform=transform, eval_transform=transform
    )

    # --- CREATE THE STRATEGY INSTANCE
    cl_strategy = AR1(criterion=CrossEntropyLoss(), device=device)

    # --- TRAINING LOOP
    print("Starting experiment...")
    results = []
    for experience in benchmark.train_stream:
        print("Start of experience: ", experience.current_experience)
        print("Current Classes: ", experience.classes_in_this_experience)

        # num_workers=0 is forwarded to the strategy; presumably it keeps data
        # loading in the main process — verify against AR1.train.
        cl_strategy.train(experience, num_workers=0)
        print("Training completed")

        print("Computing accuracy on the whole test set")
        results.append(cl_strategy.eval(benchmark.test_stream, num_workers=0))

    # Previously computed and discarded; returning it lets callers inspect
    # the per-experience evaluation results.
    return results
if __name__ == "__main__":
    # Command-line entry point: the only option selects the compute device.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--cuda",
        default=0,
        type=int,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    main(cli.parse_args())