Skip to content

Commit

Permalink
Start reorganizing and documenting
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjohnson committed Jan 24, 2017
1 parent 6f8e941 commit 1f45934
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 83 deletions.
74 changes: 74 additions & 0 deletions autograd/tf_two_layer_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import tensorflow as tf
import numpy as np

"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.
This implementation uses basic TensorFlow operations to set up a computational
graph, then executes the graph many times to actually train the network.
One of the main differences between TensorFlow and PyTorch is that TensorFlow
uses static computational graphs while PyTorch uses dynamic computational
graphs.
In TensorFlow we first set up the computational graph, then execute the same
graph many times.
"""

# First we set up the computational graph:

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create placeholders for the input and target data; these will be filled
# with real data when we execute the graph.
x = tf.placeholder(tf.float32, shape=(None, D_in))
y = tf.placeholder(tf.float32, shape=(None, D_out))

# Create Variables for the weights and initialize them with random data.
# A TensorFlow Variable persists its value across executions of the graph.
w1 = tf.Variable(tf.random_normal((D_in, H)))
w2 = tf.Variable(tf.random_normal((H, D_out)))

# Forward pass: Compute the predicted y using operations on TensorFlow Tensors.
# Note that this code does not actually perform any numeric operations; it
# merely sets up the computational graph that we will later execute.
h = tf.matmul(x, w1)
h_relu = tf.maximum(h, tf.zeros(1))
y_pred = tf.matmul(h_relu, w2)

# Compute loss using operations on TensorFlow Tensors
loss = tf.reduce_sum((y - y_pred) ** 2.0)

# Compute gradient of the loss with respect to w1 and w2.
grad_w1, grad_w2 = tf.gradients(loss, [w1, w2])

# Update the weights using gradient descent. To actually update the weights
# we need to evaluate new_w1 and new_w2 when executing the graph. Note that
# in TensorFlow the the act of updating the value of the weights is part of
# the computational graph; in PyTorch this happens outside the computational
# graph.
learning_rate = 1e-6
new_w1 = w1.assign(w1 - learning_rate * grad_w1)
new_w2 = w2.assign(w2 - learning_rate * grad_w2)

# Now we have built our computational graph, so we enter a TensorFlow session to
# actually execute the graph.
with tf.Session() as sess:
# Run the graph once to initialize the Variables w1 and w2.
sess.run(tf.global_variables_initializer())

# Create numpy arrays holding the actual data for the inputs x and targets y
x_value = np.random.randn(N, D_in)
y_value = np.random.randn(N, D_out)
for _ in range(500):
# Execute the graph many times. Each time it executes we want to bind
# x_value to x and y_value to y, specified with the feed_dict argument.
# Each time we execute the graph we want to compute the values for loss,
# new_w1, and new_w2; the values of these Tensors are returned as numpy
# arrays.
loss_value, _, _ = sess.run([loss, new_w1, new_w2],
feed_dict={x: x_value, y: y_value})
print(loss_value)
69 changes: 69 additions & 0 deletions autograd/two_layer_net_autograd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
from torch.autograd import Variable


"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.
This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.
A PyTorch Variable is a wrapper around a PyTorch Tensor, and represents a node
in a computational graph. If x is a Variable then x.data is a Tensor giving its
value, and x.grad is another Variable holding the gradient of x with respect to
some scalar value.
PyTorch Variables have the same API as PyTorch tensors: (almost) any operation
you can do on a Tensor you can also do on a Variable; the difference is that
autograd allows you to automatically compute gradients.
"""

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Variables during the backward pass.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Variables during the backward pass.
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)


learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y using operations on Variables; these
# are exactly the same operations we used to compute the forward pass using
# Tensors, but we do not need to keep references to intermediate values since
# we are not implementing the backward pass by hand.
y_pred = x.mm(w1).clamp(min=0).mm(w2)

# Compute and print loss using operations on Variables.
# Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape
# (1,); loss.data[0] is a scalar value holding the loss.
loss = (y_pred - y).pow(2).sum()
print(t, loss.data[0])

# Manually zero the gradients before running the backward pass
w1.grad.data.zero_()
w2.grad.data.zero_()

# Use autograd to compute the backward pass. This call will compute the
# gradient of all loss with respect to all Variables with requires_grad=True.
# After this call w1.data and w2.data will be Variables holding the gradient
# of the loss with respect to w1 and w2 respectively.
loss.backward()

