Skip to content

Commit 86d610f

Browse files
committed
initial support for fit size reduction
1 parent 20834ae commit 86d610f

8 files changed

+140
-5
lines changed

R/Lrnr_base.R

+51
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ Lrnr_base <- R6Class(
151151
}
152152
new_object <- self$clone() # copy parameters, and whatever else
153153
new_object$set_train(fit_object, task)
154+
if (getOption("sl3.reduce_fit")) {
155+
new_object$reduce_fit(check_preds = FALSE)
156+
}
154157
return(new_object)
155158
},
156159
set_train = function(fit_object, training_task) {
@@ -335,6 +338,53 @@ Lrnr_base <- R6Class(
335338
} else {
336339
return(task)
337340
}
341+
},
342+
reduce_fit = function(fit_object = NULL, check_preds = TRUE, set_train = TRUE) {
343+
if (is.null(fit_object)) {
344+
fit_object <- self$fit_object
345+
}
346+
if (check_preds) {
347+
preds_full <- self$predict(task)
348+
}
349+
350+
351+
# try reducing the size
352+
size_full <- true_obj_size(fit_object)
353+
354+
# see what's taking up the space
355+
# element_sizes <- sapply(fo, true_obj_size)
356+
# ranked <- sort(element_sizes/size_full, decreasing = TRUE)
357+
358+
# by default, drop out call
359+
# within(fit_object, rm(private$.fit_can_remove))
360+
keep <- setdiff(names(fit_object), private$.fit_can_remove)
361+
362+
# gotta preserve the attributes (not sure why they're getting dropped)
363+
attrs <- attributes(fit_object)
364+
attrs$names <- keep
365+
reduced <- fit_object[keep]
366+
attributes(reduced) <- attrs
367+
fit_object <- reduced
368+
size_reduced <- true_obj_size(fit_object)
369+
reduction_percent <- 1 - size_reduced / size_full
370+
371+
if (getOption("sl3.verbose")) {
372+
message(sprintf("Fit object size reduced %0.0f%%", 100 * reduction_percent))
373+
}
374+
375+
376+
if (set_train) {
377+
self$set_train(fit_object, self$training_task)
378+
}
379+
380+
381+
# verify prediction still works
382+
if (check_preds) {
383+
preds_reduced <- self$predict(task)
384+
assert_that(all.equal(preds_full, preds_reduced))
385+
}
386+
387+
return(fit_object)
338388
}
339389
),
340390
active = list(
@@ -399,6 +449,7 @@ Lrnr_base <- R6Class(
399449
.required_packages = NULL,
400450
.properties = list(),
401451
.custom_chain = NULL,
452+
.fit_can_remove = c("call"),
402453
.train_sublearners = function(task) {
403454
# train sublearners here
404455
return(NULL)

R/Lrnr_glm_fast.R

+1
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ Lrnr_glm_fast <- R6Class(
157157
}
158158
return(predictions)
159159
},
160+
.fit_can_remove = c("XTX"),
160161
.required_packages = c("speedglm")
161162
)
162163
)

R/Lrnr_hal9001.R

+1-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@ Lrnr_hal9001 <- R6Class(
9999

100100
return(fit_object)
101101
},
102-
103102
.predict = function(task = NULL) {
104103
predictions <- stats::predict(
105104
self$fit_object,
@@ -111,7 +110,7 @@ Lrnr_hal9001 <- R6Class(
111110
}
112111
return(predictions)
113112
},
114-
113+
.fit_can_remove = c("lasso_fit", "x_basis"),
115114
.required_packages = c("hal9001", "glmnet")
116115
)
117116
)

R/Lrnr_xgboost.R

+1
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ Lrnr_xgboost <- R6Class(
194194

195195
return(predictions)
196196
},
197+
.fit_can_remove = c("raw", "call"),
197198
.required_packages = c("xgboost")
198199
)
199200
)

R/utils.R

+10
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ true_obj_size <- function(obj) {
155155
length(serialize(obj, NULL))
156156
}
157157

158+
#' @keywords internal
159+
check_fit_sizes <- function(fit) {
160+
fo <- fit$fit_object
161+
# see what's taking up the space
162+
element_sizes <- sapply(fo, true_obj_size)
163+
ranked <- sort(element_sizes / sum(element_sizes), decreasing = TRUE)
164+
165+
return(ranked)
166+
}
167+
158168
################################################################################
159169

160170
#' Drop components from learner fits

R/zzz.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ sl3Options <- function(o, value) {
3838
}
3939
if (is.null(value)) {
4040
res[o] <- list(NULL)
41-
}
42-
else {
41+
} else {
4342
res[[o]] <- value
4443
}
4544
options(res[o])
@@ -62,7 +61,8 @@ sl3Options <- function(o, value) {
6261
"sl3.pcontinuous" = 0.05,
6362
"sl3.max_p_missing" = 0.5,
6463
"sl3.transform.offset" = TRUE,
65-
"sl3.enable.future" = TRUE
64+
"sl3.enable.future" = TRUE,
65+
"sl3.reduce_fit" = FALSE
6666
)
6767
# for (i in setdiff(names(opts),names(options()))) {
6868
# browser()

