Skip to content

Commit

Permalink
MAINT: simplify cast setup in flatiter internals
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Jul 17, 2023
1 parent 9c3144d commit e2259b2
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 110 deletions.
154 changes: 58 additions & 96 deletions numpy/core/src/multiarray/iterators.c
Original file line number Diff line number Diff line change
Expand Up @@ -418,15 +418,14 @@ iter_length(PyArrayIterObject *self)


static PyArrayObject *
iter_subscript_Bool(PyArrayIterObject *self, PyArrayObject *ind)
iter_subscript_Bool(PyArrayIterObject *self, PyArrayObject *ind,
NPY_cast_info *cast_info)
{
npy_intp counter, strides;
int itemsize;
npy_intp count = 0;
char *dptr, *optr;
PyArrayObject *ret;
NPY_cast_info cast_info = {.func = NULL};


if (PyArray_NDIM(ind) != 1) {
PyErr_SetString(PyExc_ValueError,
Expand Down Expand Up @@ -459,22 +458,13 @@ iter_subscript_Bool(PyArrayIterObject *self, PyArrayObject *ind)
optr = PyArray_DATA(ret);
counter = PyArray_DIMS(ind)[0];
dptr = PyArray_DATA(ind);
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
npy_intp transfer_strides[2] = {itemsize, itemsize};
/* We can assume the newly allocated output array is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, dtype, dtype, 0,
&cast_info, &transfer_flags) < 0) {
return NULL;
}
while (counter--) {
if (*((npy_bool *)dptr) != 0) {
char *args[2] = {self->dataptr, optr};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
NPY_cast_info_xfree(&cast_info);
npy_intp transfer_strides[2] = {itemsize, itemsize};
if (cast_info->func(&cast_info->context, args, &one,
transfer_strides, cast_info->auxdata) < 0) {
return NULL;
}
optr += itemsize;
Expand All @@ -484,22 +474,20 @@ iter_subscript_Bool(PyArrayIterObject *self, PyArrayObject *ind)
}
PyArray_ITER_RESET(self);
}
NPY_cast_info_xfree(&cast_info);
return ret;
}

static PyObject *
iter_subscript_int(PyArrayIterObject *self, PyArrayObject *ind)
iter_subscript_int(PyArrayIterObject *self, PyArrayObject *ind,
NPY_cast_info *cast_info)
{
npy_intp num;
PyArrayObject *ret;
PyArrayIterObject *ind_it;
int itemsize;
char *optr;
npy_intp counter;
NPY_cast_info cast_info = {.func = NULL};

itemsize = PyArray_DESCR(self->ao)->elsize;
if (PyArray_NDIM(ind) == 0) {
num = *((npy_intp *)PyArray_DATA(ind));
if (check_and_adjust_index(&num, self->size, -1, NULL) < 0) {
Expand Down Expand Up @@ -533,42 +521,31 @@ iter_subscript_int(PyArrayIterObject *self, PyArrayObject *ind)
return NULL;
}

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
npy_intp transfer_strides[2] = {itemsize, itemsize};
/* We can assume the newly allocated output array is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, dtype, dtype, 0, &cast_info,
&transfer_flags) < 0) {
return NULL;
}

itemsize = dtype->elsize;
counter = ind_it->size;
while (counter--) {
num = *((npy_intp *)(ind_it->dataptr));
if (check_and_adjust_index(&num, self->size, -1, NULL) < 0) {
Py_DECREF(ind_it);
Py_DECREF(ret);
PyArray_ITER_RESET(self);
NPY_cast_info_xfree(&cast_info);
return NULL;
}
PyArray_ITER_GOTO1D(self, num);
char *args[2] = {self->dataptr, optr};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
npy_intp transfer_strides[2] = {itemsize, itemsize};
if (cast_info->func(&cast_info->context, args, &one,
transfer_strides, cast_info->auxdata) < 0) {
Py_DECREF(ind_it);
Py_DECREF(ret);
PyArray_ITER_RESET(self);
NPY_cast_info_xfree(&cast_info);
return NULL;
}
optr += itemsize;
PyArray_ITER_NEXT(ind_it);
}
Py_DECREF(ind_it);
NPY_cast_info_xfree(&cast_info);
PyArray_ITER_RESET(self);
return (PyObject *)ret;
}
Expand Down Expand Up @@ -631,6 +608,21 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
}
}

dtype = PyArray_DESCR(self->ao);
size = dtype->elsize;

/* set up a cast to handle item copying */

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
/* We can assume the newly allocated output array is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, size, size, dtype, dtype, 0, &cast_info,
&transfer_flags) < 0) {
goto fail;
}

/* Check for Integer or Slice */
if (PyLong_Check(ind) || PySlice_Check(ind)) {
start = parse_index_entry(ind, &step_size, &n_steps,
Expand All @@ -648,10 +640,9 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
PyObject *tmp;
tmp = PyArray_ToScalar(self->dataptr, self->ao);
PyArray_ITER_RESET(self);
NPY_cast_info_xfree(&cast_info);
return tmp;
}
size = PyArray_DESCR(self->ao)->elsize;
dtype = PyArray_DESCR(self->ao);
Py_INCREF(dtype);
ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self->ao),
dtype,
Expand All @@ -662,18 +653,9 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
goto fail;
}
dptr = PyArray_DATA(ret);
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
npy_intp transfer_strides[2] = {size, size};
/* We can assume the newly allocated output array is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, size, size, dtype, dtype, 0, &cast_info,
&transfer_flags) < 0) {
goto fail;
}
while (n_steps--) {
char *args[2] = {self->dataptr, dptr};
npy_intp transfer_strides[2] = {size, size};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
goto fail;
Expand All @@ -683,6 +665,7 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
dptr += size;
}
PyArray_ITER_RESET(self);
NPY_cast_info_xfree(&cast_info);
return (PyObject *)ret;
}

Expand All @@ -707,10 +690,8 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)

/* Check for Boolean array */
if (PyArray_TYPE((PyArrayObject *)obj) == NPY_BOOL) {
ret = iter_subscript_Bool(self, (PyArrayObject *)obj);
Py_DECREF(indtype);
Py_DECREF(obj);
return (PyObject *)ret;
ret = iter_subscript_Bool(self, (PyArrayObject *)obj, &cast_info);
goto finish;
}

/* Only integer arrays left */
Expand All @@ -724,14 +705,16 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)
if (new == NULL) {
goto fail;
}
ret = (PyArrayObject *)iter_subscript_int(self, (PyArrayObject *)new,
&cast_info);
Py_DECREF(new);

finish:
Py_DECREF(indtype);
Py_DECREF(obj);
NPY_cast_info_xfree(&cast_info);
ret = (PyArrayObject *)iter_subscript_int(self, (PyArrayObject *)new);
Py_DECREF(new);
return (PyObject *)ret;


fail:
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_IndexError, "unsupported iterator index");
Expand All @@ -747,11 +730,10 @@ iter_subscript(PyArrayIterObject *self, PyObject *ind)

