forked from cs231n/cs231n.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
622 additions
and
38 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,314 @@ | ||
|
||
<html> | ||
<head> | ||
<title>Convolution demo</title> | ||
|
||
<script type="text/javascript" src="external/d3.min.js"></script> | ||
<script type="text/javascript" src="utils.js"></script> | ||
|
||
<style type="text/css"> | ||
body { | ||
margin: 0; | ||
padding: 0; | ||
} | ||
</style> | ||
|
||
<script type="text/javascript"> | ||
|
||
var W1 = 7; | ||
var H1 = 7; | ||
var D1 = 3; | ||
|
||
var K = 2; | ||
var F = 3; | ||
var S = 2; // stride | ||
|
||
var cs = 25; // cell size | ||
|
||
var X = new U.Vol(W1, H1, D1); // input volume | ||
for(var q=0;q<X.w.length;q++) { | ||
X.w[q] = Math.floor(Math.random()*3); | ||
// 0 pad with P = 1 | ||
for(var d=0;d<X.depth;d++) { | ||
for(var x=0;x<X.sx;x++) { | ||
for(var y=0;y<X.sy;y++) { | ||
if(x === 0 || x === (X.sx - 1) || y === 0 || y === (X.sy - 1)) { | ||
X.set(x,y,d,0); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
var Ws = []; | ||
var bs = []; | ||
for(var k=0;k<K;k++) { | ||
var W = new U.Vol(F, F, D1); | ||
for(var q=0;q<W.w.length;q++) { | ||
W.w[q] = Math.floor(Math.random()*3) - 1; | ||
} | ||
Ws.push(W); | ||
var b = new U.Vol(1,1,1); | ||
b.w[0] = 1 - k; | ||
bs.push(b); | ||
} | ||
|
||
var conv_forward = function(V, Ws, bs, stride) { | ||
// optimized code by @mdda that achieves 2x speedup over previous version | ||
var out_sy = ((V.sy-W.sy)/stride +1); | ||
var out_sx = ((V.sx-W.sx)/stride +1); | ||
var A = new U.Vol(out_sx |0, out_sy |0, Ws.length |0, 0.0); | ||
|
||
var V_sx = V.sx |0; | ||
var V_sy = V.sy |0; | ||
var xy_stride = stride |0; | ||
|
||
for(var d=0;d<Ws.length;d++) { | ||
var f = Ws[d]; | ||
var x = 0; | ||
var y = 0; | ||
for(var ay=0; ay<out_sy; y+=xy_stride,ay++) { // xy_stride | ||
x = 0; | ||
for(var ax=0; ax<out_sx; x+=xy_stride,ax++) { // xy_stride | ||
|
||
// convolve centered at this particular location | ||
var a = 0.0; | ||
for(var fy=0;fy<f.sy;fy++) { | ||
var oy = y+fy; // coordinates in the original input array coordinates | ||
for(var fx=0;fx<f.sx;fx++) { | ||
var ox = x+fx; | ||
if(oy>=0 && oy<V_sy && ox>=0 && ox<V_sx) { | ||
for(var fd=0;fd<f.depth;fd++) { | ||
// avoid function call overhead (x2) for efficiency, compromise modularity :( | ||
a += f.w[((f.sx * fy)+fx)*f.depth+fd] * V.w[((V_sx * oy)+ox)*V.depth+fd]; | ||
} | ||
} | ||
} | ||
} | ||
a += bs[d].w[0]; | ||
A.set(ax, ay, d, a); | ||
} | ||
} | ||
} | ||
return A; | ||
} | ||
|
||
function renderVol(svg, V, xoff, yoff, col, title, vid) { | ||
|
||
var pad = 3; | ||
var dpad = 10; | ||
|
||
var gyoff = 7; | ||
|
||
var txt = title + ' (' + V.sx + 'x' + V.sy + 'x' + V.depth + ')'; | ||
// 1 padding exception | ||
if(vid === 'x') { txt = title + ' (' + (V.sx-2) + 'x' + (V.sy-2) + 'x' + V.depth + ')'; } | ||
|
||
svg.append('text') | ||
.attr('x', xoff) | ||
.attr('y', yoff) | ||
.attr('font-size', 16) | ||
.attr('fill', 'black') | ||
.text(txt); | ||
|
||
for(var d = 0; d < V.depth; d++) { | ||
for(var x = 0; x < V.sx; x++) { | ||
for(var y = 0; y < V.sy; y++) { | ||
|
||
var xcoord = xoff + x * (cs + pad); | ||
var ycoord = yoff + y * (cs + pad) + d * (V.sy * (cs + pad) + dpad) + gyoff; | ||
|
||
svg.append('rect') | ||
.attr('x', xcoord) | ||
.attr('y', ycoord) | ||
.attr('height', cs) | ||
.attr('width', cs) | ||
.attr('fill', col) | ||
.attr('stroke', 'none') | ||
.attr('stroke-width', '2') | ||
.attr('id', vid+'_'+x+'_'+y+'_'+d) | ||
.attr('class', vid); | ||
|
||
svg.append('text') | ||
.attr('x', xcoord + 5) | ||
.attr('y', ycoord + 15) | ||
.attr('font-size', 16) | ||
.attr('fill', 'black') | ||
.text(V.get(x,y,d).toFixed(0)); | ||
|
||
} | ||
} | ||
} | ||
} | ||
|
||
function draw() { | ||
var d3elt = d3.select('#draw'); | ||
svg = d3elt.append('svg').attr('width', '100%').attr('height', '100%') | ||
.append('g').attr('transform', 'scale(1)'); | ||
|
||
var yoff = 20; | ||
// render input volume | ||
renderVol(svg, X, 10, yoff, '#DDF', 'Input Volume X', 'x'); | ||
|
||
for(var i=0;i<Ws.length;i++) { | ||
// render weights | ||
renderVol(svg, Ws[i], 270 + i*170, yoff, '#FDD', 'Weight W'+i, 'w'+i); | ||
// render biases | ||
renderVol(svg, bs[i], 270 + i*170, 350 + yoff, '#FDD', 'Bias b'+i, 'b'+i); | ||
} | ||
|
||
// render output | ||
renderVol(svg, O, 600, yoff, '#DFD', 'Output Volume', 'o'); | ||
|
||
// render controls | ||
|
||
svg.append('text') | ||
.attr('x', 520) | ||
.attr('y', 470) | ||
.attr('font-size', 16) | ||
.attr('fill', 'black') | ||
.text('toggle movement'); | ||
svg.append('rect') | ||
.attr('x', 500) | ||
.attr('y', 450) | ||
.attr('height', 30) | ||
.attr('width', 150) | ||
.attr('fill', "rgba(200, 200, 200, 0.1)") | ||
.attr('stroke', 'black') | ||
.attr('stroke-width', '2') | ||
.attr('style', 'cursor:pointer;') | ||
.on('click', function() { | ||
// toggle | ||
if(iid === -1) { | ||
iid = setInterval(focusCell, 1000); | ||
} else { | ||
clearInterval(iid); | ||
iid = -1; | ||
} | ||
}); | ||
} | ||
|
||
var fxg = 0; | ||
var fyg = 0; | ||
var fdg = 0; | ||
function focusCell() { | ||
|
||
// first unfocus all | ||
for(var i=0;i<Ws.length;i++) { | ||
d3.selectAll('.w'+i).attr('stroke', 'none'); | ||
d3.selectAll('.b'+i).attr('stroke', 'none'); | ||
} | ||
d3.selectAll('.x').attr('stroke', 'none'); | ||
d3.selectAll('.o').attr('stroke', 'none'); | ||
|
||
var fx = fxg; | ||
var fy = fyg; | ||
var fd = fdg; | ||
|
||
// highlight the output cell | ||
var csel = d3.select('#o'+'_'+fx+'_'+fy+'_'+fd); | ||
csel.attr('stroke', '#0A0'); | ||
|
||
// highlight the weights | ||
d3.selectAll('.w'+fd).attr('stroke', '#A00'); | ||
// highlight the bias | ||
d3.selectAll('.b'+fd).attr('stroke', '#A00'); | ||
|
||
d3.selectAll('.ll').remove(); | ||
|
||
// highlight the input cell | ||
for(var d=0;d<D1;d++) { | ||
for(var x=0;x<F;x++) { | ||
for(var y=0;y<F;y++) { | ||
var ix = fx * S + x; | ||
var iy = fy * S + y; | ||
var id = d; | ||
var csel = d3.select('#x'+'_'+ix+'_'+iy+'_'+id); | ||
csel.attr('stroke', '#00A'); | ||
|
||
// connect with line | ||
if(x === 0 && y === 0) { | ||
var wsel = d3.select('#w'+fd+'_'+x+'_'+y+'_'+d); | ||
svg.append('line') | ||
.attr('x1', csel.attr('x')) | ||
.attr('y1', csel.attr('y')) | ||
.attr('x2', wsel.attr('x')) | ||
.attr('y2', wsel.attr('y')) | ||
.attr('stroke', 'black') | ||
.attr('stroke-width', '1') | ||
.attr('class', 'll'); | ||
} | ||
if(x === 0 && y === (F-1)) { | ||
var wsel = d3.select('#w'+fd+'_'+x+'_'+y+'_'+d); | ||
svg.append('line') | ||
.attr('x1', csel.attr('x')) | ||
.attr('y1', parseFloat(csel.attr('y')) + cs) | ||
.attr('x2', wsel.attr('x')) | ||
.attr('y2', parseFloat(wsel.attr('y')) + cs) | ||
.attr('stroke', 'black') | ||
.attr('stroke-width', '1') | ||
.attr('class', 'll'); | ||
} | ||
if(x === (F-1) && y === 0) { | ||
var wsel = d3.select('#w'+fd+'_'+x+'_'+y+'_'+d); | ||
svg.append('line') | ||
.attr('x1', parseFloat(csel.attr('x')) + cs) | ||
.attr('y1', csel.attr('y')) | ||
.attr('x2', parseFloat(wsel.attr('x')) + cs) | ||
.attr('y2', wsel.attr('y')) | ||
.attr('stroke', 'black') | ||
.attr('stroke-width', '1') | ||
.attr('class', 'll'); | ||
} | ||
if(x === (F-1) && y === (F-1)) { | ||
var wsel = d3.select('#w'+fd+'_'+x+'_'+y+'_'+d); | ||
svg.append('line') | ||
.attr('x1', parseFloat(csel.attr('x')) + cs) | ||
.attr('y1', parseFloat(csel.attr('y')) + cs) | ||
.attr('x2', parseFloat(wsel.attr('x')) + cs) | ||
.attr('y2', parseFloat(wsel.attr('y')) + cs) | ||
.attr('stroke', 'black') | ||
.attr('stroke-width', '1') | ||
.attr('class', 'll'); | ||
} | ||
|
||
} | ||
} | ||
} | ||
|
||
// output focus cycle | ||
fxg++; | ||
if(fxg >= O.sx) { | ||
fxg = 0; | ||
fyg++; | ||
if(fyg >=O.sy) { | ||
fyg = 0; | ||
fdg++; | ||
if(fdg >= O.depth) { | ||
fdg = 0; | ||
} | ||
} | ||
} | ||
|
||
} | ||
|
||
iid = -1; | ||
function start() { | ||
O = conv_forward(X, Ws, bs, S); | ||
draw(); | ||
iid = setInterval(focusCell, 1000); | ||
} | ||
|
||
</script> | ||
|
||
|
||
|
||
</head> | ||
|
||
<body onload="start()"> | ||
|
||
<div id="draw"> | ||
</div> | ||
|
||
</body> | ||
</html> |
Oops, something went wrong.