train_dnn_model_6.m
function [modelFile, trainLoss] = train_dnn_model_6(sampleFile, trainParams)
% Train a DNN model for learning dynamic system behavior
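%   Inputs:
%     sampleFile  - MAT-file containing a cell array 'samples' that lists the
%                   state trajectory files to load
%     trainParams - struct of training settings used below: type, initTimeStep,
%                   numLayers, numNeurons, dropoutFactor, initLearningRate,
%                   lrDropFactor, lrDropEpoch, numEpochs, miniBatchSize
%   Outputs:
%     modelFile   - path where the trained network is saved
%     trainLoss   - per-iteration training loss reported by trainNetwork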
    % load samples and prepare training dataset
    ds = load(sampleFile);
    numSamples = length(ds.samples);
    modelFile = "model\" + trainParams.type + "_" + num2str(numSamples) + ".mat";

    % generate training dataset
    % Feature: 6-D initial state (x0) + the prediction time offset (t)
    % Label: the predicted state x = [q1, q2, q1dot, q2dot, q1ddot, q2ddot]'
    % Initial times start at 1 s and step by trainParams.initTimeStep up to 5 s
    initTimes = 1:trainParams.initTimeStep:5;
    xTrain = [];
    yTrain = [];
    for i = 1:numSamples
        data = load(ds.samples{i,1}).state;
        t = data(1,:);
        x = data(2:7,:);
        for tInit = initTimes
            initIdx = find(t > tInit, 1, 'first');
            x0 = x(:,initIdx);   % Initial state
            t0 = t(initIdx);     % Start time
            for j = initIdx+1:length(t)
                xTrain = [xTrain, [x0; t(j)-t0]];
                yTrain = [yTrain, x(:,j)];
            end
        end
    end
    disp(num2str(size(xTrain,2)) + " samples are generated for training.");
    xTrain = xTrain';
    yTrain = yTrain';
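    % After transposing, xTrain is N-by-7 ([x0; dt] per row) and yTrain is
    % N-by-6, the row-wise observation layout trainNetwork expects for
    % feature-input networks.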
    % Create neural network
    numStates = 6;
    layers = [
        featureInputLayer(numStates+1, "Name", "input")
        ];
    numMiddle = floor(trainParams.numLayers/2);
    for i = 1:numMiddle
        layers = [
            layers
            fullyConnectedLayer(trainParams.numNeurons)
            reluLayer
            ];
    end
    if trainParams.dropoutFactor > 0
        layers = [
            layers
            dropoutLayer(trainParams.dropoutFactor)
            ];
    end
    for i = numMiddle+1:trainParams.numLayers
        layers = [
            layers
            fullyConnectedLayer(trainParams.numNeurons)
            reluLayer
            ];
    end
    layers = [
        layers
        fullyConnectedLayer(numStates, "Name", "output")
        weightedLossLayer("mse")   % custom output loss layer (assumed to be defined elsewhere in this project)
        ];
    lgraph = layerGraph(layers);
    % plot(lgraph);
    % analyzeNetwork(lgraph);
    options = trainingOptions("adam", ...
        InitialLearnRate = trainParams.initLearningRate, ...
        LearnRateSchedule = "piecewise", ...
        LearnRateDropFactor = trainParams.lrDropFactor, ...
        LearnRateDropPeriod = trainParams.lrDropEpoch, ...
        MaxEpochs = trainParams.numEpochs, ...
        MiniBatchSize = trainParams.miniBatchSize, ...
        Shuffle = "every-epoch", ...
        Plots = "training-progress", ...
        Verbose = 1);

    % training with numeric array data
    [net, info] = trainNetwork(xTrain, yTrain, lgraph, options);
    trainLoss = info.TrainingLoss;
    save(modelFile, 'net');
    % disp(info)
end
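
% Example usage (a minimal sketch; the struct field values and the sample-file
% name below are illustrative assumptions, not values taken from this project):
%
%   trainParams = struct('type', "dnn", 'initTimeStep', 0.5, 'numLayers', 6, ...
%       'numNeurons', 128, 'dropoutFactor', 0.1, 'initLearningRate', 1e-3, ...
%       'lrDropFactor', 0.5, 'lrDropEpoch', 50, 'numEpochs', 200, ...
%       'miniBatchSize', 256);
%   [modelFile, trainLoss] = train_dnn_model_6("samples.mat", trainParams);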