Skip to content

Commit 703fb1d

Browse files
committed
dev
1 parent a136acc commit 703fb1d

File tree

2 files changed

+98
-40
lines changed

2 files changed

+98
-40
lines changed

activestorage/active.py

+59-30
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,20 @@ def __new__(cls, *args, **kwargs):
6464
"""Store reduction methods."""
6565
instance = super().__new__(cls)
6666
instance._methods = {
67-
"min": np.min,
68-
"max": np.max,
69-
"sum": np.sum,
67+
"min": np.ma.min,
68+
"max": np.ma.max,
69+
"sum": np.ma.sum,
7070
# For the unweighted mean we calculate the sum and divide
7171
# by the number of non-missing elements
72-
"mean": np.sum,
72+
"mean": np.ma.sum,
7373
}
7474
return instance
7575

7676
def __init__(
7777
self,
7878
uri,
7979
ncvar,
80+
axis=None,
8081
storage_type=None,
8182
max_threads=100,
8283
storage_options=None,
@@ -115,6 +116,13 @@ def __init__(
115116
raise ValueError("Must set a netCDF variable name to slice")
116117
self.zds = None
117118

119+
if axis is not None:
120+
if isinstance(axis, int):
121+
axis = (axis,)
122+
else:
123+
axis = tuple(axis)
124+
125+
self._axis = axis
118126
self._version = 1
119127
self._components = False
120128
self._method = None
@@ -269,7 +277,7 @@ def _via_kerchunk(self, index):
269277
self.storage_type,
270278
self.storage_options,
271279
)
272-
# The following is a hangove from exploration
280+
# The following is a hangover from exploration
273281
# and is needed if using the original doing it ourselves
274282
# self.zds = make_an_array_instance_active(ds)
275283
self.zds = ds
@@ -284,6 +292,9 @@ def _via_kerchunk(self, index):
284292
# FIXME: We do not get the correct byte order on the Zarr
285293
# Array's dtype when using S3, so capture it here.
286294
self._dtype = np.dtype(zarray['dtype'])
295+
296+
if self._axis is None:
297+
self._axis = tuple(range(len(self.zds.shape)))
287298

288299
return self._get_selection(index)
289300

@@ -333,17 +344,23 @@ def _get_selection(self, *args):
333344
# fsref = self.zds.chunk_store._mutable_mapping.fs.references
334345
fsref = self.zds.chunk_store.fs.references
335346

347+
print ('axis=', self._axis)
336348
return self._from_storage(stripped_indexer, drop_axes, out_shape,
337-
out_dtype, compressor, filters, missing, fsref)
338-
349+
out_dtype, compressor, filters, missing, fsref,
350+
axis=self._axis)
351+
339352
def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
340-
compressor, filters, missing, fsref):
353+
compressor, filters, missing, fsref, axis):
341354
method = self.method
342355
if method is not None:
343-
out = []
344-
counts = []
356+
out_shape = list(out_shape)
357+
for i in axis:
358+
out_shape[i] = 1
359+
360+
out = np.ma.empty(out_shape, dtype=out_dtype, order=self.zds._order)
361+
counts = np.ma.empty(out_shape, dtype=out_dtype, order=self.zds._order)
345362
else:
346-
out = np.empty(out_shape, dtype=out_dtype, order=self.zds._order)
363+
out = np.ma.empty(out_shape, dtype=out_dtype, order=self.zds._order)
347364
counts = None # should never get touched with no method!
348365

349366
# Create a shared session object.
@@ -381,23 +398,24 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
381398
# Wait for completion.
382399
for future in concurrent.futures.as_completed(futures):
383400
try:
384-
result = future.result()
401+
result, count, out_selection = future.result()
385402
except Exception as exc:
386403
raise
387-
else:
388-
if method is not None:
389-
result, count = result
390-
out.append(result)
391-
counts.append(count)
392-
else:
393-
# store selected data in output
394-
result, selection = result
395-
out[selection] = result
396404

405+
print (out.shape,result.shape, count.shape, out_selection)
406+
out[out_selection] = result
407+
if method is not None:
408+
out[out_selection] = result
409+
counts[out_selection] = count
410+
else:
411+
# store selected data in output
412+
out[out_selection] = result
413+
print (888, out.shape)
397414
if method is not None:
398415
# Apply the method (again) to aggregate the result
399-
out = method(out)
400-
shape1 = (1,) * len(out_shape)
416+
out = method(out, axis=axis, keepdims=True)
417+
if self._components or self._method == "mean":
418+
n = np.ma.sum(counts, axis=axis, keepdims=True)
401419

