online_replay.py
################################################################################
# Copyright (c) 2021 ContinualAI.                                              #
# Copyrights licensed under the MIT License.                                   #
# See the accompanying LICENSE file for terms.                                 #
#                                                                              #
# Date: 12-10-2020                                                             #
# Author(s): Vincenzo Lomonaco, Hamed Hemati                                   #
# E-mail: [email protected]                                              #
# Website: avalanche.continualai.org                                           #
################################################################################
"""
This is a simple example on how to use the Replay strategy in an online benchmark.
"""
import argparse
import torch
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, RandomCrop
import torch.optim.lr_scheduler
from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.datasets.dataset_utils import default_dataset_location
from avalanche.benchmarks.scenarios.supervised import class_incremental_benchmark
from avalanche.models import SimpleMLP
from avalanche.training.supervised.strategy_wrappers import Naive
from avalanche.training.supervised.strategy_wrappers_online import OnlineNaive
from avalanche.training.plugins import ReplayPlugin
from avalanche.training.storage_policy import ReservoirSamplingBuffer
from avalanche.benchmarks.scenarios.online import OnlineCLScenario, split_online_stream
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    loss_metrics,
)
from avalanche.logging import InteractiveLogger
from avalanche.training.plugins import EvaluationPlugin


def main(args):
    # --- CONFIG
    device = torch.device(
        f"cuda:{args.cuda}" if torch.cuda.is_available() and args.cuda >= 0 else "cpu"
    )
    # ---------
    # --- TRANSFORMATIONS
    train_transform = transforms.Compose(
        [
            RandomCrop(28, padding=4),
            ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    test_transform = transforms.Compose(
        [ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # ---------
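    # (0.1307,) and (0.3081,) are the commonly used mean and standard deviation
    # of the MNIST training set, so inputs are roughly zero-mean, unit-variance.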
    # --- SCENARIO CREATION
    mnist_train = MNIST(
        root=default_dataset_location("mnist"),
        train=True,
        download=True,
        transform=train_transform,
    )
    mnist_test = MNIST(
        root=default_dataset_location("mnist"),
        train=False,
        download=True,
        transform=test_transform,
    )
    benchmark = class_incremental_benchmark(
        {"train": mnist_train, "test": mnist_test}, num_experiences=5, seed=1234
    )
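    # With num_experiences=5, the 10 MNIST classes are partitioned into 5
    # experiences of 2 classes each (seed=1234 makes the split reproducible).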
    # ---------

    # MODEL CREATION
    model = SimpleMLP(num_classes=10)

    # choose some metrics and evaluation method
    interactive_logger = InteractiveLogger()

    eval_plugin = EvaluationPlugin(
        accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
        forgetting_metrics(experience=True),
        loggers=[interactive_logger],
    )
    # CREATE THE STRATEGY INSTANCE (ONLINE-REPLAY)
    storage_policy = ReservoirSamplingBuffer(max_size=100)
    replay_plugin = ReplayPlugin(
        mem_size=100, batch_size=1, storage_policy=storage_policy
    )
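    # ReservoirSamplingBuffer keeps an (approximately) uniform random sample of
    # all examples seen so far, capped at max_size=100. ReplayPlugin mixes
    # examples from this buffer into the training minibatches and updates the
    # buffer after each (online) training experience.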
    cl_strategy = Naive(
        model,
        torch.optim.Adam(model.parameters(), lr=0.1),
        CrossEntropyLoss(),
        train_epochs=1,  # in online settings, an epoch corresponds to a single pass over an online experience
        train_mb_size=10,
        eval_mb_size=32,
        device=device,
        evaluator=eval_plugin,
        plugins=[replay_plugin],
    )
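    # Note: OnlineNaive (imported above) is an online-specific variant of Naive;
    # this example instead combines the generic Naive strategy with ReplayPlugin,
    # which works on online streams as well.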
    # TRAINING LOOP
    print("Starting experiment...")
    results = []

    # you can split the whole stream like this:
    # ocl_stream = split_online_stream(benchmark.train_stream, experience_size=32)
    # but we split each experience separately because we want to call .eval()
    # after each experience
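    # With experience_size=32, every online experience contains (at most) 32
    # samples, and train_epochs=1 means each of them is seen in a single pass.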
    for i, exp in enumerate(benchmark.train_stream):
        # split experience into an online stream
        ocl_stream = split_online_stream([exp], experience_size=32)

        # Train on the online train stream of the scenario
        cl_strategy.train(ocl_stream)

        # It is easier to evaluate on the original (non-online) streams
        cl_strategy.eval(benchmark.train_stream)
        results.append(cl_strategy.eval(benchmark.test_stream))
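    # Each eval() call returns a dictionary of metric values; `results` collects
    # the test-stream dictionaries, one per processed training experience.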


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cuda",
        type=int,
        default=0,
        help="Select zero-indexed cuda device. -1 to use CPU.",
    )
    args = parser.parse_args()
    main(args)