Skip to content

Commit

Permalink
added 3-channels support to cv::setIdentity
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-lavrenov committed Mar 18, 2014
1 parent 8d97d0d commit 04884eb
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
10 changes: 5 additions & 5 deletions modules/core/src/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2679,17 +2679,17 @@ namespace cv {

static bool ocl_setIdentity( InputOutputArray _m, const Scalar& s )
{
int type = _m.type(), cn = CV_MAT_CN(type);
if (cn == 3)
return false;
int type = _m.type(), depth = CV_MAT_DEPTH(type), cn = CV_MAT_CN(type),
sctype = CV_MAKE_TYPE(depth, cn == 3 ? 4 : cn);

ocl::Kernel k("setIdentity", ocl::core::set_identity_oclsrc,
format("-D T=%s", ocl::memopTypeToStr(type)));
format("-D T=%s -D T1=%s -D cn=%d -D ST=%s", ocl::memopTypeToStr(type),
ocl::memopTypeToStr(depth), cn, ocl::memopTypeToStr(sctype)));
if (k.empty())
return false;

UMat m = _m.getUMat();
k.args(ocl::KernelArg::WriteOnly(m), ocl::KernelArg::Constant(Mat(1, 1, type, s)));
k.args(ocl::KernelArg::WriteOnly(m), ocl::KernelArg::Constant(Mat(1, 1, sctype, s)));

size_t globalsize[2] = { m.cols, m.rows };
return k.run(2, globalsize, NULL, false);
Expand Down
19 changes: 15 additions & 4 deletions modules/core/src/opencl/set_identity.cl
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,28 @@
//
//M*/

#if cn != 3
#define loadpix(addr) *(__global const T *)(addr)
#define storepix(val, addr) *(__global T *)(addr) = val
#define TSIZE (int)sizeof(T)
#define scalar scalar_
#else
#define loadpix(addr) vload3(0, (__global const T1 *)(addr))
#define storepix(val, addr) vstore3(val, 0, (__global T1 *)(addr))
#define TSIZE ((int)sizeof(T1)*3)
#define scalar (T)(scalar_.x, scalar_.y, scalar_.z)
#endif

__kernel void setIdentity(__global uchar * srcptr, int src_step, int src_offset, int rows, int cols,
T scalar)
ST scalar_)
{
int x = get_global_id(0);
int y = get_global_id(1);

if (x < cols && y < rows)
{
int src_index = mad24(y, src_step, mad24(x, (int)sizeof(T), src_offset));
__global T * src = (__global T *)(srcptr + src_index);
int src_index = mad24(y, src_step, mad24(x, TSIZE, src_offset));

src[0] = x == y ? scalar : (T)(0);
storepix(x == y ? scalar : (T)(0), srcptr + src_index);
}
}

0 comments on commit 04884eb

Please sign in to comment.