static int
iter_ass_sub_Bool(PyArrayIterObject *self, PyArrayObject *ind,
PyArrayIterObject *val)
PyArrayIterObject *val, NPY_cast_info *cast_info)
{
npy_intp counter, strides;
char *dptr;
NPY_cast_info cast_info = {.func = NULL};

if (PyArray_NDIM(ind) != 1) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -770,23 +752,15 @@ iter_ass_sub_Bool(PyArrayIterObject *self, PyArrayObject *ind,
dptr = PyArray_DATA(ind);
PyArray_ITER_RESET(self);
/* Loop over Boolean array */
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
PyArray_Descr *dtype = PyArray_DESCR(self->ao);
int itemsize = dtype->elsize;
npy_intp transfer_strides[2] = {itemsize, itemsize};
int is_aligned = IsUintAligned(self->ao) && IsUintAligned(val->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, dtype, dtype, 0,
&cast_info, &transfer_flags) < 0) {
return -1;
}
while (counter--) {
if (*((npy_bool *)dptr) != 0) {
char *args[2] = {val->dataptr, self->dataptr};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
NPY_cast_info_xfree(&cast_info);
if (cast_info->func(&cast_info->context, args, &one,
transfer_strides, cast_info->auxdata) < 0) {
return -1;
}
PyArray_ITER_NEXT(val);
Expand All @@ -798,30 +772,21 @@ iter_ass_sub_Bool(PyArrayIterObject *self, PyArrayObject *ind,
PyArray_ITER_NEXT(self);
}
PyArray_ITER_RESET(self);
NPY_cast_info_xfree(&cast_info);
return 0;
}

static int
iter_ass_sub_int(PyArrayIterObject *self, PyArrayObject *ind,
PyArrayIterObject *val)
PyArrayIterObject *val, NPY_cast_info *cast_info)
{
npy_intp num;
PyArrayIterObject *ind_it;
npy_intp counter;
NPY_cast_info cast_info = {.func = NULL};

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
PyArray_Descr *dtype = PyArray_DESCR(self->ao);
int itemsize = dtype->elsize;
npy_intp transfer_strides[2] = {itemsize, itemsize};
int is_aligned = IsUintAligned(self->ao) && IsUintAligned(val->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, dtype, dtype, 0,
&cast_info, &transfer_flags) < 0) {
return -1;
}

if (PyArray_NDIM(ind) == 0) {
num = *((npy_intp *)PyArray_DATA(ind));
Expand All @@ -830,17 +795,14 @@ iter_ass_sub_int(PyArrayIterObject *self, PyArrayObject *ind,
}
PyArray_ITER_GOTO1D(self, num);
char *args[2] = {val->dataptr, self->dataptr};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
NPY_cast_info_xfree(&cast_info);
if (cast_info->func(&cast_info->context, args, &one,
transfer_strides, cast_info->auxdata) < 0) {
return -1;
}
NPY_cast_info_xfree(&cast_info);
return 0;
}
ind_it = (PyArrayIterObject *)PyArray_IterNew((PyObject *)ind);
if (ind_it == NULL) {
NPY_cast_info_xfree(&cast_info);
return -1;
}
counter = ind_it->size;
Expand All @@ -852,9 +814,8 @@ iter_ass_sub_int(PyArrayIterObject *self, PyArrayObject *ind,
}
PyArray_ITER_GOTO1D(self, num);
char *args[2] = {val->dataptr, self->dataptr};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
NPY_cast_info_xfree(&cast_info);
if (cast_info->func(&cast_info->context, args, &one,
transfer_strides, cast_info->auxdata) < 0) {
Py_DECREF(ind_it);
return -1;
}
Expand All @@ -865,7 +826,6 @@ iter_ass_sub_int(PyArrayIterObject *self, PyArrayObject *ind,
}
}
Py_DECREF(ind_it);
NPY_cast_info_xfree(&cast_info);
return 0;
}

