Skip to content

Commit

Permalink
Merge pull request tensorflow#55 from tensorflow/update_rule
Browse files Browse the repository at this point in the history
Improve SGD with L1 regularization: Set weight to 0 once it crosses 0
  • Loading branch information
dsmilkov authored Mar 9, 2017
2 parents 67cf64f + bd89e3a commit 6ba43c8
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions src/nn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ export class Activations {
export class RegularizationFunction {
public static L1: RegularizationFunction = {
output: w => Math.abs(w),
der: w => w < 0 ? -1 : 1
der: w => w < 0 ? -1 : (w > 0 ? 1 : 0)
};
public static L2: RegularizationFunction = {
output: w => 0.5 * w * w,
Expand All @@ -159,6 +159,7 @@ export class Link {
source: Node;
dest: Node;
weight = Math.random() - 0.5;
isDead = false;
/** Error derivative with respect to this weight. */
errorDer = 0;
/** Accumulated error derivative since the last update. */
Expand Down Expand Up @@ -303,6 +304,9 @@ export function backProp(network: Node[][], target: number,
let node = currentLayer[i];
for (let j = 0; j < node.inputLinks.length; j++) {
let link = node.inputLinks[j];
if (link.isDead) {
continue;
}
link.errorDer = node.inputDer * link.source.output;
link.accErrorDer += link.errorDer;
link.numAccumulatedDers++;
Expand Down Expand Up @@ -343,11 +347,26 @@ export function updateWeights(network: Node[][], learningRate: number,
// Update the weights coming into this node.
for (let j = 0; j < node.inputLinks.length; j++) {
let link = node.inputLinks[j];
if (link.isDead) {
continue;
}
let regulDer = link.regularization ?
link.regularization.der(link.weight) : 0;
if (link.numAccumulatedDers > 0) {
link.weight -= (learningRate / link.numAccumulatedDers) *
(link.accErrorDer + regularizationRate * regulDer);
// Update the weight based on dE/dw.
link.weight = link.weight -
(learningRate / link.numAccumulatedDers) * link.accErrorDer;
// Further update the weight based on regularization.
let newLinkWeight = link.weight -
(learningRate * regularizationRate) * regulDer;
if (link.regularization === RegularizationFunction.L1 &&
link.weight * newLinkWeight < 0) {
// The weight crossed 0 due to the regularization term. Set it to 0.
link.weight = 0;
link.isDead = true;
} else {
link.weight = newLinkWeight;
}
link.accErrorDer = 0;
link.numAccumulatedDers = 0;
}
Expand Down

0 comments on commit 6ba43c8

Please sign in to comment.