Skip to content

Commit

Permalink
Added tests on passing marginals to prepare (#670)
Browse files Browse the repository at this point in the history
* added some tests and refactored a line

* explicitly specify marginals if they are unbalanced or not

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add tests and specify tolerances

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
selmanozleyen and pre-commit-ci[bot] authored Mar 20, 2024
1 parent 9667dc7 commit a5187c0
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 28 deletions.
14 changes: 14 additions & 0 deletions tests/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def _make_grid(grid_size: int) -> ArrayLike:
return np.vstack([X1.ravel(), X2.ravel()]).T


def _assert_marginals_set(adata_time, problem, key, marginal_keys):
    """Assert that the marginals of ``problem[key]`` match what was requested.

    Each side of the subproblem is checked on its own, so a mixed pair such
    as ``("some_column", None)`` is handled instead of failing on a lookup
    of ``obs[None]``.

    Parameters
    ----------
    adata_time
        Annotated data; ``obs["time"]`` selects the observations of each side.
    problem
        Prepared problem, indexable by ``key``; each entry exposes ``a`` and ``b``.
    key
        Pair ``(source, target)`` of values found in ``obs["time"]``.
    marginal_keys
        Pair ``(a_key, b_key)`` of ``obs`` column names; a ``None`` entry means
        no column was given for that side and a uniform marginal is expected.

    Raises
    ------
    AssertionError
        If a marginal differs from its expected values.
    """
    subproblem = problem[key]
    for side, actual in enumerate((subproblem.a, subproblem.b)):
        adata_side = adata_time[adata_time.obs["time"] == key[side]]
        marginal_key = marginal_keys[side]
        if marginal_key is None:
            # No column requested: every observation carries weight 1 / n_obs.
            # NOTE(review): assumes ``prepare`` defaults each side independently - confirm.
            expected = 1.0 / adata_side.shape[0]
        else:
            expected = adata_side.obs[marginal_key].values
        assert np.allclose(actual, expected), f"marginal {'ab'[side]!r} of {key} not set as expected"


class Problem(CompoundProblem[Any, OTProblem]):
@property
def _base_problem_type(self) -> Type[B]:
Expand Down
36 changes: 20 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,32 @@ def adata_y(y: Geom_t) -> AnnData:
return AnnData(X=np.asarray(y, dtype=float), obsm={"X_pca": pc})


def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t:
    """Build a length-``n`` vector of non-negative weights normalized to sum to one.

    Parameters
    ----------
    n
        Number of entries.
    uniform
        If ``True`` all entries are equal; otherwise they are absolute values
        of standard-normal draws.
    seed
        Seed passed to :class:`numpy.random.RandomState`.

    Returns
    -------
    The normalized weights converted with ``jnp.asarray``.
    """
    generator = np.random.RandomState(seed)
    if uniform:
        weights = np.ones((n,))
    else:
        weights = np.abs(generator.normal(size=(n,)))
    return jnp.asarray(weights / np.sum(weights))


@pytest.fixture()
def adata_time() -> AnnData:
rng = np.random.RandomState(42)
adatas = [AnnData(X=csr_matrix(rng.normal(size=(96, 60)))) for _ in range(3)]

adatas = [
AnnData(
X=csr_matrix(rng.normal(size=(96, 60))),
obs={
"left_marginals_balanced": creat_prob(96, seed=42),
"right_marginals_balanced": creat_prob(96, seed=42),
},
)
for _ in range(3)
]
adata = ad.concat(adatas, label="time", index_unique="-")
adata.obs["time"] = pd.to_numeric(adata.obs["time"]).astype("category")
adata.obs["batch"] = rng.choice((0, 1, 2), len(adata))
adata.obs["left_marginals"] = np.ones(len(adata))
adata.obs["right_marginals"] = np.ones(len(adata))
adata.obs["left_marginals_unbalanced"] = np.ones(len(adata))
adata.obs["right_marginals_unbalanced"] = np.ones(len(adata))
adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata))
# genes from mouse/human proliferation/apoptosis
genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"]
Expand All @@ -139,19 +156,6 @@ def adata_time() -> AnnData:
return adata


