Skip to content

Commit

Permalink
added method for getting a set of random integers (dmlc#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
asmushetzel authored and piiswrong committed Oct 8, 2017
1 parent 1f6318e commit 2f9e7a6
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion mshadow/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ class Random<cpu, DType> {
return rnd_engine_();
}

/*!
* \brief get a set of random integers
*/
inline void GetRandInt(const Tensor<cpu, 1, unsigned>& dst) {
std::generate_n(dst.dptr_, dst.size(0), rnd_engine_);
}

/*!
* \brief generate data from a distribution
* \param dst destination
Expand Down Expand Up @@ -395,6 +402,13 @@ class Random<gpu, DType> {
status = curandSetPseudoRandomGeneratorSeed(gen_, seed);
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "Set CURAND seed failed.";
}
/*!
* \brief get a set of random integers
*/
inline void GetRandInt(const Tensor<gpu, 1, unsigned>& dst) {
curandStatus_t status = curandGenerate(gen_, dst.dptr_, dst.size(0));
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen rand ints failed.";
}
/*!
* \brief generate data from uniform [a,b)
* \param dst destination
Expand Down Expand Up @@ -476,7 +490,7 @@ class Random<gpu, DType> {
CHECK_EQ(status, CURAND_STATUS_SUCCESS) << "CURAND Gen Uniform double failed."
<< " size = " << size;
}
/*! \brief random numbeer generator */
/*! \brief random number generator */
curandGenerator_t gen_;
/*! \brief templ buffer */
TensorContainer<gpu, 1, DType> buffer_;
Expand Down

0 comments on commit 2f9e7a6

Please sign in to comment.