Expand Down Expand Up @@ -959,6 +919,18 @@ iter_ass_subscript(PyArrayIterObject *self, PyObject *ind, PyObject *val)
goto finish;
}

/* set up cast to handle single-element copies into arrval */
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
int itemsize = type->elsize;
/* We can assume the newly allocated array is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, type, type, 0,
&cast_info, &transfer_flags) < 0) {
goto finish;
}

/* Check Slice */
if (PySlice_Check(ind)) {
start = parse_index_entry(ind, &step_size, &n_steps, self->size, 0, 0);
Expand All @@ -971,17 +943,7 @@ iter_ass_subscript(PyArrayIterObject *self, PyObject *ind, PyObject *val)
goto finish;
}
PyArray_ITER_GOTO1D(self, start);
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
npy_intp one = 1;
int itemsize = type->elsize;
npy_intp transfer_strides[2] = {itemsize, itemsize};
/* We can assume the newly allocated arrval is aligned */
int is_aligned = IsUintAligned(self->ao);
if (PyArray_GetDTypeTransferFunction(
is_aligned, itemsize, itemsize, type, type, 0,
&cast_info, &transfer_flags) < 0) {
goto finish;
}
if (n_steps == SINGLE_INDEX) {
char *args[2] = {PyArray_DATA(arrval), self->dataptr};
if (cast_info.func(&cast_info.context, args, &one,
Expand Down Expand Up @@ -1025,7 +987,7 @@ iter_ass_subscript(PyArrayIterObject *self, PyObject *ind, PyObject *val)
/* Check for Boolean object */
if (PyArray_TYPE((PyArrayObject *)obj)==NPY_BOOL) {
if (iter_ass_sub_Bool(self, (PyArrayObject *)obj,
val_it) < 0) {
val_it, &cast_info) < 0) {
goto finish;
}
retval=0;
Expand All @@ -1042,7 +1004,7 @@ iter_ass_subscript(PyArrayIterObject *self, PyObject *ind, PyObject *val)
goto finish;
}
if (iter_ass_sub_int(self, (PyArrayObject *)obj,
val_it) < 0) {
val_it, &cast_info) < 0) {
goto finish;
}
retval = 0;
Expand Down
Loading

0 comments on commit e2259b2

Please sign in to comment.