Skip to content

Commit 903506c

Browse files
authored
Add keepdims keyword argument (#701)
* add function to deal with keepdims=True * preliminary keepdims fix * fux keepdims code * remove out-commented code
1 parent 73ed8cc commit 903506c

File tree

7 files changed

+190
-38
lines changed

7 files changed

+190
-38
lines changed

code/numpy/numerical.c

+20-13
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
274274
}
275275
}
276276

277-
static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype, size_t ddof) {
277+
static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, mp_obj_t keepdims, uint8_t optype, size_t ddof) {
278278
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
279279
uint8_t *array = (uint8_t *)ndarray->array;
280280
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
@@ -372,15 +372,15 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
372372
mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
373373
// re-wind the array here
374374
farray = (mp_float_t *)results->array;
375-
for(size_t i=0; i < results->len; i++) {
375+
for(size_t i = 0; i < results->len; i++) {
376376
*farray++ *= norm;
377377
}
378378
}
379379
} else {
380380
bool isStd = optype == NUMERICAL_STD ? 1 : 0;
381381
results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
382382
farray = (mp_float_t *)results->array;
383-
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
383+
// we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
384384
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
385385
return MP_OBJ_FROM_PTR(results);
386386
}
@@ -397,11 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
397397
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
398398
}
399399
}
400-
if(results->ndim == 0) { // return a scalar here
401-
return mp_binary_get_val_array(results->dtype, results->array, 0);
402-
}
403-
return MP_OBJ_FROM_PTR(results);
400+
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
404401
}
402+
// we should never get to this point
405403
return mp_const_none;
406404
}
407405
#endif
@@ -441,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
441439
}
442440
}
443441

444-
static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype) {
442+
static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t keepdims, mp_obj_t axis, uint8_t optype) {
445443
// TODO: treat the flattened array
446444
if(ndarray->len == 0) {
447445
mp_raise_ValueError(MP_ERROR_TEXT("attempt to get (arg)min/(arg)max of empty sequence"));
@@ -521,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
521519
int32_t *strides = m_new0(int32_t, ULAB_MAX_DIMS);
522520

523521
numerical_reduce_axes(ndarray, ax, shape, strides);
524-
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
522+
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
523+
524+
uint8_t index = _shape_strides.axis;
525525

526526
ndarray_obj_t *results = NULL;
527527

@@ -550,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
550550
if(results->len == 1) {
551551
return mp_binary_get_val_array(results->dtype, results->array, 0);
552552
}
553-
return MP_OBJ_FROM_PTR(results);
553+
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
554554
}
555+
// we should never get to this point
555556
return mp_const_none;
556557
}
557558
#endif
@@ -560,13 +561,16 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
560561
static const mp_arg_t allowed_args[] = {
561562
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
562563
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
564+
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
563565
};
564566

565567
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
566568
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
567569

568570
mp_obj_t oin = args[0].u_obj;
569571
mp_obj_t axis = args[1].u_obj;
572+
mp_obj_t keepdims = args[2].u_obj;
573+
570574
if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
571575
mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
572576
}
@@ -598,11 +602,11 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
598602
case NUMERICAL_ARGMIN:
599603
case NUMERICAL_ARGMAX:
600604
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
601-
return numerical_argmin_argmax_ndarray(ndarray, axis, optype);
605+
return numerical_argmin_argmax_ndarray(ndarray, keepdims, axis, optype);
602606
case NUMERICAL_SUM:
603607
case NUMERICAL_MEAN:
604608
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
605-
return numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0);
609+
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, optype, 0);
606610
default:
607611
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
608612
}
@@ -1385,6 +1389,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
13851389
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } ,
13861390
{ MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } },
13871391
{ MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} },
1392+
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
13881393
};
13891394

13901395
mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
@@ -1393,6 +1398,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
13931398
mp_obj_t oin = args[0].u_obj;
13941399
mp_obj_t axis = args[1].u_obj;
13951400
size_t ddof = args[2].u_int;
1401+
mp_obj_t keepdims = args[2].u_obj;
1402+
13961403
if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) {
13971404
// this seems to pass with False, and True...
13981405
mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer"));
@@ -1401,7 +1408,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
14011408
return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof);
14021409
} else if(mp_obj_is_type(oin, &ulab_ndarray_type)) {
14031410
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
1404-
return numerical_sum_mean_std_ndarray(ndarray, axis, NUMERICAL_STD, ddof);
1411+
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, NUMERICAL_STD, ddof);
14051412
} else {
14061413
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
14071414
}

code/ulab.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "user/user.h"
3434
#include "utils/utils.h"
3535

