Skip to content

Commit

Permalink
Merge pull request #58 from alan-turing-institute/gpu-metrics
Browse files Browse the repository at this point in the history
Move calculation of metrics to GPU via`dm_pix`
  • Loading branch information
phinate authored Sep 20, 2024
2 parents 24a0f34 + c32d074 commit 6539034
Show file tree
Hide file tree
Showing 11 changed files with 592 additions and 315 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ repos:
- id: mixed-line-ending
- id: name-tests-test
args: ["--pytest-test-first"]
exclude: ^tests/legacy_metrics.py
- id: requirements-txt-fixer
- id: trailing-whitespace

Expand Down Expand Up @@ -46,3 +47,4 @@ repos:
- torch
- jaxtyping
- types-tqdm
- chex
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@ Tooling and infrastructure to enable cloud nowcasting.

## Installation

From source (development mode):
### For users:

```zsh
git clone https://github.com/alan-turing-institute/cloudcasting
cd cloudcasting
python -m pip install .
```

To run metrics on GPU:

```zsh
python -m install --upgrade "jax[cuda12]"
```
### For making changes to the library:

On macOS you first need to install `ffmpeg` with the following command. On other platforms this is
not necessary.
Expand All @@ -20,7 +33,7 @@ brew install ffmpeg
Clone and install the repo.

```bash
git clone https://github.com/climetrend/cloudcasting
git clone https://github.com/alan-turing-institute/cloudcasting
cd cloudcasting
python -m pip install ".[dev]"
```
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ dependencies = [
"typer",
"lightning",
"torch>=2.3.0", # needed for numpy 2.0
"scikit-image",
"jaxtyping",
"wandb",
"tqdm",
"moviepy>=1.0.3",
"imageio>=2.35.1",
"numpy <2.1.0", # https://github.com/wandb/wandb/issues/8166
"chex",
]
[project.optional-dependencies]
dev = [
Expand All @@ -54,6 +54,7 @@ dev = [
"pre-commit",
"scipy",
"pytest-mock",
"scikit-image",
]

[tool.setuptools.package-data]
Expand All @@ -63,10 +64,10 @@ dev = [
cloudcasting = "cloudcasting.cli:app"

[project.urls]
Homepage = "https://github.com/climetrend/cloudcasting"
"Bug Tracker" = "https://github.com/climetrend/cloudcasting/issues"
Discussions = "https://github.com/climetrend/cloudcasting/discussions"
Changelog = "https://github.com/climetrend/cloudcasting/releases"
Homepage = "https://github.com/alan-turing-institute/cloudcasting"
"Bug Tracker" = "https://github.com/alan-turing-institute/cloudcasting/issues"
Discussions = "https://github.com/alan-turing-institute/cloudcasting/discussions"
Changelog = "https://github.com/alan-turing-institute/cloudcasting/releases"

[tool.pytest.ini_options]
minversion = "6.0"
Expand Down Expand Up @@ -161,6 +162,7 @@ select = [
ignore = [
"PLR", # Design related pylint codes
"ISC001", # Conflicts with formatter
"F722" # Marks jaxtyping syntax annotations as incorrect
]
unfixable = [
"F401", # Would remove unused imports
Expand Down
7 changes: 4 additions & 3 deletions src/cloudcasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
# dataclasses, meaning that they will be type-checked
# (and therefore shape-checked via jaxtyping) at runtime.
with install_import_hook("cloudcasting", "typeguard.typechecked"):
from cloudcasting import metrics, models
from cloudcasting import models, validation

from cloudcasting import cli, dataset, download
from cloudcasting import cli, dataset, download, metrics

__all__ = (
"__version__",
"download",
"cli",
"dataset",
"metrics",
"models",
"validation",
"metrics",
)
__version__ = version(__name__)
Loading

0 comments on commit 6539034

Please sign in to comment.