Skip to content

Commit 2804490

Browse files
committed
Some refactoring of the MNIST loader, still incomplete
1 parent 7de0c24 commit 2804490

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+473
-74120
lines changed

Learn/Datasets/MNIST.cs

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,48 +43,39 @@ public class Mnist
4343
public byte [] TrainLabels, TestLabels, ValidationLabels;
4444
public byte [,] OneHotTrainLabels, OneHotTestLabels, OneHotValidationLabels;
4545

46-
// Simple batch reader to get pieces of data from the dataset
47-
public BatchReader GetBatchReader (MnistImage [] source)
48-
{
49-
return new BatchReader (source);
50-
}
46+
public BatchReader GetTrainReader () => new BatchReader (TrainImages, TrainLabels, OneHotTrainLabels);
47+
public BatchReader GetTestReader () => new BatchReader (TestImages, TestLabels, OneHotTestLabels);
48+
public BatchReader GetValidationReader () => new BatchReader (ValidationImages, ValidationLabels, OneHotValidationLabels);
5149

5250
public class BatchReader
5351
{
5452
int start = 0;
5553
MnistImage [] source;
54+
byte [] labels;
55+
byte [,] oneHotLabels;
5656

57-
public BatchReader (MnistImage [] source)
57+
internal BatchReader (MnistImage [] source, byte [] labels, byte [,] oneHotLabels)
5858
{
5959
this.source = source;
60+
this.labels = labels;
61+
this.oneHotLabels = oneHotLabels;
6062
}
6163

62-
public MnistImage [] Read (int batchSize)
63-
{
64-
var result = new MnistImage [batchSize];
65-
if (start + batchSize < source.Length) {
66-
Array.Copy (source, start, result, 0, batchSize);
67-
start += batchSize;
68-
} else {
69-
var firstLength = source.Length - start;
70-
Array.Copy (source, start, result, 0, firstLength);
71-
Array.Copy (source, 0, result, firstLength, batchSize-firstLength);
72-
start = firstLength;
73-
}
74-
return result;
75-
}
76-
77-
public TFTensor ReadAsTensor (int batchSize)
64+
public (float[,],float [,]) NextBatch (int batchSize)
7865
{
79-
var result = new float [batchSize, 784];
66+
var imageData = new float [batchSize, 784];
67+
var labelData = new float [batchSize, 10];
8068

81-
var x = Read (batchSize);
8269
int p = 0;
83-
for (int i = 0; i < batchSize; i++) {
84-
Buffer.BlockCopy (x [i].DataFloat, 0, result, p, 784);
70+
for (int item = 0; item < batchSize; item++) {
71+
Buffer.BlockCopy (source [start+item].DataFloat, 0, imageData, p, 784);
8572
p += 784;
73+
for (var j = 0; j < 10; j++)
74+
labelData [item, j] = oneHotLabels [item + start, j];
8675
}
87-
return (TFTensor)result;
76+
77+
start += batchSize;
78+
return (imageData, labelData);
8879
}
8980
}
9081

Learn/Learn.csproj

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
</PropertyGroup>
3232
<ItemGroup>
3333
<Reference Include="System" />
34+
<Reference Include="System.ValueTuple">
35+
<HintPath>..\packages\System.ValueTuple.4.3.1\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
36+
</Reference>
37+
<Reference Include="System.Numerics" />
3438
</ItemGroup>
3539
<ItemGroup>
3640
<Compile Include="Properties\AssemblyInfo.cs" />
@@ -47,5 +51,8 @@
4751
<Name>TensorFlowSharp</Name>
4852
</ProjectReference>
4953
</ItemGroup>
54+
<ItemGroup>
55+
<None Include="packages.config" />
56+
</ItemGroup>
5057
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
5158
</Project>

Learn/packages.config

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<packages>
3+
<package id="System.ValueTuple" version="4.3.1" targetFramework="net461" />
4+
</packages>

Makefile

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
all: doc-update yaml
12

23
rebuild-docs: docs/template
34
mdoc export-html --force-update -o docs --template=docs/template ecmadocs/en/
45

56
# Used to fetch XML doc updates from the C# compiler into the ECMA docs
67
doc-update:
78
mdoc update -i TensorFlowSharp/bin/Debug/TensorFlowSharp.xml -o ecmadocs/en TensorFlowSharp/bin/Debug/TensorFlowSharp.dll
9+
10+
yaml:
11+
-rm ecmadocs/en/ns-.xml
12+
mono /cvs/ECMA2Yaml/ECMA2Yaml/ECMA2Yaml/bin/Debug/ECMA2Yaml.exe --source=`pwd`/ecmadocs/en --output=`pwd`/docfx/api
13+
(cd docfx; mono ~/Downloads/docfx/docfx.exe build)
14+

SampleTest/SampleTest.cs

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ void BasicMatrix ()
187187
};
188188
}
189189