prof_dt.csv

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
learner,train (S).elapsed,predict (S).elapsed,size (MB),MSE,n,p,p_encoded
2+
lasso,0.211999999999989,0.00900000000001455,0.0478677749633789,0.161776523228791,100,100,100
3+
lasso_fast,0.0660000000002583,0.0100000000002183,0.0478677749633789,0.169428931381144,100,100,100
4+
mean,0.0109999999999673,0.00399999999990541,0.000121116638183594,0.17908,100,100,100
5+
ranger,0.0549999999998363,0.0160000000000764,0.436310768127441,0.165616521662889,100,100,100
6+
ranger_small,0.0359999999996035,0.0100000000002183,0.0905466079711914,0.168278786405556,100,100,100
7+
xgb,0.0370000000002619,0.00700000000006185,0.0262327194213867,0.174146900124846,100,100,100
8+
glm,0.0419999999999163,0.00700000000006185,0.0339899063110352,0.477337871242782,100,100,100
9+
hal_ls2,0.813000000000102,0.0279999999997926,0.760213851928711,0.179214823680041,100,100,100
10+
lasso,2.33800000000019,0.01299999999992,0.0838642120361328,0.0823997719258026,1000,100,100
11+
lasso_fast,1.41199999999981,0.0140000000001237,0.0838642120361328,0.0830177435791345,1000,100,100
12+
mean,0.0209999999997308,0.00600000000031287,0.000121116638183594,0.1780702,1000,100,100
13+
ranger,0.357000000000426,0.0879999999997381,4.20776462554932,0.14595598192,1000,100,100
14+
ranger_small,0.103000000000065,0.0329999999999018,0.850159645080566,0.146991513661111,1000,100,100
15+
xgb,0.141000000000076,0.01299999999992,0.0262327194213867,0.132485156061917,1000,100,100
16+
glm,0.0949999999997999,0.00999999999976353,0.0408601760864258,0.0903512975129307,1000,100,100
17+
hal_ls2,8.8130000000001,0.152000000000044,0.760213851928711,0.0983646681386519,1000,100,100
18+
lasso,15.2350000000001,0.0520000000001346,0.0769224166870117,0.0659705855526775,10000,100,100
19+
lasso_fast,4.76299999999992,0.0549999999998363,0.0769224166870117,0.0659733840218839,10000,100,100
20+
mean,0.0109999999999673,0.0320000000001528,0.000121116638183594,0.17774657,10000,100,100
21+
ranger,6.94100000000026,1.28200000000015,40.0050668716431,0.133503125161,10000,100,100
22+
ranger_small,1.43199999999979,0.297999999999774,8.07391452789307,0.134473546638889,10000,100,100
23+
xgb,0.909999999999854,0.0720000000001164,0.0262327194213867,0.111062539475018,10000,100,100
24+
glm,0.493999999999687,0.0500000000001819,0.109524726867676,0.0660307579952773,10000,100,100
25+
lasso_fast,58.0370000000003,0.372999999999593,0.0737161636352539,0.0654017031817335,100000,100,100
26+
mean,0.0140000000001237,0.231999999999971,0.000121116638183594,0.1777693129,100000,100,100
27+
ranger_small,51.5020000000004,4.70499999999993,77.7932748794556,0.124000491958333,100000,100,100
28+
xgb,8.3149999999996,0.532999999999447,0.0262327194213867,0.108275682560827,100000,100,100
29+
glm,4.62099999999919,0.353000000000066,0.796170234680176,0.0654282103103834,100000,100,100

tests/testthat/test-reduce_fit.R

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
2+
set.seed(1234)
3+
4+
# TODO: maybe check storage at different n to get rate
5+
n <- 1e3
6+
p <- 100
7+
# these two define the DGP (randomly)
8+
p_X <- runif(p, 0.2, 0.8)
9+
beta <- rnorm(p)
10+
11+
# simulate from the DGP
12+
X <- sapply(p_X, function(p_Xi) rbinom(n, 1, p_Xi))
13+
p_Yx <- plogis(X %*% beta)
14+
Y <- rbinom(n, 1, p_Yx)
15+
data <- data.table(X, Y)
16+
17+
# generate the sl3 task and learner
18+
outcome <- "Y"
19+
covariates <- setdiff(names(data), outcome)
20+
task <- make_sl3_Task(data, covariates, outcome)
21+
22+
options(sl3.verbose = TRUE)
23+
options(sl3.reduce_fit = TRUE)
24+
test_reduce_fit <- function(learner) {
25+
fit <- learner$train(task)
26+
print(sl3:::check_fit_sizes(fit))
27+
if (!getOption("sl3.reduce_fit")) {
28+
# if we aren't automatically reducing, do it manually
29+
fit_object <- fit$reduce_fit()
30+
}
31+
32+
still_present <- intersect(
33+
names(fit$fit_object),
34+
fit$.__enclos_env__$private$.fit_can_remove
35+
)
36+
37+
expect_equal(length(still_present), 0)
38+
}
39+
40+
test_reduce_fit(make_learner(Lrnr_glmnet))
41+
test_reduce_fit(make_learner(Lrnr_ranger))
42+
test_reduce_fit(make_learner(Lrnr_glm_fast))
43+
test_reduce_fit(make_learner(Lrnr_xgboost))
44+
test_reduce_fit(make_learner(Lrnr_hal9001))

0 commit comments

Comments
 (0)