Skip to content

Commit

Permalink
added imageNet 2012 pretrained model, followed Matt's parameters.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Apr 25, 2014
1 parent e098d56 commit 803dc42
Show file tree
Hide file tree
Showing 14 changed files with 1,173 additions and 112 deletions.
12 changes: 6 additions & 6 deletions bin/cnnclassify.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ int main(int argc, char** argv)
for (i = 0; i < rank->rnum - 1; i++)
{
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(rank, i);
printf("%d %f ", classification->id, classification->confidence);
printf("%d %f ", classification->id + 1, classification->confidence);
}
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(rank, rank->rnum - 1);
printf("%d %f\n", classification->id, classification->confidence);
printf("%d %f\n", classification->id + 1, classification->confidence);
printf("elapsed time %dms\n", elapsed_time);
ccv_array_free(rank);
ccv_matrix_free(input);
Expand Down Expand Up @@ -76,10 +76,10 @@ int main(int argc, char** argv)
for (j = 0; j < ranks[i]->rnum - 1; j++)
{
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(ranks[i], j);
printf("%d %f ", classification->id, classification->confidence);
printf("%d %f ", classification->id + 1, classification->confidence);
}
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(ranks[i], ranks[i]->rnum - 1);
printf("%d %f\n", classification->id, classification->confidence);
printf("%d %f\n", classification->id + 1, classification->confidence);
ccv_array_free(ranks[i]);
}
}
Expand All @@ -98,10 +98,10 @@ int main(int argc, char** argv)
for (j = 0; j < ranks[i]->rnum - 1; j++)
{
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(ranks[i], j);
printf("%d %f ", classification->id, classification->confidence);
printf("%d %f ", classification->id + 1, classification->confidence);
}
ccv_classification_t* classification = (ccv_classification_t*)ccv_array_get(ranks[i], ranks[i]->rnum - 1);
printf("%d %f\n", classification->id, classification->confidence);
printf("%d %f\n", classification->id + 1, classification->confidence);
ccv_array_free(ranks[i]);
}
for (i = (k % 32); i < 32; i++)
Expand Down
2 changes: 1 addition & 1 deletion bin/cuda/cwc-bench-runtime.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ extern "C" void cwc_bench_runtime(ccv_convnet_t* convnet, ccv_array_t* categoriz
cwc_convnet_context_t* context = GPU(convnet)->contexts;
for (i = 0; i < convnet->rows * convnet->cols * convnet->channels; i++)
convnet->mean_activity->data.f32[i] = 128;
_cwc_convnet_batch_formation(0, categorizeds, convnet->mean_activity, 0, 0, 0, 0, ccv_size(225, 225), convnet->rows, convnet->cols, convnet->channels, 0, batch, 0, batch, context->host.input, context->host.c);
_cwc_convnet_batch_formation(0, categorizeds, convnet->mean_activity, 0, 0, 0, 0, ccv_size(225, 225), convnet->rows, convnet->cols, convnet->channels, 1000, 0, batch, 0, batch, context->host.input, context->host.c);
cudaMemcpy(context->device.input, context->host.input, sizeof(float) * convnet->rows * convnet->cols * convnet->channels * batch, cudaMemcpyHostToDevice);

ccv_convnet_t* update_params = _ccv_convnet_update_new(convnet);
Expand Down
33 changes: 18 additions & 15 deletions bin/image-net.c
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "ccv.h"
#include "ccv_internal.h"
#include <ctype.h>
#include <getopt.h>

Expand Down Expand Up @@ -41,8 +42,8 @@ int main(int argc, char** argv)
char* base_dir = 0;
ccv_convnet_train_param_t train_params = {
.max_epoch = 100,
.mini_batch = 256,
.iterations = 5000,
.mini_batch = 128,
.iterations = 20000,
.symmetric = 1,
.color_gain = 0.001,
};
Expand Down Expand Up @@ -95,7 +96,8 @@ int main(int argc, char** argv)
ccv_file_info_t file_info = {
.filename = filename,
};
ccv_categorized_t categorized = ccv_categorized(c, 0, &file_info);
// imageNet's category class starts from 1, thus, minus 1 to get 0-index
ccv_categorized_t categorized = ccv_categorized(c - 1, 0, &file_info);
ccv_array_push(categorizeds, &categorized);
}
fclose(r0);
Expand All @@ -112,7 +114,8 @@ int main(int argc, char** argv)
ccv_file_info_t file_info = {
.filename = filename,
};
ccv_categorized_t categorized = ccv_categorized(c, 0, &file_info);
// imageNet's category class starts from 1, thus, minus 1 to get 0-index
ccv_categorized_t categorized = ccv_categorized(c - 1, 0, &file_info);
ccv_array_push(tests, &categorized);
}
fclose(r1);
Expand All @@ -134,10 +137,10 @@ int main(int argc, char** argv)
.output = {
.convolutional = {
.count = 96,
.strides = 4,
.strides = 2,
.border = 1,
.rows = 11,
.cols = 11,
.rows = 7,
.cols = 7,
.channels = 3,
.partition = 2,
},
Expand All @@ -147,8 +150,8 @@ int main(int argc, char** argv)
.type = CCV_CONVNET_LOCAL_RESPONSE_NORM,
.input = {
.matrix = {
.rows = 55,
.cols = 55,
.rows = 111,
.cols = 111,
.channels = 96,
.partition = 2,
},
Expand All @@ -166,8 +169,8 @@ int main(int argc, char** argv)
.type = CCV_CONVNET_MAX_POOL,
.input = {
.matrix = {
.rows = 55,
.cols = 55,
.rows = 111,
.cols = 111,
.channels = 96,
.partition = 2,
},
Expand All @@ -187,17 +190,17 @@ int main(int argc, char** argv)
.sigma = 0.01,
.input = {
.matrix = {
.rows = 27,
.cols = 27,
.rows = 55,
.cols = 55,
.channels = 96,
.partition = 2,
},
},
.output = {
.convolutional = {
.count = 256,
.strides = 1,
.border = 2,
.strides = 2,
.border = 1,
.rows = 5,
.cols = 5,
.channels = 96,
Expand Down
96 changes: 60 additions & 36 deletions doc/convnet.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ ground-breaking work presented in:

ImageNet Classification with Deep Convolutional Neural Networks, Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton, NIPS 2012

The parameters are modified based on Matthew D. Zeiler's work presented in:

Visualizing and Understanding Convolutional Networks, Matthew D. Zeiler and Rob Fergus, Arxiv 1311.2901 (Nov 2013)

How it works?
-------------

Expand All @@ -22,7 +26,7 @@ Long story short, with advances in GPGPU programming, we can have very large neu
have both and a bag of tricks (dropout, pooling etc.), the resulted neural networks can achieve
good image classification results.

./cnnclassify ../samples/dex.png ../samples/image-net.sqlite3 | ./cnndraw.rb ../samples/image-net.words ../samples/dex.png output.png
./cnnclassify ../samples/dex.png ../samples/image-net-2010.sqlite3 | ./cnndraw.rb ../samples/image-net-2010.words ../samples/dex.png output.png

Check output.png, the neural networks suggest a few possible relevant classes in the top
left chart.
Expand All @@ -41,36 +45,49 @@ topology followed the exact specification detailed in the paper.

Accuracy-wise:

The test is performed on ILSVRC 2010 test dataset, as of time being, I cannot obtain the validation
dataset for ILSVRC 2012.
The test is performed on ILSVRC 2010 test dataset and ILSVRC 2012 validation dataset.

For ILSVRC2010 dataset, the training stopped improving at around 60 epochs; at that time, the central
patch from test set obtained 36.56% of top-1 missing rate (lower is better) and the training set
obtained 32.2% of top-1 missing rate. In Alex's paper, they reported 37.5% top-1 missing rate when
averaging 10 patches, and 39% top-1 missing rate when using the central patch in test set.

The training stopped to improve at around 60 epochs, at that time, the central patch from test set
obtained 39.71% of top-1 missing rate (lower is better) and the training set obtained 37.80% of
top-1 missing rate. In Alex's paper, they reported 37.5% top-1 missing rate when averaging 10 patches,
and 39% top-1 missing rate when using the central patch in test set.
For ILSVRC2012 dataset, the training stopped improving at around 70 epochs; at that time, the central
patch from test set obtained 41.4% of top-1 missing rate (lower is better) and the training set
obtained 37.8% of top-1 missing rate. In Alex's paper, they reported 40.5% top-1 missing rate when
averaging 10 patches. In Matt's paper, they reported 38.4% top-1 missing rate when using 1 convnet as
configured in Fig.3 and averaging 10 patches.

Assuming you have ILSVRC 2010 test set files ordered in image-net-test.txt, run
Assuming you have ILSVRC 2012 validation set files ordered in image-net-2012-val.txt, run

./cnnclassify image-net-test.txt ../samples/image-net.sqlite3 > image-net-classify.txt
./cnnclassify image-net-2012-val.txt ../samples/image-net-2012.sqlite3 > image-net-2012-classify.txt

For complete test set to finish, this command takes an hour on GPU, and if you don't have GPU
For complete validation set to finish, this command takes half an hour on GPU, and if you don't have GPU
enabled, it will take about a day to run on CPU.

Assuming you have the ILSVRC 2010 ground truth data in LSVRC2010_test_ground_truth.txt
Assuming you have the ILSVRC 2012 validation ground truth data in LSVRC2012_val_ground_truth.txt

./cnnvldtr.rb LSVRC2010_test_ground_truth.txt image-net-classify.txt
./cnnvldtr.rb LSVRC2012_val_ground_truth.txt image-net-2012-classify.txt

will report the top-1 missing rate as well as top-5 missing rate.

For 32-bit float point image-net.sqlite3 on GPU, the top-1 missing rate is 36.82%, 0.68% better
than Alex's result, the top-5 missing rate is 16.26%, 0.74% better than Alex's. For half precision
image-net.sqlite3 (the one included in ./samples/), the top-1 missing rate is 36.83% and the top-5
missing rate is 16.25%.
For 32-bit floating point image-net-2012.sqlite3 on GPU, the top-1 missing rate is 38.17%, 2.33% better
than Alex's result with 1 convnet, and 0.23% better than Matt's result with 1 convnet configured
as in Fig.3. The top-5 missing rate is 16.22%, 1.98% better than Alex's and 0.28% better than Matt's.
For half precision image-net-2012.sqlite3 (the one included in ./samples/), the top-1 missing rate is
38.18% and the top-5 missing rate is 16.17%.

For 32-bit float point image-net.sqlite3 on CPU, the top-1 missing rate is 37.32%, and the top-5
missing rate is 16.48%.
See http://www.image-net.org/challenges/LSVRC/2013/results.php#cls for the current state-of-the-art,
ccv's implementation is still about 5% behind Clarifai (Matt's commercial implementation, later claimed
to be 10.7%: http://www.clarifai.com/) and 2% behind OverFeat on top-5 missing rate.

You can download the 32-bit float point one with ./samples/download-image-net.sh
For 32-bit float point image-net-2012.sqlite3 on CPU, the top-1 missing rate is XX.XX%, and the top-5
missing rate is XX.XX%.

For 32-bit float point image-net-2010.sqlite3 on GPU, the top-1 missing rate is 33.91%, and the top-5
missing rate is 14.08%.

You can download the 32-bit float point versions with ./samples/download-image-net.sh

Speed-wise:

Expand All @@ -87,11 +104,17 @@ makes little sense). Their reported number are 1s per image on unspecified confi
unspecified hardware (I suspect that their unspecified configuration does much more than the
averaging 10 patches ccv or Decaf does).

The GPU version does forward pass + backward error propagate for batch size of 256 in about 1.6s.
Thus, training ImageNet convolutional network takes about 9 days with 100 epochs. Caffe reported
their forward pass + backward error propagate for batch size of 256 in about 1.8s on Tesla K20 (
known to be about 30% slower cross the board than TITAN). In the paper, Alex reported 90 epochs
within 6 days on two GeForce 580, which suggests my time is within line of these implementations.
For AlexNet, the GPU version does forward pass + backward error propagate for batch size of 256
in about 1.6s. Thus, training ImageNet convolutional network takes about 9 days with 100 epochs.
Caffe reported their forward pass + backward error propagate for batch size of 256 in about 1.8s
on Tesla K20 (known to be about 30% slower across the board than TITAN). In the paper, Alex
reported 90 epochs within 6 days on two GeForce 580. In "Multi-GPU Training of ConvNets" (Omry Yadan,
Keith Adams, Yaniv Taigman, and Marc'Aurelio Ranzato, arXiv:1312.5853), Omry mentioned that they did
100 epochs of AlexNet in 10.5 days on 1 GPU, which suggests my time is in line with these
implementations.

For MattNet, the GPU version does forward pass + backward error propagate for batch size of 128
in about 1.0s.

As a preliminary implementation, I didn't spend enough time to optimize these operations in ccv if
any at all. For example, [cuda-convnet](http://code.google.com/p/cuda-convnet/) implements its
Expand Down Expand Up @@ -182,18 +205,18 @@ is exactly what I included in ./samples.
Can I use the ImageNet pre-trained data model?
----------------------------------------------

ccv is released under FreeBSD 3-clause license, and the pre-trained data model ./samples/image-net.sqlite3
is released under Creative Commons Attribution 4.0 International License. You can use it, modify it
practically anywhere and anyhow with proper attribution. As far as I can tell, this is the first pre-trained
data model released under commercial-friendly license (Caffe itself is released under FreeBSD license but
its pre-trained data model is "research only" and OverFeat is released under custom research only license).
ccv is released under FreeBSD 3-clause license, and the pre-trained data models ./samples/image-net-2010.sqlite3
and ./samples/image-net-2012.sqlite3 are released under Creative Commons Attribution 4.0 International License.
You can use it, modify it practically anywhere and anyhow with proper attribution. As far as I can tell, this is
the first pre-trained data model released under commercial-friendly license (Caffe itself is released under
FreeBSD license but its pre-trained data model is "research only" and OverFeat is released under custom research
only license).

Differences between ccv's implementation, Caffe's and Alex's
------------------------------------------------------------
Differences between ccv's implementation, Caffe's, Alex's and Matt's
--------------------------------------------------------------------

Although the network topology of ccv's implementation followed closely to Alex's (as well as Caffe's),
the reported results diverged significantly enough for me to document the differences in implementation
details.
Although the network topology of ccv's implementation closely followed Matt's, the reported results
diverged significantly enough for me to document the differences in implementation details.

Network Topology:

Expand All @@ -202,8 +225,9 @@ the local response normalization. This is briefly mentioned in Alex's paper, but
response normalization layer followed the pooling layer.

The input dimension to ccv's implemented network is 225x225, and in Caffe, it is 227x227. Alex's paper
mentioned their input size is 224x224. For 225x225, it implies a 1 pixel padding around the input image
such that with 11x11 filter and 4 stride size, a 55x55 output will be generated.
as well as Matt's mentioned their input size is 224x224. For 225x225, it implies a 1 pixel padding around
the input image such that with 7x7 filter and 2 stride size, a 111x111 output will be generated. However,
the output of the first convolutional layer in Matt's paper is 110x110.

Data Preparation:

Expand Down
8 changes: 5 additions & 3 deletions lib/cuda/cwc_convnet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1983,7 +1983,7 @@ static void _cwc_convnet_net_sgd(ccv_convnet_t* convnet, int momentum_read, int
}
}

static void _cwc_convnet_batch_formation(gsl_rng* rng, ccv_array_t* categorizeds, ccv_dense_matrix_t* mean_activity, ccv_dense_matrix_t* eigenvectors, ccv_dense_matrix_t* eigenvalues, float color_gain, int* idx, ccv_size_t dim, int rows, int cols, int channels, int symmetric, int batch, int offset, int size, float* b, int* c)
static void _cwc_convnet_batch_formation(gsl_rng* rng, ccv_array_t* categorizeds, ccv_dense_matrix_t* mean_activity, ccv_dense_matrix_t* eigenvectors, ccv_dense_matrix_t* eigenvalues, float color_gain, int* idx, ccv_size_t dim, int rows, int cols, int channels, int category_count, int symmetric, int batch, int offset, int size, float* b, int* c)
{
int i, k, x;
assert(size <= batch);
Expand All @@ -1992,6 +1992,7 @@ static void _cwc_convnet_batch_formation(gsl_rng* rng, ccv_array_t* categorizeds
for (i = 0; i < size; i++)
{
ccv_categorized_t* categorized = (ccv_categorized_t*)ccv_array_get(categorizeds, idx ? idx[offset + i] : offset + i);
assert(categorized->c < category_count && categorized->c >= 0); // now only accept classes listed
if (c)
c[i] = categorized->c;
ccv_dense_matrix_t* image;
Expand Down Expand Up @@ -2184,6 +2185,7 @@ static void _cwc_convnet_channel_eigen(ccv_array_t* categorizeds, ccv_dense_matr
covariance[i * channels + j] *= p; // scale down
ccv_dense_matrix_t covm = ccv_dense_matrix(3, 3, CCV_64F | CCV_C1, covariance, 0);
ccv_eigen(&covm, eigenvectors, eigenvalues, CCV_64F, 1e-8);
printf("\n");
}

static void _cwc_convnet_dor_mean_net(ccv_convnet_t* convnet, ccv_convnet_layer_train_param_t* layer_params, const cublasHandle_t& handle)
Expand Down Expand Up @@ -2791,7 +2793,7 @@ void cwc_convnet_supervised_train(ccv_convnet_t* convnet, ccv_array_t* categoriz
for (i = z.i; i < ccv_min(z.i + params.iterations, aligned_batches); i++)
{
cwc_convnet_context_t* context = GPU(z.convnet)->contexts + (i % 2);
_cwc_convnet_batch_formation(rng, categorizeds, z.convnet->mean_activity, z.eigenvectors, z.eigenvalues, params.color_gain, z.idx, z.convnet->input, z.convnet->rows, z.convnet->cols, z.convnet->channels, params.symmetric, params.mini_batch, i * params.mini_batch, params.mini_batch, context->host.input, context->host.c);
_cwc_convnet_batch_formation(rng, categorizeds, z.convnet->mean_activity, z.eigenvectors, z.eigenvalues, params.color_gain, z.idx, z.convnet->input, z.convnet->rows, z.convnet->cols, z.convnet->channels, category_count, params.symmetric, params.mini_batch, i * params.mini_batch, params.mini_batch, context->host.input, context->host.c);
cudaMemcpyAsync(context->device.input, context->host.input, sizeof(float) * z.convnet->rows * z.convnet->cols * z.convnet->channels * params.mini_batch, cudaMemcpyHostToDevice, context->device.stream);
assert(cudaGetLastError() == cudaSuccess);
cudaMemcpyAsync(context->device.c, context->host.c, sizeof(int) * params.mini_batch, cudaMemcpyHostToDevice, context->device.stream);
Expand Down Expand Up @@ -2836,7 +2838,7 @@ void cwc_convnet_supervised_train(ccv_convnet_t* convnet, ccv_array_t* categoriz
for (i = j = 0; i < tests->rnum; i += params.mini_batch, j++)
{
cwc_convnet_context_t* context = GPU(z.convnet)->contexts + (j % 2);
_cwc_convnet_batch_formation(0, tests, z.convnet->mean_activity, 0, 0, 0, 0, z.convnet->input, z.convnet->rows, z.convnet->cols, z.convnet->channels, params.symmetric, params.mini_batch, i, ccv_min(params.mini_batch, tests->rnum - i), context->host.input, 0);
_cwc_convnet_batch_formation(0, tests, z.convnet->mean_activity, 0, 0, 0, 0, z.convnet->input, z.convnet->rows, z.convnet->cols, z.convnet->channels, category_count, params.symmetric, params.mini_batch, i, ccv_min(params.mini_batch, tests->rnum - i), context->host.input, 0);
cudaMemcpyAsync(context->device.input, context->host.input, sizeof(float) * z.convnet->rows * z.convnet->cols * z.convnet->channels * params.mini_batch, cudaMemcpyHostToDevice, context->device.stream);
assert(cudaGetLastError() == cudaSuccess);
if (j > 0)
Expand Down
Loading

0 comments on commit 803dc42

Please sign in to comment.