Skip to content

Commit

Permalink
Added FP16/FP32 readout in camera samples
Browse files — browse the repository at this point in the history
  • Loading branch information
dusty-nv committed Sep 26, 2016
1 parent 62ec978 commit 90bb465
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 37 deletions.
2 changes: 1 addition & 1 deletion detectNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ detectNet* detectNet::Create( const char* prototxt, const char* model, const cha

for( uint32_t n=0; n < numClasses; n++ )
{
if( n != 0 )
if( n != 1 )
{
net->mClassColors[0][n*4+0] = 0.0f; // r
net->mClassColors[0][n*4+1] = 200.0f; // g
Expand Down
2 changes: 1 addition & 1 deletion detectNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class detectNet : public tensorNet
/**
* Set the visualization color of a particular class of object.
*/
void SetClassColor( uint32_t classIndex, float r, float g, float b, float a=1.0f );
void SetClassColor( uint32_t classIndex, float r, float g, float b, float a=255.0f );


protected:
Expand Down
2 changes: 1 addition & 1 deletion detectnet-camera/detectnet-camera.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ int main( int argc, char** argv )
if( display != NULL )
{
char str[256];
sprintf(str, "GIE build %x | %04.1f FPS", NV_GIE_VERSION, display->GetFPS());
sprintf(str, "GIE build %x | %s | %04.1f FPS", NV_GIE_VERSION, net->HasFP16() ? "FP16" : "FP32", display->GetFPS());
//sprintf(str, "GIE build %x | %s | %04.1f FPS | %05.2f%% %s", NV_GIE_VERSION, net->GetNetworkName(), display->GetFPS(), confidence * 100.0f, net->GetClassDesc(img_class));
display->SetTitle(str);
}
Expand Down
4 changes: 2 additions & 2 deletions imagenet-camera/imagenet-camera.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ int main( int argc, char** argv )
{
char str[256];
sprintf(str, "%05.2f%% %s", confidence * 100.0f, net->GetClassDesc(img_class));

font->RenderOverlay((float4*)imgRGBA, (float4*)imgRGBA, camera->GetWidth(), camera->GetHeight(),
str, 0, 0, make_float4(255.0f, 255.0f, 255.0f, 255.0f));
}

if( display != NULL )
{
char str[256];
sprintf(str, "GIE build %x | %s | %04.1f FPS", NV_GIE_VERSION, net->GetNetworkName(), display->GetFPS());
sprintf(str, "GIE build %x | %s | %s | %04.1f FPS", NV_GIE_VERSION, net->GetNetworkName(), net->HasFP16() ? "FP16" : "FP32", display->GetFPS());
//sprintf(str, "GIE build %x | %s | %04.1f FPS | %05.2f%% %s", NV_GIE_VERSION, net->GetNetworkName(), display->GetFPS(), confidence * 100.0f, net->GetClassDesc(img_class));
display->SetTitle(str);
}
Expand Down
49 changes: 30 additions & 19 deletions tensorNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ tensorNet::tensorNet()
mInfer = NULL;
mContext = NULL;

mWidth = 0;
mHeight = 0;
mInputSize = 0;
mInputCPU = NULL;
mInputCUDA = NULL;

mWidth = 0;
mHeight = 0;
mInputSize = 0;
mInputCPU = NULL;
mInputCUDA = NULL;
mEnableFP16 = false;

memset(&mInputDims, 0, sizeof(nvinfer1::Dims3));
}

Expand Down Expand Up @@ -59,16 +60,16 @@ bool tensorNet::ProfileModel(const std::string& deployFile, // name for caf
nvinfer1::INetworkDefinition* network = builder->createNetwork();

builder->setMinFindIterations(3); // allow time for TX1 GPU to spin up
builder->setAverageFindIterations(2);
builder->setAverageFindIterations(2);

// parse the caffe model to populate the network, then set the outputs
nvcaffeparser1::ICaffeParser* parser = nvcaffeparser1::createCaffeParser();

const bool useFp16 = builder->platformHasFastFp16(); // getHalf2Mode();
printf(LOG_GIE "platform %s FP16 support.\n", useFp16 ? "has" : "does not have");
mEnableFP16 = builder->platformHasFastFp16();
printf(LOG_GIE "platform %s FP16 support.\n", mEnableFP16 ? "has" : "does not have");
printf(LOG_GIE "loading %s %s\n", deployFile.c_str(), modelFile.c_str());

nvinfer1::DataType modelDataType = useFp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; // create a 16-bit model if it's natively supported
nvinfer1::DataType modelDataType = mEnableFP16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; // create a 16-bit model if it's natively supported
const nvcaffeparser1::IBlobNameToTensor *blobNameToTensor =
parser->parse(deployFile.c_str(), // caffe deploy file
modelFile.c_str(), // caffe model file
Expand All @@ -95,7 +96,7 @@ bool tensorNet::ProfileModel(const std::string& deployFile, // name for caf
builder->setMaxWorkspaceSize(16 << 20);

// set up the network for paired-fp16 format
if(useFp16)
if(mEnableFP16)
builder->setHalf2Mode(true);

printf(LOG_GIE "building CUDA engine\n");
Expand Down Expand Up @@ -137,7 +138,7 @@ bool tensorNet::LoadNetwork( const char* prototxt_path, const char* model_path,
return false;

/*
* load and parse network definition and model file
* attempt to load network from cache before profiling with tensorRT
*/
std::stringstream gieModelStream;
gieModelStream.seekg(0, gieModelStream.beg);
Expand All @@ -148,16 +149,16 @@ bool tensorNet::LoadNetwork( const char* prototxt_path, const char* model_path,

std::ifstream cache( cache_path );

if( !cache )
{
if( !cache )
{
printf(LOG_GIE "cache file not found, profiling network model\n");

if( !ProfileModel(prototxt_path, model_path, output_blobs, MAX_BATCH_SIZE, gieModelStream) )
{
printf("failed to load %s\n", model_path);
return 0;
}

printf(LOG_GIE "network profiling complete, writing cache to %s\n", cache_path);
std::ofstream outFile;
outFile.open(cache_path);
Expand All @@ -169,9 +170,19 @@ bool tensorNet::LoadNetwork( const char* prototxt_path, const char* model_path,
else
{
printf(LOG_GIE "loading network profile from cache... %s\n", cache_path);
gieModelStream << cache.rdbuf();
cache.close();
}
gieModelStream << cache.rdbuf();
cache.close();

// test for half FP16 support
nvinfer1::IBuilder* builder = createInferBuilder(gLogger);

if( builder != NULL )
{
mEnableFP16 = builder->platformHasFastFp16();
printf(LOG_GIE "platform %s FP16 support.\n", mEnableFP16 ? "has" : "does not have");
builder->destroy();
}
}

printf(LOG_GIE "%s loaded\n", model_path);

Expand Down
32 changes: 19 additions & 13 deletions tensorNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,14 @@ class tensorNet
*/
virtual ~tensorNet();

protected:

/**
* Constructor.
*/
tensorNet();

/**
* Load a new network instance
* @param prototxt File path to the deployable network prototxt
* @param model File path to the caffemodel
* @param mean File path to the mean value binary proto (NULL if none)
*/
bool LoadNetwork( const char* prototxt, const char* model, const char* mean=NULL,
const char* input_blob="data", const char* output_blob="prob");
const char* input_blob="data", const char* output_blob="prob");

/**
* Load a new network instance with multiple output layers
Expand All @@ -47,8 +40,20 @@ class tensorNet
* @param mean File path to the mean value binary proto (NULL if none)
*/
bool LoadNetwork( const char* prototxt, const char* model, const char* mean,
const char* input_blob, const std::vector<std::string>& output_blobs);

const char* input_blob, const std::vector<std::string>& output_blobs);

/**
* Query for half-precision FP16 support.
*/
inline bool HasFP16() const { return mEnableFP16; }

protected:

/**
* Constructor.
*/
tensorNet();

/**
* Create and output an optimized network model
* @note this function is automatically used by LoadNetwork, but also can
Expand All @@ -60,9 +65,9 @@ class tensorNet
* @param modelStream output model stream
*/
bool ProfileModel( const std::string& deployFile, const std::string& modelFile,
const std::vector<std::string>& outputs,
uint32_t maxBatchSize, std::ostream& modelStream);
const std::vector<std::string>& outputs,
uint32_t maxBatchSize, std::ostream& modelStream);

/**
* Prefix used for tagging printed log output
*/
Expand Down Expand Up @@ -96,6 +101,7 @@ class tensorNet
uint32_t mInputSize;
float* mInputCPU;
float* mInputCUDA;
bool mEnableFP16;

nvinfer1::Dims3 mInputDims;

Expand Down

0 comments on commit 90bb465

Please sign in to comment.