@@ -64,19 +64,20 @@ def __new__(cls, *args, **kwargs):
64
64
"""Store reduction methods."""
65
65
instance = super ().__new__ (cls )
66
66
instance ._methods = {
67
- "min" : np .min ,
68
- "max" : np .max ,
69
- "sum" : np .sum ,
67
+ "min" : np .ma . min ,
68
+ "max" : np .ma . max ,
69
+ "sum" : np .ma . sum ,
70
70
# For the unweighted mean we calculate the sum and divide
71
71
# by the number of non-missing elements
72
- "mean" : np .sum ,
72
+ "mean" : np .ma . sum ,
73
73
}
74
74
return instance
75
75
76
76
def __init__ (
77
77
self ,
78
78
uri ,
79
79
ncvar ,
80
+ axis = None ,
80
81
storage_type = None ,
81
82
max_threads = 100 ,
82
83
storage_options = None ,
@@ -115,6 +116,13 @@ def __init__(
115
116
raise ValueError ("Must set a netCDF variable name to slice" )
116
117
self .zds = None
117
118
119
+ if axis is not None :
120
+ if isinstance (axis , int ):
121
+ axis = (axis ,)
122
+ else :
123
+ axis = tuple (axis )
124
+
125
+ self ._axis = axis
118
126
self ._version = 1
119
127
self ._components = False
120
128
self ._method = None
@@ -269,7 +277,7 @@ def _via_kerchunk(self, index):
269
277
self .storage_type ,
270
278
self .storage_options ,
271
279
)
272
- # The following is a hangove from exploration
280
+ # The following is a hangover from exploration
273
281
# and is needed if using the original doing it ourselves
274
282
# self.zds = make_an_array_instance_active(ds)
275
283
self .zds = ds
@@ -284,6 +292,9 @@ def _via_kerchunk(self, index):
284
292
# FIXME: We do not get the correct byte order on the Zarr
285
293
# Array's dtype when using S3, so capture it here.
286
294
self ._dtype = np .dtype (zarray ['dtype' ])
295
+
296
+ if self ._axis is None :
297
+ self ._axis = tuple (range (len (self .zds .shape )))
287
298
288
299
return self ._get_selection (index )
289
300
@@ -333,17 +344,23 @@ def _get_selection(self, *args):
333
344
# fsref = self.zds.chunk_store._mutable_mapping.fs.references
334
345
fsref = self .zds .chunk_store .fs .references
335
346
347
+ print ('axis=' , self ._axis )
336
348
return self ._from_storage (stripped_indexer , drop_axes , out_shape ,
337
- out_dtype , compressor , filters , missing , fsref )
338
-
349
+ out_dtype , compressor , filters , missing , fsref ,
350
+ axis = self ._axis )
351
+
339
352
def _from_storage (self , stripped_indexer , drop_axes , out_shape , out_dtype ,
340
- compressor , filters , missing , fsref ):
353
+ compressor , filters , missing , fsref , axis ):
341
354
method = self .method
342
355
if method is not None :
343
- out = []
344
- counts = []
356
+ out_shape = list (out_shape )
357
+ for i in axis :
358
+ out_shape [i ] = 1
359
+
360
+ out = np .ma .empty (out_shape , dtype = out_dtype , order = self .zds ._order )
361
+ counts = np .ma .empty (out_shape , dtype = out_dtype , order = self .zds ._order )
345
362
else :
346
- out = np .empty (out_shape , dtype = out_dtype , order = self .zds ._order )
363
+ out = np .ma . empty (out_shape , dtype = out_dtype , order = self .zds ._order )
347
364
counts = None # should never get touched with no method!
348
365
349
366
# Create a shared session object.
@@ -381,23 +398,24 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
381
398
# Wait for completion.
382
399
for future in concurrent .futures .as_completed (futures ):
383
400
try :
384
- result = future .result ()
401
+ result , count , out_selection = future .result ()
385
402
except Exception as exc :
386
403
raise
387
- else :
388
- if method is not None :
389
- result , count = result
390
- out .append (result )
391
- counts .append (count )
392
- else :
393
- # store selected data in output
394
- result , selection = result
395
- out [selection ] = result
396
404
405
+ print (out .shape ,result .shape , count .shape , out_selection )
406
+ out [out_selection ] = result
407
+ if method is not None :
408
+ out [out_selection ] = result
409
+ counts [out_selection ] = count
410
+ else :
411
+ # store selected data in output
412
+ out [out_selection ] = result
413
+ print (888 , out .shape )
397
414
if method is not None :
398
415
# Apply the method (again) to aggregate the result
399
- out = method (out )
400
- shape1 = (1 ,) * len (out_shape )
416
+ out = method (out , axis = axis , keepdims = True )
417
+ if self ._components or self ._method == "mean" :
418
+ n = np .ma .sum (counts , axis = axis , keepdims = True )
401
419
402
420
if self ._components :
403
421
# Return a dictionary of components containing the
@@ -412,9 +430,6 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
412
430
# reductions require the per-dask-chunk partial
413
431
# reductions to retain these dimensions so that
414
432
# partial results can be concatenated correctly.)
415
- out = out .reshape (shape1 )
416
-
417
- n = np .sum (counts ).reshape (shape1 )
418
433
if self ._method == "mean" :
419
434
# For the average, the returned component is
420
435
# "sum", not "mean"
@@ -428,7 +443,7 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
428
443
# For the average, it is actually the sum that has
429
444
# been created, so we need to divide by the sample
430
445
# size.
431
- out = out / np . sum ( counts ). reshape ( shape1 )
446
+ out = out / n # TODO
432
447
433
448
return out
434
449
@@ -462,6 +477,8 @@ def _process_chunk(self, session, fsref, chunk_coords, chunk_selection, counts,
462
477
key = f"{ self .ncvar } /{ coord } "
463
478
rfile , offset , size = tuple (fsref [key ])
464
479
480
+ axis = self ._axis
481
+
465
482
# S3: pass in pre-configured storage options (credentials)
466
483
if self .storage_type == "s3" :
467
484
print ("S3 rfile is:" , rfile )
@@ -516,14 +533,26 @@ def _process_chunk(self, session, fsref, chunk_coords, chunk_selection, counts,
516
533
tmp , count = reduce_chunk (rfile , offset , size , compressor , filters ,
517
534
missing , self .zds ._dtype ,
518
535
self .zds ._chunks , self .zds ._order ,
519
- chunk_selection , method = self .method )
536
+ chunk_selection , axis , method = self .method )
520
537
521
538
if self .method is not None :
522
- return tmp , count
539
+ out_selection = list (out_selection )
540
+ for i in axis :
541
+ out_selection [i ] = slice (0 ,1 )
542
+
543
+ return tmp , count , tuple (out_selection )
523
544
else :
524
545
if drop_axes :
525
546
tmp = np .squeeze (tmp , axis = drop_axes )
526
- return tmp , out_selection
547
+
548
+ return tmp , None , out_selection
549
+
550
+ # if self.method is not None:
551
+ # return tmp, count
552
+ # else:
553
+ # if drop_axes:
554
+ # tmp = np.squeeze(tmp, axis=drop_axes)
555
+ # return tmp, out_selection
527
556
528
557
def _mask_data (self , data , ds_var ):
529
558
"""ppp"""
0 commit comments