190-
int ArgMax (byte [,] array, int idx)
190+
int ArgMax (float [,] array, int idx)
191191
{
192-
int max = -1;
192+
float max = -1;
193193
int maxIdx = -1;
194194
var l = array.GetLength (1);
195195
for (int i = 0; i < l; i++)
@@ -198,7 +198,17 @@ int ArgMax (byte [,] array, int idx)
198198
max = array [idx, i];
199199
}
200200
return maxIdx;
201-
}
201+
}
202+
203+
public float [] Extract (float [,] array, int index)
204+
{
205+
var n = array.GetLength (1);
206+
var ret = new float [n];
207+
208+
for (int i = 0; i < n; i++)
209+
ret [i] = array [index,i];
210+
return ret;
211+
}
202212

203213
// This sample has a bug, I suspect the data loaded is incorrect, because the returned
204214
// values in distance are wrong, and so is the prediction computed from them.
@@ -211,25 +221,21 @@ void NearestNeighbor ()
211221
// 5000 for training
212222
const int trainCount = 5000;
213223
const int testCount = 200;
214-
var Xtr = mnist.GetBatchReader (mnist.TrainImages).ReadAsTensor (trainCount);
215-
var Ytr = mnist.OneHotTrainLabels;
216-
var Xte = mnist.GetBatchReader (mnist.TestImages).Read (testCount);
217-
var Yte = mnist.OneHotTestLabels;
218-
219-
224+
(var trainingImages, var trainingLabels) = mnist.GetTrainReader ().NextBatch (trainCount);
225+
(var testImages, var testLabels) = mnist.GetTestReader ().NextBatch (testCount);
220226

221227
Console.WriteLine ("Nearest neighbor on Mnist images");
222228
using (var g = new TFGraph ()) {
223229
var s = new TFSession (g);
224230

225231

226-
TFOutput xtr = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
232+
TFOutput trainingInput = g.Placeholder (TFDataType.Float, new TFShape (-1, 784));
227233

228234
TFOutput xte = g.Placeholder (TFDataType.Float, new TFShape (784));
229235

230236
// Nearest Neighbor calculation using L1 Distance
231237
// Calculate L1 Distance
232-
TFOutput distance = g.ReduceSum (g.Abs (g.Add (xtr, g.Neg (xte))), axis: g.Const (1));
238+
TFOutput distance = g.ReduceSum (g.Abs (g.Add (trainingInput, g.Neg (xte))), axis: g.Const (1));
233239

234240
// Prediction: Get min distance index (Nearest neighbor)
235241
TFOutput pred = g.ArgMin (distance, g.Const (0));
@@ -241,15 +247,15 @@ void NearestNeighbor ()
241247

242248
// Get nearest neighbor
243249

244-
var result = runner.Fetch (pred).Fetch (distance).AddInput (xtr, Xtr).AddInput (xte, Xte [i].DataFloat).Run ();
250+
var result = runner.Fetch (pred).Fetch (distance).AddInput (trainingInput, trainingImages).AddInput (xte, Extract (testImages, i)).Run ();
245251
var r = result [0].GetValue ();
246252
var tr = result [1].GetValue ();
247253
var nn_index = (int)(long) result [0].GetValue ();
248254

249255
// Get nearest neighbor class label and compare it to its true label
250-
Console.WriteLine ($"Test {i}: Prediction: {ArgMax (Ytr, nn_index)} True class: {ArgMax (Yte, i)} (nn_index={nn_index})");
251-
if (ArgMax (Ytr, nn_index) == ArgMax (Yte, i))
252-
accuracy += 1f/ Xte.Length;
256+
Console.WriteLine ($"Test {i}: Prediction: {ArgMax (trainingLabels, nn_index)} True class: {ArgMax (testLabels, i)} (nn_index={nn_index})");
257+
if (ArgMax (trainingLabels, nn_index) == ArgMax (testLabels, i))
258+
accuracy += 1f/ testImages.Length;
253259
}
254260
Console.WriteLine ("Accuracy: " + accuracy);
255261
}

