|
4 | 4 | """
|
5 | 5 |
|
6 | 6 | import hashlib
|
7 |
| -from typing import Callable, Sequence |
| 7 | +from typing import Callable, Sequence, Annotated |
| 8 | +from typing_extensions import assert_never |
| 9 | + |
| 10 | +import sacc |
| 11 | +from pydantic import ( |
| 12 | + BaseModel, |
| 13 | + BeforeValidator, |
| 14 | + ConfigDict, |
| 15 | + Field, |
| 16 | + model_validator, |
| 17 | + PrivateAttr, |
| 18 | + field_serializer, |
| 19 | +) |
8 | 20 | import numpy as np
|
9 | 21 | import numpy.typing as npt
|
10 |
| -import sacc |
| 22 | + |
11 | 23 | from firecrown.metadata_types import (
|
12 | 24 | TwoPointHarmonic,
|
13 | 25 | TwoPointReal,
|
| 26 | + Measurement, |
14 | 27 | )
|
15 | 28 | from firecrown.metadata_functions import (
|
16 | 29 | extract_all_tracers_inferred_galaxy_zdists,
|
17 | 30 | extract_window_function,
|
18 | 31 | extract_all_harmonic_metadata_indices,
|
19 | 32 | extract_all_real_metadata_indices,
|
20 | 33 | make_two_point_xy,
|
| 34 | + make_measurement, |
| 35 | + make_measurement_dict, |
21 | 36 | )
|
22 | 37 | from firecrown.data_types import TwoPointMeasurement
|
23 | 38 |
|
@@ -222,3 +237,275 @@ def check_two_point_consistence_real(
|
222 | 237 | ) -> None:
|
223 | 238 | """Check the indices of the real-space two-point functions."""
|
224 | 239 | check_consistence(two_point_reals, lambda m: m.is_real(), "TwoPointReal")
|
| 240 | + |
| 241 | + |
| 242 | +class TwoPointTracerSpec(BaseModel): |
| 243 | + """Class defining a tracer bin specification.""" |
| 244 | + |
| 245 | + model_config = ConfigDict(extra="forbid", frozen=True) |
| 246 | + |
| 247 | + name: Annotated[str, Field(description="The name of the tracer bin.")] |
| 248 | + measurement: Annotated[ |
| 249 | + Measurement, |
| 250 | + Field(description="The measurement of the tracer bin."), |
| 251 | + BeforeValidator(make_measurement), |
| 252 | + ] |
| 253 | + |
| 254 | + @field_serializer("measurement") |
| 255 | + @classmethod |
| 256 | + def serialize_measurement(cls, value: Measurement) -> dict[str, str]: |
| 257 | + """Serialize the Measurement.""" |
| 258 | + return make_measurement_dict(value) |
| 259 | + |
| 260 | + |
| 261 | +def make_interval_from_list( |
| 262 | + values: list[float] | tuple[float, float], |
| 263 | +) -> tuple[float, float]: |
| 264 | + """Create an interval from a list of values.""" |
| 265 | + if isinstance(values, list): |
| 266 | + if len(values) != 2: |
| 267 | + raise ValueError("The list should have two values.") |
| 268 | + if not all(isinstance(v, float) for v in values): |
| 269 | + raise ValueError("The list should have two float values.") |
| 270 | + |
| 271 | + return (values[0], values[1]) |
| 272 | + if isinstance(values, tuple): |
| 273 | + return values |
| 274 | + |
| 275 | + raise ValueError("The values should be a list or a tuple.") |
| 276 | + |
| 277 | + |
| 278 | +class TwoPointBinFilter(BaseModel): |
| 279 | + """Class defining a filter for a bin.""" |
| 280 | + |
| 281 | + model_config = ConfigDict(extra="forbid", frozen=True) |
| 282 | + |
| 283 | + spec: Annotated[ |
| 284 | + list[TwoPointTracerSpec], |
| 285 | + Field( |
| 286 | + description="The two-point bin specification.", |
| 287 | + ), |
| 288 | + ] |
| 289 | + interval: Annotated[ |
| 290 | + tuple[float, float], |
| 291 | + BeforeValidator(make_interval_from_list), |
| 292 | + Field(description="The range of the bin to filter."), |
| 293 | + ] |
| 294 | + |
| 295 | + @model_validator(mode="after") |
| 296 | + def check_bin_filter(self) -> "TwoPointBinFilter": |
| 297 | + """Check the bin filter.""" |
| 298 | + if self.interval[0] >= self.interval[1]: |
| 299 | + raise ValueError("The bin filter should be a valid range.") |
| 300 | + if not 1 <= len(self.spec) <= 2: |
| 301 | + raise ValueError("The bin_spec must contain one or two elements.") |
| 302 | + return self |
| 303 | + |
| 304 | + @field_serializer("interval") |
| 305 | + @classmethod |
| 306 | + def serialize_interval(cls, value: tuple[float, float]) -> list[float]: |
| 307 | + """Serialize the Measurement.""" |
| 308 | + return list(value) |
| 309 | + |
| 310 | + @classmethod |
| 311 | + def from_args( |
| 312 | + cls, |
| 313 | + name1: str, |
| 314 | + measurement1: Measurement, |
| 315 | + name2: str, |
| 316 | + measurement2: Measurement, |
| 317 | + lower: float, |
| 318 | + upper: float, |
| 319 | + ) -> "TwoPointBinFilter": |
| 320 | + """Create a TwoPointBinFilter from the arguments.""" |
| 321 | + return cls( |
| 322 | + spec=[ |
| 323 | + TwoPointTracerSpec(name=name1, measurement=measurement1), |
| 324 | + TwoPointTracerSpec(name=name2, measurement=measurement2), |
| 325 | + ], |
| 326 | + interval=(lower, upper), |
| 327 | + ) |
| 328 | + |
| 329 | + @classmethod |
| 330 | + def from_args_auto( |
| 331 | + cls, name: str, measurement: Measurement, lower: float, upper: float |
| 332 | + ) -> "TwoPointBinFilter": |
| 333 | + """Create a TwoPointBinFilter from the arguments.""" |
| 334 | + return cls( |
| 335 | + spec=[ |
| 336 | + TwoPointTracerSpec(name=name, measurement=measurement), |
| 337 | + ], |
| 338 | + interval=(lower, upper), |
| 339 | + ) |
| 340 | + |
| 341 | + |
| 342 | +BinSpec = frozenset[TwoPointTracerSpec] |
| 343 | + |
| 344 | + |
| 345 | +def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec: |
| 346 | + """Return the bin spec from the metadata.""" |
| 347 | + return frozenset( |
| 348 | + ( |
| 349 | + TwoPointTracerSpec( |
| 350 | + name=metadata.XY.x.bin_name, |
| 351 | + measurement=metadata.XY.x_measurement, |
| 352 | + ), |
| 353 | + TwoPointTracerSpec( |
| 354 | + name=metadata.XY.y.bin_name, |
| 355 | + measurement=metadata.XY.y_measurement, |
| 356 | + ), |
| 357 | + ) |
| 358 | + ) |
| 359 | + |
| 360 | + |
| 361 | +class TwoPointBinFilterCollection(BaseModel): |
| 362 | + """Class defining a collection of bin filters.""" |
| 363 | + |
| 364 | + model_config = ConfigDict(extra="forbid", frozen=True) |
| 365 | + |
| 366 | + require_filter_for_all: bool = Field( |
| 367 | + default=False, |
| 368 | + description="If True, all bins should match a filter.", |
| 369 | + ) |
| 370 | + allow_empty: bool = Field( |
| 371 | + default=False, |
| 372 | + description=( |
| 373 | + "When true, objects with no elements remaining after applying " |
| 374 | + "the filter will be ignored rather than treated as an error." |
| 375 | + ), |
| 376 | + ) |
| 377 | + filters: list[TwoPointBinFilter] = Field( |
| 378 | + description="The list of bin filters.", |
| 379 | + ) |
| 380 | + |
| 381 | + _bin_filter_dict: dict[BinSpec, tuple[float, float]] = PrivateAttr() |
| 382 | + |
| 383 | + @model_validator(mode="after") |
| 384 | + def check_bin_filters(self) -> "TwoPointBinFilterCollection": |
| 385 | + """Check the bin filters.""" |
| 386 | + bin_specs = set() |
| 387 | + for bin_filter in self.filters: |
| 388 | + bin_spec = frozenset(bin_filter.spec) |
| 389 | + if bin_spec in bin_specs: |
| 390 | + raise ValueError( |
| 391 | + f"The bin name {bin_filter.spec} is repeated " |
| 392 | + f"in the bin filters." |
| 393 | + ) |
| 394 | + bin_specs.add(bin_spec) |
| 395 | + |
| 396 | + self._bin_filter_dict = { |
| 397 | + frozenset(bin_filter.spec): bin_filter.interval |
| 398 | + for bin_filter in self.filters |
| 399 | + } |
| 400 | + return self |
| 401 | + |
| 402 | + @property |
| 403 | + def bin_filter_dict(self) -> dict[BinSpec, tuple[float, float]]: |
| 404 | + """Return the bin filter dictionary.""" |
| 405 | + return self._bin_filter_dict |
| 406 | + |
| 407 | + def filter_match(self, tpm: TwoPointMeasurement) -> bool: |
| 408 | + """Check if the TwoPointMeasurement matches the filter.""" |
| 409 | + bin_spec_key = bin_spec_from_metadata(tpm.metadata) |
| 410 | + return bin_spec_key in self._bin_filter_dict |
| 411 | + |
| 412 | + def run_bin_filter( |
| 413 | + self, |
| 414 | + bin_filter: tuple[float, float], |
| 415 | + vals: npt.NDArray[np.float64] | npt.NDArray[np.int64], |
| 416 | + ) -> npt.NDArray[np.bool_]: |
| 417 | + """Run the filter merge.""" |
| 418 | + return (vals >= bin_filter[0]) & (vals <= bin_filter[1]) |
| 419 | + |
| 420 | + def apply_filter_single( |
| 421 | + self, tpm: TwoPointMeasurement |
| 422 | + ) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]: |
| 423 | + """Apply the filter to a single TwoPointMeasurement.""" |
| 424 | + assert self.filter_match(tpm) |
| 425 | + bin_spec_key = bin_spec_from_metadata(tpm.metadata) |
| 426 | + bin_filter = self._bin_filter_dict[bin_spec_key] |
| 427 | + if tpm.is_real(): |
| 428 | + assert isinstance(tpm.metadata, TwoPointReal) |
| 429 | + match_elements = self.run_bin_filter(bin_filter, tpm.metadata.thetas) |
| 430 | + return match_elements, match_elements |
| 431 | + |
| 432 | + assert isinstance(tpm.metadata, TwoPointHarmonic) |
| 433 | + match_elements = self.run_bin_filter(bin_filter, tpm.metadata.ells) |
| 434 | + match_obs = match_elements |
| 435 | + if tpm.metadata.window is not None: |
| 436 | + # The window function is represented by a matrix where each column |
| 437 | + # corresponds to the weights for the ell values of each observation. We |
| 438 | + # need to ensure that the window function is filtered correctly. To do this, |
| 439 | + # we will check each column of the matrix and verify that all non-zero |
| 440 | + # elements are within the filtered set. If any non-zero element falls |
| 441 | + # outside the filtered set, the match_elements will be set to False for that |
| 442 | + # observation. |
| 443 | + non_zero_window = tpm.metadata.window > 0 |
| 444 | + match_obs = ( |
| 445 | + np.all( |
| 446 | + (non_zero_window & match_elements[:, None]) == non_zero_window, |
| 447 | + axis=0, |
| 448 | + ) |
| 449 | + .ravel() |
| 450 | + .astype(np.bool_) |
| 451 | + ) |
| 452 | + |
| 453 | + return match_elements, match_obs |
| 454 | + |
| 455 | + def __call__( |
| 456 | + self, tpms: Sequence[TwoPointMeasurement] |
| 457 | + ) -> list[TwoPointMeasurement]: |
| 458 | + """Filter the two-point measurements.""" |
| 459 | + result = [] |
| 460 | + |
| 461 | + for tpm in tpms: |
| 462 | + if not self.filter_match(tpm): |
| 463 | + if not self.require_filter_for_all: |
| 464 | + result.append(tpm) |
| 465 | + continue |
| 466 | + raise ValueError(f"The bin name {tpm.metadata} does not have a filter.") |
| 467 | + |
| 468 | + match_elements, match_obs = self.apply_filter_single(tpm) |
| 469 | + if not match_obs.any(): |
| 470 | + if not self.allow_empty: |
| 471 | + # If empty results are not allowed, we raise an error |
| 472 | + raise ValueError( |
| 473 | + f"The TwoPointMeasurement {tpm.metadata} does not " |
| 474 | + f"have any elements matching the filter." |
| 475 | + ) |
| 476 | + # If the filter is empty, we skip this measurement |
| 477 | + continue |
| 478 | + |
| 479 | + assert isinstance(tpm.metadata, (TwoPointReal, TwoPointHarmonic)) |
| 480 | + new_metadata: TwoPointReal | TwoPointHarmonic |
| 481 | + match tpm.metadata: |
| 482 | + case TwoPointReal(): |
| 483 | + new_metadata = TwoPointReal( |
| 484 | + XY=tpm.metadata.XY, |
| 485 | + thetas=tpm.metadata.thetas[match_elements], |
| 486 | + ) |
| 487 | + case TwoPointHarmonic(): |
| 488 | + # If the window function is not None, we need to filter it as well |
| 489 | + # and update the metadata accordingly. |
| 490 | + new_metadata = TwoPointHarmonic( |
| 491 | + XY=tpm.metadata.XY, |
| 492 | + window=( |
| 493 | + tpm.metadata.window[:, match_obs][match_elements, :] |
| 494 | + if tpm.metadata.window is not None |
| 495 | + else None |
| 496 | + ), |
| 497 | + ells=tpm.metadata.ells[match_elements], |
| 498 | + ) |
| 499 | + case _ as unreachable: |
| 500 | + assert_never(unreachable) |
| 501 | + |
| 502 | + result.append( |
| 503 | + TwoPointMeasurement( |
| 504 | + data=tpm.data[match_obs], |
| 505 | + indices=tpm.indices[match_obs], |
| 506 | + covariance_name=tpm.covariance_name, |
| 507 | + metadata=new_metadata, |
| 508 | + ) |
| 509 | + ) |
| 510 | + |
| 511 | + return result |
0 commit comments