Skip to content

Commit be37523

Browse files
committed Oct 1, 2021
initial support for fit size reduction
1 parent 20834ae commit be37523

File tree

7 files changed

+111
-5
lines changed

7 files changed

+111
-5
lines changed
 

‎R/Lrnr_base.R

+51
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
151151
}
152152
new_object <- self$clone() # copy parameters, and whatever else
153153
new_object$set_train(fit_object, task)
154+
if (getOption("sl3.reduce_fit")) {
155+
new_object$reduce_fit(check_preds = FALSE)
156+
}
154157
return(new_object)
155158
},
156159
set_train = function(fit_object, training_task) {
@@ -335,6 +338,53 @@ Lrnr_base <- R6Class(
335338
} else {
336339
return(task)
337340
}
341+
},
342+
# Shrink a fit object by dropping the components listed in
# private$.fit_can_remove.
#
# Args:
#   fit_object: the fit to reduce; when NULL, self$fit_object is used.
#   check_preds: if TRUE, predict on the training task before and after the
#     reduction and fail if the two sets of predictions differ.
#   set_train: if TRUE, store the reduced fit on this learner via
#     self$set_train().
#
# Returns: the reduced fit object.
reduce_fit = function(fit_object = NULL, check_preds = TRUE, set_train = TRUE) {
  if (is.null(fit_object)) {
    fit_object <- self$fit_object
  }

  # This method has no `task` argument, so predictions are checked on the
  # task the learner was trained on.
  if (check_preds) {
    preds_full <- self$predict(self$training_task)
  }

  size_full <- true_obj_size(fit_object)

  # Select by position rather than by name: a fit object without names (or
  # with duplicated names) must not lose elements that were never marked as
  # removable.
  element_names <- names(fit_object)
  removable <- element_names %in% private$.fit_can_remove

  if (any(removable)) {
    # `[` keeps only the names attribute, so the class and any other
    # attributes have to be restored by hand after subsetting.
    attrs <- attributes(fit_object)
    attrs$names <- element_names[!removable]
    reduced <- fit_object[!removable]
    attributes(reduced) <- attrs
    fit_object <- reduced
  }

  size_reduced <- true_obj_size(fit_object)
  reduction_percent <- 1 - size_reduced / size_full

  # isTRUE() guards against the option being unset (NULL)
  if (isTRUE(getOption("sl3.verbose"))) {
    message(sprintf("Fit object size reduced %0.0f%%", 100 * reduction_percent))
  }

  if (set_train) {
    self$set_train(fit_object, self$training_task)
  }

  # verify prediction still works
  # NOTE(review): when set_train is FALSE the learner's stored fit is left
  # untouched, so this presumably compares two identical predictions.
  if (check_preds) {
    preds_reduced <- self$predict(self$training_task)
    # all.equal() returns a character vector (not FALSE) on mismatch
    assert_that(
      isTRUE(all.equal(preds_full, preds_reduced)),
      msg = "Predictions changed after reducing the fit object"
    )
  }

  return(fit_object)
}
339389
),
340390
active = list(
@@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
399449
.required_packages = NULL,
400450
.properties = list(),
401451
.custom_chain = NULL,
452+
.fit_can_remove = c("call"),
402453
.train_sublearners = function(task) {
403454
# train sublearners here
404455
return(NULL)

‎R/Lrnr_glm_fast.R

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ Lrnr_glm_fast <- R6Class(
157157
}
158158
return(predictions)
159159
},
160+
.fit_can_remove = c("XTX"),
160161
.required_packages = c("speedglm")
161162
)
162163
)

‎R/Lrnr_hal9001.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class(
9999

100100
return(fit_object)
101101
},
102-
103102
.predict = function(task = NULL) {
104103
predictions <- stats::predict(
105104
self$fit_object,
@@ -111,7 +110,7 @@ Lrnr_hal9001 <- R6Class(
111110
}
112111
return(predictions)
113112
},
114-
113+
.fit_can_remove = c("lasso_fit", "x_basis"),
115114
.required_packages = c("hal9001", "glmnet")
116115
)
117116
)

