Skip to content

Commit

Permalink
Modified MATLAB API to use the new C++ interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jdrugo committed Oct 12, 2014
1 parent 227c8df commit 06ccb07
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 181 deletions.
133 changes: 22 additions & 111 deletions matlab/ddm_fpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@

#include "../ddm_fpt_lib/ddm_fpt_lib.h"

#include <cmath>
#include <cstdlib>
#include <string>
#include <cassert>
#include <algorithm>
Expand All @@ -66,9 +64,6 @@
/** the gateway function */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
int mu_size, bound_size, cur_argin, k_max;
int weighted_ddm = 0, normalise_mass = 0;
double *mu, *bound, delta_t, t_max, k = 0.0;
/* [g1, g2] = ddm_fpt(mu, bound, delta_t, t_max) or
[g1, g2] = ddm_fpt(a, k, bound, delta_t, t_max) */

Expand All @@ -95,12 +90,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (!MEX_ARGIN_IS_REAL_DOUBLE(3))
mexErrMsgIdAndTxt("ddm_fpt:WrongInput",
"Forth input argument expected to be a double");
mu_size = std::max(mxGetN(prhs[0]), mxGetM(prhs[0]));
mu = mxGetPr(prhs[0]);
bound_size = std::max(mxGetN(prhs[1]), mxGetM(prhs[1]));
bound = mxGetPr(prhs[1]);
delta_t = mxGetScalar(prhs[2]);
t_max = mxGetScalar(prhs[3]);
int mu_size = std::max(mxGetN(prhs[0]), mxGetM(prhs[0]));
int bound_size = std::max(mxGetN(prhs[1]), mxGetM(prhs[1]));
ExtArray mu(ExtArray::shared_noowner(mxGetPr(prhs[0])), mu_size);
ExtArray bound(ExtArray::shared_noowner(mxGetPr(prhs[1])), bound_size);
double delta_t = mxGetScalar(prhs[2]);
double t_max = mxGetScalar(prhs[3]);
if (delta_t <= 0.0)
mexErrMsgIdAndTxt("ddm_fpt:WrongInput",
"delta_t needs to be larger than 0.0");
Expand All @@ -109,17 +104,20 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
"t_max needs to be at least as large as delta_t");

/* Process possible 5th non-string argument */
cur_argin = 4;
bool weighted_ddm = false;
int cur_argin = 4;
double k = 0.0;
if (nrhs > 4 && !mxIsChar(prhs[4])) {
if (!MEX_ARGIN_IS_REAL_DOUBLE(4))
mexErrMsgIdAndTxt("ddm_fpt:WrongInput",
"Fifth input argument expected to be a double");
k = mxGetScalar(prhs[4]);
weighted_ddm = 1;
weighted_ddm = true;
++cur_argin;
}

/* Process string arguments */
bool normalise_mass = false;
if (nrhs > cur_argin) {
char str_arg[6];
/* current only accept 'mnorm' string argument */
Expand All @@ -138,116 +136,29 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
(strcmp(str_arg, "yes") != 0 && strcmp(str_arg, "no") != 0))
mexErrMsgIdAndTxt("ddm_fpt:WrongInput",
"\"yes\" or \"no\" expected after \"mnorm\"");
normalise_mass = strcmp(str_arg, "yes") == 0;
normalise_mass = (strcmp(str_arg, "yes") == 0);

/* no arguments allowed after that */
if (nrhs > cur_argin + 2)
mexErrMsgIdAndTxt("ddm_fpt:WrongInputs",
"Too many input arguments");
}

/* extend mu and bound by replicating last element, if necessary */
k_max = (int) ceil(t_max / delta_t);

/* reserve space for output */
int k_max = (int) ceil(t_max / delta_t);
plhs[0] = mxCreateDoubleMatrix(1, k_max, mxREAL);
plhs[1] = mxCreateDoubleMatrix(1, k_max, mxREAL);