SampleTest/SampleTest.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
<Reference Include="mscorlib" />
4040
<Reference Include="System.Core" />
4141
<Reference Include="System.Numerics" />
42+
<Reference Include="System.ValueTuple">
43+
<HintPath>..\packages\System.ValueTuple.4.3.1\lib\netstandard1.0\System.ValueTuple.dll</HintPath>
44+
</Reference>
4245
</ItemGroup>
4346
<ItemGroup>
4447
<Compile Include="SampleTest.cs" />

SampleTest/packages.config

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
<?xml version="1.0" encoding="utf-8"?>
22
<packages>
33
<package id="CsvHelper" version="2.16.3.0" targetFramework="net45" />
4+
<package id="System.ValueTuple" version="4.3.1" targetFramework="net461" />
45
</packages>

TensorFlowSharp/Buffer.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,22 @@ internal struct LLBuffer
2323
/// Holds a block of data, suitable to pass, or retrieve from TensorFlow.
2424
/// </summary>
2525
/// <remarks>
26+
/// <para>
2627
/// Use the TFBuffer to pass blobs of data into TensorFlow, or to retrieve blocks
2728
/// of data out of TensorFlow.
28-
///
29+
/// </para>
30+
/// <para>
2931
/// There are two constructors to wrap existing data, one to wrap blocks that are
3032
/// pointed to by an IntPtr and one that takes a byte array that we want to wrap.
31-
///
33+
/// </para>
34+
/// <para>
3235
/// The empty constructor can be used to create a new TFBuffer that can be populated
3336
/// by the TensorFlow library and returned to user code.
34-
///
37+
/// </para>
38+
/// <para>
3539
/// Typically, the data consists of a serialized protocol buffer, but other data
3640
/// may also be held in a buffer.
41+
/// </para>
3742
/// </remarks>
3843
// TODO: the string ctor
3944
// TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray?

TensorFlowSharp/Tensor.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,20 @@ namespace TensorFlow
1818
/// TFTensor holds a multi-dimensional array of elements of a single data type.
1919
/// </summary>
2020
/// <remarks>
21+
/// <para>
2122
/// You can create tensors with the various constructors in this class, or using
2223
/// the implicit conversions from various data types into a TFTensor.
23-
///
24+
///</para>
25+
/// <para>
2426
/// The implicit conversions for basic types produce tensors of one dimension with
2527
/// a single element, while the implicit conversion from an array expects a multi-dimensional
2628
/// array that is converted into a tensor of the right dimensions.
27-
///
29+
/// </para>
30+
/// <para>
2831
/// The special "String" tensor data type that you will find in TensorFlow documentation
2932
/// really represents a byte array. You can create string tensors by using the <see cref="M:TensorFlow.TFTensor.CreateString"/>
3033
/// method that takes a byte array buffer as input.
34+
/// </para>
3135
/// </remarks>
3236
public class TFTensor : TFDisposable
3337
{

0 commit comments

Comments
 (0)