Fix determination of which variables are coordinates #224

Open · wants to merge 9 commits into base: main
conftest.py: 13 changes (13 additions, 0 deletions)
@@ -53,6 +53,19 @@ def netcdf4_files(tmpdir):
return filepath1, filepath2


@pytest.fixture
def netcdf4_file_with_2d_coords(tmpdir):
# Set up example xarray dataset
ds = xr.tutorial.open_dataset("ROMS_example.nc")

# Save it to disk as netCDF (in temporary directory)
filepath = f"{tmpdir}/ROMS_example.nc"
ds.to_netcdf(filepath, format="NETCDF4")
ds.close()

return filepath


@pytest.fixture
def hdf5_empty(tmpdir):
filepath = f"{tmpdir}/empty.nc"
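For context (not part of the diff): the new fixture is useful because xarray's ROMS_example tutorial dataset carries 2D curvilinear coordinates, which a dimension-coordinate-only check cannot detect. A minimal sketch of inspecting it, assuming the tutorial dataset can be downloaded (the printed dimension names are from memory and may differ):

```python
import xarray as xr

# Hedged sketch: open the same tutorial dataset the fixture writes to disk
# and look at its coordinates; lon_rho / lat_rho are 2D (curvilinear) coords.
ds = xr.tutorial.open_dataset("ROMS_example.nc")
print(ds.coords)           # includes 2D coordinates such as lon_rho and lat_rho
print(ds["lon_rho"].dims)  # spans two dimensions, e.g. ("eta_rho", "xi_rho")
ds.close()
```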
virtualizarr/backend.py: 82 changes (61 additions, 21 deletions)
@@ -7,20 +7,27 @@
Any,
Hashable,
Optional,
TypeAlias,
cast,
)

import xarray as xr
from xarray.backends import AbstractDataStore, BackendArray
from xarray.coding.times import CFDatetimeCoder
from xarray.conventions import decode_cf_variables
from xarray.core.indexes import Index, PandasIndex
from xarray.core.variable import IndexVariable
from xarray.core.variable import IndexVariable, Variable

from virtualizarr.manifests import ManifestArray
from virtualizarr.utils import _fsspec_openfile_from_filepath

XArrayOpenT = str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore

T_Attrs = MutableMapping[Any, Any]
T_Variables = Mapping[Any, Variable]
# alias for (dims, data, attrs, encoding)
T_VariableExpanded: TypeAlias = tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]]


class AutoName(Enum):
# Recommended by official Python docs for auto naming:
@@ -238,43 +245,56 @@ def open_virtual_dataset(

vars = {**virtual_vars, **loadable_vars}

data_vars, coords = separate_coords(vars, indexes, coord_names)
decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, ds_attrs)

vds = xr.Dataset(
data_vars,
coords=coords,
# indexes={}, # TODO should be added in a later version of xarray
attrs=ds_attrs,
vds = construct_virtual_dataset(
decoded_vars, indexes, decoded_attrs, coord_names
)

# TODO we should probably also use vds.set_close() to tell xarray how to close the file we opened

return vds


def separate_coords(
def determine_cf_coords(
variables: T_Variables,
attributes: T_Attrs,
) -> tuple[T_Variables, T_Attrs, set[Hashable]]:
"""
Determines which variables are coordinate variables according to CF conventions.

Should not actually decode any values in the variables; only inspect and possibly alter their metadata.
"""
new_vars, attrs, coord_names = decode_cf_variables(
variables=variables,
attributes=attributes,
concat_characters=False,
mask_and_scale=False,
decode_times=False,
decode_coords="all",
drop_variables=None, # should have already been dropped
use_cftime=False, # done separately, to only the loadable_vars
decode_timedelta=False, # done separately, to only the loadable_vars
)
return new_vars, attrs, coord_names


def construct_virtual_dataset(
vars: Mapping[str, xr.Variable],
indexes: MutableMapping[str, Index],
attrs: T_Attrs,
coord_names: Iterable[str] | None = None,
) -> tuple[dict[str, xr.Variable], xr.Coordinates]:
) -> xr.Dataset:
"""
Try to generate a set of coordinates that won't cause xarray to automatically build a pandas.Index for the 1D coordinates.
Constructs the virtual dataset without automatically building a pandas.Index for 1D coordinates.

This function is currently required as a workaround until xarray PR #8124 is merged.

It will also preserve any loaded variables and indexes it is passed.
"""

if coord_names is None:
coord_names = []

# split data and coordinate variables (promote dimension coordinates)
coord_vars: dict[str, T_VariableExpanded | xr.Variable] = {}
data_vars = {}
coord_vars: dict[
str, tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] | xr.Variable
] = {}
for name, var in vars.items():
if name in coord_names or var.dims == (name,):
if name in coord_names:
# use the workaround described in https://github.com/pydata/xarray/pull/8107#discussion_r1311214263 to avoid creating IndexVariables
if len(var.dims) == 1:
dim1d, *_ = var.dims
@@ -293,4 +313,24 @@ def separate_coords(

coords = xr.Coordinates(coord_vars, indexes=indexes)

return data_vars, coords

vds = xr.Dataset(
data_vars,
coords=coords,
# indexes={}, # TODO should be added in a later version of xarray
attrs=attrs,
)

# TODO we should probably also use vds.set_close() to tell xarray how to close the file we opened
# TODO see how it's done inside `xr.decode_cf`

return vds
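To make the coordinate-detection step concrete, here is a hedged, self-contained sketch (not part of the PR) of how xarray's decode_cf_variables, called with decode_coords="all" as in determine_cf_coords above, reports coordinates named in a variable's CF "coordinates" attribute. The variable names temp/lat/lon are made up for illustration:

```python
import numpy as np
import xarray as xr
from xarray.conventions import decode_cf_variables

# "temp" points at its 2D coordinates via the CF "coordinates" attribute
variables = {
    "temp": xr.Variable(
        dims=("y", "x"),
        data=np.zeros((2, 3)),
        attrs={"coordinates": "lat lon"},
    ),
    "lat": xr.Variable(dims=("y", "x"), data=np.zeros((2, 3))),
    "lon": xr.Variable(dims=("y", "x"), data=np.zeros((2, 3))),
}

new_vars, attrs, coord_names = decode_cf_variables(
    variables=variables,
    attributes={},
    concat_characters=False,
    mask_and_scale=False,
    decode_times=False,
    decode_coords="all",
)

print(coord_names)  # expected to contain {"lat", "lon"}
```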
virtualizarr/readers/kerchunk.py: 16 changes (8 additions, 8 deletions)
@@ -7,7 +7,11 @@
from xarray.core.indexes import Index
from xarray.core.variable import Variable

from virtualizarr.backend import FileType, separate_coords
from virtualizarr.backend import (
FileType,
construct_virtual_dataset,
determine_cf_coords,
)
from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.types.kerchunk import (
KerchunkArrRefs,
@@ -176,14 +180,10 @@ def dataset_from_kerchunk_refs(

if indexes is None:
indexes = {}
data_vars, coords = separate_coords(vars, indexes, coord_names)

vds = Dataset(
data_vars,
coords=coords,
# indexes={}, # TODO should be added in a later version of xarray
attrs=ds_attrs,
)
decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, ds_attrs)

vds = construct_virtual_dataset(decoded_vars, indexes, decoded_attrs, coord_names)

return vds

virtualizarr/readers/zarr.py: 13 changes (4 additions, 9 deletions)
@@ -8,7 +8,7 @@
from xarray.core.indexes import Index
from xarray.core.variable import Variable

from virtualizarr.backend import separate_coords
from virtualizarr.backend import construct_virtual_dataset, determine_cf_coords
from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.zarr import ZArray

@@ -53,16 +53,11 @@ def open_virtual_dataset_from_v3_store(
else:
indexes = dict(**indexes) # for type hinting: to allow mutation

data_vars, coords = separate_coords(vars, indexes, coord_names)
decoded_vars, decoded_attrs, coord_names = determine_cf_coords(vars, attrs)

ds = Dataset(
data_vars,
coords=coords,
# indexes={}, # TODO should be added in a later version of xarray
attrs=ds_attrs,
)
vds = construct_virtual_dataset(decoded_vars, indexes, decoded_attrs, coord_names)

return ds
return vds


def attrs_from_zarr_group_json(filepath: Path) -> dict:
virtualizarr/tests/test_backend.py: 27 changes (27 additions, 0 deletions)
@@ -121,6 +121,33 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file):
}


class TestDetermineCoords:
def test_determine_all_coords(self, netcdf4_file_with_2d_coords):
vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={})

expected_dimension_coords = ["ocean_time", "s_rho"]
expected_2d_coords = ["lon_rho", "lat_rho", "h"]
expected_1d_non_dimension_coords = ["Cs_r"]
expected_scalar_coords = ["hc", "Vtransform"]
expected_coords = (
expected_dimension_coords
+ expected_2d_coords
+ expected_1d_non_dimension_coords
+ expected_scalar_coords
)
assert set(vds.coords) == set(expected_coords)

# print(vds.attrs)
# assert False

# TODO assert coord attributes have been altered
for coord_name in expected_coords:
print(vds[coord_name].attrs)
# assert vds[coord_name].attrs['']

# assert False


@network
@requires_s3fs
class TestReadFromS3:
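An aside on why the TestDetermineCoords case above matters: the previous separate_coords logic only promoted variables satisfying var.dims == (name,), i.e. dimension coordinates. A hedged plain-xarray sketch (with made-up variable names) of the coordinate kinds that check misses:

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    data_vars={"temp": (("y", "x"), np.zeros((2, 3)))},
    coords={
        "x": ("x", np.arange(3)),               # dimension coordinate: dims == ("x",)
        "lat": (("y", "x"), np.zeros((2, 3))),  # 2D non-dimension coordinate
        "hc": 0.1,                              # scalar coordinate
    },
)

for name, var in ds.coords.items():
    print(name, var.dims, var.dims == (name,))
# only "x" passes the old check, yet all three are coordinates
```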
virtualizarr/tests/test_xarray.py: 87 changes (60 additions, 27 deletions)
@@ -7,33 +7,66 @@
from virtualizarr.zarr import ZArray


def test_wrapping():
chunks = (5, 10)
shape = (5, 20)
dtype = np.dtype("int32")
zarray = ZArray(
chunks=chunks,
compressor={"id": "zlib", "level": 1},
dtype=dtype,
fill_value=0.0,
filters=None,
order="C",
shape=shape,
zarr_format=2,
)

chunks_dict = {
"0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
marr = ManifestArray(zarray=zarray, chunkmanifest=manifest)
ds = xr.Dataset({"a": (["x", "y"], marr)})

assert isinstance(ds["a"].data, ManifestArray)
assert ds["a"].shape == shape
assert ds["a"].dtype == dtype
assert ds["a"].chunks == chunks
class TestWrapping:
def test_wrapping(self):
chunks = (5, 10)
shape = (5, 20)
dtype = np.dtype("int32")
zarray = ZArray(
chunks=chunks,
compressor={"id": "zlib", "level": 1},
dtype=dtype,
fill_value=0.0,
filters=None,
order="C",
shape=shape,
zarr_format=2,
)

chunks_dict = {
"0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
marr = ManifestArray(zarray=zarray, chunkmanifest=manifest)
ds = xr.Dataset({"a": (["x", "y"], marr)})

assert isinstance(ds["a"].data, ManifestArray)
assert ds["a"].shape == shape
assert ds["a"].dtype == dtype
assert ds["a"].chunks == chunks

def test_wrap_no_indexes(self):
chunks = (10,)
shape = (20,)
dtype = np.dtype("int32")
zarray = ZArray(
chunks=chunks,
compressor={"id": "zlib", "level": 1},
dtype=dtype,
fill_value=0.0,
filters=None,
order="C",
shape=shape,
zarr_format=2,
)

chunks_dict = {
"0.0": {"path": "foo.nc", "offset": 100, "length": 100},
"0.1": {"path": "foo.nc", "offset": 200, "length": 100},
}
manifest = ChunkManifest(entries=chunks_dict)
marr = ManifestArray(zarray=zarray, chunkmanifest=manifest)

coords = xr.Coordinates({"x": (["x"], marr)}, indexes={})
ds = xr.Dataset(coords=coords)

assert isinstance(ds["x"].data, ManifestArray)
assert ds["x"].shape == shape
assert ds["x"].dtype == dtype
assert ds["x"].chunks == chunks
assert "x" in ds.coords
assert ds.xindexes == {}


class TestEquals:
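A short usage note on the pattern test_wrap_no_indexes exercises: passing indexes={} to xr.Coordinates is what prevents xarray from building a pandas.Index for a 1D coordinate, and is the same mechanism construct_virtual_dataset relies on. A hedged sketch of the pattern with a plain numpy array instead of a ManifestArray:

```python
import numpy as np
import xarray as xr

# Build a 1D coordinate without letting xarray create a pandas.Index for it
coords = xr.Coordinates({"x": (["x"], np.arange(20))}, indexes={})
ds = xr.Dataset(coords=coords)

assert "x" in ds.coords
assert ds.xindexes == {}  # no index was created for "x"
```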