-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
92 changed files
with
8,428 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# CS229 Problem Set Instructions | ||
|
||
|
||
## Setup for Written Parts | ||
|
||
1. We have provided a LaTeX template in the `tex/` directory to make it easy to typeset your homework solutions. | ||
2. Every problem has its own directory (*e.g.,* `tex/featuremaps/` for Problem 1). | ||
3. Every subproblem has two files within the parent problem’s directory: | ||
- The problem statement, *e.g.* `tex/featuremaps/01-degree-3-math.tex` for Problem 1(a)). You do not need to modify this. | ||
- Your solution, *e.g.* `tex/featuremaps/01-degree-3-math-sol.tex` for your solution to Problem 1(a). You will need to modify these files (and the source files in `src` for coding parts). | ||
4. You can use the given `Makefile` to typeset your solution, or use an editor with built-in typesetting such as TeXShop (comes free with the standard [LaTeX distribution](https://www.latex-project.org/get/)) or [Texpad](https://www.texpad.com/) (separate download, not free). | ||
|
||
|
||
## Setup for Coding Parts | ||
|
||
1. Install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) | ||
- Conda is a package manager that sandboxes your project’s dependencies in a virtual environment | ||
- Miniconda contains Conda and its dependencies with no extra packages by default (as opposed to Anaconda, which installs some extra packages) | ||
2. Extract the zip file and run `conda env create -f environment.yml` from inside the extracted directory. | ||
- This creates a Conda environment called `cs229` | ||
3. Run `source activate cs229` | ||
- This activates the `cs229` environment | ||
- Do this each time you want to write/test your code | ||
4. (Optional) If you use PyCharm: | ||
- Open the `src` directory in PyCharm | ||
- Go to `PyCharm` > `Preferences` > `Project` > `Project interpreter` | ||
- Click the gear in the top-right corner, then `Add` | ||
- Select `Conda environment` > `Existing environment` > Button on the right with `…` | ||
- Select `/Users/YOUR_USERNAME/miniconda3/envs/cs229/bin/python` | ||
- Select `OK` then `Apply` | ||
5. Notice some coding problems come with `util.py` file. In it you have access to methods that do the following tasks: | ||
- Load a dataset in the CSV format provided in the problem | ||
- Add an intercept to a dataset (*i.e.,* add a new column of 1s to the design matrix) | ||
- Plot a dataset and a linear decision boundary. Some plots might require modified plotting code, but you can use this as a starting point. | ||
7. Notice that start codes are provided in each problem directory (e.g. `gda.py`, `posonly.py`) | ||
- Within each starter file, there are highlighted regions of the code with the comments ** START CODE HERE ** and ** END CODE HERE **. You are strongly suggested to make your changes only within this region. You can add helper functions within this region as well. | ||
8. Once you are done with all the code changes, run `make_zip.py` to create a `submission.zip`. | ||
- You must upload this `submission.zip` to Gradescope. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
name: cs229 | ||
channels: | ||
- defaults | ||
dependencies: | ||
- matplotlib=2.2.2 | ||
- numpy=1.15.0 | ||
- pip=10.0.1 | ||
- python=3.6.6 | ||
- scipy | ||
- pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import util | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
np.seterr(all='raise') | ||
|
||
|
||
factor = 2.0 | ||
|
||
class LinearModel(object): | ||
"""Base class for linear models.""" | ||
|
||
def __init__(self, theta=None): | ||
""" | ||
Args: | ||
theta: Weights vector for the model. | ||
""" | ||
self.theta = theta | ||
|
||
def fit(self, X, y): | ||
"""Run solver to fit linear model. You have to update the value of | ||
self.theta using the normal equations. | ||
Args: | ||
X: Training example inputs. Shape (n_examples, dim). | ||
y: Training example labels. Shape (n_examples,). | ||
""" | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
|
||
def create_poly(self, k, X): | ||
""" | ||
Generates a polynomial feature map using the data x. | ||
The polynomial map should have powers from 0 to k | ||
Output should be a numpy array whose shape is (n_examples, k+1) | ||
Args: | ||
X: Training example inputs. Shape (n_examples, 2). | ||
""" | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
|
||
def create_sin(self, k, X): | ||
""" | ||
Generates a sin with polynomial featuremap to the data x. | ||
Output should be a numpy array whose shape is (n_examples, k+2) | ||
Args: | ||
X: Training example inputs. Shape (n_examples, 2). | ||
""" | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
|
||
def predict(self, X): | ||
""" | ||
Make a prediction given new inputs x. | ||
Returns the numpy array of the predictions. | ||
Args: | ||
X: Inputs of shape (n_examples, dim). | ||
Returns: | ||
Outputs of shape (n_examples,). | ||
""" | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
|
||
|
||
def run_exp(train_path, sine=False, ks=[1, 2, 3, 5, 10, 20], filename='plot.png'): | ||
train_x,train_y=util.load_dataset(train_path,add_intercept=True) | ||
plot_x = np.ones([1000, 2]) | ||
plot_x[:, 1] = np.linspace(-factor*np.pi, factor*np.pi, 1000) | ||
plt.figure() | ||
plt.scatter(train_x[:, 1], train_y) | ||
|
||
for k in ks: | ||
''' | ||
Our objective is to train models and perform predictions on plot_x data | ||
''' | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
''' | ||
Here plot_y are the predictions of the linear model on the plot_x data | ||
''' | ||
plt.ylim(-2, 2) | ||
plt.plot(plot_x[:, 1], plot_y, label='k=%d' % k) | ||
|
||
plt.legend() | ||
plt.savefig(filename) | ||
plt.clf() | ||
|
||
|
||
def main(train_path, small_path, eval_path): | ||
''' | ||
Run all expetriments | ||
''' | ||
# *** START CODE HERE *** | ||
# *** END CODE HERE *** | ||
|
||
if __name__ == '__main__': | ||
main(train_path='train.csv', | ||
small_path='small.csv', | ||
eval_path='test.csv') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
x,y | ||
-1.7135959928671598,-1.3236697512018198 | ||
5.775453161144872,-0.6218361590317668 | ||
2.4751942119192307,0.743951019738786 | ||
3.1098593944626245,0.31500240387463146 | ||
4.506122796058088,-0.835822224031054 | ||
-4.759988869075444,1.1278082211457872 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
x,y | ||
-0.317332591271696,-0.27252612060220777 | ||
0.4442656277803749,0.7441417051994887 | ||
-2.3482611754105522,-0.4250988546387261 | ||
0.5711986642890539,0.19700515728840046 | ||
-0.571198664289053,-0.49942444579341955 | ||
4.759988869075444,-1.283338184074442 | ||
-0.82506473730641,-0.08245646574302756 | ||
3.8714576135146945,-0.18415458829873677 | ||
2.9829263579539447,0.4758064248683743 | ||
-0.19039955476301706,-0.2506767039464714 | ||
2.7290602849365886,0.19785264653127796 | ||
1.3327968833411248,0.6735951593012379 | ||
-5.013854942092801,0.3195507168821301 | ||
5.14078797860148,-0.811294748353128 | ||
-3.6175915404973376,0.8902099228214186 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
x,y | ||
-1.7135959928671598,-1.3236697512018198 | ||
5.775453161144872,-0.6218361590317668 | ||
2.4751942119192307,0.743951019738786 | ||
3.1098593944626245,0.31500240387463146 | ||
4.506122796058088,-0.835822224031054 | ||
-4.759988869075444,1.1278082211457872 | ||
5.394654051618836,-0.6168932910620547 | ||
-5.521587088127515,1.3759450983569934 | ||
-6.283185307179586,0.2618867695071692 | ||
-2.7290602849365877,-0.5473273253020433 | ||
6.283185307179586,-0.3832308586050987 | ||
-2.094395102393195,-0.9182733428038911 | ||
-5.648520124636193,0.6210859271978089 | ||
-5.394654051618836,0.7836974901592654 | ||
-3.998390650023373,-0.020685920053224338 | ||
0.06346651825433991,0.2544178427196909 | ||
-3.744524577006016,-0.3268237701924621 | ||
6.02931923416223,-0.833379369636474 | ||
-1.2058638468324459,-0.7104043587071238 | ||
0.3173325912716969,0.6438364256881062 | ||
-3.8714576135146945,0.44964965142488666 | ||
-6.156252270670907,0.21633028930973713 | ||
0.9519977738150889,0.6164849765140968 | ||
-5.267721015110158,1.1009599569712076 | ||
0.6981317007977319,0.5528540681082632 | ||
-3.3637254674799806,0.0866952359542428 | ||
1.9674620658845168,1.1494443580332727 | ||
-1.8405290293758378,-1.2275075156125759 | ||
4.886921905584122,-1.4335039057135976 | ||
-0.6981317007977319,-0.9236960692116173 | ||
4.3791897595494085,-1.1031181380479915 | ||
2.2213281389018746,1.0781326803448221 | ||
0.8250647373064108,0.7003850606887475 | ||
-2.9829263579539447,-0.1114406925823648 | ||
-2.855993321445266,0.27533667382849525 | ||
-2.602127248427909,-0.4306848018452264 | ||
-4.5061227960580865,1.012701360877828 | ||
-5.90238619765355,0.03708104757204683 | ||
-5.140787978601479,1.2845599157401097 | ||
-0.9519977738150889,-0.5720716776034998 | ||
-1.4597299198498028,-0.9518429894325678 | ||
1.0789308103237678,0.711906804552946 | ||
-5.775453161144872,0.5550372497940461 | ||
1.2058638468324459,0.8439781779232759 | ||
4.25225672304073,-0.6016651669541326 | ||
6.156252270670908,-0.1230834265935148 | ||
2.0943951023931966,0.7719608249024198 | ||
3.3637254674799806,-0.11733481320098312 | ||
3.7445245770060165,-0.28068317876140825 | ||
1.5866629563584818,0.7285755609071582 | ||
-3.490658503988659,0.23292673955463286 | ||
5.902386197653552,-0.7345162716500795 | ||
1.4597299198498028,0.9331551217299455 | ||
-1.5866629563584809,-0.541238884746396 | ||
-1.9674620658845168,-1.297516048968786 | ||
5.648520124636194,-0.7018219915149035 | ||
-0.06346651825433902,-0.13334885210532094 | ||
1.7135959928671598,1.1215905266317125 | ||
3.2367924309713025,-0.006711923492500341 | ||
3.6175915404973384,-0.8240343370379357 | ||
-1.332796883341124,-0.9068045806994273 | ||
2.8559933214452666,0.3653127453048675 | ||
-4.25225672304073,0.8490100861943669 | ||
5.267721015110158,-0.6531886198510957 | ||
-4.886921905584122,1.4220244701824878 | ||
-1.078930810323767,-0.8240161588129145 | ||
1.8405290293758387,0.6914572054829922 | ||
-4.6330558325667655,1.2434696618948986 | ||
5.013854942092802,-0.9881685000098093 | ||
-2.4751942119192307,-0.5993405651186523 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
||
def add_intercept(x): | ||
"""Add intercept to matrix x. | ||
Args: | ||
x: 2D NumPy array. | ||
Returns: | ||
New matrix same as x with 1's in the 0th column. | ||
""" | ||
new_x = np.zeros((x.shape[0], x.shape[1] + 1), dtype=x.dtype) | ||
new_x[:, 0] = 1 | ||
new_x[:, 1:] = x | ||
|
||
return new_x | ||
|
||
|
||
def load_dataset(csv_path, label_col='y', add_intercept=False): | ||
"""Load dataset from a CSV file. | ||
Args: | ||
csv_path: Path to CSV file containing dataset. | ||
label_col: Name of column to use as labels (should be 'y' or 't'). | ||
add_intercept: Add an intercept entry to x-values. | ||
Returns: | ||
xs: Numpy array of x-values (inputs). | ||
ys: Numpy array of y-values (labels). | ||
""" | ||
|
||
def add_intercept_fn(x): | ||
global add_intercept | ||
return add_intercept(x) | ||
|
||
# Validate label_col argument | ||
allowed_label_cols = ('y', 't') | ||
if label_col not in allowed_label_cols: | ||
raise ValueError('Invalid label_col: {} (expected {})' | ||
.format(label_col, allowed_label_cols)) | ||
|
||
# Load headers | ||
with open(csv_path, 'r') as csv_fh: | ||
headers = csv_fh.readline().strip().split(',') | ||
|
||
# Load features and labels | ||
x_cols = [i for i in range(len(headers)) if headers[i].startswith('x')] | ||
l_cols = [i for i in range(len(headers)) if headers[i] == label_col] | ||
inputs = np.loadtxt(csv_path, delimiter=',', skiprows=1, usecols=x_cols) | ||
labels = np.loadtxt(csv_path, delimiter=',', skiprows=1, usecols=l_cols) | ||
|
||
if inputs.ndim == 1: | ||
inputs = np.expand_dims(inputs, -1) | ||
|
||
if add_intercept: | ||
inputs = add_intercept_fn(inputs) | ||
|
||
return inputs, labels | ||
|
||
|
||
def plot(x, y, theta, save_path, correction=1.0): | ||
"""Plot dataset and fitted logistic regression parameters. | ||
Args: | ||
x: Matrix of training examples, one per row. | ||
y: Vector of labels in {0, 1}. | ||
theta: Vector of parameters for logistic regression model. | ||
save_path: Path to save the plot. | ||
correction: Correction factor to apply, if any. | ||
""" | ||
# Plot dataset | ||
plt.figure() | ||
plt.plot(x[y == 1, -2], x[y == 1, -1], 'bx', linewidth=2) | ||
plt.plot(x[y == 0, -2], x[y == 0, -1], 'go', linewidth=2) | ||
|
||
# Plot decision boundary (found by solving for theta^T x = 0) | ||
x1 = np.arange(min(x[:, -2]), max(x[:, -2]), 0.01) | ||
x2 = -(theta[0] / theta[2] + theta[1] / theta[2] * x1 | ||
+ np.log((2 - correction) / correction) / theta[2]) | ||
plt.plot(x1, x2, c='red', linewidth=2) | ||
plt.xlim(x[:, -2].min()-.1, x[:, -2].max()+.1) | ||
plt.ylim(x[:, -1].min()-.1, x[:, -1].max()+.1) | ||
|
||
# Add labels and save to disk | ||
plt.xlabel('x1') | ||
plt.ylabel('x2') | ||
plt.savefig(save_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
x,y | ||
-0.4442656277803749,-0.5157917925219232 | ||
4.125323686532052,-0.8687817345629987 | ||
-6.029319234162229,-0.006427591745514716 | ||
-3.109859394462623,0.31388354222106174 | ||
0.19039955476301795,-0.2926502709673009 | ||
3.9983906500233743,-1.0862782512895621 | ||
5.521587088127516,-0.6288963129290719 | ||
-3.2367924309713016,0.4202167381212816 | ||
2.6021272484279105,0.8337955956240055 | ||
-4.3791897595494085,1.1331404417490503 | ||
2.3482611754105527,0.9592953129389641 | ||
-4.1253236865320515,0.749100578779578 | ||
-2.2213281389018737,-1.8415080963614736 | ||
3.4906585039886586,-0.951207483221332 | ||
4.633055832566766,-0.6520944814558465 |
Oops, something went wrong.