1
1
# ' Learns the Optimal Rule given a tmle_task and likelihood, using the Revere framework.
2
2
# ' Complements 'tmle3_Spec_mopttx_blip_revere'.
3
- # '
3
+ # '
4
4
# '
5
5
# ' @importFrom R6 R6Class
6
6
# ' @importFrom data.table data.table
@@ -14,16 +14,16 @@ Optimal_Rule_Revere <- R6Class(
14
14
inherit = tmle3_Spec ,
15
15
lock_objects = FALSE ,
16
16
public = list (
17
- initialize = function (tmle_task , likelihood , fold_number = " split-specific" , V = NULL ,
18
- blip_type = " blip2" , learners , maximize = TRUE , realistic = FALSE ) {
17
+ initialize = function (tmle_task , likelihood , fold_number = " split-specific" , V = NULL ,
18
+ blip_type = " blip2" , learners , maximize = TRUE , realistic = FALSE ) {
19
19
private $ .tmle_task <- tmle_task
20
20
private $ .likelihood <- likelihood
21
21
private $ .fold_number <- fold_number
22
22
private $ .blip_type <- blip_type
23
23
private $ .learners <- learners
24
24
private $ .maximize <- maximize
25
25
private $ .realistic <- realistic
26
-
26
+
27
27
if (missing(V )) {
28
28
V <- tmle_task $ npsem $ W $ variables
29
29
}
@@ -47,64 +47,63 @@ Optimal_Rule_Revere <- R6Class(
47
47
DR <- data.frame (private $ .DR_full [[v ]])
48
48
return (data.frame (DR [indx , ]))
49
49
},
50
-
51
- blip_revere_function = function (tmle_task , fold_number ){
52
-
50
+
51
+ blip_revere_function = function (tmle_task , fold_number ) {
53
52
likelihood <- self $ likelihood
54
53
A_vals <- tmle_task $ npsem $ A $ variable_type $ levels
55
54
V <- self $ V
56
-
55
+
57
56
# Generate counterfactual tasks for each value of A:
58
57
cf_tasks <- lapply(A_vals , function (A_val ) {
59
- if (is.character(A_val )){
60
- A_val <- as.numeric(A_val )
61
- # A_val<-as.factor(A_val)
58
+ if (is.character(A_val )) {
59
+ A_val <- as.numeric(A_val )
60
+ # A_val<-as.factor(A_val)
62
61
}
63
62
newdata <- data.table(A = A_val )
64
63
cf_task <- tmle_task $ generate_counterfactual_task(UUIDgenerate(), new_data = newdata )
65
64
return (cf_task )
66
65
})
67
-
66
+
68
67
# DR A-IPW mapping of blip
69
68
A <- tmle_task $ get_tmle_node(" A" )
70
69
Y <- tmle_task $ get_tmle_node(" Y" )
71
70
A_vals <- tmle_task $ npsem $ A $ variable_type $ levels
72
- A_ind <- self $ factor_to_indicators(A ,A_vals )
71
+ A_ind <- self $ factor_to_indicators(A , A_vals )
73
72
Y_mat <- replicate(length(A_vals ), Y )
74
-
75
- # Use fold_number fits for Q and g:
73
+
74
+ # Use fold_number fits for Q and g:
76
75
Q_vals <- sapply(cf_tasks , likelihood $ get_likelihood , " Y" , fold_number )
77
76
g_vals <- sapply(cf_tasks , likelihood $ get_likelihood , " A" , fold_number )
78
77
DR <- (A_ind / g_vals ) * (Y_mat - Q_vals ) + Q_vals
79
78
80
79
# Type of pseudo-blip:
81
80
blip_type <- self $ blip_type
82
-
83
- if (blip_type == " blip1" ){
84
- blip <- DR [,2 ] - DR [,1 ]
85
- }else if (blip_type == " blip2" ){
81
+
82
+ if (blip_type == " blip1" ) {
83
+ blip <- DR [, 2 ] - DR [, 1 ]
84
+ } else if (blip_type == " blip2" ) {
86
85
blip <- DR - rowMeans(DR )
87
- }else if (blip_type == " blip3" ){
86
+ } else if (blip_type == " blip3" ) {
88
87
blip <- DR - (rowMeans(DR ) * g_vals )
89
88
}
90
-
91
- # TO DO: Nicer solutions. Do it one by one, for now
92
- if (is.null(V )){
93
- data <- data.table(V = blip ,blip = blip )
89
+
90
+ # TO DO: Nicer solutions. Do it one by one, for now
91
+ if (is.null(V )) {
92
+ data <- data.table(V = blip , blip = blip )
94
93
outcomes <- grep(" blip" , names(data ), value = TRUE )
95
94
V <- grep(" V" , names(data ), value = TRUE )
96
- revere_task <- make_sl3_Task(data , outcome = outcomes , covariates = V , folds = tmle_task $ folds )
97
- }else {
98
- V <- tmle_task $ data [,self $ V ,with = FALSE ]
99
- data <- data.table(V ,blip = blip )
95
+ revere_task <- make_sl3_Task(data , outcome = outcomes , covariates = V , folds = tmle_task $ folds )
96
+ } else {
97
+ V <- tmle_task $ data [, self $ V , with = FALSE ]
98
+ data <- data.table(V , blip = blip )
100
99
outcomes <- grep(" blip" , names(data ), value = TRUE )
101
- revere_task <- make_sl3_Task(data , outcome = outcomes , covariates = self $ V , folds = tmle_task $ folds )
100
+ revere_task <- make_sl3_Task(data , outcome = outcomes , covariates = self $ V , folds = tmle_task $ folds )
102
101
}
103
-
102
+
104
103
105
104
return (revere_task )
106
105
},
107
-
106
+
108
107
bound = function (cv_g ) {
109
108
cv_g [cv_g < 0.01 ] <- 0.01
110
109
cv_g [cv_g > 0.99 ] <- 0.99
@@ -123,81 +122,79 @@ Optimal_Rule_Revere <- R6Class(
123
122
private $ .blip_fit <- blip_fit
124
123
},
125
124
126
- rule = function (tmle_task , fold_number = " full" ) {
127
-
125
+ rule = function (tmle_task , fold_number = " full" ) {
128
126
realistic <- private $ .realistic
129
127
likelihood <- self $ likelihood
130
-
128
+
131
129
# TODO: when applying the rule, we actually only need the covariates
132
130
blip_task <- self $ blip_revere_function(tmle_task , fold_number )
133
131
blip_preds <- self $ blip_fit $ predict_fold(blip_task , fold_number )
134
-
132
+
135
133
# Type of pseudo-blip:
136
134
blip_type <- self $ blip_type
137
-
138
- if (is.list(blip_preds )){
135
+
136
+ if (is.list(blip_preds )) {
139
137
blip_preds <- unpack_predictions(blip_preds )
140
138
}
141
-
139
+
142
140
rule_preds <- NULL
143
-
144
- if (realistic ){
145
-
146
- # Need to grab the propensity score:
141
+
142
+ if (realistic ) {
143
+
144
+ # Need to grab the propensity score:
147
145
g_learner <- likelihood $ factor_list [[" A" ]]$ learner
148
146
g_task <- tmle_task $ get_regression_task(" A" )
149
147
g_fits <- unpack_predictions(g_learner $ predict(g_task ))
150
-
148
+
151
149
if (! private $ .maximize ) {
152
150
blip_preds <- blip_preds * - 1
153
151
}
154
-
155
- if (blip_type == " blip1" ){
152
+
153
+ if (blip_type == " blip1" ) {
156
154
rule_preds <- as.numeric(blip_preds > 0 )
157
-
158
- for (i in 1 : length(rule_preds )){
159
- rule_preds_prob <- g_fits [i ,]
160
- # TO DO: What is a realistic cutoff here?
161
- if (rule_preds_prob < 0.05 ){
162
- # Switch- assumes options are 0 and 1.
155
+
156
+ for (i in 1 : length(rule_preds )) {
157
+ rule_preds_prob <- g_fits [i , ]
158
+ # TO DO: What is a realistic cutoff here?
159
+ if (rule_preds_prob < 0.05 ) {
160
+ # Switch- assumes options are 0 and 1.
163
161
rule_preds [i ] <- abs(rule_preds [i ] - 1 )
164
162
}
165
163
}
166
-
167
- }else {
168
- if (dim(blip_preds )[2 ]< 3 ){
164
+ } else {
165
+ if (dim(blip_preds )[2 ] < 3 ) {
169
166
rule_preds <- max.col(blip_preds ) - 1
170
- for (i in 1 : length(rule_preds )){
171
- rule_preds_prob <- g_fits [i ,rule_preds [i ]]
172
- # TO DO: What is a realistic cutoff here?
173
- if (rule_preds_prob < 0.05 ){
174
- # Pick the next largest blip
175
- rule_preds [i ] <- max.col(blip_preds [i ,order(blip_preds [i ,], decreasing = TRUE )[2 ]])
167
+ for (i in 1 : length(rule_preds )) {
168
+ rule_preds_prob <- g_fits [i , rule_preds [i ]]
169
+ # TO DO: What is a realistic cutoff here?
170
+ if (rule_preds_prob < 0.05 ) {
171
+ # Pick the next largest blip
172
+ rule_preds [i ] <- max.col(blip_preds [i , order(blip_preds [i , ], decreasing = TRUE )[2 ]])
176
173
}
177
174
}
178
- }else {
175
+ } else {
179
176
rule_preds <- max.col(blip_preds )
180
- for (i in 1 : length(rule_preds )){
181
- rule_preds_prob <- g_fits [i ,rule_preds [i ]]
182
- # TO DO: What is a realistic cutoff here?
183
- if (rule_preds_prob < 0.07 ){
184
- # Pick the next largest blip
185
- rule_preds [i ] <- order(blip_preds [i ,], decreasing = TRUE )[2 ]
177
+ for (i in 1 : length(rule_preds )) {
178
+ rule_preds_prob <- g_fits [i , rule_preds [i ]]
179
+ # TO DO: What is a realistic cutoff here?
180
+ if (rule_preds_prob < 0.07 ) {
181
+ # Pick the next largest blip
182
+ rule_preds [i ] <- order(blip_preds [i , ], decreasing = TRUE )[2 ]
186
183
}
187
184
}
188
185
}
189
186
}
190
- }else {
187
+ } else {
191
188
if (! private $ .maximize ) {
192
189
blip_preds <- blip_preds * - 1
193
190
}
194
-
195
- if (blip_type == " blip1" ){
191
+
192
+ if (blip_type == " blip1" ) {
196
193
rule_preds <- as.numeric(blip_preds > 0 )
197
- }else {
198
- if (dim(blip_preds )[2 ]< 3 ) {
194
+ } else {
195
+ if (dim(blip_preds )[2 ] < 3 ) {
199
196
rule_preds <- max.col(blip_preds ) - 1
200
- }else {
197
+ } else {
201
198
rule_preds <- max.col(blip_preds )
202
199
}
203
200
}
0 commit comments