Skip to content

Commit a6996a8

Browse files
committed
part 2, multivariate linear regression with gradient descent
1 parent 1897007 commit a6996a8

File tree

3 files changed

+69
-19
lines changed

3 files changed

+69
-19
lines changed

src/csvToMatrix.js

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/index.js

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
getMeanByVector,
88
getStdByVector,
99
setVector,
10+
pushVector,
1011
} from './util';
1112

1213
csvToMatrix('./src/data.csv', init);
@@ -19,6 +20,7 @@ function init(matrix) {
1920
let X = getSubset(matrix, ':, 1:2');
2021
let y = getSubset(matrix, ':, 3');
2122
let m = getDimension(y, 1);
23+
let n = getDimension(y, 2);
2224

2325
// Part 1: Feature Normalization
2426
console.log('Part 1: Feature Normalization ...\n');
@@ -31,6 +33,20 @@ function init(matrix) {
3133
console.log('\n');
3234
console.log('Std: ', sigma);
3335
console.log('\n');
36+
37+
// Part 2: Gradient Descent
38+
console.log('Part 2: Gradient Descent ...\n');
39+
40+
// Add Intercept Term
41+
XNorm = pushVector(XNorm, 0, math.ones([m, 1]).valueOf());
42+
43+
const ALPHA = 0.01;
44+
const ITERATIONS = 400;
45+
46+
let theta = math.zeros(3, 1).valueOf();
47+
theta = gradientDescentMulti(XNorm, y, theta, ALPHA, ITERATIONS);
48+
49+
console.log(theta);
3450
}
3551

3652
function featureNormalize(X) {
@@ -51,3 +67,35 @@ function featureNormalize(X) {
5167

5268
return { XNorm: X, mu, sigma };
5369
}
70+
71+
function gradientDescentMulti(X, y, theta, ALPHA, ITERATIONS) {
72+
const m = getDimension(y, 1);
73+
74+
for (let i = 0; i < ITERATIONS; i++) {
75+
// Octave:
76+
// theta = theta - ALPHA / m * ((X * theta - y)' * X)';
77+
78+
theta = math.subtract(
79+
theta,
80+
math.multiply(
81+
(ALPHA / m),
82+
math.transpose(
83+
math.multiply(
84+
math.transpose(
85+
math.subtract(
86+
math.multiply(
87+
X,
88+
theta
89+
),
90+
y
91+
)
92+
),
93+
X
94+
)
95+
)
96+
)
97+
);
98+
}
99+
100+
return theta;
101+
}

src/util.js

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,27 @@ import math from 'mathjs';
33
export const getSubset = (matrix, selector) =>
44
math.eval(`matrix[${selector}]`, { matrix });
55

6+
export const pushVector = (matrix, index, vector) => {
7+
const extendedMatrix = math
8+
.ones([
9+
getDimension(matrix, 1),
10+
getDimension(matrix, 2) + 1
11+
])
12+
.valueOf();
13+
14+
return extendedMatrix.map((row, rowKey) => row.map((column, columnKey) => {
15+
if (index === columnKey) {
16+
return vector[rowKey][0];
17+
}
18+
if (columnKey < index) {
19+
return matrix[rowKey][columnKey];
20+
}
21+
if (columnKey > index) {
22+
return matrix[rowKey][columnKey - 1];
23+
}
24+
}));
25+
};
26+
627
export const setVector = (matrix, index, vector) =>
728
matrix.map((row, rowKey) => row.map((column, columnKey) => index === columnKey ? vector[rowKey][0] : column));
829

0 commit comments

Comments
 (0)