Skip to content

Commit

Permalink
Add rgb_range percentile support
Browse files Browse the repository at this point in the history
  • Loading branch information
atanas-balevsky committed Dec 7, 2023
1 parent 1a60e00 commit b345546
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 10 deletions.
24 changes: 21 additions & 3 deletions terracotta/handlers/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from terracotta.profile import trace

Number = TypeVar("Number", int, float)
ListOfRanges = Sequence[Optional[Tuple[Optional[Number], Optional[Number]]]]
NumberOrString = TypeVar("NumberOrString", int, float, str)
ListOfRanges = Sequence[Optional[Tuple[Optional[NumberOrString], Optional[NumberOrString]]]]


@trace("rgb_handler")
Expand Down Expand Up @@ -90,10 +91,10 @@ def get_band_future(band_key: str) -> Future:
scale_min, scale_max = band_stretch_override

if scale_min is not None:
band_stretch_range[0] = scale_min
band_stretch_range[0] = get_scale(scale_min, metadata)

if scale_max is not None:
band_stretch_range[1] = scale_max
band_stretch_range[1] = get_scale(scale_max, metadata)

if band_stretch_range[1] < band_stretch_range[0]:
raise exceptions.InvalidArgumentsError(
Expand All @@ -105,3 +106,20 @@ def get_band_future(band_key: str) -> Future:

out = np.ma.stack(out_arrays, axis=-1)
return image.array_to_png(out)


def get_scale(scale: NumberOrString, metadata) -> Number:
if isinstance(scale, (int, float)):
return scale
if isinstance(scale, str):
# can be a percentile
if scale.startswith("p"):
# TODO check if percentile is in range
percentile = int(scale[1:]) - 1
return metadata["percentiles"][percentile]

# can be a number
return float(scale)
raise exceptions.InvalidArgumentsError(
"Invalid scale value: %s" % scale
)
21 changes: 15 additions & 6 deletions terracotta/server/rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,34 @@ class Meta:
g = fields.String(required=True, description="Key value for green band")
b = fields.String(required=True, description="Key value for blue band")
r_range = fields.List(
fields.Number(allow_none=True),
fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
description="Stretch range [min, max] to use for red band as JSON array",
description=(
"Stretch range [min, max] to use for red band as JSON array, "
"prefix with `p` for percentile"
),
)
g_range = fields.List(
fields.Number(allow_none=True),
fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
description="Stretch range [min, max] to use for green band as JSON array",
description=(
"Stretch range [min, max] to use for red band as JSON array, "
"prefix with `p` for percentile"
),
)
b_range = fields.List(
fields.Number(allow_none=True),
fields.String(allow_none=True, validate=validate.Regexp("^p?(\d*\.)?\d+$")),
validate=validate.Length(equal=2),
example="[0,1]",
missing=None,
description="Stretch range [min, max] to use for blue band as JSON array",
description=(
"Stretch range [min, max] to use for red band as JSON array, "
"prefix with `p` for percentile"
),
)
tile_size = fields.List(
fields.Integer(),
Expand Down
50 changes: 49 additions & 1 deletion tests/handlers/test_rgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def test_rgb_lowzoom(use_testdb, raster_file, raster_file_xyz_lowzoom):


@pytest.mark.parametrize(
"stretch_range", [[0, 20000], [10000, 20000], [-50000, 50000], [100, 100]]
"stretch_range", [
[0, 20000], [10000, 20000], [-50000, 50000], [100, 100],
["0", "20000"], ["10000", "20000"], ["-50000", "50000"], ["100", "100"],
]
)
def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz):
import terracotta
Expand Down Expand Up @@ -106,6 +109,7 @@ def test_rgb_stretch(stretch_range, use_testdb, testdb, raster_file_xyz):
valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

stretch_range = [float(stretch_range[0]), float(stretch_range[1])]
assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
Expand All @@ -131,6 +135,50 @@ def test_rgb_invalid_stretch(use_testdb, raster_file_xyz):
)


def test_rgb_percentile_stretch(use_testdb, testdb, raster_file_xyz):
import terracotta
from terracotta.xyz import get_tile_data
from terracotta.handlers import rgb

ds_keys = ["val21", "x", "val22"]
bands = ["val22", "val23", "val24"]
pct_stretch_range = ["p2", "p98"]

raw_img = rgb.rgb(
ds_keys[:2],
bands,
raster_file_xyz,
stretch_ranges=[pct_stretch_range] * 3,
)
img_data = np.asarray(Image.open(raw_img))[..., 0]

# get unstretched data to compare to
driver = terracotta.get_driver(testdb)

with driver.connect():
tile_data = get_tile_data(
driver, ds_keys, tile_xyz=raster_file_xyz, tile_size=img_data.shape
)
band_metadata = driver.get_metadata(ds_keys)

stretch_range = [band_metadata["percentiles"][1], band_metadata["percentiles"][97]]

# filter transparent values
valid_mask = ~tile_data.mask
assert np.all(img_data[~valid_mask] == 0)

valid_img = img_data[valid_mask]
valid_data = tile_data.compressed()

assert np.all(valid_img[valid_data < stretch_range[0]] == 1)
stretch_range_mask = (valid_data > stretch_range[0]) & (
valid_data < stretch_range[1]
)
assert np.all(valid_img[stretch_range_mask] >= 1)
assert np.all(valid_img[stretch_range_mask] <= 255)
assert np.all(valid_img[valid_data > stretch_range[1]] == 255)


def test_rgb_preview(use_testdb):
import terracotta
from terracotta.handlers import rgb
Expand Down

0 comments on commit b345546

Please sign in to comment.