diff --git a/include/prnn/detail/rnn/recurrent_ops_config.h b/include/prnn/detail/rnn/recurrent_ops_config.h index 0efb737..3ab8099 100644 --- a/include/prnn/detail/rnn/recurrent_ops_config.h +++ b/include/prnn/detail/rnn/recurrent_ops_config.h @@ -410,7 +410,7 @@ class RecurrentConfig { }; enum { - BARRIER_WAIT_COUNT = 3//333 // about 10us + BARRIER_WAIT_COUNT = 3333 // about 10us }; public: diff --git a/include/prnn/detail/rnn/recurrent_ops_kernels.h b/include/prnn/detail/rnn/recurrent_ops_kernels.h index 92ffa99..3f0635c 100644 --- a/include/prnn/detail/rnn/recurrent_ops_kernels.h +++ b/include/prnn/detail/rnn/recurrent_ops_kernels.h @@ -6,7 +6,7 @@ #include -#define DEBUG_RECURRENT_OPS 1 +#define DEBUG_RECURRENT_OPS 0 #define ATOMIC_INCREMENT 1 #define USE_BARRIER 1 #define SHOULD_SPIN 1 @@ -18,11 +18,11 @@ #if DEBUG_RECURRENT_OPS -#define dprintf(...) do { if( true ) \ +#define dprintf(...) do { if( blockIdx.x == 0 && blockIdx.y == 0 ) \ { std::printf(__VA_ARGS__); } } while(0) #define t0printf(...) do { if(threadIdx.x == 0 && (threadIdx.y == 0) && \ - true ) { std::printf(__VA_ARGS__); } } while(0) + blockIdx.x == 0 && blockIdx.y == 0 ) { std::printf(__VA_ARGS__); } } while(0) #define UNROLL @@ -1828,7 +1828,7 @@ class PersistentEngine index_t offsetInLayer = threadOffset + row * 2 * Config::THREADS_PER_BLOCK; index_t offset = offsetInLayer + blockOffset + bufferOffset; - bool condition = offset < register_state.expanded_layer_size && + bool condition = (offsetInLayer + blockOffset) < register_state.expanded_layer_size && offsetInLayer < Config::EXPANDED_BLOCK_TILE_ROWS; #if DEBUG_RECURRENT_OPS