402420
if self._components:
403421
# Return a dictionary of components containing the
@@ -412,9 +430,6 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
412430
# reductions require the per-dask-chunk partial
413431
# reductions to retain these dimensions so that
414432
# partial results can be concatenated correctly.)
415-
out = out.reshape(shape1)
416-
417-
n = np.sum(counts).reshape(shape1)
418433
if self._method == "mean":
419434
# For the average, the returned component is
420435
# "sum", not "mean"
@@ -428,7 +443,7 @@ def _from_storage(self, stripped_indexer, drop_axes, out_shape, out_dtype,
428443
# For the average, it is actually the sum that has
429444
# been created, so we need to divide by the sample
430445
# size.
431-
out = out / np.sum(counts).reshape(shape1)
446+
out = out / n # TODO
432447

433448
return out
434449

@@ -462,6 +477,8 @@ def _process_chunk(self, session, fsref, chunk_coords, chunk_selection, counts,
462477
key = f"{self.ncvar}/{coord}"
463478
rfile, offset, size = tuple(fsref[key])
464479

480+
axis = self._axis
481+
465482
# S3: pass in pre-configured storage options (credentials)
466483
if self.storage_type == "s3":
467484
print("S3 rfile is:", rfile)
@@ -516,14 +533,26 @@ def _process_chunk(self, session, fsref, chunk_coords, chunk_selection, counts,
516533
tmp, count = reduce_chunk(rfile, offset, size, compressor, filters,
517534
missing, self.zds._dtype,
518535
self.zds._chunks, self.zds._order,
519-
chunk_selection, method=self.method)
536+
chunk_selection, axis, method=self.method)
520537

521538
if self.method is not None:
522-
return tmp, count
539+
out_selection = list(out_selection)
540+
for i in axis:
541+
out_selection[i] = slice(0,1)
542+
543+
return tmp, count, tuple(out_selection)
523544
else:
524545
if drop_axes:
525546
tmp = np.squeeze(tmp, axis=drop_axes)
526-
return tmp, out_selection
547+
548+
return tmp, None, out_selection
549+
550+
# if self.method is not None:
551+
# return tmp, count
552+
# else:
553+
# if drop_axes:
554+
# tmp = np.squeeze(tmp, axis=drop_axes)
555+
# return tmp, out_selection
527556

528557
def _mask_data(self, data, ds_var):
529558
"""ppp"""

activestorage/storage.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from numcodecs.compat import ensure_ndarray
55

6-
def reduce_chunk(rfile, offset, size, compression, filters, missing, dtype, shape, order, chunk_selection, method=None):
6+
def reduce_chunk(rfile, offset, size, compression, filters, missing, dtype, shape, order, chunk_selection, axis, method=None):
77
""" We do our own read of chunks and decoding etc
88
99
rfile - the actual file with the data
@@ -39,17 +39,28 @@ def reduce_chunk(rfile, offset, size, compression, filters, missing, dtype, shap
3939
chunk = chunk.reshape(shape, order=order)
4040

4141
tmp = chunk[chunk_selection]
42+
tmp = mask_missing(tmp, missing)
43+
# print ('tmp0', tmp)
4244
if method:
43-
if missing != (None, None, None, None):
44-
tmp = remove_missing(tmp, missing)
45-
# check on size of tmp; method(empty) returns nan
46-
if tmp.any():
47-
return method(tmp), tmp.size
48-
else:
49-
return tmp, None
50-
else:
51-
return tmp, None
45+
N = np.ma.count(tmp, axis=axis, keepdims=True)
46+
tmp = method(tmp, axis=axis, keepdims=True)
47+
# print ('tmp', tmp)
48+
# print (chunk_selection, axis, 'N', N)
49+
return tmp, N
5250

51+
return tmp, None
52+
#
53+
# if method:
54+
# if missing != (None, None, None, None):
55+
# tmp = remove_missing(tmp, missing)
56+
# # check on size of tmp; method(empty) returns nan
57+
# if tmp.any():
58+
# return method(tmp), tmp.size
59+
# else:
60+
# return tmp, None
61+
# else:
62+
# return tmp, None
63+
#
5364

5465
def filter_pipeline(chunk, compression, filters):
5566
"""
@@ -70,6 +81,24 @@ def filter_pipeline(chunk, compression, filters):
7081
return chunk
7182

7283

84+
def mask_missing(data, missing):
    """
    Mask the missing and out-of-range elements of ``data``.

    As we are using numpy, we can use a masked array; storage
    implementations will have to do this by hand.

    Parameters
    ----------
    data : array-like
        The values to be masked.
    missing : tuple
        A 4-tuple ``(fill_value, missing_value, valid_min, valid_max)``.
        An entry that is ``None`` is not applied.

    Returns
    -------
    numpy.ndarray or numpy.ma.MaskedArray
        ``data`` itself, untouched, when every entry of ``missing`` is
        ``None``. Otherwise a masked array in which elements equal to
        ``fill_value`` or ``missing_value``, greater than ``valid_max``,
        or less than ``valid_min`` are masked.
    """
    fill_value, missing_value, valid_min, valid_max = missing

    # Test against None rather than relying on truthiness: 0 (or 0.0) is
    # a perfectly legitimate fill value or valid-range bound, and a bare
    # ``if value:`` would silently skip it and leave those elements
    # unmasked.
    if fill_value is not None:
        data = np.ma.masked_equal(data, fill_value)

    if missing_value is not None:
        data = np.ma.masked_equal(data, missing_value)

    if valid_max is not None:
        data = np.ma.masked_greater(data, valid_max)

    if valid_min is not None:
        data = np.ma.masked_less(data, valid_min)

    return data
101+
73102
def remove_missing(data, missing):
74103
"""
75104
As we are using numpy, we can use a masked array, storage implementations

0 commit comments

Comments
 (0)