Skip to content

Commit

Permalink
Some refactoring of the MNIST loader, still incomplete
Browse files Browse the repository at this point in the history
  • Loading branch information
migueldeicaza committed Jun 4, 2017
1 parent 7de0c24 commit 2804490
Show file tree
Hide file tree
Showing 52 changed files with 473 additions and 74,120 deletions.
45 changes: 18 additions & 27 deletions Learn/Datasets/MNIST.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,48 +43,39 @@ public class Mnist
public byte [] TrainLabels, TestLabels, ValidationLabels;
public byte [,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;

// Simple batch reader to get pieces of data from the dataset
public BatchReader GetBatchReader (MnistImage [] source)
{
return new BatchReader (source);
}
// Creates a new reader over the training images, their labels and one-hot labels.
public BatchReader GetTrainReader ()
{
	return new BatchReader (TrainImages, TrainLabels, OneHotTrainLabels);
}

// Creates a new reader over the test images, their labels and one-hot labels.
public BatchReader GetTestReader ()
{
	return new BatchReader (TestImages, TestLabels, OneHotTestLabels);
}

// Creates a new reader over the validation images, their labels and one-hot labels.
public BatchReader GetValidationReader ()
{
	return new BatchReader (ValidationImages, ValidationLabels, OneHotValidationLabels);
}

// Sequential reader over one set of MNIST images: it keeps the images together
// with their raw and one-hot labels, plus a cursor marking where the next batch starts.
// NOTE(review): this listing appears to come from a diff view, so superseded and
// replacement lines sit side by side (e.g. the two constructor signatures below).
public class BatchReader
{
// Index into source of the first element of the next batch.
int start = 0;
MnistImage [] source;
// Stored by the constructor; not read by any method visible in this class.
byte [] labels;
// One-hot label rows; NextBatch copies 10 values per item from here.
byte [,] oneHotLabels;

// Stores the image array and both label arrays; nothing is copied or validated here.
public BatchReader (MnistImage [] source)
internal BatchReader (MnistImage [] source, byte [] labels, byte [,] oneHotLabels)
{
this.source = source;
this.labels = labels;
this.oneHotLabels = oneHotLabels;
}

// Returns the next batchSize images and moves the cursor. When the batch reaches
// or passes the end of source, the remaining tail is copied first and the rest of
// the batch is taken from the beginning of source.
public MnistImage [] Read (int batchSize)
{
var result = new MnistImage [batchSize];
if (start + batchSize < source.Length) {
Array.Copy (source, start, result, 0, batchSize);
start += batchSize;
} else {
// Number of images left between the cursor and the end of source.
var firstLength = source.Length - start;
Array.Copy (source, start, result, 0, firstLength);
Array.Copy (source, 0, result, firstLength, batchSize-firstLength);
// NOTE(review): batchSize-firstLength items were just taken from the front of
// source, yet the cursor is set to firstLength -- confirm this is intended.
start = firstLength;
}
return result;
}

// Builds the next batch as two float matrices: batchSize x 784 image data and
// batchSize x 10 label data (copied element by element from oneHotLabels),
// then advances the cursor by batchSize.
// NOTE(review): source [start+item] and oneHotLabels [item + start, j] are indexed
// directly; no wrap-around or bounds handling is visible here -- confirm callers
// never ask for more items than remain after the cursor.
public TFTensor ReadAsTensor (int batchSize)
public (float[,],float [,]) NextBatch (int batchSize)
{
var result = new float [batchSize, 784];
var imageData = new float [batchSize, 784];
var labelData = new float [batchSize, 10];

var x = Read (batchSize);
// Destination offset passed to Buffer.BlockCopy; advanced by 784 per item.
int p = 0;
for (int i = 0; i < batchSize; i++) {
Buffer.BlockCopy (x [i].DataFloat, 0, result, p, 784);
for (int item = 0; item < batchSize; item++) {
// NOTE(review): Buffer.BlockCopy takes its offsets and count in bytes, not
// elements. Assuming DataFloat is a float [] (TODO confirm), 784 bytes is only
// 196 floats per image -- check whether 784 * sizeof (float) was intended for
// both the count and the step applied to p.
Buffer.BlockCopy (source [start+item].DataFloat, 0, imageData, p, 784);
p += 784;
// Widen the byte one-hot row for this item into the float label matrix.
for (var j = 0; j < 10; j++)
labelData [item, j] = oneHotLabels [item + start, j];
}
return (TFTensor)result;

start += batchSize;
return (imageData, labelData);
}
}

Expand Down
7 changes: 7 additions & 0 deletions Learn/Learn.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
</PropertyGroup>
<ItemGroup>
<Reference Include="System" />
<Reference Include="System.ValueTuple">
<HintPath>..\packages\System.ValueTuple.4.3.1\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
</Reference>
<Reference Include="System.Numerics" />
</ItemGroup>
<ItemGroup>
<Compile Include="Properties\AssemblyInfo.cs" />
Expand All @@ -47,5 +51,8 @@
<Name>TensorFlowSharp</Name>
</ProjectReference>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
</Project>
4 changes: 4 additions & 0 deletions Learn/packages.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="System.ValueTuple" version="4.3.1" targetFramework="net461" />
</packages>
7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
all: doc-update yaml

rebuild-docs: docs/template
mdoc export-html --force-update -o docs --template=docs/template ecmadocs/en/

# Used to fetch XML doc updates from the C# compiler into the ECMA docs
doc-update:
mdoc update -i TensorFlowSharp/bin/Debug/TensorFlowSharp.xml -o ecmadocs/en TensorFlowSharp/bin/Debug/TensorFlowSharp.dll

yaml:
-rm ecmadocs/en/ns-.xml
mono /cvs/ECMA2Yaml/ECMA2Yaml/ECMA2Yaml/bin/Debug/ECMA2Yaml.exe --source=`pwd`/ecmadocs/en --output=`pwd`/docfx/api
(cd docfx; mono ~/Downloads/docfx/docfx.exe build)

36 changes: 21 additions & 15 deletions SampleTest/SampleTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ void BasicMatrix ()
};
}

