@@ -372,7 +372,7 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
372
372
mp_float_t norm = (mp_float_t )_shape_strides .shape [0 ];
373
373
// re-wind the array here
374
374
farray = (mp_float_t * )results -> array ;
375
- for (size_t i = 0 ; i < results -> len ; i ++ ) {
375
+ for (size_t i = 0 ; i < results -> len ; i ++ ) {
376
376
* farray ++ *= norm ;
377
377
}
378
378
}
@@ -397,9 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
397
397
RUN_MEAN_STD (mp_float_t , array , farray , _shape_strides , div , isStd );
398
398
}
399
399
}
400
- // return(ulab_tools_restore_dims(results, keepdims, axis));
401
- return MP_OBJ_FROM_PTR (results );
400
+ return ulab_tools_restore_dims (ndarray , results , keepdims , _shape_strides );
402
401
}
402
+ // we should never get to this point
403
403
return mp_const_none ;
404
404
}
405
405
#endif
@@ -439,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
439
439
}
440
440
}
441
441
442
- static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t axis , uint8_t optype ) {
442
+ static mp_obj_t numerical_argmin_argmax_ndarray (ndarray_obj_t * ndarray , mp_obj_t keepdims , mp_obj_t axis , uint8_t optype ) {
443
443
// TODO: treat the flattened array
444
444
if (ndarray -> len == 0 ) {
445
445
mp_raise_ValueError (MP_ERROR_TEXT ("attempt to get (arg)min/(arg)max of empty sequence" ));
@@ -519,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
519
519
int32_t * strides = m_new0 (int32_t , ULAB_MAX_DIMS );
520
520
521
521
numerical_reduce_axes (ndarray , ax , shape , strides );
522
- uint8_t index = ULAB_MAX_DIMS - ndarray -> ndim + ax ;
522
+ shape_strides _shape_strides = tools_reduce_axes (ndarray , axis );
523
+
524
+ uint8_t index = _shape_strides .axis ;
523
525
524
526
ndarray_obj_t * results = NULL ;
525
527
@@ -548,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
548
550
if (results -> len == 1 ) {
549
551
return mp_binary_get_val_array (results -> dtype , results -> array , 0 );
550
552
}
551
- return MP_OBJ_FROM_PTR ( results );
553
+ return ulab_tools_restore_dims ( ndarray , results , keepdims , _shape_strides );
552
554
}
555
+ // we should never get to this point
553
556
return mp_const_none ;
554
557
}
555
558
#endif
@@ -599,7 +602,7 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
599
602
case NUMERICAL_ARGMIN :
600
603
case NUMERICAL_ARGMAX :
601
604
COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
602
- return numerical_argmin_argmax_ndarray (ndarray , axis , optype );
605
+ return numerical_argmin_argmax_ndarray (ndarray , keepdims , axis , optype );
603
606
case NUMERICAL_SUM :
604
607
case NUMERICAL_MEAN :
605
608
COMPLEX_DTYPE_NOT_IMPLEMENTED (ndarray -> dtype )
@@ -1423,6 +1426,66 @@ MP_DEFINE_CONST_FUN_OBJ_KW(numerical_std_obj, 1, numerical_std);
1423
1426
1424
1427
mp_obj_t numerical_sum (size_t n_args , const mp_obj_t * pos_args , mp_map_t * kw_args ) {
1425
1428
return numerical_function (n_args , pos_args , kw_args , NUMERICAL_SUM );
1429
+ // static const mp_arg_t allowed_args[] = {
1430
+ // { MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
1431
+ // { MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
1432
+ // { MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
1433
+ // };
1434
+
1435
+ // mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
1436
+ // mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);
1437
+
1438
+ // mp_obj_t oin = args[0].u_obj;
1439
+ // mp_obj_t axis = args[1].u_obj;
1440
+ // mp_obj_t keepdims = args[2].u_obj;
1441
+
1442
+ // if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
1443
+ // mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
1444
+ // }
1445
+
1446
+ // ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
1447
+ // if(!mp_obj_is_int(axis) & (axis != mp_const_none)) {
1448
+ // mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
1449
+ // }
1450
+
1451
+ // shape_strides _shape_strides;
1452
+
1453
+ // _shape_strides.increment = 0;
1454
+ // // this is the contracted dimension (won't be overwritten for axis == None)
1455
+ // _shape_strides.ndim = 0;
1456
+
1457
+ // size_t *shape = m_new(size_t, ULAB_MAX_DIMS);
1458
+ // _shape_strides.shape = shape;
1459
+ // int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS);
1460
+ // _shape_strides.strides = strides;
1461
+
1462
+ // memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
1463
+ // memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);
1464
+
1465
+ // uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
1466
+
1467
+ // if(axis != mp_const_none) { // i.e., axis is an integer
1468
+ // int8_t ax = tools_get_axis(axis, ndarray->ndim);
1469
+ // index = ULAB_MAX_DIMS - ndarray->ndim + ax;
1470
+ // _shape_strides.ndim = ndarray->ndim - 1;
1471
+ // }
1472
+
1473
+ // // move the value stored at index to the leftmost position, and align everything else to the right
1474
+ // _shape_strides.shape[0] = ndarray->shape[index];
1475
+ // _shape_strides.strides[0] = ndarray->strides[index];
1476
+ // for(uint8_t i = 0; i < index; i++) {
1477
+ // // entries to the right of index must be shifted by one position to the left
1478
+ // _shape_strides.shape[i + 1] = ndarray->shape[i];
1479
+ // _shape_strides.strides[i + 1] = ndarray->strides[i];
1480
+ // }
1481
+
1482
+ // if(_shape_strides.ndim != 0) {
1483
+ // _shape_strides.increment = 1;
1484
+ // }
1485
+
1486
+
1487
+ // return mp_const_none;
1488
+
1426
1489
}
1427
1490
1428
1491
MP_DEFINE_CONST_FUN_OBJ_KW (numerical_sum_obj , 1 , numerical_sum );
0 commit comments