36-
#define ULAB_VERSION 6.7.0
36+
#define ULAB_VERSION 6.7.1
3737
#define xstr(s) str(s)
3838
#define str(s) #s
3939

code/ulab_tools.c

+52-24
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,15 @@ void *ndarray_set_float_function(uint8_t dtype) {
162162
}
163163
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */
164164

165+
int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
166+
int8_t ax = mp_obj_get_int(axis);
167+
if(ax < 0) ax += ndim;
168+
if((ax < 0) || (ax > ndim - 1)) {
169+
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
170+
}
171+
return ax;
172+
}
173+
165174
shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
166175
// TODO: replace numerical_reduce_axes with this function, wherever applicable
167176
// This function should be used, whenever a tensor is contracted;
@@ -172,38 +181,36 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
172181
}
173182
shape_strides _shape_strides;
174183

175-
size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
176-
_shape_strides.shape = shape;
177-
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
178-
_shape_strides.strides = strides;
179-
180184
_shape_strides.increment = 0;
181185
// this is the contracted dimension (won't be overwritten for axis == None)
182186
_shape_strides.ndim = 0;
183187

184-
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
185-
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
186-
187188
if(axis == mp_const_none) {
189+
_shape_strides.shape = ndarray->shape;
190+
_shape_strides.strides = ndarray->strides;
188191
return _shape_strides;
189192
}
190193

191-
uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
194+
size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
195+
_shape_strides.shape = shape;
196+
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
197+
_shape_strides.strides = strides;
198+
199+
memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
200+
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
201+
202+
_shape_strides.axis = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
192203

193204
if(axis != mp_const_none) { // i.e., axis is an integer
194-
int8_t ax = mp_obj_get_int(axis);
195-
if(ax < 0) ax += ndarray->ndim;
196-
if((ax < 0) || (ax > ndarray->ndim - 1)) {
197-
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
198-
}
199-
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
205+
int8_t ax = tools_get_axis(axis, ndarray->ndim);
206+
_shape_strides.axis = ULAB_MAX_DIMS - ndarray->ndim + ax;
200207
_shape_strides.ndim = ndarray->ndim - 1;
201208
}
202209

203210
// move the value stored at index to the leftmost position, and align everything else to the right
204-
_shape_strides.shape[0] = ndarray->shape[index];
205-
_shape_strides.strides[0] = ndarray->strides[index];
206-
for(uint8_t i = 0; i < index; i++) {
211+
_shape_strides.shape[0] = ndarray->shape[_shape_strides.axis];
212+
_shape_strides.strides[0] = ndarray->strides[_shape_strides.axis];
213+
for(uint8_t i = 0; i < _shape_strides.axis; i++) {
207214
// entries to the right of index must be shifted by one position to the left
208215
_shape_strides.shape[i + 1] = ndarray->shape[i];
209216
_shape_strides.strides[i + 1] = ndarray->strides[i];
@@ -213,16 +220,37 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
213220
_shape_strides.increment = 1;
214221
}
215222

223+
if(_shape_strides.ndim == 0) {
224+
_shape_strides.ndim = 1;
225+
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
226+
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
227+
}
228+
216229
return _shape_strides;
217230
}
218231

219-
int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
220-
int8_t ax = mp_obj_get_int(axis);
221-
if(ax < 0) ax += ndim;
222-
if((ax < 0) || (ax > ndim - 1)) {
223-
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
232+
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *ndarray, ndarray_obj_t *results, mp_obj_t keepdims, shape_strides _shape_strides) {
233+
// restores the contracted dimension, if keepdims is True
234+
if((ndarray->ndim == 1) && (keepdims != mp_const_true)) {
235+
// since the original array has already been contracted and
236+
// we don't want to keep the dimensions here, we have to return a scalar
237+
return mp_binary_get_val_array(results->dtype, results->array, 0);
224238
}
225-
return ax;
239+
240+
if(keepdims == mp_const_true) {
241+
results->ndim += 1;
242+
for(int8_t i = 0; i < ULAB_MAX_DIMS; i++) {
243+
results->shape[i] = ndarray->shape[i];
244+
}
245+
results->shape[_shape_strides.axis] = 1;
246+
247+
results->strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
248+
for(uint8_t i = ULAB_MAX_DIMS; i > 1; i--) {
249+
results->strides[i - 2] = results->strides[i - 1] * results->shape[i - 1];
250+
}
251+
}
252+
253+
return MP_OBJ_FROM_PTR(results);
226254
}
227255

228256
#if ULAB_MAX_DIMS > 1

