@@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
151
151
}
152
152
new_object <- self $ clone() # copy parameters, and whatever else
153
153
new_object $ set_train(fit_object , task )
154
+ if (getOption(" sl3.reduce_fit" )) {
155
+ new_object $ reduce_fit(check_preds = FALSE )
156
+ }
154
157
return (new_object )
155
158
},
156
159
set_train = function (fit_object , training_task ) {
@@ -335,6 +338,53 @@ Lrnr_base <- R6Class(
335
338
} else {
336
339
return (task )
337
340
}
341
+ },
342
+ reduce_fit = function (fit_object = NULL , check_preds = TRUE , set_train = TRUE ) {
343
+ if (is.null(fit_object )) {
344
+ fit_object <- self $ fit_object
345
+ }
346
+ if (check_preds ) {
347
+ preds_full <- self $ predict(task )
348
+ }
349
+
350
+
351
+ # try reducing the size
352
+ size_full <- true_obj_size(fit_object )
353
+
354
+ # see what's taking up the space
355
+ # element_sizes <- sapply(fo, true_obj_size)
356
+ # ranked <- sort(element_sizes/size_full, decreasing = TRUE)
357
+
358
+ # by default, drop out call
359
+ # within(fit_object, rm(private$.fit_can_remove))
360
+ keep <- setdiff(names(fit_object ), private $ .fit_can_remove )
361
+
362
+ # gotta preserve the attributes (not sure why they're getting dropped)
363
+ attrs <- attributes(fit_object )
364
+ attrs $ names <- keep
365
+ reduced <- fit_object [keep ]
366
+ attributes(reduced ) <- attrs
367
+ fit_object <- reduced
368
+ size_reduced <- true_obj_size(fit_object )
369
+ reduction_percent <- 1 - size_reduced / size_full
370
+
371
+ if (getOption(" sl3.verbose" )) {
372
+ message(sprintf(" Fit object size reduced %0.0f%%" , 100 * reduction_percent ))
373
+ }
374
+
375
+
376
+ if (set_train ) {
377
+ self $ set_train(fit_object , self $ training_task )
378
+ }
379
+
380
+
381
+ # verify prediction still works
382
+ if (check_preds ) {
383
+ preds_reduced <- self $ predict(task )
384
+ assert_that(all.equal(preds_full , preds_reduced ))
385
+ }
386
+
387
+ return (fit_object )
338
388
}
339
389
),
340
390
active = list (
@@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
399
449
.required_packages = NULL ,
400
450
.properties = list (),
401
451
.custom_chain = NULL ,
452
+ .fit_can_remove = c(" call" ),
402
453
.train_sublearners = function (task ) {
403
454
# train sublearners here
404
455
return (NULL )
0 commit comments