"""
Title: Transfer learning & fine-tuning
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/15
Last modified: 2023/06/25
Description: Complete guide to transfer learning & fine-tuning in Keras.
Accelerator: GPU
"""
"""
## Setup
"""
import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
"""
## Introduction
**Transfer learning** consists of taking features learned on one problem, and
leveraging them on a new, similar problem. For instance, features from a model that has
learned to identify raccoons may be useful to kick-start a model meant to identify
tanukis.
Transfer learning is usually done for tasks where your dataset has too little data to
train a full-scale model from scratch.
The most common incarnation of transfer learning in the context of deep learning is the
following workflow:
1. Take layers from a previously trained model.
2. Freeze them, so as to avoid destroying any of the information they contain during
future training rounds.
3. Add some new, trainable layers on top of the frozen layers. They will learn to turn
the old features into predictions on a new dataset.
4. Train the new layers on your dataset.
A last, optional step is **fine-tuning**, which consists of unfreezing the entire
model you obtained above (or part of it), and re-training it on the new data with a
very low learning rate. This can potentially achieve meaningful improvements, by
incrementally adapting the pretrained features to the new data.
First, we will go over the Keras `trainable` API in detail, which underlies most
transfer learning & fine-tuning workflows.
Then, we'll demonstrate the typical workflow by taking a model pretrained on the
ImageNet dataset, and retraining it on the Kaggle "cats vs dogs" classification
dataset.
This is adapted from
[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python)
and the 2016 blog post
["building powerful image classification models using very little data"](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html).
"""
"""
## Freezing layers: understanding the `trainable` attribute
Layers & models have three weight attributes:
- `weights` is the list of all weight variables of the layer.
- `trainable_weights` is the list of those that are meant to be updated (via gradient
descent) to minimize the loss during training.
- `non_trainable_weights` is the list of those that aren't meant to be trained.
Typically they are updated by the model during the forward pass.
**Example: the `Dense` layer has 2 trainable weights (kernel & bias)**
"""
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
"""
In general, all weights are trainable weights. The only built-in layer that has
non-trainable weights is the `BatchNormalization` layer. It uses non-trainable weights
to keep track of the mean and variance of its inputs during training.
To learn how to use non-trainable weights in your own custom layers, see the
[guide to writing new layers from scratch](https://keras.io/guides/making_new_layers_and_models_via_subclassing/).
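For a quick illustration, here is a minimal sketch of a custom layer with a
non-trainable weight (the layer and its logic are hypothetical, for illustration
only):

```python
class RunningTotal(keras.layers.Layer):
    # Hypothetical layer: accumulates the sum of its inputs in a
    # non-trainable weight that is updated during the forward pass.
    def build(self, input_shape):
        self.total = self.add_weight(
            shape=(), initializer="zeros", trainable=False, name="total"
        )

    def call(self, inputs):
        self.total.assign_add(keras.ops.sum(inputs))
        return inputs
```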
**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable
weights**
"""
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
"""
Layers & models also feature a boolean attribute `trainable`. Its value can be changed.
Setting `layer.trainable` to `False` moves all the layer's weights from trainable to
non-trainable. This is called "freezing" the layer: the state of a frozen layer won't
be updated during training (either when training with `fit()` or when training with
any custom loop that relies on `trainable_weights` to apply gradient updates).
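For instance, a custom training loop only respects freezing if it applies
gradients to `trainable_weights`. Here is a minimal sketch (assuming the
TensorFlow backend and a `dataset` of `(x, y)` batches, both assumptions for
illustration):

```python
import tensorflow as tf

optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.MeanSquaredError()

for x_batch, y_batch in dataset:  # `dataset` is assumed to exist
    with tf.GradientTape() as tape:
        loss = loss_fn(y_batch, model(x_batch, training=True))
    # Gradients are computed & applied only for trainable weights,
    # so frozen layers are left untouched.
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
```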
**Example: setting `trainable` to `False`**
"""
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
"""
When a trainable weight becomes non-trainable, its value is no longer updated during
training.
"""
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
"""
Do not confuse the `layer.trainable` attribute with the argument `training` in
`layer.__call__()` (which controls whether the layer should run its forward pass in
inference mode or training mode). For more information, see the
[Keras FAQ](
https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).
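As a minimal sketch of the difference, using `BatchNormalization` (where both
concepts apply):

```python
bn = keras.layers.BatchNormalization()
x = np.random.random((2, 4)).astype("float32")

# `training` is a forward-pass argument: it selects between normalizing with
# batch statistics (training=True) and the moving averages (training=False).
y = bn(x, training=True)
y = bn(x, training=False)

# `trainable` is a weight attribute: it freezes gamma & beta (and, for this
# particular layer, also forces inference mode -- see the notes further down).
bn.trainable = False
```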
"""
"""
## Recursive setting of the `trainable` attribute
If you set `trainable = False` on a model or on any layer that has sublayers,
all children layers become non-trainable as well.
**Example:**
"""
inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)
model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        inner_model,
        keras.layers.Dense(3, activation="sigmoid"),
    ]
)
model.trainable = False  # Freeze the outer model
assert inner_model.trainable == False  # All layers in `model` are now frozen
assert (
    inner_model.layers[0].trainable == False
)  # `trainable` is propagated recursively
"""
## The typical transfer-learning workflow
This leads us to how a typical transfer learning workflow can be implemented in Keras:
1. Instantiate a base model and load pre-trained weights into it.
2. Freeze all layers in the base model by setting `trainable = False`.
3. Create a new model on top of the output of one (or several) layers from the base
model.
4. Train your new model on your new dataset.
Note that an alternative, more lightweight workflow could also be:
1. Instantiate a base model and load pre-trained weights into it.
2. Run your new dataset through it and record the output of one (or several) layers
from the base model. This is called **feature extraction**.
3. Use that output as input data for a new, smaller model.
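In code, this second workflow might look like the following sketch
(non-authoritative; `new_dataset` and the variable names are assumptions for
illustration):

```python
base_model = keras.applications.Xception(
    weights="imagenet", input_shape=(150, 150, 3), include_top=False
)
base_model.trainable = False

# Run the base model once over the data to extract features.
all_features, all_labels = [], []
for images, labels in new_dataset:
    all_features.append(base_model(images, training=False))
    all_labels.append(labels)
features = np.concatenate(all_features)
labels = np.concatenate(all_labels)

# Train a new, smaller model on the precomputed features.
small_model = keras.Sequential(
    [
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(1),
    ]
)
small_model.compile(
    optimizer="adam", loss=keras.losses.BinaryCrossentropy(from_logits=True)
)
small_model.fit(features, labels, epochs=20)
```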
A key advantage of that second workflow is that you only run the base model once on
your data, rather than once per epoch of training. So it's a lot faster & cheaper.
An issue with that second workflow, though, is that it doesn't allow you to dynamically
modify the input data of your new model during training, which is required when doing
data augmentation, for instance. Transfer learning is typically used for tasks where
your new dataset has too little data to train a full-scale model from scratch, and in
such scenarios data augmentation is very important. So in what follows, we will focus
on the first workflow.
Here's what the first workflow looks like in Keras:
First, instantiate a base model with pre-trained weights.
```python
base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.
```
Then, freeze the base model.
```python
base_model.trainable = False
```
Create a new model on top.
```python
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```
Train the model on new data.
```python
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
```
"""
"""
## Fine-tuning
Once your model has converged on the new data, you can try to unfreeze all or part of
the base model and retrain the whole model end-to-end with a very low learning rate.
This is an optional last step that can potentially give you incremental improvements.
It could also potentially lead to quick overfitting -- keep that in mind.
It is critical to only do this step *after* the model with frozen layers has been
trained to convergence. If you mix randomly-initialized trainable layers with
trainable layers that hold pre-trained features, the randomly-initialized layers will
cause very large gradient updates during training, which will destroy your pre-trained
features.
It's also critical to use a very low learning rate at this stage, because
you are training a much larger model than in the first round of training, on a dataset
that is typically very small.
As a result, you are at risk of overfitting very quickly if you apply large weight
updates. Here, you only want to readapt the pretrained weights in an incremental way.
This is how to implement fine-tuning of the whole base model:
```python
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are taken into account
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
```
**Important note about `compile()` and `trainable`**
Calling `compile()` on a model is meant to "freeze" the behavior of that model. This
implies that the `trainable` attribute values at the time the model is compiled should
be preserved throughout the lifetime of that model, until `compile` is called again.
Hence, if you change any `trainable` value, make sure to call `compile()` again on
your model for your changes to be taken into account.
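As a minimal sketch of the pitfall:

```python
model.compile(optimizer="adam", loss="mse")  # `trainable` state captured here
base_model.trainable = True  # No effect on the compiled training step...
model.compile(optimizer="adam", loss="mse")  # ...until you recompile.
```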
**Important notes about `BatchNormalization` layer**
Many image models contain `BatchNormalization` layers. That layer is a special case in
almost every imaginable way. Here are a few things to keep in mind.
- `BatchNormalization` contains 2 non-trainable weights that get updated during
training. These are the variables tracking the mean and variance of the inputs.
- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will
run in inference mode, and will not update its mean & variance statistics. This is not
the case for other layers in general, as
[weight trainability & inference/training modes are two orthogonal concepts](
https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).
But the two are tied in the case of the `BatchNormalization` layer.
- When you unfreeze a model that contains `BatchNormalization` layers in order to do
fine-tuning, you should keep the `BatchNormalization` layers in inference mode by
passing `training=False` when calling the base model.
Otherwise the updates applied to the non-trainable weights will suddenly destroy
what the model has learned.
You'll see this pattern in action in the end-to-end example at the end of this guide.
"""
"""
## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset
To solidify these concepts, let's walk you through a concrete end-to-end transfer
learning & fine-tuning example. We will load the Xception model, pre-trained on
ImageNet, and use it on the Kaggle "cats vs. dogs" classification dataset.
"""
"""
### Getting the data
First, let's fetch the cats vs. dogs dataset using TFDS. If you have your own dataset,
you'll probably want to use the utility
`keras.utils.image_dataset_from_directory` to generate similar labeled
dataset objects from a set of images on disk filed into class-specific folders.
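With an on-disk dataset, that might look like the following sketch (the
directory path is a placeholder):

```python
train_ds = keras.utils.image_dataset_from_directory(
    "path/to/images",  # one subfolder per class, e.g. cats/ and dogs/
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=(150, 150),
    batch_size=64,
)
```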
Transfer learning is most useful when working with very small datasets. To keep our
dataset small, we will use 40% of the original training data (25,000 images) for
training, 10% for validation, and 10% for testing.
"""
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)
print(f"Number of training samples: {train_ds.cardinality()}")
print(f"Number of validation samples: {validation_ds.cardinality()}")
print(f"Number of test samples: {test_ds.cardinality()}")
"""
These are the first 9 images in the training dataset -- as you can see, they're all
different sizes.
"""
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")
"""
We can also see that label 1 is "dog" and label 0 is "cat".
"""
"""
### Standardizing the data
Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer
values between 0 and 255 (RGB level values). This isn't a great fit for feeding a
neural network. We need to do 2 things:
- Standardize to a fixed image size. We pick 150x150.
- Normalize pixel values between -1 and 1. We'll do this using a `Rescaling` layer as
part of the model itself.
In general, it's a good practice to develop models that take raw data as input, as
opposed to models that take already-preprocessed data. The reason being that, if your
model expects preprocessed data, any time you export your model to use it elsewhere
(in a web browser, in a mobile app), you'll need to reimplement the exact same
preprocessing pipeline. This gets very tricky very quickly. So we should do the least
possible amount of preprocessing before hitting the model.
Here, we'll do image resizing in the data pipeline (because a deep neural network can
only process contiguous batches of data), and we'll do the input value scaling as part
of the model, when we create it.
Let's resize images to 150x150:
"""
resize_fn = keras.layers.Resizing(150, 150)
train_ds = train_ds.map(lambda x, y: (resize_fn(x), y))
validation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))
test_ds = test_ds.map(lambda x, y: (resize_fn(x), y))
"""
### Using random data augmentation
When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to
the training images, such as random horizontal flipping or small random rotations. This
helps expose the model to different aspects of the training data while slowing down
overfitting.
"""
augmentation_layers = [
layers.RandomFlip("horizontal"),
layers.RandomRotation(0.1),
]
def data_augmentation(x):
    for layer in augmentation_layers:
        x = layer(x)
    return x
train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
"""
Let's batch the data and use prefetching to optimize loading speed.
"""
from tensorflow import data as tf_data
batch_size = 64
train_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
validation_ds = (
    validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
)
test_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()
"""
Let's visualize what the first image of the first batch looks like after various random
transformations:
"""
for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(np.expand_dims(first_image, 0))
        plt.imshow(np.array(augmented_image[0]).astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
"""
## Build a model
Now let's build a model that follows the blueprint we've explained earlier.
Note that:
- We add a `Rescaling` layer to scale input values (initially in the `[0, 255]`
range) to the `[-1, 1]` range.
- We add a `Dropout` layer before the classification layer, for regularization.
- We make sure to pass `training=False` when calling the base model, so that
it runs in inference mode and batchnorm statistics don't get updated
even after we unfreeze the base model for fine-tuning.
"""
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
# Pre-trained Xception weights require that the input be scaled
# from (0, 255) to a range of (-1., +1.); the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(inputs)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary(show_trainable=True)
"""
## Train the top layer
"""
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 2
print("Fitting the top layer of the model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
"""
## Do a round of fine-tuning of the entire model
Finally, let's unfreeze the base model and train the entire model end-to-end with a low
learning rate.
Importantly, although the base model becomes trainable, it is still running in
inference mode since we passed `training=False` when calling it when we built the
model. This means that the batch normalization layers inside won't update their batch
statistics. If they did, they would wreak havoc on the representations learned by the
model so far.
"""
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary(show_trainable=True)
model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 1
print("Fitting the end-to-end model")
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
"""
After one epoch of fine-tuning, we get a nice improvement here.
Let's evaluate the model on the test dataset:
"""
print("Test dataset evaluation")
model.evaluate(test_ds)