diff --git a/doc/roadmap.rst b/doc/roadmap.rst
index eeaaf10813b..820ff82151c 100644
--- a/doc/roadmap.rst
+++ b/doc/roadmap.rst
@@ -156,7 +156,7 @@ types would also be highly useful for xarray users.
 By pursuing these improvements in NumPy we hope to extend the benefits
 to the full scientific Python community, and avoid tight coupling
 between xarray and specific third-party libraries (e.g., for
-implementing untis). This will allow xarray to maintain its domain
+implementing units). This will allow xarray to maintain its domain
 agnostic strengths.

 We expect that we may eventually add some minimal interfaces in xarray
diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 50eece5f0af..16562ed0988 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -90,9 +90,16 @@ Internal Changes
   when the data isn't datetime-like. (:issue:`8718`, :pull:`8724`)
   By `Maximilian Roos <https://github.com/max-sixty>`_.

-- Move `parallelcompat` and `chunk managers` modules from `xarray/core` to `xarray/namedarray`. (:pull:`8319`)
+- Move ``parallelcompat`` and ``chunk managers`` modules from ``xarray/core`` to ``xarray/namedarray``. (:pull:`8319`)
   By `Tom Nicholas <https://github.com/TomNicholas>`_ and `Anderson Banihirwe <https://github.com/andersy005>`_.

+- Imports ``datatree`` repository and history into internal
+  location. (:pull:`8688`) By `Matt Savoie <https://github.com/flamingbear>`_
+  and `Justus Magin <https://github.com/keewis>`_.
+
+- Adds :py:func:`open_datatree` into ``xarray/backends`` (:pull:`8697`) By `Matt
+  Savoie <https://github.com/flamingbear>`_.
+
 .. _whats-new.2024.01.1:

 v2024.01.1 (23 Jan, 2024)
diff --git a/pyproject.toml b/pyproject.toml
index 62c2eed3295..4b5f6b31a43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,6 +96,11 @@ warn_redundant_casts = true
 warn_unused_configs = true
 warn_unused_ignores = true

+# Ignore mypy errors for modules imported from datatree_.
+[[tool.mypy.overrides]]
+module = "xarray.datatree_.*"
+ignore_errors = true
+
 # Much of the numerical computing stack doesn't have type annotations yet.
 [[tool.mypy.overrides]]
 ignore_missing_imports = true
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index e69faa4b100..d3026a535e2 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -69,6 +69,7 @@
     T_NetcdfTypes = Literal[
         "NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", "NETCDF3_CLASSIC"
     ]
+    from xarray.datatree_.datatree import DataTree

 DATAARRAY_NAME = "__xarray_dataarray_name__"
 DATAARRAY_VARIABLE = "__xarray_dataarray_variable__"
