Skip to content

Commit

Permalink
Sbachmei/mic 4936 5456/add tests for sequelae risk factors covariates (
Browse files Browse the repository at this point in the history
…#382)

* Add unmocked sequela-like tests; refactor DataType
* Add risk-like tests
* Add covariate-like test
  • Loading branch information
stevebachmeier authored Nov 13, 2024
1 parent d7d4977 commit 7958298
Show file tree
Hide file tree
Showing 7 changed files with 430 additions and 122 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
**5.2.1 - 11/13/24**

- Add more get_measure tests: sequelae (mocked), risk factors, covariates

**5.2.0 - 11/07/24**

- Add framework for getting mean data
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

setup_requires = ["setuptools_scm"]

data_requires = ["vivarium-gbd-access>=4.0.0, <5.0.0", "core-maths"]
data_requires = ["vivarium-gbd-access>=4.1.0, <5.0.0", "core-maths"]

lint_requirements = ["black==22.3.0", "isort"]

Expand Down
11 changes: 6 additions & 5 deletions src/vivarium_inputs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def get_cause_specific_mortality_rate(
)
data = deaths.join(pop, lsuffix="_deaths", rsuffix="_pop")
data[data_type.value_columns] = data[data_type.value_columns].divide(data.value, axis=0)
return data.drop(["value"], axis="columns")
data = data.drop(columns="value")
return data


def get_excess_mortality_rate(
Expand Down Expand Up @@ -398,7 +399,7 @@ def get_exposure(
)

data = extract.extract_data(entity, "exposure", location_id, years, data_type)
data = data.drop("modelable_entity_id", "columns")
data = data.drop(columns="modelable_entity_id")

value_columns = data_type.value_columns

Expand Down Expand Up @@ -455,7 +456,7 @@ def get_exposure_standard_deviation(
data = extract.extract_data(
entity, "exposure_standard_deviation", location_id, years, data_type
)
data = data.drop("modelable_entity_id", "columns")
data = data.drop(columns="modelable_entity_id")

exposure = extract.extract_data(entity, "exposure", location_id, years, data_type)
valid_age_groups = utilities.get_exposure_and_restriction_ages(exposure, entity)
Expand Down Expand Up @@ -485,7 +486,7 @@ def get_exposure_distribution_weights(
exposure = extract.extract_data(entity, "exposure", location_id, years, data_type)
valid_ages = utilities.get_exposure_and_restriction_ages(exposure, entity)

data.drop("age_group_id", axis=1, inplace=True)
data = data.drop(columns="age_group_id")
df = []
for age_id in valid_ages:
copied = data.copy()
Expand Down Expand Up @@ -690,7 +691,7 @@ def get_structure(
f"Data type(s) {data_type.type} are not supported for this function."
)
data = extract.extract_data(entity, "structure", location_id, years, data_type)
data = data.drop("run_id", axis="columns").rename(columns={"population": "value"})
data = data.drop(columns="run_id").rename(columns={"population": "value"})
data = utilities.normalize(data, data_type.value_columns)
return data

Expand Down
14 changes: 7 additions & 7 deletions src/vivarium_inputs/validation/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from loguru import logger

from vivarium_inputs import utilities, utility_data
from vivarium_inputs import utility_data
from vivarium_inputs.globals import (
DEMOGRAPHIC_COLUMNS,
DRAW_COLUMNS,
Expand Down Expand Up @@ -1268,19 +1268,19 @@ def validate_population_attributable_fraction(

grouped = data.groupby(["cause_id", "measure_id"])

for (c_id, _), g in grouped:
cause = [c for c in causes if c.gbd_id == c_id][0]
for (c_id, _), group in grouped:
cause = [cause for cause in causes if cause.gbd_id == c_id][0]
cause_male_expected = risk_male_expected and not cause.restrictions.female_only
cause_female_expected = risk_female_expected and not cause.restrictions.male_only

check_age_group_ids(g, context, None, None)
check_sex_ids(g, context, cause_male_expected, cause_female_expected)
check_age_group_ids(group, context, None, None)
check_sex_ids(group, context, cause_male_expected, cause_female_expected)
# check only if there is a sex restriction (male only or female only).
if not cause_male_expected or not cause_female_expected:
check_sex_restrictions(
g, context, cause_male_expected, cause_female_expected, DRAW_COLUMNS
group, context, cause_male_expected, cause_female_expected, DRAW_COLUMNS
)
check_paf_rr_exposure_age_groups(g, context, entity)
check_paf_rr_exposure_age_groups(group, context, entity)

protective_causes = (
PROTECTIVE_CAUSE_RISK_PAIRS[entity.name]
Expand Down
2 changes: 1 addition & 1 deletion src/vivarium_inputs/validation/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from loguru import logger

from vivarium_inputs.globals import DRAW_COLUMNS, VivariumInputsError
from vivarium_inputs.globals import VivariumInputsError

###############################
# Shared validation utilities #
Expand Down
Loading

0 comments on commit 7958298

Please sign in to comment.