Skip to content

Commit 1157929

Browse files
committed
validator for the fit CI estimator
1 parent 19a4564 commit 1157929

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

probscale/tests/test_validate.py

+16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import pytest
44

55
from probscale import validate
6+
from probscale import algo
7+
68

79

810
def test_axes_object_invalid():
@@ -86,3 +88,17 @@ def test_other_options(value, expected):
8688
def test_axis_label(value, expected):
8789
result = validate.axis_label(value)
8890
assert result == expected
91+
92+
93+
@pytest.mark.parametrize(('value', 'expected'), [
94+
('fit', algo._bs_fit),
95+
('resids', algo._bs_resid),
96+
('junk', None)
97+
])
98+
def test_estimator(value, expected):
99+
if expected is None:
100+
with pytest.raises(ValueError):
101+
validate.estimator(value)
102+
else:
103+
est = validate.estimator(value)
104+
assert est is expected

probscale/validate.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,15 @@ def other_options(options):
8282
Replaces None with an empty dict for plotting options.
8383
"""
8484

85-
return dict() if options is None else options.copy()
85+
return dict() if options is None else options.copy()
86+
87+
def estimator(value):
88+
from .algo import _bs_fit, _bs_resid
89+
if value.lower() in ['res', 'resid', 'resids', 'residual', 'residuals']:
90+
est = _bs_resid
91+
elif value.lower() in ['fit', 'values']:
92+
est = _bs_fit
93+
else:
94+
raise ValueError('estimator must be either "resid" or "fit".')
95+
96+
return est

0 commit comments

Comments
 (0)