forked from tensorflow/tfjs-core
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net.ts
155 lines (133 loc) · 5.2 KB
/
net.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/**
* @license
* Copyright 2017 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
import * as dl from 'deeplearn';
/**
 * Base URL of the hosted TransformNet checkpoints. The checkpoint loader is
 * pointed at this URL followed by the style name and a trailing slash.
 */
const GOOGLE_CLOUD_STORAGE_DIR =
'https://storage.googleapis.com/learnjs-data/checkpoint_zoo/transformnet/';
/**
 * Style-transfer network whose weights are fetched per style from a hosted
 * checkpoint directory.
 *
 * Typical usage: construct with a style name, `await load()`, then call
 * `predict()`. Call `setStyle()` followed by `load()` to switch styles, and
 * `dispose()` when the model is no longer needed.
 */
export class TransformNet implements dl.Model {
  /** Variables of the active style; undefined until `load()` resolves. */
  private variables: {[varName: string]: dl.Tensor}|undefined;
  /** Cache of checkpoint variables that were already fetched, per style. */
  private variableDictionary:
      {[styleName: string]: {[varName: string]: dl.Tensor}};
  private timesScalar: dl.Scalar;
  private plusScalar: dl.Scalar;
  private epsilonScalar: dl.Scalar;

  /**
   * @param style Name of the style; used as the checkpoint sub-directory.
   */
  constructor(private style: string) {
    this.variableDictionary = {};
    this.timesScalar = dl.scalar(150);
    this.plusScalar = dl.scalar(255. / 2);
    this.epsilonScalar = dl.scalar(1e-3);
  }

  /**
   * Selects the style used by the next `load()` call. The variables used by
   * `predict()` do not change until `load()` has been called again.
   */
  setStyle(style: string) {
    this.style = style;
  }

  /**
   * Loads necessary variables for TransformNet. Resolves the promise when the
   * variables have all been loaded. Variables of a style that was loaded
   * before are reused from the cache instead of being fetched again.
   */
  async load(): Promise<void> {
    // Pin the style this call fetches so the result is cached under the right
    // key even if setStyle() is called while the download is in flight.
    const requestedStyle = this.style;
    if (this.variableDictionary[requestedStyle] == null) {
      const checkpointLoader = new dl.CheckpointLoader(
          GOOGLE_CLOUD_STORAGE_DIR + requestedStyle + '/');
      this.variableDictionary[requestedStyle] =
          await checkpointLoader.getAllVariables();
    }
    // Bind whatever style is current *now*, so the active variables follow
    // the most recent setStyle() once its own load() has completed.
    this.variables = this.variableDictionary[this.style];
  }

  /**
   * Infer through TransformNet, assumes variables have been loaded.
   * Original Tensorflow version of model can be found at
   * https://github.com/lengstrom/fast-style-transfer
   *
   * The raw network output is passed through tanh, scaled by 150, shifted by
   * 127.5, clipped to [0, 255] and finally divided by 255.
   *
   * @param preprocessedInput preprocessed input Array.
   * @return dl.Tensor3D containing pixels of output img
   * @throws Error if a required variable has not been loaded.
   */
  predict(preprocessedInput: dl.Tensor3D): dl.Tensor3D {
    // Every conv / conv-transpose layer consumes three consecutive variable
    // ids (filter, shift, scale); a residual block consumes six.
    return dl.tidy(() => {
      let activation = this.convLayer(preprocessedInput.toFloat(), 1, true, 0);
      activation = this.convLayer(activation, 2, true, 3);
      activation = this.convLayer(activation, 2, true, 6);
      for (const blockVarId of [9, 15, 21, 27, 33]) {
        activation = this.residualBlock(activation, blockVarId);
      }
      activation = this.convTransposeLayer(activation, 64, 2, 39);
      activation = this.convTransposeLayer(activation, 32, 2, 42);
      activation = this.convLayer(activation, 1, false, 45);
      return activation.tanh()
                 .mul(this.timesScalar)
                 .add(this.plusScalar)
                 .clipByValue(0, 255)
                 .div(dl.scalar(255)) as dl.Tensor3D;
    });
  }

