Skip to content

Commit 969b5cf

Browse files
author
Luke Sewell
committed
Fix more generically, using itemsize rather than checking strides
1 parent 9f026eb commit 969b5cf

File tree

2 files changed

+46
-79
lines changed

2 files changed

+46
-79
lines changed

bottleneck/include/iterators.h

+35-76
Original file line numberDiff line numberDiff line change
@@ -98,86 +98,45 @@ static inline void init_iter_all(iter *it, PyArrayObject *a, int ravel, int anyo
9898
it->ndim_m2 = -1;
9999
it->length = 1;
100100
it->astride = 0;
101+
} else if (F_CONTIGUOUS(a) && !C_CONTIGUOUS(a) && ravel && !anyorder) {
102+
it->ndim_m2 = -1;
103+
a = (PyArrayObject *)PyArray_Ravel(a, NPY_CORDER);
104+
it->a_ravel = a;
105+
it->length = PyArray_DIM(a, 0);
106+
it->astride = PyArray_STRIDE(a, 0);
107+
} else if (C_CONTIGUOUS(a) || F_CONTIGUOUS(a)) {
108+
/* If continguous then we just need the itemsize */
109+
it->ndim_m2 = -1;
110+
// it->axis does not matter
111+
it->length = PyArray_SIZE(a);
112+
it->astride = item_size;
113+
} else if (ravel) {
114+
it->ndim_m2 = -1;
115+
if (anyorder) {
116+
a = (PyArrayObject *)PyArray_Ravel(a, NPY_ANYORDER);
117+
} else {
118+
a = (PyArrayObject *)PyArray_Ravel(a, NPY_CORDER);
119+
}
120+
it->a_ravel = a;
121+
it->length = PyArray_DIM(a, 0);
122+
it->astride = PyArray_STRIDE(a, 0);
101123
} else {
102-
/* If strides aren't in descending order, some of the assumptions for C_CONTIGUOUS don't hold */
103-
int strides_descending = 1;
124+
it->ndim_m2 = ndim - 2;
125+
it->astride = strides[0];
104126
for (i = 1; i < ndim; i++) {
105-
if (strides[i] > strides[i-1]) {
106-
strides_descending = 0;
107-
break;
108-
}
109-
}
110-
111-
if (strides_descending && C_CONTIGUOUS(a) && !F_CONTIGUOUS(a)) {
112-
113-
/* The &&! in the next two else ifs is to deal with relaxed
114-
* stride checking introduced in numpy 1.12.0; see gh #161 */
115-
it->ndim_m2 = -1;
116-
it->axis = ndim - 1;
117-
it->length = PyArray_SIZE(a);
118-
it->astride = 0;
119-
for (i = ndim - 1; i > -1; i--) {
120-
/* protect against length zero strides such as in
121-
* np.ones((2, 2))[..., np.newaxis] */
122-
if (strides[i] == 0) {
123-
continue;
124-
}
127+
if (strides[i] < it->astride) {
125128
it->astride = strides[i];
126-
break;
127-
}
128-
} else if (F_CONTIGUOUS(a) && !C_CONTIGUOUS(a)) {
129-
if (anyorder || !ravel) {
130-
it->ndim_m2 = -1;
131-
it->length = PyArray_SIZE(a);
132-
it->astride = 0;
133-
for (i = 0; i < ndim; i++) {
134-
/* protect against length zero strides such as in
135-
* np.ones((2, 2), order='F')[np.newaxis, ...] */
136-
if (strides[i] == 0) {
137-
continue;
138-
}
139-
it->astride = strides[i];
140-
break;
141-
}
142-
} else {
143-
it->ndim_m2 = -1;
144-
if (anyorder) {
145-
a = (PyArrayObject *)PyArray_Ravel(a, NPY_ANYORDER);
146-
} else {
147-
a = (PyArrayObject *)PyArray_Ravel(a, NPY_CORDER);
148-
}
149-
it->a_ravel = a;
150-
it->length = PyArray_DIM(a, 0);
151-
it->astride = PyArray_STRIDE(a, 0);
152-
}
153-
} else if (ravel) {
154-
it->ndim_m2 = -1;
155-
if (anyorder) {
156-
a = (PyArrayObject *)PyArray_Ravel(a, NPY_ANYORDER);
157-
} else {
158-
a = (PyArrayObject *)PyArray_Ravel(a, NPY_CORDER);
129+
it->axis = i;
159130
}
160-
it->a_ravel = a;
161-
it->length = PyArray_DIM(a, 0);
162-
it->astride = PyArray_STRIDE(a, 0);
163-
} else {
164-
it->ndim_m2 = ndim - 2;
165-
it->astride = strides[0];
166-
for (i = 1; i < ndim; i++) {
167-
if (strides[i] < it->astride) {
168-
it->astride = strides[i];
169-
it->axis = i;
170-
}
171-
}
172-
it->length = shape[it->axis];
173-
for (i = 0; i < ndim; i++) {
174-
if (i != it->axis) {
175-
it->indices[j] = 0;
176-
it->astrides[j] = strides[i];
177-
it->shape[j] = shape[i];
178-
it->nits *= shape[i];
179-
j++;
180-
}
131+
}
132+
it->length = shape[it->axis];
133+
for (i = 0; i < ndim; i++) {
134+
if (i != it->axis) {
135+
it->indices[j] = 0;
136+
it->astrides[j] = strides[i];
137+
it->shape[j] = shape[i];
138+
it->nits *= shape[i];
139+
j++;
181140
}
182141
}
183142
}

bottleneck/tests/reduce_test.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,15 @@ def test_ddof_nans(func, dtype) -> None:
306306

307307
@pytest.mark.parametrize("dtype", DTYPES)
308308
@pytest.mark.parametrize("func", (bn.nanmean, bn.nanmax), ids=lambda x: x.__name__)
309-
def test_reduce_with_unordered_strides(func, dtype) -> None:
310-
array = np.zeros((1, 500, 2), dtype=dtype).transpose((1,2,0))
309+
def test_reduce_with_unordered_strides_ccontig(func, dtype) -> None:
310+
array = np.ones((1, 500, 2), dtype=dtype).transpose((1,2,0))
311311
result = func(array)
312-
assert result == 0
312+
assert result == 1000
313+
314+
@pytest.mark.parametrize("dtype", DTYPES)
315+
@pytest.mark.parametrize("func", (bn.nanmean, bn.nanmax), ids=lambda x: x.__name__)
316+
def test_reduce_with_unordered_strides_fcontig(func, dtype) -> None:
317+
array = np.ones((1, 500, 2), dtype=dtype).transpose((0,2,1))
318+
result = func(array)
319+
assert result == 1000
320+

0 commit comments

Comments
 (0)