@@ -788,6 +789,34 @@ def open_dataarray(
     return data_array


+def open_datatree(
+    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+    engine: T_Engine = None,
+    **kwargs,
+) -> DataTree:
+    """
+    Open and decode a DataTree from a file or file-like object, creating one tree node for each group in the file.
+
+    Parameters
+    ----------
+    filename_or_obj : str, Path, file-like, or DataStore
+        Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
+    engine : str, optional
+        Xarray backend engine to use. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`.
+    **kwargs : dict
+        Additional keyword arguments passed to :py:func:`~xarray.open_dataset` for each group.
+    Returns
+    -------
+    xarray.DataTree
+    """
+    if engine is None:
+        engine = plugins.guess_engine(filename_or_obj)
+
+    backend = plugins.get_backend(engine)
+
+    return backend.open_datatree(filename_or_obj, **kwargs)
+
+
 def open_mfdataset(
     paths: str | NestedSequence[str | os.PathLike],
     chunks: T_Chunks | None = None,
diff --git a/xarray/backends/common.py b/xarray/backends/common.py
index 5b7cdc4cf50..6245b3442a3 100644
--- a/xarray/backends/common.py
+++ b/xarray/backends/common.py
@@ -19,8 +19,12 @@
 if TYPE_CHECKING:
     from io import BufferedIOBase

+    from h5netcdf.legacyapi import Dataset as ncDatasetLegacyH5
+    from netCDF4 import Dataset as ncDataset
+
     from xarray.core.dataset import Dataset
     from xarray.core.types import NestedSequence
+    from xarray.datatree_.datatree import DataTree

 # Create a logger object, but don't add any handlers. Leave that to user code.
 logger = logging.getLogger(__name__)
@@ -127,6 +131,43 @@ def _decode_variable_name(name):
     return name


+def _open_datatree_netcdf(
+    ncDataset: ncDataset | ncDatasetLegacyH5,
+    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+    **kwargs,
+) -> DataTree:
+    from xarray.backends.api import open_dataset
+    from xarray.datatree_.datatree import DataTree
+    from xarray.datatree_.datatree.treenode import NodePath
+
+    ds = open_dataset(filename_or_obj, **kwargs)
+    tree_root = DataTree.from_dict({"/": ds})
+    with ncDataset(filename_or_obj, mode="r") as ncds:
+        for path in _iter_nc_groups(ncds):
+            subgroup_ds = open_dataset(filename_or_obj, group=path, **kwargs)
+
+            # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
+            node_name = NodePath(path).name
+            new_node: DataTree = DataTree(name=node_name, data=subgroup_ds)
+            tree_root._set_item(
+                path,
+                new_node,
+                allow_overwrite=False,
+                new_nodes_along_path=True,
+            )
+    return tree_root
+
+
+def _iter_nc_groups(root, parent="/"):
+    from xarray.datatree_.datatree.treenode import NodePath
+
+    parent = NodePath(parent)
+    for path, group in root.groups.items():
+        gpath = parent / path
+        yield str(gpath)
+        yield from _iter_nc_groups(group, parent=gpath)
+
+
 def find_root_and_group(ds):
     """Find the root and group name of a netCDF4/h5netcdf dataset."""
     hierarchy = ()
@@ -458,6 +499,11 @@ class BackendEntrypoint:
     - ``guess_can_open`` method: it shall return ``True`` if the backend is able to open
       ``filename_or_obj``, ``False`` otherwise. The implementation of this
       method is not mandatory.
+    - ``open_datatree`` method: it shall implement reading from file, variables
+      decoding and it returns an instance of :py:class:`~datatree.DataTree`.
+      It shall take in input at least ``filename_or_obj`` argument. The
+      implementation of this method is not mandatory. For more details see
+      .

     Attributes
     ----------
@@ -496,7 +542,7 @@ def open_dataset(
         Backend open_dataset method used by Xarray in :py:func:`~xarray.open_dataset`.
         """

-        raise NotImplementedError
+        raise NotImplementedError()

     def guess_can_open(
         self,
@@ -508,6 +554,17 @@ def guess_can_open(

         return False

+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs: Any,
+    ) -> DataTree:
+        """
+        Backend open_datatree method used by Xarray in :py:func:`~xarray.open_datatree`.
+        """
+
+        raise NotImplementedError()
+

 # mapping of engine name to (module name, BackendEntrypoint Class)
 BACKEND_ENTRYPOINTS: dict[str, tuple[str | None, type[BackendEntrypoint]]] = {}
diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py
index d9385fc68a9..b7c1b2a5f03 100644
--- a/xarray/backends/h5netcdf_.py
+++ b/xarray/backends/h5netcdf_.py
@@ -11,6 +11,7 @@
     BackendEntrypoint,
     WritableCFDataStore,
     _normalize_path,
+    _open_datatree_netcdf,
     find_root_and_group,
 )
 from xarray.backends.file_manager import CachingFileManager, DummyFileManager
@@ -38,6 +39,7 @@
     from xarray.backends.common import AbstractDataStore
     from xarray.core.dataset import Dataset
+    from xarray.datatree_.datatree import DataTree


 class H5NetCDFArrayWrapper(BaseNetCDF4Array):
@@ -423,5 +425,14 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         )
         return ds

+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
+        from h5netcdf.legacyapi import Dataset as ncDataset
+
+        return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs)
+

 BACKEND_ENTRYPOINTS["h5netcdf"] = ("h5netcdf", H5netcdfBackendEntrypoint)
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index d3845568709..6720a67ae2f 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -16,6 +16,7 @@
     BackendEntrypoint,
     WritableCFDataStore,
     _normalize_path,
+    _open_datatree_netcdf,
     find_root_and_group,
     robust_getitem,
 )
@@ -44,6 +45,7 @@
     from xarray.backends.common import AbstractDataStore
     from xarray.core.dataset import Dataset
+    from xarray.datatree_.datatree import DataTree

 # This lookup table maps from dtype.byteorder to a readable endian
 # string used by netCDF4.
@@ -667,5 +669,14 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         )
         return ds

+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
+        from netCDF4 import Dataset as ncDataset
+
+        return _open_datatree_netcdf(ncDataset, filename_or_obj, **kwargs)
+

 BACKEND_ENTRYPOINTS["netcdf4"] = ("netCDF4", NetCDF4BackendEntrypoint)
diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py
index 73e2145468b..ac208da097a 100644
--- a/xarray/backends/zarr.py
+++ b/xarray/backends/zarr.py
@@ -34,6 +34,7 @@
     from xarray.backends.common import AbstractDataStore
     from xarray.core.dataset import Dataset
+    from xarray.datatree_.datatree import DataTree


 # need some special secret attributes to tell us the dimensions
@@ -1039,5 +1040,48 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         )
         return ds

+    def open_datatree(
+        self,
+        filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+        **kwargs,
+    ) -> DataTree:
+        import zarr
+
+        from xarray.backends.api import open_dataset
+        from xarray.datatree_.datatree import DataTree
+        from xarray.datatree_.datatree.treenode import NodePath
+
+        zds = zarr.open_group(filename_or_obj, mode="r")
+        ds = open_dataset(filename_or_obj, engine="zarr", **kwargs)
+        tree_root = DataTree.from_dict({"/": ds})
+        for path in _iter_zarr_groups(zds):
+            try:
+                subgroup_ds = open_dataset(
+                    filename_or_obj, engine="zarr", group=path, **kwargs
+                )
+            except zarr.errors.PathNotFoundError:
+                subgroup_ds = Dataset()
+
+            # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
+            node_name = NodePath(path).name
+            new_node: DataTree = DataTree(name=node_name, data=subgroup_ds)
+            tree_root._set_item(
+                path,
+                new_node,
+                allow_overwrite=False,
+                new_nodes_along_path=True,
+            )
+        return tree_root
+
+
+def _iter_zarr_groups(root, parent="/"):
+    from xarray.datatree_.datatree.treenode import NodePath
+
+    parent = NodePath(parent)
+    for path, group in root.groups():
+        gpath = parent / path
+        yield str(gpath)
+        yield from _iter_zarr_groups(group, parent=gpath)
+

 BACKEND_ENTRYPOINTS["zarr"] = ("zarr", ZarrBackendEntrypoint)
diff --git a/xarray/core/options.py b/xarray/core/options.py
index 25b56b5ef06..18e3484e9c4 100644
--- a/xarray/core/options.py
+++ b/xarray/core/options.py
@@ -20,6 +20,7 @@
     "display_expand_coords",
     "display_expand_data_vars",
     "display_expand_data",
+    "display_expand_groups",
     "display_expand_indexes",
     "display_default_indexes",
     "enable_cftimeindex",
@@ -44,6 +45,7 @@ class T_Options(TypedDict):
     display_expand_coords: Literal["default", True, False]
     display_expand_data_vars: Literal["default", True, False]
     display_expand_data: Literal["default", True, False]
+    display_expand_groups: Literal["default", True, False]
     display_expand_indexes: Literal["default", True, False]
     display_default_indexes: Literal["default", True, False]
     enable_cftimeindex: bool
@@ -68,6 +70,7 @@
     "display_expand_coords": "default",
     "display_expand_data_vars": "default",
     "display_expand_data": "default",
+    "display_expand_groups": "default",
     "display_expand_indexes": "default",
     "display_default_indexes": False,
     "enable_cftimeindex": True,
diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py
index 3b97ea9d4db..f9fd419bddc 100644
--- a/xarray/datatree_/datatree/__init__.py
+++ b/xarray/datatree_/datatree/__init__.py
@@ -1,25 +1,15 @@
 # import public API
 from .datatree import DataTree
 from .extensions import register_datatree_accessor
-from .io import open_datatree
 from .mapping import TreeIsomorphismError, map_over_subtree
 from .treenode import InvalidTreeError, NotFoundInTreeError

-try:
-    # NOTE: the `_version.py` file must not be present in the git repository
-    #   as it is generated by setuptools at install time
-    from ._version import __version__
-except ImportError:  # pragma: no cover
-    # Local copy or not installed with setuptools
-    __version__ = "999"

 __all__ = (
     "DataTree",
-    "open_datatree",
     "TreeIsomorphismError",
     "InvalidTreeError",
     "NotFoundInTreeError",
     "map_over_subtree",
     "register_datatree_accessor",
-    "__version__",
 )
diff --git a/xarray/datatree_/datatree/datatree.py b/xarray/datatree_/datatree/datatree.py
index 0ce382a6460..13cca7de80d 100644
--- a/xarray/datatree_/datatree/datatree.py
+++ b/xarray/datatree_/datatree/datatree.py
@@ -16,6 +16,7 @@
     List,
     Mapping,
     MutableMapping,
+    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -160,7 +161,7 @@ def __setitem__(self, key, val) -> None:
             "use `.copy()` first to get a mutable version of the input dataset."
         )

-    def update(self, other) -> None:
+    def update(self, other) -> NoReturn:
         raise AttributeError(
             "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, "
             "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`,"
diff --git a/xarray/datatree_/datatree/formatting_html.py b/xarray/datatree_/datatree/formatting_html.py
index 4531f5aec18..547b567a396 100644
--- a/xarray/datatree_/datatree/formatting_html.py
+++ b/xarray/datatree_/datatree/formatting_html.py
@@ -10,9 +10,6 @@
     datavar_section,
     dim_section,
 )
-from xarray.core.options import OPTIONS
-
-OPTIONS["display_expand_groups"] = "default"


 def summarize_children(children: Mapping[str, Any]) -> str:
diff --git a/xarray/datatree_/datatree/io.py b/xarray/datatree_/datatree/io.py
index 8bb7682f085..d3d533ee71e 100644
--- a/xarray/datatree_/datatree/io.py
+++ b/xarray/datatree_/datatree/io.py
@@ -1,22 +1,4 @@
-from xarray import Dataset, open_dataset
-
-from .datatree import DataTree, NodePath
-
-
-def _iter_zarr_groups(root, parent="/"):
-    parent = NodePath(parent)
-    for path, group in root.groups():
-        gpath = parent / path
-        yield str(gpath)
-        yield from _iter_zarr_groups(group, parent=gpath)
-
-
-def _iter_nc_groups(root, parent="/"):
-    parent = NodePath(parent)
-    for path, group in root.groups.items():
-        gpath = parent / path
-        yield str(gpath)
-        yield from _iter_nc_groups(group, parent=gpath)
+from xarray.datatree_.datatree import DataTree


 def _get_nc_dataset_class(engine):
@@ -34,77 +16,6 @@ def _get_nc_dataset_class(engine):
     return Dataset

-
-def open_datatree(filename_or_obj, engine=None, **kwargs) -> DataTree:
-    """
-    Open and decode a dataset from a file or file-like object, creating one Tree node for each group in the file.
-
-    Parameters
-    ----------
-    filename_or_obj : str, Path, file-like, or DataStore
-        Strings and Path objects are interpreted as a path to a netCDF file or Zarr store.
-    engine : str, optional
-        Xarray backend engine to us. Valid options include `{"netcdf4", "h5netcdf", "zarr"}`.
-    kwargs :
-        Additional keyword arguments passed to ``xarray.open_dataset`` for each group.
-
-    Returns
-    -------
-    DataTree
-    """
-
-    if engine == "zarr":
-        return _open_datatree_zarr(filename_or_obj, **kwargs)
-    elif engine in [None, "netcdf4", "h5netcdf"]:
-        return _open_datatree_netcdf(filename_or_obj, engine=engine, **kwargs)
-    else:
-        raise ValueError("Unsupported engine")
-
-
-def _open_datatree_netcdf(filename: str, **kwargs) -> DataTree:
-    ncDataset = _get_nc_dataset_class(kwargs.get("engine", None))
-
-    ds = open_dataset(filename, **kwargs)
-    tree_root = DataTree.from_dict({"/": ds})
-    with ncDataset(filename, mode="r") as ncds:
-        for path in _iter_nc_groups(ncds):
-            subgroup_ds = open_dataset(filename, group=path, **kwargs)
-
-            # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
-            node_name = NodePath(path).name
-            new_node: DataTree = DataTree(name=node_name, data=subgroup_ds)
-            tree_root._set_item(
-                path,
-                new_node,
-                allow_overwrite=False,
-                new_nodes_along_path=True,
-            )
-    return tree_root
-
-
-def _open_datatree_zarr(store, **kwargs) -> DataTree:
-    import zarr  # type: ignore
-
-    zds = zarr.open_group(store, mode="r")
-    ds = open_dataset(store, engine="zarr", **kwargs)
-    tree_root = DataTree.from_dict({"/": ds})
-    for path in _iter_zarr_groups(zds):
-        try:
-            subgroup_ds = open_dataset(store, engine="zarr", group=path, **kwargs)
-        except zarr.errors.PathNotFoundError:
-            subgroup_ds = Dataset()
-
-        # TODO refactor to use __setitem__ once creation of new nodes by assigning Dataset works again
-        node_name = NodePath(path).name
-        new_node: DataTree = DataTree(name=node_name, data=subgroup_ds)
-        tree_root._set_item(
-            path,
-            new_node,
-            allow_overwrite=False,
-            new_nodes_along_path=True,
-        )
-    return tree_root
-
-
 def _create_empty_netcdf_group(filename, group, mode, engine):
     ncDataset = _get_nc_dataset_class(engine)
diff --git a/xarray/datatree_/datatree/tests/conftest.py b/xarray/datatree_/datatree/tests/conftest.py
index 3ed1325ccd5..bd2e7ba3247 100644
--- a/xarray/datatree_/datatree/tests/conftest.py
+++ b/xarray/datatree_/datatree/tests/conftest.py
@@ -1,7 +1,7 @@
 import pytest
 import xarray as xr

-from datatree import DataTree
+from xarray.datatree_.datatree import DataTree


 @pytest.fixture(scope="module")
diff --git a/xarray/datatree_/datatree/tests/test_dataset_api.py b/xarray/datatree_/datatree/tests/test_dataset_api.py
index 6879b869299..c3eb74451a6 100644
--- a/xarray/datatree_/datatree/tests/test_dataset_api.py
+++ b/xarray/datatree_/datatree/tests/test_dataset_api.py
@@ -1,8 +1,8 @@
 import numpy as np
 import xarray as xr

-from datatree import DataTree
-from datatree.testing import assert_equal
+from xarray.datatree_.datatree import DataTree
+from xarray.datatree_.datatree.testing import assert_equal


 class TestDSMethodInheritance:
diff --git a/xarray/datatree_/datatree/tests/test_datatree.py b/xarray/datatree_/datatree/tests/test_datatree.py
index fde83b2e226..cfb57470651 100644
--- a/xarray/datatree_/datatree/tests/test_datatree.py
+++ b/xarray/datatree_/datatree/tests/test_datatree.py
@@ -6,8 +6,8 @@
 import xarray.testing as xrt
 from xarray.tests import create_test_data, source_ndarray

-import datatree.testing as dtt
-from datatree import DataTree, NotFoundInTreeError
+import xarray.datatree_.datatree.testing as dtt
+from xarray.datatree_.datatree import DataTree, NotFoundInTreeError


 class TestTreeCreation:
diff --git a/xarray/datatree_/datatree/tests/test_extensions.py b/xarray/datatree_/datatree/tests/test_extensions.py
index b288998e2ce..0241e496abf 100644
--- a/xarray/datatree_/datatree/tests/test_extensions.py
+++ b/xarray/datatree_/datatree/tests/test_extensions.py
@@ -1,6 +1,6 @@
 import pytest

-from datatree import DataTree, register_datatree_accessor
+from xarray.datatree_.datatree import DataTree, register_datatree_accessor


 class TestAccessor:
diff --git a/xarray/datatree_/datatree/tests/test_formatting.py b/xarray/datatree_/datatree/tests/test_formatting.py
index 0f64644c05a..8726c95fe62 100644
--- a/xarray/datatree_/datatree/tests/test_formatting.py
+++ b/xarray/datatree_/datatree/tests/test_formatting.py
@@ -2,8 +2,8 @@

 from xarray import Dataset

-from datatree import DataTree
-from datatree.formatting import diff_tree_repr
+from xarray.datatree_.datatree import DataTree
+from xarray.datatree_.datatree.formatting import diff_tree_repr


 class TestRepr:
diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py
index 7c6a47ea86e..943bbab4154 100644
--- a/xarray/datatree_/datatree/tests/test_formatting_html.py
+++ b/xarray/datatree_/datatree/tests/test_formatting_html.py
@@ -1,7 +1,7 @@
 import pytest
 import xarray as xr

-from datatree import DataTree, formatting_html
+from xarray.datatree_.datatree import DataTree, formatting_html


 @pytest.fixture(scope="module", params=["some html", "some other html"])
diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/datatree_/datatree/tests/test_mapping.py
index 929ce7644dd..53d6e085440 100644
--- a/xarray/datatree_/datatree/tests/test_mapping.py
+++ b/xarray/datatree_/datatree/tests/test_mapping.py
@@ -2,9 +2,9 @@
 import pytest
 import xarray as xr

-from datatree.datatree import DataTree
-from datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
-from datatree.testing import assert_equal
+from xarray.datatree_.datatree.datatree import DataTree
+from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree
+from xarray.datatree_.datatree.testing import assert_equal


 empty = xr.Dataset()
diff --git a/xarray/datatree_/datatree/tests/test_treenode.py b/xarray/datatree_/datatree/tests/test_treenode.py
index f2d314c50e3..3c75f3ac8a4 100644
--- a/xarray/datatree_/datatree/tests/test_treenode.py
+++ b/xarray/datatree_/datatree/tests/test_treenode.py
@@ -1,7 +1,7 @@
 import pytest

-from datatree.iterators import LevelOrderIter, PreOrderIter
-from datatree.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode
+from xarray.datatree_.datatree.iterators import LevelOrderIter, PreOrderIter
+from xarray.datatree_.datatree.treenode import InvalidTreeError, NamedNode, NodePath, TreeNode


 class TestFamilyTree:
diff --git a/xarray/datatree_/datatree/tests/test_version.py b/xarray/datatree_/datatree/tests/test_version.py
deleted file mode 100644
index 207d5d86d53..00000000000
--- a/xarray/datatree_/datatree/tests/test_version.py
+++ /dev/null
@@ -1,5 +0,0 @@
-import datatree
-
-
-def test_version():
-    assert datatree.__version__ != "999"
diff --git a/xarray/datatree_/pyproject.toml b/xarray/datatree_/pyproject.toml
deleted file mode 100644
index 40f7d5a59b3..00000000000
--- a/xarray/datatree_/pyproject.toml
+++ /dev/null
@@ -1,61 +0,0 @@
-[project]
-name = "xarray-datatree"
-description = "Hierarchical tree-like data structures for xarray"
-readme = "README.md"
-authors = [
-    {name = "Thomas Nicholas", email = "thomas.nicholas@columbia.edu"}
-]
-license = {text = "Apache-2"}
-classifiers = [
-    "Development Status :: 3 - Alpha",
-    "Intended Audience :: Science/Research",
-    "Topic :: Scientific/Engineering",
-    "License :: OSI Approved :: Apache Software License",
-    "Operating System :: OS Independent",
-    "Programming Language :: Python",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
-]
-requires-python = ">=3.9"
-dependencies = [
-    "xarray >=2023.12.0",
-    "packaging",
-]
-dynamic = ["version"]
-
-[project.urls]
-Home = "https://github.com/xarray-contrib/datatree"
-Documentation = "https://xarray-datatree.readthedocs.io/en/stable/"
-
-[build-system]
-requires = [
-    "setuptools>=61.0.0",
-    "wheel",
-    "setuptools_scm[toml]>=7.0",
-    "check-manifest"
-]
-
-[tool.setuptools_scm]
-write_to = "datatree/_version.py"
-write_to_template = '''
-# Do not change! Do not track in version control!
-__version__ = "{version}"
-'''
-
-[tool.setuptools.packages.find]
-exclude = ["docs", "tests", "tests.*", "docs.*"]
-
-[tool.setuptools.package-data]
-datatree = ["py.typed"]
-
-[tool.isort]
-profile = "black"
-skip_gitignore = true
-float_to_top = true
-default_section = "THIRDPARTY"
-known_first_party = "datatree"
-
-[mypy]
-files = "datatree/**/*.py"
-show_error_codes = true
diff --git a/xarray/tests/datatree/conftest.py b/xarray/tests/datatree/conftest.py
new file mode 100644
index 00000000000..b341f3007aa
--- /dev/null
+++ b/xarray/tests/datatree/conftest.py
@@ -0,0 +1,65 @@
+import pytest
+
+import xarray as xr
+from xarray.datatree_.datatree import DataTree
+
+
+@pytest.fixture(scope="module")
+def create_test_datatree():
+    """
+    Create a test datatree with this structure:
+
+
+    |-- set1
+    |   |--
+    |   |   Dimensions:  ()
+    |   |   Data variables:
+    |   |       a        int64 0
+    |   |       b        int64 1
+    |   |-- set1
+    |   |-- set2
+    |-- set2
+    |   |--
+    |   |   Dimensions:  (x: 2)
+    |   |   Data variables:
+    |   |       a        (x) int64 2, 3
+    |   |       b        (x) int64 0.1, 0.2
+    |   |-- set1
+    |-- set3
+    |--
+    |   Dimensions:  (x: 2, y: 3)
+    |   Data variables:
+    |       a        (y) int64 6, 7, 8
+    |       set0     (x) int64 9, 10
+
+    The structure has deliberately repeated names of tags, variables, and
+    dimensions in order to better check for bugs caused by name conflicts.
+    """
+
+    def _create_test_datatree(modify=lambda ds: ds):
+        set1_data = modify(xr.Dataset({"a": 0, "b": 1}))
+        set2_data = modify(xr.Dataset({"a": ("x", [2, 3]), "b": ("x", [0.1, 0.2])}))
+        root_data = modify(xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])}))
+
+        # Avoid using __init__ so we can independently test it
+        root: DataTree = DataTree(data=root_data)
+        set1: DataTree = DataTree(name="set1", parent=root, data=set1_data)
+        DataTree(name="set1", parent=set1)
+        DataTree(name="set2", parent=set1)
+        set2: DataTree = DataTree(name="set2", parent=root, data=set2_data)
+        DataTree(name="set1", parent=set2)
+        DataTree(name="set3", parent=root)
+
+        return root
+
+    return _create_test_datatree
+
+
+@pytest.fixture(scope="module")
+def simple_datatree(create_test_datatree):
+    """
+    Invoke create_test_datatree fixture (callback).
+
+    Returns a DataTree.
+ """ + return create_test_datatree() diff --git a/xarray/datatree_/datatree/tests/test_io.py b/xarray/tests/datatree/test_io.py similarity index 92% rename from xarray/datatree_/datatree/tests/test_io.py rename to xarray/tests/datatree/test_io.py index 6fa20479f9a..4f32e19de4a 100644 --- a/xarray/datatree_/datatree/tests/test_io.py +++ b/xarray/tests/datatree/test_io.py @@ -1,9 +1,12 @@ import pytest -import zarr.errors -from datatree.io import open_datatree -from datatree.testing import assert_equal -from datatree.tests import requires_h5netcdf, requires_netCDF4, requires_zarr +from xarray.backends.api import open_datatree +from xarray.datatree_.datatree.testing import assert_equal +from xarray.tests import ( + requires_h5netcdf, + requires_netCDF4, + requires_zarr, +) class TestIO: @@ -35,7 +38,7 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): assert roundtrip_dt["/set2/a"].encoding["zlib"] == comp["zlib"] assert roundtrip_dt["/set2/a"].encoding["complevel"] == comp["complevel"] - enc["/not/a/group"] = {"foo": "bar"} + enc["/not/a/group"] = {"foo": "bar"} # type: ignore with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_netcdf(filepath, encoding=enc, engine="netcdf4") @@ -78,7 +81,7 @@ def test_zarr_encoding(self, tmpdir, simple_datatree): print(roundtrip_dt["/set2/a"].encoding) assert roundtrip_dt["/set2/a"].encoding["compressor"] == comp["compressor"] - enc["/not/a/group"] = {"foo": "bar"} + enc["/not/a/group"] = {"foo": "bar"} # type: ignore with pytest.raises(ValueError, match="unexpected encoding group.*"): original_dt.to_zarr(filepath, encoding=enc, engine="zarr") @@ -113,6 +116,8 @@ def test_to_zarr_not_consolidated(self, tmpdir, simple_datatree): @requires_zarr def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): + import zarr + simple_datatree.to_zarr(tmpdir) # with default settings, to_zarr should not overwrite an existing dir