code/ulab_tools.h

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
typedef struct _shape_strides_t {
1919
uint8_t increment;
20+
uint8_t axis;
2021
uint8_t ndim;
2122
size_t *shape;
2223
int32_t *strides;
@@ -34,6 +35,7 @@ void *ndarray_set_float_function(uint8_t );
3435

3536
shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
3637
int8_t tools_get_axis(mp_obj_t , uint8_t );
38+
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , ndarray_obj_t * , mp_obj_t , shape_strides );
3739
ndarray_obj_t *tools_object_is_square(mp_obj_t );
3840

3941
uint8_t ulab_binary_get_size(uint8_t );

docs/ulab-change-log.md

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
Mon, 30 Dec 2024
2+
3+
version 6.7.1
4+
5+
add keepdims keyword argument to numerical functions
6+
7+
Sun, 15 Dec 2024
8+
9+
version 6.7.0
10+
11+
add scipy.integrate module
12+
113
Sun, 24 Nov 2024
214

315
version 6.6.1

tests/2d/numpy/sum.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
try:
2+
from ulab import numpy as np
3+
except ImportError:
4+
import numpy as np
5+
6+
for dtype in (np.uint8, np.int8, np.uint16, np.int8, np.float):
7+
a = np.array(range(12), dtype=dtype)
8+
b = a.reshape((3, 4))
9+
10+
print(a)
11+
print(b)
12+
print()
13+
14+
print(np.sum(a))
15+
print(np.sum(a, axis=0))
16+
print(np.sum(a, axis=0, keepdims=True))
17+
18+
print()
19+
print(np.sum(b))
20+
print(np.sum(b, axis=0))
21+
print(np.sum(b, axis=1))
22+
print(np.sum(b, axis=0, keepdims=True))
23+
print(np.sum(b, axis=1, keepdims=True))

tests/2d/numpy/sum.py.exp

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
array([0, 1, 2, ..., 9, 10, 11], dtype=uint8)
2+
array([[0, 1, 2, 3],
3+
[4, 5, 6, 7],
4+
[8, 9, 10, 11]], dtype=uint8)
5+
6+
66
7+
66
8+
array([66], dtype=uint8)
9+
10+
66
11+
array([12, 15, 18, 21], dtype=uint8)
12+
array([6, 22, 38], dtype=uint8)
13+
array([[12, 15, 18, 21]], dtype=uint8)
14+
array([[6],
15+
[22],
16+
[38]], dtype=uint8)
17+
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
18+
array([[0, 1, 2, 3],
19+
[4, 5, 6, 7],
20+
[8, 9, 10, 11]], dtype=int8)
21+
22+
66
23+
66
24+
array([66], dtype=int8)
25+
26+
66
27+
array([12, 15, 18, 21], dtype=int8)
28+
array([6, 22, 38], dtype=int8)
29+
array([[12, 15, 18, 21]], dtype=int8)
30+
array([[6],
31+
[22],
32+
[38]], dtype=int8)
33+
array([0, 1, 2, ..., 9, 10, 11], dtype=uint16)
34+
array([[0, 1, 2, 3],
35+
[4, 5, 6, 7],
36+
[8, 9, 10, 11]], dtype=uint16)
37+
38+
66
39+
66
40+
array([66], dtype=uint16)
41+
42+
66
43+
array([12, 15, 18, 21], dtype=uint16)
44+
array([6, 22, 38], dtype=uint16)
45+
array([[12, 15, 18, 21]], dtype=uint16)
46+
array([[6],
47+
[22],
48+
[38]], dtype=uint16)
49+
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
50+
array([[0, 1, 2, 3],
51+
[4, 5, 6, 7],
52+
[8, 9, 10, 11]], dtype=int8)
53+
54+
66
55+
66
56+
array([66], dtype=int8)
57+
58+
66
59+
array([12, 15, 18, 21], dtype=int8)
60+
array([6, 22, 38], dtype=int8)
61+
array([[12, 15, 18, 21]], dtype=int8)
62+
array([[6],
63+
[22],
64+
[38]], dtype=int8)
65+
array([0.0, 1.0, 2.0, ..., 9.0, 10.0, 11.0], dtype=float64)
66+
array([[0.0, 1.0, 2.0, 3.0],
67+
[4.0, 5.0, 6.0, 7.0],
68+
[8.0, 9.0, 10.0, 11.0]], dtype=float64)
69+
70+
66.0
71+
66.0
72+
array([66.0], dtype=float64)
73+
74+
66.0
75+
array([12.0, 15.0, 18.0, 21.0], dtype=float64)
76+
array([6.0, 22.0, 38.0], dtype=float64)
77+
array([[12.0, 15.0, 18.0, 21.0]], dtype=float64)
78+
array([[6.0],
79+
[22.0],
80+
[38.0]], dtype=float64)

0 commit comments

Comments
 (0)