Skip to content

Commit 9613906

Browse files
vitentipaulrogozenskimarcpaterno
authoredFeb 13, 2025··
Adding support for TwoPointMeasurement filters. (#479)
* Adding support for TwoPointMeasurement filters. * Adding tests for filters. * Removed unused import. * Using two-point pair as filter specification. * Adding support for serializing and deserializing filters. * Adding tests for factories. * Added check for never reached branch. * TwoPointFilter documentation first draft * Simplify logic * If _path is set do not search current directory * Release cython version restriction * Do not load our duplicate_code plugin * Update finding of SACC files in some tests * Update version tag * Refactor for improved test coverage and fix missing error case * Complete branch coverage * Improving tutorial. * Correct serialization for TwoPointFactory. * Added test for TwoPointFactory serialization. --------- Co-authored-by: paulrogozenski <[email protected]> Co-authored-by: Marc Paterno <[email protected]>
1 parent da7bb3a commit 9613906

21 files changed

+1578
-115
lines changed
 

‎docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
author = "LSST DESC Firecrown Contributors"
2424

2525
# The full version, including alpha/beta/rc tags
26-
release = "1.8.0"
26+
release = "1.9.0a0"
2727

2828

2929
# -- General configuration ---------------------------------------------------

‎environment.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies:
99
- cosmosis >= 3.0
1010
- cosmosis-build-standard-library
1111
- coverage
12-
- cython < 3.0.0
12+
- cython
1313
- dill
1414
- fitsio
1515
- flake8

‎fctools/tracer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# some global context to be used in the tracing. We are relying on
2626
# 'trace_call' to act as a closure that captures these names.
27-
tracefile = None # the file used for logging
27+
tracefile: TextIO | None = None # the file used for logging
2828
level = 0 # the call nesting level
2929
entry = 0 # sequential entry number for each record
3030

‎firecrown/data_functions.py

+289-2
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,35 @@
44
"""
55

66
import hashlib
7-
from typing import Callable, Sequence
7+
from typing import Callable, Sequence, Annotated
8+
from typing_extensions import assert_never
9+
10+
import sacc
11+
from pydantic import (
12+
BaseModel,
13+
BeforeValidator,
14+
ConfigDict,
15+
Field,
16+
model_validator,
17+
PrivateAttr,
18+
field_serializer,
19+
)
820
import numpy as np
921
import numpy.typing as npt
10-
import sacc
22+
1123
from firecrown.metadata_types import (
1224
TwoPointHarmonic,
1325
TwoPointReal,
26+
Measurement,
1427
)
1528
from firecrown.metadata_functions import (
1629
extract_all_tracers_inferred_galaxy_zdists,
1730
extract_window_function,
1831
extract_all_harmonic_metadata_indices,
1932
extract_all_real_metadata_indices,
2033
make_two_point_xy,
34+
make_measurement,
35+
make_measurement_dict,
2136
)
2237
from firecrown.data_types import TwoPointMeasurement
2338

@@ -222,3 +237,275 @@ def check_two_point_consistence_real(
222237
) -> None:
223238
"""Check the indices of the real-space two-point functions."""
224239
check_consistence(two_point_reals, lambda m: m.is_real(), "TwoPointReal")
240+
241+
242+
class TwoPointTracerSpec(BaseModel):
243+
"""Class defining a tracer bin specification."""
244+
245+
model_config = ConfigDict(extra="forbid", frozen=True)
246+
247+
name: Annotated[str, Field(description="The name of the tracer bin.")]
248+
measurement: Annotated[
249+
Measurement,
250+
Field(description="The measurement of the tracer bin."),
251+
BeforeValidator(make_measurement),
252+
]
253+
254+
@field_serializer("measurement")
255+
@classmethod
256+
def serialize_measurement(cls, value: Measurement) -> dict[str, str]:
257+
"""Serialize the Measurement."""
258+
return make_measurement_dict(value)
259+
260+
261+
def make_interval_from_list(
262+
values: list[float] | tuple[float, float],
263+
) -> tuple[float, float]:
264+
"""Create an interval from a list of values."""
265+
if isinstance(values, list):
266+
if len(values) != 2:
267+
raise ValueError("The list should have two values.")
268+
if not all(isinstance(v, float) for v in values):
269+
raise ValueError("The list should have two float values.")
270+
271+
return (values[0], values[1])
272+
if isinstance(values, tuple):
273+
return values
274+
275+
raise ValueError("The values should be a list or a tuple.")
276+
277+
278+
class TwoPointBinFilter(BaseModel):
279+
"""Class defining a filter for a bin."""
280+
281+
model_config = ConfigDict(extra="forbid", frozen=True)
282+
283+
spec: Annotated[
284+
list[TwoPointTracerSpec],
285+
Field(
286+
description="The two-point bin specification.",
287+
),
288+
]
289+
interval: Annotated[
290+
tuple[float, float],
291+
BeforeValidator(make_interval_from_list),
292+
Field(description="The range of the bin to filter."),
293+
]
294+
295+
@model_validator(mode="after")
296+
def check_bin_filter(self) -> "TwoPointBinFilter":
297+
"""Check the bin filter."""
298+
if self.interval[0] >= self.interval[1]:
299+
raise ValueError("The bin filter should be a valid range.")
300+
if not 1 <= len(self.spec) <= 2:
301+
raise ValueError("The bin_spec must contain one or two elements.")
302+
return self
303+
304+
@field_serializer("interval")
305+
@classmethod
306+
def serialize_interval(cls, value: tuple[float, float]) -> list[float]:
307+
"""Serialize the Measurement."""
308+
return list(value)
309+
310+
@classmethod
311+
def from_args(
312+
cls,
313+
name1: str,
314+
measurement1: Measurement,
315+
name2: str,
316+
measurement2: Measurement,
317+
lower: float,
318+
upper: float,
319+
) -> "TwoPointBinFilter":
320+
"""Create a TwoPointBinFilter from the arguments."""
321+
return cls(
322+
spec=[
323+
TwoPointTracerSpec(name=name1, measurement=measurement1),
324+
TwoPointTracerSpec(name=name2, measurement=measurement2),
325+
],
326+
interval=(lower, upper),
327+
)
328+
329+
@classmethod
330+
def from_args_auto(
331+
cls, name: str, measurement: Measurement, lower: float, upper: float
332+
) -> "TwoPointBinFilter":
333+
"""Create a TwoPointBinFilter from the arguments."""
334+
return cls(
335+
spec=[
336+
TwoPointTracerSpec(name=name, measurement=measurement),
337+
],
338+
interval=(lower, upper),
339+
)
340+
341+
342+
BinSpec = frozenset[TwoPointTracerSpec]
343+
344+
345+
def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec:
346+
"""Return the bin spec from the metadata."""
347+
return frozenset(
348+
(
349+
TwoPointTracerSpec(
350+
name=metadata.XY.x.bin_name,
351+
measurement=metadata.XY.x_measurement,
352+
),
353+
TwoPointTracerSpec(
354+
name=metadata.XY.y.bin_name,
355+
measurement=metadata.XY.y_measurement,
356+
),
357+
)
358+
)
359+
360+
361+
class TwoPointBinFilterCollection(BaseModel):
362+
"""Class defining a collection of bin filters."""
363+
364+
model_config = ConfigDict(extra="forbid", frozen=True)
365+
366+
require_filter_for_all: bool = Field(
367+
default=False,
368+
description="If True, all bins should match a filter.",
369+
)
370+
allow_empty: bool = Field(
371+
default=False,
372+
description=(
373+
"When true, objects with no elements remaining after applying "
374+
"the filter will be ignored rather than treated as an error."
375+
),
376+
)
377+
filters: list[TwoPointBinFilter] = Field(
378+
description="The list of bin filters.",
379+
)
380+
381+
_bin_filter_dict: dict[BinSpec, tuple[float, float]] = PrivateAttr()
382+
383+
@model_validator(mode="after")
384+
def check_bin_filters(self) -> "TwoPointBinFilterCollection":
385+
"""Check the bin filters."""
386+
bin_specs = set()
387+
for bin_filter in self.filters:
388+
bin_spec = frozenset(bin_filter.spec)
389+
if bin_spec in bin_specs:
390+
raise ValueError(
391+
f"The bin name {bin_filter.spec} is repeated "
392+
f"in the bin filters."
393+
)
394+
bin_specs.add(bin_spec)
395+
396+
self._bin_filter_dict = {
397+
frozenset(bin_filter.spec): bin_filter.interval
398+
for bin_filter in self.filters
399+
}
400+
return self
401+
402+
@property
403+
def bin_filter_dict(self) -> dict[BinSpec, tuple[float, float]]:
404+
"""Return the bin filter dictionary."""
405+
return self._bin_filter_dict
406+
407+
def filter_match(self, tpm: TwoPointMeasurement) -> bool:
408+
"""Check if the TwoPointMeasurement matches the filter."""
409+
bin_spec_key = bin_spec_from_metadata(tpm.metadata)
410+
return bin_spec_key in self._bin_filter_dict
411+
412+
def run_bin_filter(
413+
self,
414+
bin_filter: tuple[float, float],
415+
vals: npt.NDArray[np.float64] | npt.NDArray[np.int64],
416+
) -> npt.NDArray[np.bool_]:
417+
"""Run the filter merge."""
418+
return (vals >= bin_filter[0]) & (vals <= bin_filter[1])
419+
420+
def apply_filter_single(
421+
self, tpm: TwoPointMeasurement
422+
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
423+
"""Apply the filter to a single TwoPointMeasurement."""
424+
assert self.filter_match(tpm)
425+
bin_spec_key = bin_spec_from_metadata(tpm.metadata)
426+
bin_filter = self._bin_filter_dict[bin_spec_key]
427+
if tpm.is_real():
428+
assert isinstance(tpm.metadata, TwoPointReal)
429+
match_elements = self.run_bin_filter(bin_filter, tpm.metadata.thetas)
430+
return match_elements, match_elements
431+
432+
assert isinstance(tpm.metadata, TwoPointHarmonic)
433+
match_elements = self.run_bin_filter(bin_filter, tpm.metadata.ells)
434+
match_obs = match_elements
435+
if tpm.metadata.window is not None:
436+
# The window function is represented by a matrix where each column
437+
# corresponds to the weights for the ell values of each observation. We
438+
# need to ensure that the window function is filtered correctly. To do this,
439+
# we will check each column of the matrix and verify that all non-zero
440+
# elements are within the filtered set. If any non-zero element falls
441+
# outside the filtered set, the match_elements will be set to False for that
442+
# observation.
443+
non_zero_window = tpm.metadata.window > 0
444+
match_obs = (
445+
np.all(
446+
(non_zero_window & match_elements[:, None]) == non_zero_window,
447+
axis=0,
448+
)
449+
.ravel()
450+
.astype(np.bool_)
451+
)
452+
453+
return match_elements, match_obs
454+
455+
def __call__(
456+
self, tpms: Sequence[TwoPointMeasurement]
457+
) -> list[TwoPointMeasurement]:
458+
"""Filter the two-point measurements."""
459+
result = []
460+
461+
for tpm in tpms:
462+
if not self.filter_match(tpm):
463+
if not self.require_filter_for_all:
464+
result.append(tpm)
465+
continue
466+
raise ValueError(f"The bin name {tpm.metadata} does not have a filter.")
467+
468+
match_elements, match_obs = self.apply_filter_single(tpm)
469+
if not match_obs.any():
470+
if not self.allow_empty:
471+
# If empty results are not allowed, we raise an error
472+
raise ValueError(
473+
f"The TwoPointMeasurement {tpm.metadata} does not "
474+
f"have any elements matching the filter."
475+
)
476+
# If the filter is empty, we skip this measurement
477+
continue
478+
479+
assert isinstance(tpm.metadata, (TwoPointReal, TwoPointHarmonic))
480+
new_metadata: TwoPointReal | TwoPointHarmonic
481+
match tpm.metadata:
482+
case TwoPointReal():
483+
new_metadata = TwoPointReal(
484+
XY=tpm.metadata.XY,
485+
thetas=tpm.metadata.thetas[match_elements],
486+
)
487+
case TwoPointHarmonic():
488+
# If the window function is not None, we need to filter it as well
489+
# and update the metadata accordingly.
490+
new_metadata = TwoPointHarmonic(
491+
XY=tpm.metadata.XY,
492+
window=(
493+
tpm.metadata.window[:, match_obs][match_elements, :]
494+
if tpm.metadata.window is not None
495+
else None
496+
),
497+
ells=tpm.metadata.ells[match_elements],
498+
)
499+
case _ as unreachable:
500+
assert_never(unreachable)
501+
502+
result.append(
503+
TwoPointMeasurement(
504+
data=tpm.data[match_obs],
505+
indices=tpm.indices[match_obs],
506+
covariance_name=tpm.covariance_name,
507+
metadata=new_metadata,
508+
)
509+
)
510+
511+
return result

‎firecrown/generators/inferred_galaxy_zdist.py

+4-42
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@
1313

1414
from numcosmo_py import Ncm
1515

16-
from firecrown.metadata_types import (
17-
InferredGalaxyZDist,
18-
ALL_MEASUREMENT_TYPES,
16+
from firecrown.metadata_types import InferredGalaxyZDist, Galaxies
17+
from firecrown.metadata_functions import (
18+
Measurement,
19+
make_measurements,
1920
make_measurements_dict,
20-
Galaxies,
21-
CMB,
22-
Clusters,
2321
)
24-
from firecrown.metadata_functions import Measurement
2522

2623

2724
BinsType = TypedDict("BinsType", {"edges": npt.NDArray, "sigma_z": float})
@@ -446,41 +443,6 @@ def generate(self) -> npt.NDArray:
446443
Grid1D = LinearGrid1D | RawGrid1D
447444

448445

449-
def make_measurements(
450-
value: set[Measurement] | list[dict[str, Any]],
451-
) -> set[Measurement]:
452-
"""Create a Measurement object from a dictionary."""
453-
if isinstance(value, set) and all(
454-
isinstance(v, ALL_MEASUREMENT_TYPES) for v in value
455-
):
456-
return value
457-
458-
measurements: set[Measurement] = set()
459-
for measurement_dict in value:
460-
if not isinstance(measurement_dict, dict):
461-
raise ValueError(f"Invalid Measurement: {value} is not a dictionary")
462-
463-
if "subject" not in measurement_dict:
464-
raise ValueError(
465-
"Invalid Measurement: dictionary does not contain 'subject'"
466-
)
467-
468-
subject = measurement_dict["subject"]
469-
470-
match subject:
471-
case "Galaxies":
472-
measurements.update({Galaxies[measurement_dict["property"]]})
473-
case "CMB":
474-
measurements.update({CMB[measurement_dict["property"]]})
475-
case "Clusters":
476-
measurements.update({Clusters[measurement_dict["property"]]})
477-
case _:
478-
raise ValueError(
479-
f"Invalid Measurement: subject: '{subject}' is not recognized"
480-
)
481-
return measurements
482-
483-
484446
class ZDistLSSTSRDBin(BaseModel):
485447
"""LSST Inferred galaxy redshift distributions in bins."""
486448

‎firecrown/likelihood/factories.py

+68-31
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from typing_extensions import assert_never
2323
import yaml
24-
from pydantic import BaseModel, ConfigDict, BeforeValidator
24+
from pydantic import BaseModel, ConfigDict, BeforeValidator, Field, field_serializer
2525

2626
import sacc
2727
from firecrown.likelihood.likelihood import Likelihood, NamedParameters
@@ -34,6 +34,7 @@
3434
extract_all_harmonic_data,
3535
check_two_point_consistence_real,
3636
check_two_point_consistence_harmonic,
37+
TwoPointBinFilterCollection,
3738
)
3839
from firecrown.modeling_tools import ModelingTools
3940
from firecrown.ccl_factory import CCLFactory
@@ -56,10 +57,12 @@ def _generate_next_value_(name, _start, _count, _last_values):
5657
HARMONIC = auto()
5758

5859

59-
def _validate_correlation_space(value):
60-
if isinstance(value, str):
60+
def _validate_correlation_space(value: TwoPointCorrelationSpace | str):
61+
if not isinstance(value, TwoPointCorrelationSpace) and isinstance(value, str):
6162
try:
62-
return TwoPointCorrelationSpace(value) # Convert from string to Enum
63+
return TwoPointCorrelationSpace(
64+
value.lower()
65+
) # Convert from string to Enum
6366
except ValueError as exc:
6467
raise ValueError(
6568
f"Invalid value for TwoPointCorrelationSpace: {value}"
@@ -73,42 +76,60 @@ class TwoPointFactory(BaseModel):
7376
model_config = ConfigDict(extra="forbid", frozen=True)
7477

7578
correlation_space: Annotated[
76-
TwoPointCorrelationSpace, BeforeValidator(_validate_correlation_space)
79+
TwoPointCorrelationSpace,
80+
BeforeValidator(_validate_correlation_space),
81+
Field(description="The two-point correlation space."),
7782
]
7883
weak_lensing_factory: WeakLensingFactory
7984
number_counts_factory: NumberCountsFactory
8085

8186
def model_post_init(self, __context) -> None:
8287
"""Initialize the WeakLensingFactory object."""
8388

89+
@field_serializer("correlation_space")
90+
@classmethod
91+
def serialize_correlation_space(cls, value: TwoPointCorrelationSpace) -> str:
92+
"""Serialize the amplitude parameter."""
93+
return value.name
94+
8495

8596
class DataSourceSacc(BaseModel):
8697
"""Model for the data source in a likelihood configuration."""
8798

8899
sacc_data_file: str
100+
filters: TwoPointBinFilterCollection | None = None
89101
_path: Path | None = None
90102

91103
def set_path(self, path: Path) -> None:
92104
"""Set the path for the data source."""
93105
self._path = path
94106

95-
def get_sacc_data(self) -> sacc.Sacc:
96-
"""Load the SACC data file."""
107+
def get_filepath(self) -> Path:
108+
"""Return the filename of the data source.
109+
110+
Raises a FileNotFoundError if the file does not exist.
111+
:return: The filename
112+
"""
97113
sacc_data_path = Path(self.sacc_data_file)
98114
# If sacc_data_file is absolute, use it directly
99-
if sacc_data_path.is_absolute():
100-
return sacc.Sacc.load_fits(self.sacc_data_file)
115+
if sacc_data_path.is_absolute() and sacc_data_path.exists():
116+
return Path(self.sacc_data_file)
101117
# If path is set, use it to find the file
102118
if self._path is not None:
103119
full_sacc_data_path = self._path / sacc_data_path
104120
if full_sacc_data_path.exists():
105-
return sacc.Sacc.load_fits(full_sacc_data_path)
121+
return full_sacc_data_path
106122
# If path is not set, use the current directory
107-
if sacc_data_path.exists():
108-
return sacc.Sacc.load_fits(sacc_data_path)
123+
elif sacc_data_path.exists():
124+
return sacc_data_path
109125
# If the file does not exist, raise an error
110126
raise FileNotFoundError(f"File {sacc_data_path} does not exist")
111127

128+
def get_sacc_data(self) -> sacc.Sacc:
129+
"""Load the SACC data file."""
130+
filename = self.get_filepath()
131+
return sacc.Sacc.load_fits(filename)
132+
112133

113134
def ensure_path(file: str | Path) -> Path:
114135
"""Ensure the file path is a Path object."""
@@ -130,6 +151,8 @@ class TwoPointExperiment(BaseModel):
130151

131152
def model_post_init(self, __context) -> None:
132153
"""Initialize the TwoPointExperiment object."""
154+
if self.ccl_factory is None:
155+
self.ccl_factory = CCLFactory()
133156

134157
@classmethod
135158
def load_from_yaml(cls, file: str | Path) -> "TwoPointExperiment":
@@ -144,6 +167,32 @@ def load_from_yaml(cls, file: str | Path) -> "TwoPointExperiment":
144167
tpe.data_source.set_path(filepath.parent)
145168
return tpe
146169

170+
def make_likelihood(self) -> Likelihood:
171+
"""Create a likelihood object for two-point statistics from a SACC file."""
172+
# Load the SACC file
173+
sacc_data = self.data_source.get_sacc_data()
174+
175+
likelihood: None | Likelihood = None
176+
match self.two_point_factory.correlation_space:
177+
case TwoPointCorrelationSpace.REAL:
178+
likelihood = _build_two_point_likelihood_real(
179+
sacc_data,
180+
self.two_point_factory.weak_lensing_factory,
181+
self.two_point_factory.number_counts_factory,
182+
filters=self.data_source.filters,
183+
)
184+
case TwoPointCorrelationSpace.HARMONIC:
185+
likelihood = _build_two_point_likelihood_harmonic(
186+
sacc_data,
187+
self.two_point_factory.weak_lensing_factory,
188+
self.two_point_factory.number_counts_factory,
189+
filters=self.data_source.filters,
190+
)
191+
case _ as unreachable:
192+
assert_never(unreachable)
193+
assert likelihood is not None
194+
return likelihood
195+
147196

148197
def build_two_point_likelihood(
149198
build_parameters: NamedParameters,
@@ -165,24 +214,7 @@ def build_two_point_likelihood(
165214
exp = TwoPointExperiment.load_from_yaml(likelihood_config_file)
166215
modeling_tools = ModelingTools(ccl_factory=exp.ccl_factory)
167216

168-
# Load the SACC file
169-
sacc_data = exp.data_source.get_sacc_data()
170-
171-
match exp.two_point_factory.correlation_space:
172-
case TwoPointCorrelationSpace.REAL:
173-
likelihood = _build_two_point_likelihood_real(
174-
sacc_data,
175-
exp.two_point_factory.weak_lensing_factory,
176-
exp.two_point_factory.number_counts_factory,
177-
)
178-
case TwoPointCorrelationSpace.HARMONIC:
179-
likelihood = _build_two_point_likelihood_harmonic(
180-
sacc_data,
181-
exp.two_point_factory.weak_lensing_factory,
182-
exp.two_point_factory.number_counts_factory,
183-
)
184-
case _ as unreachable:
185-
assert_never(unreachable)
217+
likelihood = exp.make_likelihood()
186218

187219
return likelihood, modeling_tools
188220

@@ -191,6 +223,7 @@ def _build_two_point_likelihood_harmonic(
191223
sacc_data: sacc.Sacc,
192224
wl_factory: WeakLensingFactory,
193225
nc_factory: NumberCountsFactory,
226+
filters: TwoPointBinFilterCollection | None = None,
194227
):
195228
"""
196229
Build a likelihood object for two-point statistics in harmonic space.
@@ -211,8 +244,9 @@ def _build_two_point_likelihood_harmonic(
211244
raise ValueError(
212245
"No two-point measurements in harmonic space found in the SACC file."
213246
)
214-
215247
check_two_point_consistence_harmonic(tpms)
248+
if filters is not None:
249+
tpms = filters(tpms)
216250

217251
two_points = TwoPoint.from_measurement(
218252
tpms, wl_factory=wl_factory, nc_factory=nc_factory
@@ -227,6 +261,7 @@ def _build_two_point_likelihood_real(
227261
sacc_data: sacc.Sacc,
228262
wl_factory: WeakLensingFactory,
229263
nc_factory: NumberCountsFactory,
264+
filters: TwoPointBinFilterCollection | None = None,
230265
):
231266
"""
232267
Build a likelihood object for two-point statistics in real space.
@@ -248,6 +283,8 @@ def _build_two_point_likelihood_real(
248283
"No two-point measurements in real space found in the SACC file."
249284
)
250285
check_two_point_consistence_real(tpms)
286+
if filters is not None:
287+
tpms = filters(tpms)
251288

252289
two_points = TwoPoint.from_measurement(
253290
tpms, wl_factory=wl_factory, nc_factory=nc_factory

‎firecrown/metadata_functions.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
from itertools import combinations_with_replacement, product
8-
from typing import TypedDict
8+
from typing import TypedDict, Any
99

1010
import numpy as np
1111
import numpy.typing as npt
@@ -25,6 +25,10 @@
2525
measurement_is_compatible,
2626
GALAXY_LENS_TYPES,
2727
GALAXY_SOURCE_TYPES,
28+
ALL_MEASUREMENT_TYPES,
29+
Galaxies,
30+
CMB,
31+
Clusters,
2832
)
2933

3034
# TwoPointRealIndex is a type used to create intermediate objects when reading SACC
@@ -48,6 +52,63 @@
4852
)
4953

5054

55+
def make_measurement(value: Measurement | dict[str, Any]) -> Measurement:
56+
"""Create a Measurement object from a dictionary."""
57+
if isinstance(value, ALL_MEASUREMENT_TYPES):
58+
return value
59+
60+
if not isinstance(value, dict):
61+
raise ValueError(f"Invalid Measurement: {value} is not a dictionary")
62+
63+
if "subject" not in value:
64+
raise ValueError("Invalid Measurement: dictionary does not contain 'subject'")
65+
66+
subject = value["subject"]
67+
68+
match subject:
69+
case "Galaxies":
70+
return Galaxies[value["property"]]
71+
case "CMB":
72+
return CMB[value["property"]]
73+
case "Clusters":
74+
return Clusters[value["property"]]
75+
case _:
76+
raise ValueError(
77+
f"Invalid Measurement: subject: '{subject}' is not recognized"
78+
)
79+
80+
81+
def make_measurements(
82+
value: set[Measurement] | list[dict[str, Any]],
83+
) -> set[Measurement]:
84+
"""Create a Measurement object from a dictionary."""
85+
if isinstance(value, set) and all(
86+
isinstance(v, ALL_MEASUREMENT_TYPES) for v in value
87+
):
88+
return value
89+
90+
measurements: set[Measurement] = set()
91+
for measurement_dict in value:
92+
measurements.update([make_measurement(measurement_dict)])
93+
return measurements
94+
95+
96+
def make_measurement_dict(value: Measurement) -> dict[str, str]:
97+
"""Create a dictionary from a Measurement object.
98+
99+
:param value: the measurement to turn into a dictionary
100+
"""
101+
return {"subject": type(value).__name__, "property": value.name}
102+
103+
104+
def make_measurements_dict(value: set[Measurement]) -> list[dict[str, str]]:
105+
"""Create a dictionary from a Measurement object.
106+
107+
:param value: the measurement to turn into a dictionary
108+
"""
109+
return [make_measurement_dict(measurement) for measurement in value]
110+
111+
51112
def _extract_all_candidate_measurement_types(
52113
data_points: list[sacc.DataPoint],
53114
include_maybe_types: bool = False,

‎firecrown/metadata_types.py

-11
Original file line numberDiff line numberDiff line change
@@ -312,17 +312,6 @@ def __eq__(self, other):
312312
)
313313

314314

315-
def make_measurements_dict(value: set[Measurement]) -> list[dict[str, str]]:
316-
"""Create a dictionary from a Measurement object.
317-
318-
:param value: the measurement to turn into a dictionary
319-
"""
320-
return [
321-
{"subject": type(measurement).__name__, "property": measurement.name}
322-
for measurement in value
323-
]
324-
325-
326315
def measurement_is_compatible(a: Measurement, b: Measurement) -> bool:
327316
"""Check if two Measurement are compatible.
328317

‎firecrown/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def base_model_from_yaml(cls: type, yaml_str: str):
4949

5050
def base_model_to_yaml(model: BaseModel) -> str:
5151
"""Convert a base model to a yaml string."""
52-
return yaml.dump(model.model_dump(), default_flow_style=False, sort_keys=False)
52+
return yaml.dump(
53+
model.model_dump(), default_flow_style=None, sort_keys=False, width=80
54+
)
5355

5456

5557
def upper_triangle_indices(n: int) -> Generator[tuple[int, int], None, None]:

‎firecrown/version.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
"""
88

99
FIRECROWN_MAJOR = 1
10-
FIRECROWN_MINOR = 8
11-
FIRECROWN_PATCH = 0
10+
FIRECROWN_MINOR = 9
11+
FIRECROWN_PATCH = "0a0"
1212
__version__ = f"{FIRECROWN_MAJOR}.{FIRECROWN_MINOR}.{FIRECROWN_PATCH}"

‎pylintrc

-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ py-version=3.10
1212
# Discover python modules and packages in the file system subtree.
1313
recursive=yes
1414

15-
# Add custom pylint plugins
16-
load-plugins=pylint_plugins.duplicate_code
17-
1815
[MESSAGES CONTROL]
1916

2017
# Enable the message, report, category or checker with the given id(s). You can

‎tests/conftest.py

+41
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,25 @@ def make_harmonic_bin_2(request) -> InferredGalaxyZDist:
206206
return x
207207

208208

209+
@pytest.fixture(
210+
name="all_harmonic_bins",
211+
)
212+
def make_all_harmonic_bins() -> list[InferredGalaxyZDist]:
213+
"""Generate a list of InferredGalaxyZDist objects with 5 bins."""
214+
z = np.linspace(0.0, 1.0, 256)
215+
dndzs = [
216+
np.exp(-0.5 * (z - 0.5) ** 2 / 0.05**2) / (np.sqrt(2 * np.pi) * 0.05),
217+
np.exp(-0.5 * (z - 0.6) ** 2 / 0.05**2) / (np.sqrt(2 * np.pi) * 0.05),
218+
]
219+
return [
220+
InferredGalaxyZDist(
221+
bin_name=f"bin_{i + 1}", z=z, dndz=dndzs[i], measurements={m}
222+
)
223+
for i in range(2)
224+
for m in [Galaxies.COUNTS, Galaxies.SHEAR_E]
225+
]
226+
227+
209228
@pytest.fixture(
210229
name="real_bin_1",
211230
params=[
@@ -248,6 +267,28 @@ def make_real_bin_2(request) -> InferredGalaxyZDist:
248267
return x
249268

250269

270+
@pytest.fixture(
271+
name="all_real_bins",
272+
)
273+
def make_all_real_bins() -> list[InferredGalaxyZDist]:
274+
"""Generate a list of InferredGalaxyZDist objects with 5 bins."""
275+
return [
276+
InferredGalaxyZDist(
277+
bin_name=f"bin_{i + 1}",
278+
z=np.linspace(0, 1, 5),
279+
dndz=np.array([0.1, 0.5, 0.2, 0.3, 0.4]),
280+
measurements={m},
281+
)
282+
for i in range(2)
283+
for m in [
284+
Galaxies.COUNTS,
285+
Galaxies.SHEAR_T,
286+
Galaxies.SHEAR_MINUS,
287+
Galaxies.SHEAR_PLUS,
288+
]
289+
]
290+
291+
251292
@pytest.fixture(name="window_1")
252293
def make_window_1() -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float64]]:
253294
"""Generate a Window object with 100 ells."""

‎tests/generators/test_inferred_galaxy_zdist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def test_make_measurement_from_dictionary():
467467
_ = make_measurements([{}])
468468

469469
with pytest.raises(
470-
ValueError, match=re.escape(r"Invalid Measurement: {3} is not a dictionary")
470+
ValueError, match=re.escape(r"Invalid Measurement: 3 is not a dictionary")
471471
):
472472
_ = make_measurements({3}) # type: ignore
473473

‎tests/likelihood/test_factories.py

+301-12
Large diffs are not rendered by default.

‎tests/metadata/test_data_functions.py

+631
Large diffs are not rendered by default.

‎tests/metadata/test_metadata_two_point_measurement.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def test_two_point_cells_with_data(harmonic_two_point_xy: TwoPointXY):
4040

4141
def test_two_point_two_point_cwindow_with_data(harmonic_two_point_xy: TwoPointXY):
4242
ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
43-
weights = np.ones(400).reshape(-1, 4)
43+
weights = np.zeros((100, 4), dtype=np.float64)
44+
# Create a window with 4 bins, each containing 25 elements with a weight of 1.0.
45+
# The bins are defined as follows:
46+
# - Bin 1: Elements 0 to 24
47+
# - Bin 2: Elements 25 to 49
48+
# - Bin 3: Elements 50 to 74
49+
# - Bin 4: Elements 75 to 99
50+
rows = np.arange(100)
51+
cols = rows // 25
52+
weights[rows, cols] = 1.0
4453

4554
ells = np.array(np.linspace(0, 100, 100), dtype=np.int64)
4655
data = np.array(np.zeros(4) + 1.1, dtype=np.float64)

‎tests/test_version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88

99

1010
def test_version():
11-
assert firecrown.__version__ == "1.8.0"
11+
assert firecrown.__version__ == "1.9.0a0"

‎tutorial/_quarto.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ format:
5858

5959
reference-location: margin
6060
citation-location: margin
61-
subtitle: "version 1.8.0"
61+
subtitle: "version 1.9.0a0"
6262
authors:
6363
- Marc Paterno
6464
- Sandro Vitenti

‎tutorial/introduction_to_firecrown.qmd

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
title: An Introduction to Firecrown
3-
subtitle: "Version 1.8.0"
3+
subtitle: "Version 1.9.0a0"
44
authors:
55
- Marc Paterno
66
- Sandro Vitenti

‎tutorial/two_point_factories.qmd

+159-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ sacc_data = sacc.Sacc.load_fits("../examples/des_y1_3x2pt/sacc_data.fits")
3737
all_meta = extract_all_real_metadata_indices(sacc_data)
3838
```
3939

40-
The metadata can be seem below:
40+
The metadata can be seen below:
4141
```{python}
4242
# | code-fold: true
4343
import yaml
@@ -194,6 +194,10 @@ from firecrown.modeling_tools import ModelingTools
194194
from firecrown.ccl_factory import CCLFactory
195195
from firecrown.updatable import get_default_params_map
196196
from firecrown.parameters import ParamsMap
197+
from firecrown.utils import base_model_to_yaml
198+
from firecrown.data_functions import TwoPointBinFilterCollection, TwoPointBinFilter
199+
from firecrown.metadata_types import Galaxies
200+
197201
198202
tools = ModelingTools(ccl_factory=CCLFactory(require_nonlinear_pk=True))
199203
params = get_default_params_map(tools, likelihood)
@@ -211,3 +215,157 @@ print(f"Loglike from metadata only: {likelihood.compute_loglike(tools)}")
211215
print(f"Loglike from ready state: {likelihood_ready.compute_loglike(tools)}")
212216
```
213217

218+
## Filtering Data: Scale-cuts
219+
220+
Real analyses use only a subset of the measured two-point statistics, where the utilized data is typically limited by the accuracy of the models used to fit the data.
221+
It is then useful to define the physical scales (corresponding to the data) that should be analyzed in a given likelihood evaluation of two-point statistics.
222+
Firecrown can implement this feature through its factories, notably by defining a `TwoPointBinFilterCollection` object.
223+
This object is a collection of `TwoPointBinFilter` objects, which define the valid data analysis range for a given combination of two-point tracers.
224+
For instance, we can define the filtered range of galaxy clustering auto-correlations as follows:
225+
226+
```{python}
227+
tp_collection = TwoPointBinFilterCollection(
228+
filters=[
229+
TwoPointBinFilter.from_args(
230+
name1=f"lens{i}",
231+
measurement1=Galaxies.COUNTS,
232+
name2=f"lens{i}",
233+
measurement2=Galaxies.COUNTS,
234+
lower=2,
235+
upper=300,
236+
)
237+
for i in range(5)
238+
],
239+
require_filter_for_all=True,
240+
allow_empty=True,
241+
)
242+
Markdown(f"```yaml\n{base_model_to_yaml(tp_collection)}\n```")
243+
```
244+
245+
Equivalently, we may reduce the complexity of the code slightly and specify the use of auto-correlations only:
246+
247+
```{python}
248+
tp_collection = TwoPointBinFilterCollection(
249+
filters=[
250+
TwoPointBinFilter.from_args_auto(
251+
name=f"lens{i}",
252+
measurement=Galaxies.COUNTS,
253+
lower=2,
254+
upper=300,
255+
)
256+
for i in range(5)
257+
],
258+
require_filter_for_all=True,
259+
allow_empty=True,
260+
)
261+
Markdown(f"```yaml\n{base_model_to_yaml(tp_collection)}\n```")
262+
```
263+
264+
One may alternatively define the tracers directly (instead of from arguments) as `TwoPointTracerSpec` objects.
265+
266+
A `TwoPointExperiment` object is able to keep track of the relevant `Factory` instances to generate the two-point configurations of the analysis (either in configuration or harmonic space) and the scale-cut/data filtering choices to evaluate a defined likelihood.
267+
The interpretation of the filtered lower and upper limits of the data depend on the definition of the `TwoPointExperiment` factories in either configuration or harmonic space.
268+
269+
With this formalism, we are able to evaluate the likelihood exactly as the previous section by defining filters to be very wide.
270+
Alternatively, by setting a restrictively small filtered range, we can remove data from the analysis; we do so in the example below by filtering out all galaxy clustering data.
271+
272+
```{python}
273+
from firecrown.likelihood.factories import (
274+
DataSourceSacc,
275+
TwoPointCorrelationSpace,
276+
TwoPointExperiment,
277+
TwoPointFactory,
278+
)
279+
280+
tpf = TwoPointFactory(
281+
correlation_space=TwoPointCorrelationSpace.REAL,
282+
weak_lensing_factory=weak_lensing_factory,
283+
number_counts_factory=number_counts_factory,
284+
)
285+
286+
two_point_experiment = TwoPointExperiment(
287+
two_point_factory=tpf,
288+
data_source=DataSourceSacc(
289+
sacc_data_file="../examples/des_y1_3x2pt/sacc_data.fits",
290+
filters=TwoPointBinFilterCollection(
291+
require_filter_for_all=False,
292+
allow_empty=True,
293+
filters=[
294+
TwoPointBinFilter.from_args_auto(
295+
name=f"lens{i}",
296+
measurement=Galaxies.COUNTS,
297+
lower=0.5,
298+
upper=300,
299+
)
300+
for i in range(5)
301+
],
302+
),
303+
),
304+
)
305+
306+
two_point_experiment_filtered = TwoPointExperiment(
307+
two_point_factory=tpf,
308+
data_source=DataSourceSacc(
309+
sacc_data_file="../examples/des_y1_3x2pt/sacc_data.fits",
310+
filters=TwoPointBinFilterCollection(
311+
require_filter_for_all=False,
312+
allow_empty=True,
313+
filters=[
314+
TwoPointBinFilter.from_args_auto(
315+
name=f"lens{i}",
316+
measurement=Galaxies.COUNTS,
317+
lower=2999,
318+
upper=3000,
319+
)
320+
for i in range(5)
321+
],
322+
),
323+
),
324+
)
325+
```
326+
327+
The `TwoPointExperiment` objects can also be used to create likelihoods in the ready state.
328+
Additionally, they can be serialized into a yaml file, making it easier to share specific analysis choices with other users and collaborators.
329+
330+
The `yaml` below shows the first experiment.
331+
```{python}
332+
# | code-fold: true
333+
Markdown(f"```yaml\n{base_model_to_yaml(two_point_experiment)}\n```")
334+
```
335+
336+
The `yaml` below shows the second experiment.
337+
```{python}
338+
# | code-fold: true
339+
Markdown(f"```yaml\n{base_model_to_yaml(two_point_experiment_filtered)}\n```")
340+
```
341+
342+
Next, we can create likelihoods from the `TwoPointExperiment` objects and compare the loglike values.
343+
344+
```{python}
345+
likelihood_tpe = two_point_experiment.make_likelihood()
346+
347+
params = get_default_params_map(tools, likelihood_tpe)
348+
349+
tools = ModelingTools()
350+
tools.update(params)
351+
tools.prepare()
352+
likelihood_tpe.update(params)
353+
354+
likelihood_tpe_filtered = two_point_experiment_filtered.make_likelihood()
355+
356+
params = get_default_params_map(tools, likelihood_tpe_filtered)
357+
358+
tools = ModelingTools()
359+
tools.update(params)
360+
tools.prepare()
361+
likelihood_tpe_filtered.update(params)
362+
363+
```
364+
365+
```{python}
366+
# | code-fold: true
367+
print(f"Loglike from metadata only: {likelihood.compute_loglike(tools)}")
368+
print(f"Loglike from ready state: {likelihood_ready.compute_loglike(tools)}")
369+
print(f"Loglike from TwoPointExperiment: {likelihood_tpe.compute_loglike(tools)}")
370+
print(f"Loglike from filtered TwoPointExperiment: {likelihood_tpe_filtered.compute_loglike(tools)}")
371+
```

‎tutorial/two_point_generators.qmd

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ wl_photoz = wl.PhotoZShiftFactory()
9393
wl_mult_bias = wl.MultiplicativeShearBiasFactory()
9494
9595
# NumberCounts systematics -- global
96-
# As for Firecrown 1.8.0, we do not have any global systematics for number counts
96+
# As for Firecrown 1.9.0a0, we do not have any global systematics for number counts
9797
# NumberCounts systematics -- per-bin
9898
nc_photoz = nc.PhotoZShiftFactory()
9999

0 commit comments

Comments
 (0)
Please sign in to comment.