  /**
   * 'same'-padded convolution followed by instance normalization and an
   * optional ReLU.
   *
   * @param varId Id of the filter variable; the normalization parameters are
   *     read from the two ids that follow it.
   */
  private convLayer(
      input: dl.Tensor3D, strides: number, relu: boolean,
      varId: number): dl.Tensor3D {
    const filter = this.getVariable(varId) as dl.Tensor4D;
    const convolved = input.conv2d(filter, [strides, strides], 'same');
    const normalized = this.instanceNorm(convolved, varId + 1);
    return relu ? normalized.relu() : normalized;
  }

  /**
   * 'same'-padded transposed convolution that multiplies the first two input
   * dimensions by `strides`, followed by instance normalization and ReLU.
   *
   * @param numFilters Size of the last dimension of the output.
   * @param varId Id of the filter variable; the normalization parameters are
   *     read from the two ids that follow it.
   */
  private convTransposeLayer(
      input: dl.Tensor3D, numFilters: number, strides: number,
      varId: number): dl.Tensor3D {
    const [height, width] = input.shape;
    const outputShape: [number, number, number] =
        [height * strides, width * strides, numFilters];
    const filter = this.getVariable(varId) as dl.Tensor4D;
    const upsampled =
        input.conv2dTranspose(filter, outputShape, [strides, strides], 'same');
    return this.instanceNorm(upsampled, varId + 1).relu();
  }

  /**
   * Two stride-1 conv layers (ReLU only after the first) whose result is
   * added to the block input.
   */
  private residualBlock(input: dl.Tensor3D, varId: number): dl.Tensor3D {
    const first = this.convLayer(input, 1, true, varId);
    const second = this.convLayer(first, 1, false, varId + 3);
    return second.addStrict(input);
  }

  /**
   * Normalizes `input` with the mean and variance computed over its first two
   * axes, then applies the learned scale and shift.
   *
   * @param varId Id of the shift variable; the scale variable is `varId + 1`.
   */
  private instanceNorm(input: dl.Tensor3D, varId: number): dl.Tensor3D {
    const [height, width, inDepth]: [number, number, number] = input.shape;
    const moments = dl.moments(input, [0, 1]);
    const mu = moments.mean;
    const sigmaSq = moments.variance as dl.Tensor3D;
    const shift = this.getVariable(varId) as dl.Tensor1D;
    const scale = this.getVariable(varId + 1) as dl.Tensor1D;
    // The epsilon keeps the division finite when the variance is zero.
    const normalized =
        input.sub(mu).div(sigmaSq.add(this.epsilonScalar).sqrt());
    const shifted = scale.mul(normalized).add(shift);
    return shifted.as3D(height, width, inDepth);
  }

  /**
   * Looks up a variable of the active style and fails with a descriptive
   * error instead of passing `undefined` on to a tensor op.
   */
  private getVariable(varId: number): dl.Tensor {
    const name = this.varName(varId);
    const variable = this.variables == null ? undefined : this.variables[name];
    if (variable == null) {
      throw new Error(
          `TransformNet variable "${name}" is not available; ` +
          'call load() and wait for it to resolve before predict().');
    }
    return variable;
  }

  /** Maps a numeric id to its checkpoint name: 0 -> 'Variable', n -> 'Variable_n'. */
  private varName(varId: number): string {
    return varId === 0 ? 'Variable' : 'Variable_' + varId.toString();
  }

  /**
   * Releases every tensor owned by this model: the cached variables of all
   * loaded styles and the scalar constants created in the constructor. The
   * instance must not be used for prediction afterwards.
   */
  dispose() {
    for (const styleName of Object.keys(this.variableDictionary)) {
      const styleVariables = this.variableDictionary[styleName];
      for (const varName of Object.keys(styleVariables)) {
        styleVariables[varName].dispose();
      }
    }
    // Forget the disposed tensors so a later load() cannot rebind them.
    this.variableDictionary = {};
    this.variables = undefined;
    // These were allocated in the constructor and are not owned by any tidy().
    this.timesScalar.dispose();
    this.plusScalar.dispose();
    this.epsilonScalar.dispose();
  }
}