def create_marginals(n: int, m: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t:
    """Build a pair of weight vectors of lengths ``n`` and ``m``, each summing to one.

    Parameters
    ----------
    n
        Length of the first vector.
    m
        Length of the second vector.
    uniform
        If ``True`` both vectors are constant; otherwise they are absolute
        values of standard-normal draws.
    seed
        Seed passed to :class:`numpy.random.RandomState`.

    Returns
    -------
    A 2-tuple ``(a, b)`` of arrays converted with ``jnp.asarray``.
    NOTE(review): the annotation names a single ``Geom_t`` although a pair is returned.
    """
    generator = np.random.RandomState(seed)
    if uniform:
        first, second = np.ones((n,)), np.ones((m,))
    else:
        # Draw order is part of the contract for a fixed seed: first, then second.
        first = np.abs(generator.normal(size=(n,)))
        second = np.abs(generator.normal(size=(m,)))
    first = first / np.sum(first)
    second = second / np.sum(second)
    return jnp.asarray(first), jnp.asarray(second)


@pytest.fixture()
def gt_temporal_adata() -> AnnData:
adata = _gt_temporal_adata.copy()
Expand Down
12 changes: 12 additions & 0 deletions tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
return adata


# keys for marginals
@pytest.fixture(
    params=[
        pytest.param((None, None), id="default"),
        pytest.param(("left_marginals_balanced", "right_marginals_balanced"), id="balanced"),
    ],
)
def marginal_keys(request):
    """Provide the ``(a, b)`` pair of marginal keys for the current parametrization.

    ``"default"`` yields ``(None, None)``; ``"balanced"`` yields the names of
    the balanced marginal columns.
    """
    return request.param


sinkhorn_args_1 = {
"epsilon": 0.7,
"tau_a": 1.0,
Expand Down
19 changes: 19 additions & 0 deletions tests/problems/generic/test_fgw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import FGWProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
fgw_args_1,
fgw_args_2,
Expand Down Expand Up @@ -67,6 +68,21 @@ def test_prepare(self, adata_space_rotate: AnnData, policy):
assert key in expected_keys[policy]
assert isinstance(problem[key], OTProblem)

@pytest.mark.fast()
def test_prepare_marginals(self, adata_time: AnnData, marginal_keys):
    """Prepare a sequential FGW problem over ``time`` and verify each subproblem's marginals."""
    prepare_kwargs = {
        "key": "time",
        "policy": "sequential",
        "joint_attr": "X_pca",
        "x_attr": "X_pca",
        "y_attr": "X_pca",
        "a": marginal_keys[0],
        "b": marginal_keys[1],
    }
    prepared = FGWProblem(adata=adata_time).prepare(**prepare_kwargs)
    for pair in prepared:
        _assert_marginals_set(adata_time, prepared, pair, marginal_keys)

def test_solve_balanced(self, adata_space_rotate: AnnData):
eps = 0.5
adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy()
Expand All @@ -84,6 +100,9 @@ def test_solve_balanced(self, adata_space_rotate: AnnData):
for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
assert key in expected_keys
# assert that prior and posterior marginals are the same
assert np.allclose(subsol.a, problem[key].a, atol=1e-5)
assert np.allclose(subsol.b, problem[key].b, atol=1e-5)

@pytest.mark.parametrize("args_to_check", [fgw_args_1, fgw_args_2])
def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mapping[str, Any]):
Expand Down
18 changes: 17 additions & 1 deletion tests/problems/generic/test_gw_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import GWProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
geometry_args,
gw_args_1,
Expand All @@ -37,7 +38,10 @@

class TestGWProblem:
@pytest.mark.fast()
@pytest.mark.parametrize("policy", ["sequential", "star"])
@pytest.mark.parametrize(
"policy",
["sequential", "star"],
)
def test_prepare(self, adata_space_rotate: AnnData, policy):
expected_keys = {
"sequential": [("0", "1"), ("1", "2")],
Expand Down Expand Up @@ -80,6 +84,9 @@ def test_solve_balanced(self, adata_space_rotate: AnnData): # type: ignore[no-u
assert isinstance(subsol, BaseSolverOutput)
assert key in expected_keys
assert problem[key].solver._problem.geom_xy is None
# assert prior and posterior marginals are the same
assert np.allclose(subsol.a, problem[key].solver._problem.a, atol=1e-5)
assert np.allclose(subsol.b, problem[key].solver._problem.b, atol=1e-5)

@pytest.mark.parametrize("method", ["fisher", "perm_test"])
def test_compute_feature_correlation(self, adata_space_rotate: AnnData, method: str):
Expand Down Expand Up @@ -181,6 +188,15 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any,

problem = problem.solve(max_iterations=2)

@pytest.mark.fast()
def test_prepare_marginals(self, adata_time: AnnData, marginal_keys):
    """Prepare a sequential GW problem over ``time`` and verify each subproblem's marginals."""
    prepare_kwargs = {
        "a": marginal_keys[0],
        "b": marginal_keys[1],
        "key": "time",
        "policy": "sequential",
        "x_attr": "X_pca",
        "y_attr": "X_pca",
    }
    prepared = GWProblem(adata=adata_time).prepare(**prepare_kwargs)
    for pair in prepared:
        _assert_marginals_set(adata_time, prepared, pair, marginal_keys)

@pytest.mark.fast()
@pytest.mark.parametrize(
("cost_str", "cost_inst", "cost_kwargs"),
Expand Down
15 changes: 10 additions & 5 deletions tests/problems/generic/test_sinkhorn_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from moscot.base.output import BaseSolverOutput
from moscot.base.problems import OTProblem
from moscot.problems.generic import SinkhornProblem
from tests._utils import _assert_marginals_set
from tests.problems.conftest import (
geometry_args,
lin_prob_args,
Expand All @@ -37,15 +38,14 @@
class TestSinkhornProblem:
@pytest.mark.fast()
@pytest.mark.parametrize("policy", ["sequential", "star"])
def test_prepare(self, adata_time: AnnData, policy):
def test_prepare(self, adata_time: AnnData, policy, marginal_keys):
expected_keys = {"sequential": [(0, 1), (1, 2)], "star": [(1, 0), (2, 0)]}
problem = SinkhornProblem(adata=adata_time)

assert len(problem) == 0
assert problem.problems == {}
assert problem.solutions == {}

problem = problem.prepare(key="time", policy=policy, reference=0)
problem = problem.prepare(key="time", policy=policy, reference=0, a=marginal_keys[0], b=marginal_keys[1])

assert isinstance(problem.problems, dict)
assert len(problem.problems) == len(expected_keys[policy])
Expand All @@ -54,16 +54,21 @@ def test_prepare(self, adata_time: AnnData, policy):
assert key in expected_keys[policy]
assert isinstance(problem[key], OTProblem)

def test_solve_balanced(self, adata_time: AnnData):
_assert_marginals_set(adata_time, problem, key, marginal_keys)

def test_solve_balanced(self, adata_time: AnnData, marginal_keys):
eps = 0.5
expected_keys = [(0, 1), (1, 2)]
problem = SinkhornProblem(adata=adata_time)
problem = problem.prepare(key="time")
problem = problem.prepare(key="time", a=marginal_keys[0], b=marginal_keys[1])
problem = problem.solve(epsilon=eps)

for key, subsol in problem.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
assert key in expected_keys
assert subsol.converged
assert np.allclose(subsol.a, problem[key].a, atol=1e-5)
assert np.allclose(subsol.b, problem[key].b, atol=1e-5)

@pytest.mark.fast()
@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ def test_solve_unbalanced(self, adata_spatio_temporal: AnnData):
taus = [9e-1, 1e-2]
problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal)
problem2 = SpatioTemporalProblem(adata=adata_spatio_temporal)
problem1 = problem1.prepare("time", spatial_key="spatial", a="left_marginals", b="right_marginals")
problem2 = problem2.prepare("time", spatial_key="spatial", a="left_marginals", b="right_marginals")
problem1 = problem1.prepare(
"time", spatial_key="spatial", a="left_marginals_unbalanced", b="right_marginals_unbalanced"
)
problem2 = problem2.prepare(
"time", spatial_key="spatial", a="left_marginals_unbalanced", b="right_marginals_unbalanced"
)
assert problem1[0, 1].a is not None
assert problem1[0, 1].b is not None
assert problem2[0, 1].a is not None
Expand Down
3 changes: 1 addition & 2 deletions tests/problems/time/test_lineage_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData):
)
problem = problem.solve(epsilon=eps)

for key, subsol in problem.solutions.items():
for _, subsol in problem.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
assert key == key

def test_solve_unbalanced(self, adata_time_barcodes: AnnData):
taus = [9e-1, 1e-2]
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/time/test_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def test_solve_unbalanced(self, adata_time: AnnData):
taus = [9e-1, 1e-2]
problem1 = TemporalProblem(adata=adata_time)
problem2 = TemporalProblem(adata=adata_time)
problem1 = problem1.prepare("time", a="left_marginals", b="right_marginals")
problem2 = problem2.prepare("time", a="left_marginals", b="right_marginals")
problem1 = problem1.prepare("time", a="left_marginals_unbalanced", b="right_marginals_unbalanced")
problem2 = problem2.prepare("time", a="left_marginals_unbalanced", b="right_marginals_unbalanced")

assert problem1[0, 1].a is not None
assert problem1[0, 1].b is not None
Expand Down

0 comments on commit a5187c0

Please sign in to comment.