Skip to content

Commit 07f89c5

Browse files
committed
move only_colors to model and remove from estimate
1 parent aca3baf commit 07f89c5

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/rail/estimation/algos/k_nearneigh.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,11 @@ def run(self):
134134
print(f"\n\n\nbest fit values are sigma={sigma} and numneigh={numneigh}\n\n\n")
135135
# remake tree with full dataset!
136136
kdtree = KDTree(colordata, leaf_size=self.config.leaf_size)
137-
self.model = dict(kdtree=kdtree, bestsig=sigma, nneigh=numneigh, truezs=trainszs)
137+
self.model = dict(kdtree=kdtree,
138+
bestsig=sigma,
139+
nneigh=numneigh,
140+
truezs=trainszs,
141+
only_colors=self.config.only_colors)
138142
self.add_data('model', self.model)
139143

140144

@@ -150,8 +154,7 @@ class KNearNeighEstimator(CatEstimator):
150154
ref_band=SHARED_PARAMS,
151155
nondetect_val=SHARED_PARAMS,
152156
mag_limits=SHARED_PARAMS,
153-
redshift_col=SHARED_PARAMS,
154-
only_colors=Param(bool, False, msg="if only_colors True, then do not use ref_band mag, only use colors"))
157+
redshift_col=SHARED_PARAMS)
155158

156159
def __init__(self, args, **kwargs):
157160
""" Constructor:
@@ -161,6 +164,7 @@ def __init__(self, args, **kwargs):
161164
self.model = None
162165
self.trainszs = None
163166
self.zgrid = None
167+
self.only_colors = None
164168
super().__init__(args, **kwargs)
165169
usecols = self.config.bands.copy()
166170
usecols.append(self.config.redshift_col)
@@ -174,6 +178,7 @@ def open_model(self, **kwargs):
174178
self.numneigh = self.model['nneigh']
175179
self.kdtree = self.model['kdtree']
176180
self.trainszs = self.model['truezs']
181+
self.only_colors = self.model['only_colors']
177182

178183
def _process_chunk(self, start, end, data, first):
179184
"""
@@ -191,7 +196,7 @@ def _process_chunk(self, start, end, data, first):
191196
else:
192197
knn_df.loc[np.isclose(knn_df[col], self.config.nondetect_val), col] = np.float32(self.config.mag_limits[col])
193198

194-
testcolordata = _computecolordata(knn_df, self.config.ref_band, self.config.bands, self.config.only_colors)
199+
testcolordata = _computecolordata(knn_df, self.config.ref_band, self.config.bands, self.only_colors)
195200
dists, idxs = self.kdtree.query(testcolordata, k=self.numneigh)
196201
dists += TEENY
197202
test_ens = _makepdf(dists, idxs, self.trainszs, self.sigma)

tests/sklearn/test_algos.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def test_KNearNeigh_justcol():
136136
model="KNearNeighEstimator_justcols.pkl",
137137
only_colors=True,
138138
)
139-
estim_config_dict = dict(hdf5_groupname="photometry", model="KNearNeighEstimator_justcols.pkl", only_colors=True)
139+
estim_config_dict = dict(hdf5_groupname="photometry", model="KNearNeighEstimator_justcols.pkl")
140140

141141
# zb_expected = np.array([0.13, 0.14, 0.13, 0.13, 0.11, 0.15, 0.13, 0.14,
142142
# 0.11, 0.12])

0 commit comments

Comments
 (0)