if (weighted_ddm) {
/* extend mu and bound by replicating last element */
double *mu_ext, *bound_ext, last_mu, last_bound;
int i, err;

mu_ext = static_cast<double*>(malloc(k_max * sizeof(double)));
bound_ext = static_cast<double*>(malloc(k_max * sizeof(double)));
if (mu_ext == NULL || bound_ext == NULL) {
free(mu_ext);
free(bound_ext);
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");
}

memcpy(mu_ext, mu, sizeof(double) * std::min(mu_size, k_max));
last_mu = mu[mu_size - 1];
for (i = mu_size; i < k_max; ++i)
mu_ext[i] = last_mu;

memcpy(bound_ext, bound, sizeof(double) * std::min(bound_size, k_max));
last_bound = bound[bound_size - 1];
for (i = bound_size; i < k_max; ++i)
bound_ext[i] = last_bound;

/* compute the pdf's with weighted evidence */
err = ddm_fpt_w(mu_ext, bound_ext, k, delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));

free(mu_ext);
free(bound_ext);

if (err == -1)
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");
ExtArray g1(ExtArray::shared_noowner(mxGetPr(plhs[0])), k_max);
ExtArray g2(ExtArray::shared_noowner(mxGetPr(plhs[1])), k_max);

} else if (mu_size == 1) {
if (bound_size == 1) {
/* constant bound and drift - can use simpler method */
ddm_fpt_const(mu[0], bound[0], delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));
} else {
/* extend bound by replicating last element */
double *bound_ext, last_bound;
int i, err;

bound_ext = static_cast<double*>(malloc(k_max * sizeof(double)));
if (bound_ext == NULL)
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");

memcpy(bound_ext, bound, sizeof(double) * std::min(bound_size, k_max));
last_bound = bound[bound_size - 1];
for (i = bound_size; i < k_max; ++i)
bound_ext[i] = last_bound;

/* constant drift - slightly more efficient */
err = ddm_fpt_const_mu(mu[0], bound_ext, delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));

free(bound_ext);

if (err == -1)
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");

}
if (weighted_ddm) {
DMBase* dm = DMBase::createw(mu, bound, k, delta_t);
dm->pdfseq(k_max, g1, g2);
delete dm;
} else {
/* extend mu and bound by replicating last element */
double *mu_ext, *bound_ext, last_mu, last_bound;
int i, err;

mu_ext = static_cast<double*>(malloc(k_max * sizeof(double)));
bound_ext = static_cast<double*>(malloc(k_max * sizeof(double)));
if (mu_ext == NULL || bound_ext == NULL) {
free(mu_ext);
free(bound_ext);
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");
}

memcpy(mu_ext, mu, sizeof(double) * std::min(mu_size, k_max));
last_mu = mu[mu_size - 1];
for (i = mu_size; i < k_max; ++i)
mu_ext[i] = last_mu;

memcpy(bound_ext, bound, sizeof(double) * std::min(bound_size, k_max));
last_bound = bound[bound_size - 1];
for (i = bound_size; i < k_max; ++i)
bound_ext[i] = last_bound;

/* compute the pdf's */
err = ddm_fpt(mu_ext, bound_ext, delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));

free(mu_ext);
free(bound_ext);

if (err == -1)
mexErrMsgIdAndTxt("ddm_fpt:OutOfMemory", "Out of memory");

DMBase* dm = DMBase::create(mu, bound, delta_t);
dm->pdfseq(k_max, g1, g2);
delete dm;
}

/* normalise mass, if requested */
Expand Down
108 changes: 38 additions & 70 deletions matlab/ddm_fpt_full.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@

#include "../ddm_fpt_lib/ddm_fpt_lib.h"

#include <cmath>
#include <cstdlib>
#include <string>
#include <cassert>
#include <algorithm>
#include <limits>

