@@ -37,7 +37,7 @@ def init_attention(self, feat_info=None):
        self.feat_map = random.choices(population=list(range(self.in_features)), k=self.nGateFuncs)  #weights = weight,
        #self.feat_map[0]=1
        self.init_attention_func = nn.init.eye
-        if self.no_attention:
+        if False:  # self.no_attention:
            self.feat_weight = nn.Parameter(torch.Tensor(self.nGateFuncs).uniform_(), requires_grad=True)
            pass
        elif False:  # only used for comparison
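
[Note] In the `no_attention` path that this hunk disables, `self.feat_map` assigns each gate function a randomly chosen input column, which the forward pass then gathers and reshapes to `[batch, num_trees, depth]` (see the commented-out code in the `forward` hunk below). A minimal standalone sketch of that gather; the sizes here are illustrative, not taken from the repo:

import random
import torch

batch, in_features, num_trees, depth = 4, 10, 3, 2
# one random feature index per (tree, depth) gate, as in self.feat_map
feat_map = random.choices(population=list(range(in_features)), k=num_trees * depth)
x = torch.randn(batch, in_features)
feature_values = x[:, feat_map].reshape(batch, num_trees, depth)  # [batch, num_trees, depth]
assert feature_values.shape[0] == x.shape[0]
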
@@ -63,7 +63,7 @@ def init_attention(self, feat_info=None):
            self.feat_attention = nn.Parameter(feat_val, requires_grad=True)
            #self.feat_W = nn.Parameter(torch.Tensor(self.in_features).uniform_(), requires_grad=True)

-        print(f"====== init_attention f={self.init_attention_func.__name__} no_attention={self.no_attention}")
+        print(f"====== init_attention f={self.init_attention_func.__name__} attention={self.config.attention_alg}")

    #weights computed as entmax over the learnable feature selection matrix F ∈ R d×n
    def get_attention_value(self, input):
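
[Note] The comment above `get_attention_value` describes NODE-style feature selection: entmax over a learnable selection matrix F. A hedged sketch of that computation, assuming `feat_attention` is stored as an `[in_features, num_trees, depth]` tensor (the actual shape and einsum in the repo may differ):

import torch
from entmax import entmax15

batch, in_features, num_trees, depth = 4, 10, 3, 2
feat_attention = torch.randn(in_features, num_trees, depth, requires_grad=True)  # learnable F (assumed shape)
x = torch.randn(batch, in_features)
weights = entmax15(feat_attention, dim=0)                  # sparse selection weights over the feature axis
feature_values = torch.einsum('bi,ind->bnd', x, weights)   # [batch, num_trees, depth]
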
@@ -226,20 +226,25 @@ def __init__(self, in_features, num_trees,config, flatten_output=True,feat_info=
        if self.config.path_way == "TREE_map":
            self.nFeature = 2 ** depth - 1
            self.nGateFuncs = self.num_trees * self.nFeature
-        if False:
+        if self.config.attention_alg == "eca_response":
            self.att_reponse = eca_reponse(self.num_trees)
-            #self.att_input = eca_input(self.in_features)
+        elif self.config.attention_alg == "eca_input":
+            self.att_input = eca_input(self.in_features)
+        elif self.config.attention_alg == "":
+            print("!!! Empty attention_alg. Please try \"--attention=eca_response\" !!!\n")
+        else:
+            raise ValueError(f'INVALID self.config.attention_alg = {self.config.attention_alg}')

-        if self.config.attention_alg == "weight":
-            if False and isAdptiveAlpha:  # worth trying, but training takes too long
-                self.entmax_alpha = nn.Parameter(torch.tensor(1.5, requires_grad=True))
-                self.attention_func = entmax.entmax_bisect  #sparsemax, entmax15, entmax_bisect
-            else:
-                self.attention_func = entmax15
+
+        if False and isAdptiveAlpha:  # worth trying, but training takes too long
+            self.entmax_alpha = nn.Parameter(torch.tensor(1.5, requires_grad=True))
+            self.attention_func = entmax.entmax_bisect  #sparsemax, entmax15, entmax_bisect
        else:
+            self.attention_func = entmax15
+        if self.config.attention_alg == "excitation_max":  # a failed attempt
            self.attention_func = excitation_max(in_features, self.nGateFuncs, self.num_trees)
        self.bin_func = "05_01"  #"05_01" "entmoid15" "softmax"
-        self.no_attention = config.no_attention
+        # self.no_attention = config.no_attention
        self.threshold_init_beta, self.threshold_init_cutoff = threshold_init_beta, threshold_init_cutoff
        self.init_responce_func = initialize_response_
        self.response_mean, self.response_std = 0, 0
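
[Note] The new dispatch wires `--attention=eca_response` / `eca_input` to the project's `eca_reponse` / `eca_input` modules, whose implementations are not shown in this diff. As an assumption about the pattern only (not the repo's code), a generic ECA-style (Efficient Channel Attention) gate over a `[batch, channels]` tensor could look roughly like this:

import torch
import torch.nn as nn

class ECAGate(nn.Module):
    """ECA-style gate: a small 1-D conv produces one sigmoid weight per channel."""
    def __init__(self, k_size: int = 3):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=k_size // 2, bias=False)

    def forward(self, x):                       # x: [batch, channels]
        y = self.conv(x.unsqueeze(1))           # [batch, 1, channels]
        return x * torch.sigmoid(y.squeeze(1))  # re-weight each channel

gate = ECAGate()
out = gate(torch.randn(4, 10))                  # e.g. a gate over in_features=10 columns
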
@@ -309,13 +314,13 @@ def forward(self, input):
        if hasattr(self, "att_input"):
            input = self.att_input(input)
        # new input shape: [batch_size, in_features]
-        if self.no_attention:
-            feature_values = input[:, self.feat_map]  # torch.index_select(input.flatten(), 0, self.feat_select)
-            #feature_values = torch.einsum('bf,f->bf',feature_values,self.feat_weight)
-            feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
-            assert feature_values.shape[0] == input.shape[0]
-        else:
-            feature_values = self.get_attention_value(input)
+        # if self.no_attention:
+        #     feature_values = input[:, self.feat_map]  # torch.index_select(input.flatten(), 0, self.feat_select)
+        #     #feature_values = torch.einsum('bf,f->bf',feature_values,self.feat_weight)
+        #     feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
+        #     assert feature_values.shape[0] == input.shape[0]
+        # else:
+        feature_values = self.get_attention_value(input)
        # ^--[batch_size, num_trees, depth]

        threshold_logits = (feature_values - self.feature_thresholds) * torch.exp(-self.log_temperatures)
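
[Note] For reference, the soft split computed right after the attention step: each (tree, depth) gate compares its selected feature value against a learned threshold, scaled by a learned temperature, before the configured `bin_func` ("05_01", entmoid15, softmax) squashes it. A tiny numeric sketch with assumed shapes and a sigmoid stand-in for the bin function:

import torch

batch, num_trees, depth = 4, 3, 2
feature_values     = torch.randn(batch, num_trees, depth)   # from get_attention_value
feature_thresholds = torch.randn(num_trees, depth)          # learned per gate
log_temperatures   = torch.zeros(num_trees, depth)          # learned per gate
threshold_logits = (feature_values - feature_thresholds) * torch.exp(-log_temperatures)
bins = torch.sigmoid(threshold_logits)   # stand-in for "05_01" / entmoid15: soft left/right routing
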
@@ -421,13 +426,13 @@ def initialize(self, input, eps=1e-6):
            print(f"!!!!!! DeTree::initialize@{self.__repr__()} has only {nSamp} sampls. This may cause instability.\n")

        with torch.no_grad():
-            if self.no_attention:
-                feature_values = input[:, self.feat_map]  #torch.index_select(input.flatten(), 0, self.feat_select)
-                feature_values = torch.einsum('bf,f->bf', feature_values, self.feat_weight)
-                feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
-                assert feature_values.shape[0] == input.shape[0]
-            else:
-                feature_values = self.get_attention_value(input)
+            # if self.no_attention:
+            #     feature_values = input[:, self.feat_map]  #torch.index_select(input.flatten(), 0, self.feat_select)
+            #     feature_values = torch.einsum('bf,f->bf', feature_values, self.feat_weight)
+            #     feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
+            #     assert feature_values.shape[0] == input.shape[0]
+            # else:
+            feature_values = self.get_attention_value(input)

            # initialize thresholds: sample random percentiles of data
            percentiles_q = 100 * np.random.beta(self.threshold_init_beta, self.threshold_init_beta,
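
[Note] The `percentiles_q` line that closes this hunk feeds the NODE-style threshold initialization: thresholds are set to random percentiles of the first batch's feature values, with `Beta(threshold_init_beta, threshold_init_beta)` controlling how the sampled percentiles spread around the median. A hedged sketch of that idea (the repo's exact indexing may differ):

import numpy as np

batch, num_trees, depth = 256, 3, 2
feature_values = np.random.randn(batch, num_trees, depth)    # first batch after attention
beta = 1.0                                                   # self.threshold_init_beta
percentiles_q = 100 * np.random.beta(beta, beta, size=(num_trees, depth))
feature_thresholds = np.array([
    [np.percentile(feature_values[:, t, d], percentiles_q[t, d]) for d in range(depth)]
    for t in range(num_trees)
])                                                           # [num_trees, depth]
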
@@ -450,8 +455,8 @@ def __repr__(self):
            f_info = self.attention_func.__repr__()
            f_name = "excitation_max"
            f_init = ""
-        else:
-            f_info = 0 if self.no_attention else self.feat_attention.shape[0]
+        else:
+            f_info = self.feat_attention.shape[0]
            f_name = self.attention_func.__name__
            f_init = self.init_attention_func.__name__
        main_str = "{}(F={},f={},B={}, T={},D={}, response_dim={}, " \