# Update weights using gradient descent: w1.grad and w2.grad are Variables
# and w1.grad.data and w2.grad.data are Tensors.
w1.data -= learning_rate * w1.grad.data
w2.data -= learning_rate * w2.grad.data
87 changes: 87 additions & 0 deletions autograd/two_layer_net_custom_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
from torch.autograd import Variable



"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.
This implementation computes the forward pass using operations on PyTorch
Variables, and uses PyTorch autograd to compute gradients.
In this implementation we implement our own custom autograd function to perform
the ReLU function.
"""


class MyReLU(torch.autograd.Function):
"""
We can implement our own custom autograd Functions by subclassing
torch.autograd.Function and implementing the forward and backward passes
which operate on Tensors.
"""

def forward(self, input):
"""
In the forward pass we receive a Tensor containing the input and return a
Tensor containing the output. You can save cache arbitrary Tensors for use
in the backward pass using the save_for_backward method.
"""
self.save_for_backward(input)
return input.clamp(min=0)

def backward(self, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
input, = self.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input


dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
# Setting requires_grad=False indicates that we do not need to compute gradients
# with respect to these Variables during the backward pass.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Variables during the backward pass.
w1 = Variable(torch.randn(D_in, H), requires_grad=True)
w2 = Variable(torch.randn(H, D_out), requires_grad=True)

learning_rate = 1e-6
for t in range(500):
# Construct an instance of our MyReLU class to use in our network
relu = MyReLU()

# Forward pass: compute predicted y using operations on Variables; we compute
# ReLU using our custom autograd operation.
y_pred = relu(x.mm(w1)).mm(w2)

# Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss.data[0])

# Manually zero the gradients before running the backward pass
w1.grad.data.zero_()
w2.grad.data.zero_()

# Use autograd to compute the backward pass.
loss.backward()

# Update weights using gradient descent
w1.data -= learning_rate * w1.grad.data
w2.data -= learning_rate * w2.grad.data
48 changes: 48 additions & 0 deletions tensor/two_layer_net_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np

"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x using Euclidean error.
This implementation uses numpy to manually compute the forward pass, loss, and
backward pass.
A numpy array is a generic n-dimensional array; it does not know anything about
deep learning or gradients or computational graphs, and is just a way to perform
generic numeric computations.
"""

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.dot(w1)
h_relu = np.maximum(h, 0)
y_pred = h_relu.dot(w2)

# Compute and print loss
loss = np.square(y_pred - y).sum()
print(t, loss)

# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.T.dot(grad_y_pred)
grad_h_relu = grad_y_pred.dot(w2.T)
grad_h = grad_h_relu.copy()
grad_h[h < 0] = 0
grad_w1 = x.T.dot(grad_h)

# Update weights
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2
55 changes: 55 additions & 0 deletions tensor/two_layer_net_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch

"""
A fully-connected ReLU network with one hidden layer and no biases, trained to
predict y from x by minimizing squared Euclidean distance.
This implementation uses PyTorch tensors to manually compute the forward pass,
loss, and backward pass.
A PyTorch Tensor is basically the same as a numpy array: it does not know
anything about deep learning or computational graphs or gradients, and is just
a generic n-dimensional array to be used for arbitrary numeric computation.
The biggest difference between a numpy array and a PyTorch Tensor is that
a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU,
just cast the Tensor to a cuda datatype.
"""

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype)
w2 = torch.randn(H, D_out).type(dtype)

learning_rate = 1e-6
for t in range(500):
# Forward pass: compute predicted y
h = x.mm(w1)
h_relu = h.clamp(min=0)
y_pred = h_relu.mm(w2)

# Compute and print loss
loss = (y_pred - y).pow(2).sum()
print(t, loss)

# Backprop to compute gradients of w1 and w2 with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_w2 = h_relu.t().mm(grad_y_pred)
grad_h_relu = grad_y_pred.mm(w2.t())
grad_h = grad_h_relu.clone()
grad_h[h < 0] = 0
grad_w1 = x.t().mm(grad_h)

# Update weights using gradient descent
w1 -= learning_rate * grad_w1
w2 -= learning_rate * grad_w2
32 changes: 0 additions & 32 deletions tf_two_layer_net.py

This file was deleted.

23 changes: 0 additions & 23 deletions two_layer_net_autograd.py

This file was deleted.

Loading

0 comments on commit 1f45934

Please sign in to comment.