‎R/Lrnr_xgboost.R

+1
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ Lrnr_xgboost <- R6Class(
194194

195195
return(predictions)
196196
},
197+
.fit_can_remove = c("raw", "call"),
197198
.required_packages = c("xgboost")
198199
)
199200
)

‎R/utils.R

+10
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ true_obj_size <- function(obj) {
155155
length(serialize(obj, NULL))
156156
}
157157

158+
#' Rank the components of a fit object by serialized size
#'
#' @param fit an object with a \code{fit_object} element (e.g. a trained
#'  learner) whose components should be measured.
#'
#' @return A numeric vector giving each component's share of the summed
#'  component sizes, sorted in decreasing order. Empty when the fit object has
#'  no components.
#'
#' @keywords internal
check_fit_sizes <- function(fit) {
  fo <- fit$fit_object
  # see what's taking up the space; vapply() keeps the result numeric even
  # when the fit object has no components
  element_sizes <- vapply(fo, true_obj_size, numeric(1))
  ranked <- sort(element_sizes / sum(element_sizes), decreasing = TRUE)

  return(ranked)
}
167+
158168
################################################################################
159169

160170
#' Drop components from learner fits

‎R/zzz.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ sl3Options <- function(o, value) {
3838
}
3939
if (is.null(value)) {
4040
res[o] <- list(NULL)
41-
}
42-
else {
41+
} else {
4342
res[[o]] <- value
4443
}
4544
options(res[o])
@@ -62,7 +61,8 @@ sl3Options <- function(o, value) {
6261
"sl3.pcontinuous" = 0.05,
6362
"sl3.max_p_missing" = 0.5,
6463
"sl3.transform.offset" = TRUE,
65-
"sl3.enable.future" = TRUE
64+
"sl3.enable.future" = TRUE,
65+
"sl3.reduce_fit" = FALSE
6666
)
6767
# for (i in setdiff(names(opts),names(options()))) {
6868
# browser()

‎tests/testthat/test-reduce_fit.R

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
set.seed(1234)

# TODO: maybe check storage at different n to get rate
n <- 1e3
p <- 100
# these two define the DGP (randomly)
p_X <- runif(p, 0.2, 0.8)
beta <- rnorm(p)

# simulate from the DGP
X <- vapply(p_X, function(p_Xi) rbinom(n, 1, p_Xi), integer(n))
p_Yx <- plogis(X %*% beta)
Y <- rbinom(n, 1, p_Yx)
data <- data.table(X, Y)

# generate the sl3 task and learner
outcome <- "Y"
covariates <- setdiff(names(data), outcome)
task <- make_sl3_Task(data, covariates, outcome)

# Train `learner` on the simulated task and check that none of the components
# it declares removable are left in the fit object.
#
# Args:
#   learner: an untrained learner.
#   auto_reduce: value used for the "sl3.reduce_fit" option while training;
#     when FALSE the fit is reduced manually through reduce_fit() instead.
#
# The options are set only for the duration of the call so they do not leak
# into other test files.
test_reduce_fit <- function(learner, auto_reduce = TRUE) {
  old_opts <- options(sl3.verbose = TRUE, sl3.reduce_fit = auto_reduce)
  on.exit(options(old_opts), add = TRUE)

  test_that(
    sprintf("reduce_fit drops removable components (%s)", class(learner)[1]),
    {
      fit <- learner$train(task)
      print(sl3:::check_fit_sizes(fit))
      if (!isTRUE(auto_reduce)) {
        # if we aren't automatically reducing, do it manually
        fit$reduce_fit()
      }

      still_present <- intersect(
        names(fit$fit_object),
        fit$.__enclos_env__$private$.fit_can_remove
      )

      expect_length(still_present, 0)
    }
  )
}

test_reduce_fit(make_learner(Lrnr_glmnet))
test_reduce_fit(make_learner(Lrnr_ranger))
test_reduce_fit(make_learner(Lrnr_glm_fast))
test_reduce_fit(make_learner(Lrnr_xgboost))
test_reduce_fit(make_learner(Lrnr_hal9001))

0 commit comments

Comments
 (0)