#define MEX_ARGIN_IS_REAL_DOUBLE(arg_idx) (mxIsDouble(prhs[arg_idx]) && !mxIsComplex(prhs[arg_idx]) && mxGetN(prhs[arg_idx]) == 1 && mxGetM(prhs[arg_idx]) == 1)
#define MEX_ARGIN_IS_REAL_VECTOR(arg_idx) (mxIsDouble(prhs[arg_idx]) && !mxIsComplex(prhs[arg_idx]) && ((mxGetN(prhs[arg_idx]) == 1 && mxGetM(prhs[arg_idx]) >= 1) || (mxGetN(prhs[arg_idx]) >= 1 && mxGetM(prhs[arg_idx]) == 1)))
Expand All @@ -63,11 +61,6 @@
/** the gateway function */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
int mu_size, sig2_size, k_max, err, cur_argin, normalise_mass = 0;
int b_lo_size, b_up_size, b_lo_deriv_size, b_up_deriv_size, has_leak = 0;
double *mu, *sig2, *b_lo, *b_up, *b_lo_deriv, *b_up_deriv, *mu_ext;
double *sig2_ext, *b_lo_ext, *b_up_ext, *b_lo_deriv_ext, *b_up_deriv_ext;
double delta_t, t_max, inv_leak;
/* [g1, g2] = ddm_fpt_full(mu, sig2, b_lo, b_up, b_lo_deriv, b_up_deriv,
delta_t, t_max, [leak]) */

Expand Down Expand Up @@ -106,20 +99,20 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (!MEX_ARGIN_IS_REAL_DOUBLE(7))
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInput",
"Eight input argument expected to be a double");
mu_size = std::max(mxGetN(prhs[0]), mxGetM(prhs[0]));
sig2_size = std::max(mxGetN(prhs[1]), mxGetM(prhs[1]));
b_lo_size = std::max(mxGetN(prhs[2]), mxGetM(prhs[2]));
b_up_size = std::max(mxGetN(prhs[3]), mxGetM(prhs[3]));
b_lo_deriv_size = std::max(mxGetN(prhs[4]), mxGetM(prhs[4]));
b_up_deriv_size = std::max(mxGetN(prhs[5]), mxGetM(prhs[5]));
mu = mxGetPr(prhs[0]);
sig2 = mxGetPr(prhs[1]);
b_lo = mxGetPr(prhs[2]);
b_up = mxGetPr(prhs[3]);
b_lo_deriv = mxGetPr(prhs[4]);
b_up_deriv = mxGetPr(prhs[5]);
delta_t = mxGetScalar(prhs[6]);
t_max = mxGetScalar(prhs[7]);
int mu_size = std::max(mxGetN(prhs[0]), mxGetM(prhs[0]));
int sig2_size = std::max(mxGetN(prhs[1]), mxGetM(prhs[1]));
int b_lo_size = std::max(mxGetN(prhs[2]), mxGetM(prhs[2]));
int b_up_size = std::max(mxGetN(prhs[3]), mxGetM(prhs[3]));
int b_lo_deriv_size = std::max(mxGetN(prhs[4]), mxGetM(prhs[4]));
int b_up_deriv_size = std::max(mxGetN(prhs[5]), mxGetM(prhs[5]));
ExtArray mu(ExtArray::shared_noowner(mxGetPr(prhs[0])), mu_size);
ExtArray sig2(ExtArray::shared_noowner(mxGetPr(prhs[1])), sig2_size);
ExtArray b_lo(ExtArray::shared_noowner(mxGetPr(prhs[2])), b_lo_size);
ExtArray b_up(ExtArray::shared_noowner(mxGetPr(prhs[3])), b_up_size);
ExtArray b_lo_deriv(ExtArray::shared_noowner(mxGetPr(prhs[4])), 0.0, b_lo_deriv_size);
ExtArray b_up_deriv(ExtArray::shared_noowner(mxGetPr(prhs[5])), 0.0, b_up_deriv_size);
double delta_t = mxGetScalar(prhs[6]);
double t_max = mxGetScalar(prhs[7]);
if (delta_t <= 0.0)
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInput",
"delta_t needs to be larger than 0.0");
Expand All @@ -128,7 +121,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
"t_max needs to be at least as large as delta_t");

/* Process possible 9th non-string argument */
cur_argin = 8;
int cur_argin = 8;
bool has_leak = false;
double inv_leak = std::numeric_limits<double>::infinity();
if (nrhs > cur_argin && !mxIsChar(prhs[cur_argin])) {
if (!MEX_ARGIN_IS_REAL_DOUBLE(cur_argin))
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInput",
Expand All @@ -137,11 +132,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (inv_leak < 0.0)
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInput",
"inv_leak needs to be non-negative");
has_leak = 1;
has_leak = true;
++cur_argin;
}