int ArgMax (byte [,] array, int idx)
int ArgMax (float [,] array, int idx)
{
int max = -1;
float max = -1;
int maxIdx = -1;
var l = array.GetLength (1);
for (int i = 0; i < l; i++)
Expand All @@ -198,7 +198,17 @@ int ArgMax (byte [,] array, int idx)
max = array [idx, i];
}
return maxIdx;
}
}

// Copies row `index` of the two-dimensional `array` into a freshly allocated
// one-dimensional array whose length equals the size of the second dimension.
public float [] Extract (float [,] array, int index)
{
	int columns = array.GetLength (1);
	float [] row = new float [columns];

	int col = 0;
	while (col < columns) {
		row [col] = array [index, col];
		col++;
	}
	return row;
}

// This sample has a bug: I suspect the data loaded is incorrect, because the values
// returned in distance are wrong, and so is the prediction computed from them.
Expand All @@ -211,25 +221,21 @@ void NearestNeighbor ()
// 5000 for training
const int trainCount = 5000;
const int testCount = 200;
var Xtr = mnist.GetBatchReader (mnist.TrainImages).ReadAsTensor (trainCount);
var Ytr = mnist.OneHotTrainLabels;
var Xte = mnist.GetBatchReader (mnist.TestImages).Read (testCount);
var Yte = mnist.OneHotTestLabels;


(var trainingImages, var trainingLabels) = mnist.GetTrainReader ().NextBatch (trainCount);
(var testImages, var testLabels) = mnist.GetTestReader ().NextBatch (testCount);

Console.WriteLine ("Nearest neighbor on Mnist images");
using (var g = new TFGraph ()) {
var s = new TFSession (g);


TFOutput xtr = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
TFOutput trainingInput = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));

TFOutput xte = g.Placeholder (TFDataType.Float, new TFShape (784));

// Nearest Neighbor calculation using L1 Distance
// Calculate L1 Distance
TFOutput distance = g.ReduceSum (g.Abs (g.Add (xtr, g.Neg (xte))), axis: g.Const (1));
TFOutput distance = g.ReduceSum (g.Abs (g.Add (trainingInput, g.Neg (xte))), axis: g.Const (1));

