Skip to content

Commit

Permalink
full documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
ChriZ982 committed Jul 19, 2017
1 parent ef0592e commit 2fababa
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 6 deletions.
87 changes: 82 additions & 5 deletions src/zindach/neuralnettest/main/ButtonPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,23 @@
import zindach.neuralnetlib.options.regularization.L2Regularization;
import zindach.neuralnetlib.trainer.StochasticGradientDescentTrainer;

/**
* Panel holding all the buttons and interacting with neural network.
*
* @author ChriZ98
*/
public class ButtonPanel extends JPanel {

private final JLabel[] labels;
private final Frame frame;
private final JButton predictButton;
private StochasticGradientDescentTrainer sgdt;
private JButton predictButton;

public JButton getPredictButton() {
return predictButton;
}

/**
* Initializes all buttons and labels.
*
* @param frame the parent frame
*/
public ButtonPanel(Frame frame) {
super(new GridLayout(12, 3));
this.frame = frame;
Expand Down Expand Up @@ -83,6 +89,9 @@ public ButtonPanel(Frame frame) {
frame.repaint();
}

/**
* Trains the network using MNIST data.
*/
private void trainButtonActionPerformed() {
if (sgdt == null) {
sgdt = new StochasticGradientDescentTrainer(frame.getNet(), new CrossEntropyCostFunction(), new L2Regularization());
Expand All @@ -94,6 +103,10 @@ private void trainButtonActionPerformed() {
sgdt.train(2, 0.5, 5.0, 10, true);
}

/**
* Testing the network by iterating through some test images and diplaying
* calculated results.
*/
private void testButtonActionPerformed() {
Random rand = new Random();
Vector[] data = MNISTLoader.importData("data/t10k-images-idx3-ubyte.gz");
Expand All @@ -111,14 +124,23 @@ private void testButtonActionPerformed() {
}).start();
}

/**
* Saves the current neural network.
*/
private void saveButtonActionPerformed() {
NetworkIO.saveNetwork("data/network.dat", frame.getNet());
}

/**
* Loads some saved neural network.
*/
private void loadButtonActionPerformed() {
frame.setNet(NetworkIO.loadNetwork("data/network.dat"));
}

/**
* Resets the drawing panel and the labels.
*/
public void resetButtonActionPerformed() {
Graphics2D g = frame.getGraphics2D();
g.setBackground(Color.WHITE);
Expand All @@ -128,6 +150,11 @@ public void resetButtonActionPerformed() {
frame.repaint();
}

/**
* Sets the probabilities as label content.
*
* @param output calculated probabilities
*/
public void setLabels(double[] output) {
double sum = 0;
double max = Double.NEGATIVE_INFINITY;
Expand All @@ -151,6 +178,14 @@ public void setLabels(double[] output) {
}
}

/**
* Generates a double array representing the drawing. White pixels are
* encoded as 0. Black pixels are encoded as 1. Grey pixels are values
* inbetween.
*
* @param image drawing to represent
* @return double array with image data
*/
public double[] calcInput(BufferedImage image) {
double[] input = new double[784];
for (int i = 0; i < 28; i++) {
Expand All @@ -161,6 +196,12 @@ public double[] calcInput(BufferedImage image) {
return input;
}

/**
* Draws image from vector. The vector contains values between 0 and 1. That
* is convertet into different grey colors and then painted out.
*
* @param vector vector to be shown on panel
*/
private void drawImageFromVector(Vector vector) {
resetButtonActionPerformed();
Graphics2D g = frame.getGraphics2D();
Expand All @@ -176,6 +217,15 @@ private void drawImageFromVector(Vector vector) {
setLabels(frame.getNet().feedforward(new Matrix(vector)).getCols()[0].getArray());
}

/**
* Counts rows of white pixels in current drawing.
*
* @param image image to analyze
* @param i1 start index
* @param i2 direchtion. either -1 or 1
* @param horizontal count rows if true. count cols if false.
* @return number of rows with white pixels
*/
private int getWhitespaceInImage(BufferedImage image, int i1, int i2, boolean horizontal) {
int cut = i1;
for (int i = i1; i >= 0 && i < Frame.DRAW_SIZE; i += i2) {
Expand All @@ -198,6 +248,12 @@ private int getWhitespaceInImage(BufferedImage image, int i1, int i2, boolean ho
return cut;
}

/**
* Calculates the center of an image with the highest pixel density.
*
* @param image image to analyze
* @return vector containing y and x value of center
*/
private Vector centerOfMassOfPixels(BufferedImage image) {
int iCount = 0, jCount = 0;
int iVal = 0, jVal = 0;
Expand All @@ -214,6 +270,15 @@ private Vector centerOfMassOfPixels(BufferedImage image) {
return new Vector(iVal / iCount, jVal / jCount);
}

/**
* Translates the image towards some point.
*
* @param image image to translate
* @param iC new center y
* @param jC new center x
* @param iM image center y
* @param jM image center x
*/
private void moveTowardsPoint(BufferedImage image, int iC, int jC, int iM, int jM) {
int iDiff = iM - iC;
int jDiff = jM - jC;
Expand Down Expand Up @@ -262,6 +327,9 @@ private void moveTowardsPoint(BufferedImage image, int iC, int jC, int iM, int j
}
}

/**
* Predicts the digit drawn by using the neural network to calculate.
*/
public void predictButtonActionPerformed() {
if (frame.isPredicted()) {
return;
Expand Down Expand Up @@ -325,4 +393,13 @@ public void predictButtonActionPerformed() {
});
t1.start();
}

/**
* Gets the predict button.
*
* @return predict button
*/
public JButton getPredictButton() {
return predictButton;
}
}
15 changes: 15 additions & 0 deletions src/zindach/neuralnettest/main/DrawPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@
import java.awt.Graphics;
import javax.swing.JPanel;

/**
* The panel where drawing is performed.
*
* @author ChriZ98
*/
public class DrawPanel extends JPanel {

private final Frame frame;

/**
* Initializes drawing space.
*
* @param frame parent frame
*/
public DrawPanel(Frame frame) {
super();
this.frame = frame;
Expand All @@ -22,6 +32,11 @@ public DrawPanel(Frame frame) {
addMouseMotionListener(frame.getMouse());
}

/**
* Paints the panel on screen.
*
* @param grphcs graphics to paint at
*/
@Override
protected void paintComponent(Graphics grphcs) {
super.paintComponent(grphcs);
Expand Down
63 changes: 62 additions & 1 deletion src/zindach/neuralnettest/main/Frame.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,18 @@
import javax.swing.JFrame;
import zindach.neuralnetlib.net.NeuralNetwork;

/**
* Main and window class of the project. Responsible for initializing the gui
* and performing actions on the neural network.
*
* @author ChriZ98
*/
public class Frame extends JFrame {

/**
* Size of the panel used for drawing.
*/
public static final int DRAW_SIZE = 700;
public static final String FILE_ENDING = ".png";

private BufferedImage image;
private NeuralNetwork net;
Expand All @@ -23,6 +31,9 @@ public class Frame extends JFrame {
private final Mouse mouse;
private boolean predicted = false;

/**
* Initializes all needed components for digit recognition.
*/
public Frame() {
super("Neural Network Test");

Expand All @@ -46,44 +57,94 @@ public Frame() {
setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
}

/**
* Sets image.
*
* @param image image to set
*/
public void setImage(BufferedImage image) {
this.image = image;
repaint();
}

/**
* Gets image.
*
* @return image
*/
public BufferedImage getImage() {
return image;
}

/**
* Gets 2D Graphics to draw on.
*
* @return graphics2d object
*/
public Graphics2D getGraphics2D() {
return (Graphics2D) image.getGraphics();
}

/**
* Gets button panel.
*
* @return button panel
*/
public ButtonPanel getButtonPanel() {
return buttonPanel;
}

/**
* Sets neural net.
*
* @param nn neural net to set
*/
public void setNet(NeuralNetwork nn) {
this.net = nn;
}

/**
* Gets neural net.
*
* @return neural net
*/
public NeuralNetwork getNet() {
return net;
}

/**
* Gets mouse.
*
* @return mouse
*/
public Mouse getMouse() {
return mouse;
}

/**
* Gets prediction status.
*
* @return prediction status
*/
public boolean isPredicted() {
return predicted;
}

/**
* Sets prediction status.
*
* @param predicted prediction status
*/
public void setPredicted(boolean predicted) {
this.predicted = predicted;
buttonPanel.getPredictButton().setEnabled(!predicted);
}

/**
* Main method. Creates new Frame.
*
* @param args arguments passed
*/
public static void main(String[] args) {
new Frame();
}
Expand Down
32 changes: 32 additions & 0 deletions src/zindach/neuralnettest/main/MNISTLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,20 @@
import java.util.zip.GZIPInputStream;
import zindach.mathlib.algebra.Vector;

/**
* Loads MNIST data sets. Loading works according to specified standards at
* http://yann.lecun.com/exdb/mnist/.
*
* @author ChriZ98
*/
public class MNISTLoader {

/**
* Tries to read MNIST data from some file.
*
* @param fileName filename e.g. "train-images-idx3-ubyte.gz"
* @return imported data
*/
public static Vector[] importData(String fileName) {
try {
System.out.println("\n---Importing MNIST data---\nfile: " + fileName);
Expand All @@ -35,10 +47,23 @@ public static Vector[] importData(String fileName) {
return new Vector[0];
}

/**
* Converts 4 bytes to 32 bit integer.
*
* @param bytes byte array
* @return integer value representation
*/
private static int bytesToInt(byte[] bytes) {
return ((bytes[0] & 0xFF) << 24 | (bytes[1] & 0xFF) << 16 | (bytes[2] & 0xFF) << 8 | (bytes[3] & 0xFF));
}

/**
* Imports a label file form MNIST database
*
* @param gzip file stream
* @return imported labels
* @throws IOException error while reading file
*/
private static Vector[] importLabelFile(GZIPInputStream gzip) throws IOException {
byte[] itemCountBytes = new byte[4];
gzip.read(itemCountBytes);
Expand All @@ -54,6 +79,13 @@ private static Vector[] importLabelFile(GZIPInputStream gzip) throws IOException
return data;
}

/**
* Imports an image file form MNIST database
*
* @param gzip file stream
* @return imported images
* @throws IOException error while reading file
*/
private static Vector[] importImageFile(GZIPInputStream gzip) throws IOException {
byte[] infoBytes = new byte[4];
gzip.read(infoBytes);
Expand Down
Loading

0 comments on commit 2fababa

Please sign in to comment.