/* Process string arguments */
bool normalise_mass = false;
if (nrhs > cur_argin) {
char str_arg[6];
/* current only accept 'mnorm' string argument */
Expand All @@ -160,63 +156,35 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
(strcmp(str_arg, "yes") != 0 && strcmp(str_arg, "no") != 0))
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInput",
"\"yes\" or \"no\" expected after \"mnorm\"");
normalise_mass = strcmp(str_arg, "yes") == 0;
normalise_mass = (strcmp(str_arg, "yes") == 0);

/* no arguments allowed after that */
if (nrhs > cur_argin + 2)
mexErrMsgIdAndTxt("ddm_fpt_full:WrongInputs",
"Too many input arguments");
}

/* extend mu and bound by replicating last element, if necessary */
k_max = (int) ceil(t_max / delta_t);

/* reserve space for output */
plhs[0] = mxCreateDoubleMatrix(1, k_max, mxREAL);
plhs[1] = mxCreateDoubleMatrix(1, k_max, mxREAL);

/* extend vectors by replicating last element */
mu_ext = extend_vector(mu, mu_size, k_max, mu[mu_size - 1]);
sig2_ext = extend_vector(sig2, sig2_size, k_max, sig2[sig2_size - 1]);
b_lo_ext = extend_vector(b_lo, b_lo_size, k_max, b_lo[b_lo_size - 1]);
b_up_ext = extend_vector(b_up, b_up_size, k_max, b_up[b_up_size - 1]);
b_lo_deriv_ext = extend_vector(b_lo_deriv, b_lo_deriv_size, k_max, 0.0);
b_up_deriv_ext = extend_vector(b_up_deriv, b_up_deriv_size, k_max, 0.0);
if (mu_ext == NULL || sig2_ext == NULL ||
b_lo_ext == NULL || b_up_ext == NULL ||
b_lo_deriv_ext == NULL || b_up_deriv_ext == NULL) {
free(mu_ext);
free(sig2_ext);
free(b_lo_ext);
free(b_up_ext);
free(b_lo_deriv_ext);
free(b_up_deriv_ext);
mexErrMsgIdAndTxt("ddm_fpt_full:OutOfMemory", "Out of memory");
}
int n = (int) ceil(t_max / delta_t);
plhs[0] = mxCreateDoubleMatrix(1, n, mxREAL);
plhs[1] = mxCreateDoubleMatrix(1, n, mxREAL);
ExtArray g1(ExtArray::shared_noowner(mxGetPr(plhs[0])), n);
ExtArray g2(ExtArray::shared_noowner(mxGetPr(plhs[1])), n);

/* compute the pdf's */
if (has_leak)
err = ddm_fpt_full_leak(mu_ext, sig2_ext, b_lo_ext, b_up_ext,
b_lo_deriv_ext, b_up_deriv_ext,
inv_leak, delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));
else
err = ddm_fpt_full(mu_ext, sig2_ext, b_lo_ext, b_up_ext,
b_lo_deriv_ext, b_up_deriv_ext, delta_t, k_max,
mxGetPr(plhs[0]), mxGetPr(plhs[1]));

free(mu_ext);
free(sig2_ext);
free(b_lo_ext);
free(b_up_ext);
free(b_lo_deriv_ext);
free(b_up_deriv_ext);

if (err == -1)
mexErrMsgIdAndTxt("ddm_fpt_full:OutOfMemory", "Out of memory");
if (has_leak) {
DMBase* dm = DMBase::create(mu, sig2, b_lo, b_up, b_lo_deriv, b_up_deriv,
delta_t, inv_leak);
dm->pdfseq(n, g1, g2);
delete dm;
} else {
DMBase* dm = DMBase::create(mu, sig2, b_lo, b_up, b_lo_deriv, b_up_deriv,
delta_t);
dm->pdfseq(n, g1, g2);
delete dm;
}


/* normalise mass, if requested */
if (normalise_mass)
mnorm(mxGetPr(plhs[0]), mxGetPr(plhs[1]), k_max, delta_t);
mnorm(mxGetPr(plhs[0]), mxGetPr(plhs[1]), n, delta_t);
}

0 comments on commit 06ccb07

Please sign in to comment.