@@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
274
274
}
275
275
}
276
276
277
- static mp_obj_t numerical_sum_mean_std_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype , size_t ddof ) {
277
+ static mp_obj_t numerical_sum_mean_std_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , mp_obj_t keepdims , uint8_t optype , size_t ddof ) {
278
278
COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
279
279
uint8_t * array = (uint8_t * )ndarray -> array ;
280
280
shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
@@ -372,15 +372,15 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
372
372
mp_float_t norm = (mp_float_t )_shape_strides .shape [0 ];
373
373
// re-wind the array here
374
374
farray = (mp_float_t * )results -> array ;
375
- for (size_t i = 0 ; i < results -> len ; i ++ ) {
375
+ for (size_t i = 0 ; i < results -> len ; i ++ ) {
376
376
* farray ++ *= norm ;
377
377
}
378
378
}
379
379
} else {
380
380
bool isStd = optype == NUMERICAL_STD ? 1 : 0 ;
381
381
results = ndarray_new_dense_ndarray (_shape_strides .ndim , _shape_strides .shape , NDARRAY_FLOAT );
382
382
farray = (mp_float_t * )results -> array ;
383
- // we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
383
+ // we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
384
384
if ((optype == NUMERICAL_STD ) && (_shape_strides .shape [0 ] <= ddof )) {
385
385
return MP_OBJ_FROM_PTR (results );
386
386
}
@@ -397,11 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
397
397
RUN_MEAN_STD (mp_float_t , array , farray , _shape_strides , div , isStd );
398
398
}
399
399
}
400
- if (results -> ndim == 0 ) { // return a scalar here
401
- return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
402
- }
403
- return MP_OBJ_FROM_PTR (results );
400
+ return ulab_tools_restore_dims (ndarray , results , keepdims , _shape_strides );
404
401
}
402
+ // we should never get to this point
405
403
return mp_const_none ;
406
404
}
407
405
#endif
@@ -441,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
441
439
}
442
440
}
443
441
444
- static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype ) {
442
+ static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t keepdims , mp_obj_t axis , uint8_t optype ) {
445
443
// TODO: treat the flattened array
446
444
if (ndarray -> len == 0 ) {
447
445
mp_raise_ValueError (MP_ERROR_TEXT ("attempt to get (arg)min/(arg)max of empty sequence" ));
@@ -521,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
521
519
int32_t * strides = m_new0 (int32_t , ULAB_MAX_DIMS );
522
520
523
521
numerical_reduce_axes (ndarray , ax , shape , strides );
524
- uint8_t index = ULAB_MAX_DIMS - ndarray -> ndim + ax ;
522
+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
523
+
524
+ uint8_t index = _shape_strides .axis ;
525
525
526
526
ndarray_obj_t * results = NULL ;
527
527
@@ -550,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
550
550
if (results -> len == 1 ) {
551
551
return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
552
552
}
553
- return MP_OBJ_FROM_PTR ( results );
553
+ return ulab_tools_restore_dims ( ndarray , results , keepdims , _shape_strides );
554
554
}
555
+ // we should never get to this point
555
556
return mp_const_none ;
556
557
}
557
558
#endif
@@ -560,13 +561,16 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
560
561
static const mp_arg_t allowed_args [] = {
561
562
{ MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , { .u_rom_obj = MP_ROM_NONE } } ,
562
563
{ MP_QSTR_axis , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_NONE } },
564
+ { MP_QSTR_keepdims , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_FALSE } },
563
565
};
564
566
565
567
mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
566
568
mp_arg_parse_all (n_args , pos_args , kw_args , MP_ARRAY_SIZE (allowed_args ), allowed_args , args );
567
569
568
570
mp_obj_t oin = args [0 ].u_obj ;
569
571
mp_obj_t axis = args [1 ].u_obj ;
572
+ mp_obj_t keepdims = args [2 ].u_obj ;
573
+
570
574
if ((axis != mp_const_none ) && (!mp_obj_is_int (axis ))) {
571
575
mp_raise_TypeError (MP_ERROR_TEXT ("axis must be None, or an integer" ));
572
576
}
@@ -598,11 +602,11 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
598
602
case NUMERICAL_ARGMIN :
599
603
case NUMERICAL_ARGMAX :
600
604
COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
601
- return numerical_argmin_argmax_ndarray (ndarray , axis , optype );
605
+ return numerical_argmin_argmax_ndarray (ndarray , keepdims , axis , optype );
602
606
case NUMERICAL_SUM :
603
607
case NUMERICAL_MEAN :
604
608
COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
605
- return numerical_sum_mean_std_ndarray (ndarray , axis , optype , 0 );
609
+ return numerical_sum_mean_std_ndarray (ndarray , axis , keepdims , optype , 0 );
606
610
default :
607
611
mp_raise_NotImplementedError (MP_ERROR_TEXT ("operation is not implemented on ndarrays" ));
608
612
}
@@ -1385,6 +1389,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
1385
1389
{ MP_QSTR_ , MP_ARG_REQUIRED | MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } } ,
1386
1390
{ MP_QSTR_axis , MP_ARG_OBJ , {.u_rom_obj = MP_ROM_NONE } },
1387
1391
{ MP_QSTR_ddof , MP_ARG_KW_ONLY | MP_ARG_INT , {.u_int = 0 } },
1392
+ { MP_QSTR_keepdims , MP_ARG_OBJ , { .u_rom_obj = MP_ROM_FALSE } },
1388
1393
};
1389
1394
1390
1395
mp_arg_val_t args [MP_ARRAY_SIZE (allowed_args )];
@@ -1393,6 +1398,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
1393
1398
mp_obj_t oin = args [0 ].u_obj ;
1394
1399
mp_obj_t axis = args [1 ].u_obj ;
1395
1400
size_t ddof = args [2 ].u_int ;
1401
+ mp_obj_t keepdims = args [2 ].u_obj ;
1402
+
1396
1403
if ((axis != mp_const_none ) && (mp_obj_get_int (axis ) != 0 ) && (mp_obj_get_int (axis ) != 1 )) {
1397
1404
// this seems to pass with False, and True...
1398
1405
mp_raise_ValueError (MP_ERROR_TEXT ("axis must be None, or an integer" ));
@@ -1401,7 +1408,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
1401
1408
return numerical_sum_mean_std_iterable (oin , NUMERICAL_STD , ddof );
1402
1409
} else if (mp_obj_is_type (oin , & ulab_ndarray_type )) {
1403
1410
ndarray_obj_t * ndarray = MP_OBJ_TO_PTR (oin );
1404
- return numerical_sum_mean_std_ndarray (ndarray , axis , NUMERICAL_STD , ddof );
1411
+ return numerical_sum_mean_std_ndarray (ndarray , axis , keepdims , NUMERICAL_STD , ddof );
1405
1412
} else {
1406
1413
mp_raise_TypeError (MP_ERROR_TEXT ("input must be tuple, list, range, or ndarray" ));
1407
1414
}
0 commit comments