// Prediction: Get min distance index (Nearest neighbor)
TFOutput pred = g.ArgMin (distance, g.Const (0));
Expand All @@ -241,15 +247,15 @@ void NearestNeighbor ()

// Get nearest neighbor

var result = runner.Fetch (pred).Fetch (distance).AddInput (xtr, Xtr).AddInput (xte, Xte [i].DataFloat).Run ();
var result = runner.Fetch (pred).Fetch (distance).AddInput (trainingInput, trainingImages).AddInput (xte, Extract (testImages, i)).Run ();
var r = result [0].GetValue ();
var tr = result [1].GetValue ();
var nn_index = (int)(long) result [0].GetValue ();

// Get nearest neighbor class label and compare it to its true label
Console.WriteLine ($"Test {i}: Prediction: {ArgMax (Ytr, nn_index)} True class: {ArgMax (Yte, i)} (nn_index={nn_index})");
if (ArgMax (Ytr, nn_index) == ArgMax (Yte, i))
accuracy += 1f/ Xte.Length;
Console.WriteLine ($"Test {i}: Prediction: {ArgMax (trainingLabels, nn_index)} True class: {ArgMax (testLabels, i)} (nn_index={nn_index})");
if (ArgMax (trainingLabels, nn_index) == ArgMax (testLabels, i))
accuracy += 1f/ testImages.Length;
}
Console.WriteLine ("Accuracy: " + accuracy);
}
Expand Down
3 changes: 3 additions & 0 deletions SampleTest/SampleTest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
<Reference Include="mscorlib" />
<Reference Include="System.Core" />
<Reference Include="System.Numerics" />
<Reference Include="System.ValueTuple">
<HintPath>..\packages\System.ValueTuple.4.3.1\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
</Reference>
</ItemGroup>
<ItemGroup>
<Compile Include="SampleTest.cs" />
Expand Down
1 change: 1 addition & 0 deletions SampleTest/packages.config
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="CsvHelper" version="2.16.3.0" targetFramework="net45" />
<package id="System.ValueTuple" version="4.3.1" targetFramework="net461" />
</packages>
11 changes: 8 additions & 3 deletions TensorFlowSharp/Buffer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,22 @@ internal struct LLBuffer
/// Holds a block of data, suitable to pass, or retrieve from TensorFlow.
/// </summary>
/// <remarks>
/// <para>
/// Use the TFBuffer to pass blobs of data into TensorFlow, or to retrieve blocks
/// of data out of TensorFlow.
///
/// </para>
/// <para>
/// There are two constructors to wrap existing data, one to wrap blocks that are
/// pointed to by an IntPtr and one that takes a byte array that we want to wrap.
///
/// </para>
/// <para>
/// The empty constructor can be used to create a new TFBuffer that can be populated
/// by the TensorFlow library and returned to user code.
///
/// </para>
/// <para>
/// Typically, the data consists of a serialized protocol buffer, but other data
/// may also be held in a buffer.
/// </para>
/// </remarks>
// TODO: the string ctor
// TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray?
Expand Down
8 changes: 6 additions & 2 deletions TensorFlowSharp/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,20 @@ namespace TensorFlow
/// TFTensor holds a multi-dimensional array of elements of a single data type.
/// </summary>
/// <remarks>
/// <para>
/// You can create tensors with the various constructors in this class, or using
/// the implicit conversions from various data types into a TFTensor.
///
///</para>
/// <para>
/// The implicit conversions for basic types produce tensors of one dimension with
/// a single element, while the implicit conversion from an array expects a multi-dimensional
/// array that is converted into a tensor of the right dimensions.
///
/// </para>
/// <para>
/// The special "String" tensor data type that you will find in TensorFlow documentation
/// really represents a byte array. You can create string tensors by using the <see cref="M:TensorFlow.TFTensor.CreateString"/>
/// method that takes a byte array buffer as input.
/// </para>
/// </remarks>
public class TFTensor : TFDisposable
{
Expand Down
Loading

0 comments on commit 2804490

Please sign in to comment.