@@ -134,7 +134,11 @@ def run(self):
134
134
print (f"\n \n \n best fit values are sigma={ sigma } and numneigh={ numneigh } \n \n \n " )
135
135
# remake tree with full dataset!
136
136
kdtree = KDTree (colordata , leaf_size = self .config .leaf_size )
137
- self .model = dict (kdtree = kdtree , bestsig = sigma , nneigh = numneigh , truezs = trainszs )
137
+ self .model = dict (kdtree = kdtree ,
138
+ bestsig = sigma ,
139
+ nneigh = numneigh ,
140
+ truezs = trainszs ,
141
+ only_colors = self .config .only_colors )
138
142
self .add_data ('model' , self .model )
139
143
140
144
@@ -150,8 +154,7 @@ class KNearNeighEstimator(CatEstimator):
150
154
ref_band = SHARED_PARAMS ,
151
155
nondetect_val = SHARED_PARAMS ,
152
156
mag_limits = SHARED_PARAMS ,
153
- redshift_col = SHARED_PARAMS ,
154
- only_colors = Param (bool , False , msg = "if only_colors True, then do not use ref_band mag, only use colors" ))
157
+ redshift_col = SHARED_PARAMS )
155
158
156
159
def __init__ (self , args , ** kwargs ):
157
160
""" Constructor:
@@ -161,6 +164,7 @@ def __init__(self, args, **kwargs):
161
164
self .model = None
162
165
self .trainszs = None
163
166
self .zgrid = None
167
+ self .only_colors = None
164
168
super ().__init__ (args , ** kwargs )
165
169
usecols = self .config .bands .copy ()
166
170
usecols .append (self .config .redshift_col )
@@ -174,6 +178,7 @@ def open_model(self, **kwargs):
174
178
self .numneigh = self .model ['nneigh' ]
175
179
self .kdtree = self .model ['kdtree' ]
176
180
self .trainszs = self .model ['truezs' ]
181
+ self .only_colors = self .model ['only_colors' ]
177
182
178
183
def _process_chunk (self , start , end , data , first ):
179
184
"""
@@ -191,7 +196,7 @@ def _process_chunk(self, start, end, data, first):
191
196
else :
192
197
knn_df .loc [np .isclose (knn_df [col ], self .config .nondetect_val ), col ] = np .float32 (self .config .mag_limits [col ])
193
198
194
- testcolordata = _computecolordata (knn_df , self .config .ref_band , self .config .bands , self .config . only_colors )
199
+ testcolordata = _computecolordata (knn_df , self .config .ref_band , self .config .bands , self .only_colors )
195
200
dists , idxs = self .kdtree .query (testcolordata , k = self .numneigh )
196
201
dists += TEENY
197
202
test_ens = _makepdf (dists , idxs , self .trainszs , self .sigma )
0 commit comments