Commit fea856a

1)attention_alg 2)no_attention 3)scale

1 parent a012976 commit fea856a

6 files changed: +85 −38 lines

main_tabular_data.py (+16 −7)
@@ -285,18 +285,20 @@ def Fold_learning(fold_n,data,config,visual):
 parser = argparse.ArgumentParser()
 # parser.add_argument('--use-gpu', action='store_true')
 parser.add_argument('--data_root',required=True)
-parser.add_argument('--dataset', default='CLICK',help="MICROSOFT,YAHOO,YEAR,CLICK,HIGGS")
+parser.add_argument('--dataset', default='CLICK',help="MICROSOFT,YAHOO,YEAR,CLICK,HIGGS,EPSILON")
 parser.add_argument('--iterations', default=1000, type=int)
 parser.add_argument('--model', default="QForest", help='QForest,GBDT,LinearRegressor')
-parser.add_argument('--learning_rate', default="0.01", type=float)
+parser.add_argument('--learning_rate', default="0.001", type=float)
 parser.add_argument('--subsample', default="1", type=float)
-parser.add_argument('--QF_fit', default="0", type=int)
+parser.add_argument('--QF_fit', default="1", type=int)
+parser.add_argument('--attention', default="eca_response", type=str)
+parser.add_argument('--scale', default="medium",help='small,medium,large', type=str)
 args = parser.parse_args()
 dataset = args.dataset
 data = quantum_forest.TabularDataset(dataset,data_path=args.data_root, random_state=1337, quantile_transform=True, quantile_noise=1e-3)
 #data = quantum_forest.TabularDataset(dataset,data_path=data_root, random_state=1337, quantile_transform=True)

-config = quantum_forest.QForest_config(data,0.002,feat_info="attention")   #,feat_info="importance","attention"
+config = quantum_forest.QForest_config(data,0.002)   #,feat_info="importance","attention"
 random_state = 42
 config.device = quantum_forest.OnInitInstance(random_state)
 config.model=args.model   #"QForest" "GBDT" "LinearRegressor"
@@ -306,19 +308,26 @@ def Fold_learning(fold_n,data,config,visual):
 config.bagging_fraction = args.subsample
 config.nMostEpochs = args.iterations
 config.QF_fit = args.QF_fit
+config.attention_alg = args.attention
+if args.scale == "small":
+    config.depth, config.batch_size, config.nTree = 4, 256, 256
+elif args.scale == "medium":
+    config.depth, config.batch_size, config.nTree = 5, 512, 1024
+elif args.scale == "large":
+    config.depth, config.batch_size, config.nTree = 5, 512, 2048

 if dataset=="YAHOO" or dataset=="MICROSOFT" or dataset=="CLICK" or dataset=="HIGGS" or dataset=="EPSILON":
     config,visual = quantum_forest.InitExperiment(config, 0)
     data.onFold(0,config,pkl_path=f"{args.data_root}{dataset}/FOLD_Quantile_.pickle")
     Fold_learning(0,data, config,visual)
-else:
+else:   #"YEAR"
     nFold = 5 if dataset != "HIGGS" else 20
     folds = KFold(n_splits=nFold, shuffle=True)
     index_sets=[]
     for fold_n, (train_index, valid_index) in enumerate(folds.split(data.X)):
         index_sets.append(valid_index)
     for fold_n in range(len(index_sets)):
-        config, visual = InitExperiment(config, fold_n)
+        config, visual = quantum_forest.InitExperiment(config, fold_n)
         train_list=[]
         for i in range(nFold):
             if i==fold_n:   #test
@@ -330,7 +339,7 @@ def Fold_learning(fold_n,data,config,visual):
         train_index=np.concatenate(train_list)
         print(f"train={len(train_index)} valid={len(valid_index)} test={len(index_sets[fold_n])}")

-        data.onFold(fold_n,config,train_index=train_index, valid_index=valid_index,test_index=index_sets[fold_n],pkl_path=f"{data_root}{dataset}/FOLD_{fold_n}.pickle")
+        data.onFold(fold_n,config,train_index=train_index, valid_index=valid_index,test_index=index_sets[fold_n],pkl_path=f"{args.data_root}{dataset}/FOLD_{fold_n}.pickle")
         Fold_learning(fold_n,data,config,visual)
         break

python-package/quantum_forest/DifferentiableTree.py (+32 −27)
@@ -37,7 +37,7 @@ def init_attention(self, feat_info=None):
         self.feat_map = random.choices(population = list(range(self.in_features)),k = self.nGateFuncs)   #weights = weight,
         #self.feat_map[0]=1
         self.init_attention_func=nn.init.eye
-        if self.no_attention:
+        if False:   #self.no_attention:
             self.feat_weight = nn.Parameter(torch.Tensor(self.nGateFuncs).uniform_(), requires_grad=True)
             pass
         elif False:   # for comparison only
@@ -63,7 +63,7 @@ def init_attention(self, feat_info=None):
             self.feat_attention = nn.Parameter( feat_val, requires_grad=True )
             #self.feat_W = nn.Parameter(torch.Tensor(self.in_features).uniform_(), requires_grad=True)

-        print(f"====== init_attention f={self.init_attention_func.__name__} no_attention={self.no_attention}")
+        print(f"====== init_attention f={self.init_attention_func.__name__} attention={self.config.attention_alg}")

     #weights computed as entmax over the learnable feature selection matrix F ∈ R d×n
     def get_attention_value(self,input):
@@ -226,20 +226,25 @@ def __init__(self, in_features, num_trees,config, flatten_output=True,feat_info=
         if self.config.path_way=="TREE_map":
             self.nFeature = 2**depth-1
             self.nGateFuncs = self.num_trees*self.nFeature
-            if False:
+            if self.config.attention_alg == "eca_response":
                 self.att_reponse = eca_reponse(self.num_trees)
-                #self.att_input = eca_input(self.in_features)
+            elif self.config.attention_alg == "eca_input":
+                self.att_input = eca_input(self.in_features)
+            elif self.config.attention_alg == "":
+                print("!!! Empty attention_alg.Please try \"--attention=eca_response\" !!!\n")
+            else:
+                raise ValueError( f'INVALID self.config.attention_alg = {self.config.attention_alg}' )

-        if self.config.attention_alg == "weight":
-            if False and isAdptiveAlpha:   # worth trying, but it takes too long
-                self.entmax_alpha = nn.Parameter(torch.tensor(1.5, requires_grad=True))
-                self.attention_func = entmax.entmax_bisect   #sparsemax, entmax15, entmax_bisect
-            else:
-                self.attention_func = entmax15
+
+        if False and isAdptiveAlpha:   # worth trying, but it takes too long
+            self.entmax_alpha = nn.Parameter(torch.tensor(1.5, requires_grad=True))
+            self.attention_func = entmax.entmax_bisect   #sparsemax, entmax15, entmax_bisect
         else:
+            self.attention_func = entmax15
+        if self.config.attention_alg == "excitation_max":   # a failed attempt
             self.attention_func = excitation_max(in_features,self.nGateFuncs,self.num_trees)
         self.bin_func = "05_01"   #"05_01" "entmoid15" "softmax"
-        self.no_attention = config.no_attention
+        # self.no_attention = config.no_attention
         self.threshold_init_beta, self.threshold_init_cutoff = threshold_init_beta, threshold_init_cutoff
         self.init_responce_func = initialize_response_
         self.response_mean,self.response_std = 0,0
@@ -309,13 +314,13 @@ def forward(self, input):
         if hasattr(self,"att_input"):
             input = self.att_input(input)
         # new input shape: [batch_size, in_features]
-        if self.no_attention:
-            feature_values = input[:, self.feat_map]   # torch.index_select(input.flatten(), 0, self.feat_select)
-            #feature_values = torch.einsum('bf,f->bf',feature_values,self.feat_weight)
-            feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
-            assert feature_values.shape[0] == input.shape[0]
-        else:
-            feature_values = self.get_attention_value(input)
+        # if self.no_attention:
+        #     feature_values = input[:, self.feat_map]   # torch.index_select(input.flatten(), 0, self.feat_select)
+        #     #feature_values = torch.einsum('bf,f->bf',feature_values,self.feat_weight)
+        #     feature_values = feature_values.reshape(-1, self.num_trees, self.depth)
+        #     assert feature_values.shape[0] == input.shape[0]
+        # else:
+        feature_values = self.get_attention_value(input)
         # ^--[batch_size, num_trees, depth]

         threshold_logits = (feature_values - self.feature_thresholds) * torch.exp(-self.log_temperatures)
@@ -421,13 +426,13 @@ def initialize(self, input, eps=1e-6):
             print(f"!!!!!! DeTree::initialize@{self.__repr__()} has only {nSamp} sampls. This may cause instability.\n")

         with torch.no_grad():
-            if self.no_attention:
-                feature_values = input[:, self.feat_map]   #torch.index_select(input.flatten(), 0, self.feat_select)
-                feature_values = torch.einsum('bf,f->bf', feature_values, self.feat_weight)
-                feature_values=feature_values.reshape(-1,self.num_trees,self.depth)
-                assert feature_values.shape[0]==input.shape[0]
-            else:
-                feature_values = self.get_attention_value(input)
+            # if self.no_attention:
+            #     feature_values = input[:, self.feat_map]   #torch.index_select(input.flatten(), 0, self.feat_select)
+            #     feature_values = torch.einsum('bf,f->bf', feature_values, self.feat_weight)
+            #     feature_values=feature_values.reshape(-1,self.num_trees,self.depth)
+            #     assert feature_values.shape[0]==input.shape[0]
+            # else:
+            feature_values = self.get_attention_value(input)

             # initialize thresholds: sample random percentiles of data
             percentiles_q = 100 * np.random.beta(self.threshold_init_beta, self.threshold_init_beta,
@@ -450,8 +455,8 @@ def __repr__(self):
             f_info = self.attention_func.__repr__()
             f_name = "excitation_max"
             f_init = ""
-        else:
-            f_info = 0 if self.no_attention else self.feat_attention.shape[0]
+        else:
+            f_info = self.feat_attention.shape[0]
             f_name = self.attention_func.__name__
             f_init = self.init_attention_func.__name__
         main_str = "{}(F={},f={},B={}, T={},D={}, response_dim={}, " \
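
The eca_reponse / eca_input modules selected by the new dispatch above are not shown in this diff. For orientation only, an ECA-style ("efficient channel attention") re-weighting layer typically looks like the sketch below; the class name, kernel size, and 2-D input shape are assumptions, not the repository's implementation:

import torch
import torch.nn as nn

class ECASketch(nn.Module):
    """Hypothetical stand-in for eca_input/eca_reponse: re-weights each channel
    (input feature or tree response) via a cheap 1-D conv over the channel axis."""
    def __init__(self, n_channels, k_size=3):
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=k_size // 2, bias=False)

    def forward(self, x):                    # x: [batch, n_channels]
        w = self.conv(x.unsqueeze(1))        # local cross-channel interaction
        w = torch.sigmoid(w).squeeze(1)      # attention weights in (0, 1)
        return x * w                         # re-weighted channels, same shape as x

A module of this shape would slot into the forward() path above, where the tree model calls input = self.att_input(input) when --attention=eca_input is selected.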

python-package/quantum_forest/QForest.py (+3 −3)
@@ -21,7 +21,7 @@ def __init__(self,data, lr_base, nLayer=1,choice_func="r_0.5",feat_info = None,r
         self.rDrop = 0
         self.custom_legend = None
         self.feat_info = feat_info
-        self.no_attention = False
+        # self.no_attention = False   #always False since 5/28/2020
         self.max_features = None
         self.input_dropout = 0   #YAHOO 0.59253-0.59136, not much use
         self.num_layers = 1
@@ -46,7 +46,7 @@ def __init__(self,data, lr_base, nLayer=1,choice_func="r_0.5",feat_info = None,r

         self.err_relative = False
         self.task = "train"
-        self.attention_alg = "weight"   #
+        self.attention_alg = "eca_reponse"   # 'eca_input'

         self.nMostEpochs = 1000
         self.depth, self.batch_size, self.nTree, self.response_dim, self.nLayers = 5, 256, 256, 8, 1
@@ -121,7 +121,7 @@ def env_title(self):
     def main_str(self):
         main_str = f"{self.data_set}_ layers={self.nLayer} depth={self.depth} batch={self.batch_size} nTree={self.nTree} response_dim={self.response_dim} " \
             f"\nmax_out={self.max_out} choice=[{self.choice_func}] feat_info={self.feat_info}" \
-            f"\nNO_ATTENTION={self.no_attention} reg_L1={self.reg_L1} path_way={self.path_way}"
+            f"\nATTENTION={self.config.attention_alg} reg_L1={self.reg_L1} path_way={self.path_way}"
         #if self.isFC: main_str+=" [FC]"
         if self.custom_legend is not None:
             main_str = main_str + f"_{self.custom_legend}"
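
With no_attention retired, config.attention_alg is now the single switch that picks the attention module. The dispatch added in DifferentiableTree.py above accepts "eca_response", "eca_input", or an empty string (and raises ValueError otherwise), so in practice the value comes from the new --attention flag rather than the "eca_reponse" default set here. A minimal sketch of configuring it from Python, mirroring the calls in main_tabular_data.py; the dataset name and path are placeholders:

import quantum_forest

# Placeholders for dataset/data_path; calls mirror main_tabular_data.py in this commit.
data = quantum_forest.TabularDataset("YEAR", data_path="../Datasets/", random_state=1337,
                                     quantile_transform=True, quantile_noise=1e-3)
config = quantum_forest.QForest_config(data, 0.002)
config.attention_alg = "eca_response"   # or "eca_input"; "" disables the ECA modules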

python-package/quantum_forest/tabular_data.py (+1 −1)
@@ -156,7 +156,7 @@ def problem(self):
     def onFold(self,fold,config,pkl_path=None, train_index=None, valid_index=None, test_index=None):
         if pkl_path is not None:
             print("====== onFold pkl_path={} ......".format(pkl_path))
-        if False and pkl_path is not None and os.path.isfile(pkl_path):
+        if pkl_path is not None and os.path.isfile(pkl_path):
             with open(pkl_path, "rb") as fp:
                 [self.X_train,self.y_train,self.X_valid, self.y_valid,self.X_test,self.y_test,\
                     self.quantile_noise,self.Y_trans_method,self.accu_scale,self.Y_mu_0, self.Y_std_0,self.zero_feats] = pickle.load(fp)
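
Removing the "if False" guard means onFold() now reuses an existing FOLD pickle instead of always rebuilding the split. A small sketch for forcing a rebuild by deleting the cached file; the snippet is not in the repo, and the path pattern is the one used in main_tabular_data.py above:

import os

pkl_path = f"{args.data_root}{dataset}/FOLD_Quantile_.pickle"   # layout assumed from main_tabular_data.py
if os.path.isfile(pkl_path):
    os.remove(pkl_path)    # the next data.onFold(...) call re-splits and re-pickles the fold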

test_LGB.bat (+9 −0)
@@ -0,0 +1,9 @@
+E:
+cd E:\fengnaixing\cys\QuantumForest
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=YAHOO --model=GBDT
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=CLICK --model=GBDT
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=MICROSOFT --model=GBDT
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=YEAR --model=GBDT
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=HIGGS --model=GBDT
+python main_tabular_data.py --data_root="E:/fengnaixing/cys/Datasets/" --dataset=EPSILON --model=GBDT
+:: python main_tabular_data.py --data_root=“E:/fengnaixing/cys/Datasets/” --model=GBDT --dataset=MICROSOFTtes

test_QF.bat (+24 −0)
@@ -0,0 +1,24 @@
+::@echo off
+E:
+cd E:\fengnaixing\cys\QuantumForest
+:: https://stackoverflow.com/questions/18462169/how-to-loop-through-array-in-batch
+::set DATA[0]=YAHOO
+::set DATA[1]=CLICK
+::set DATA[2]=MICROSOFT
+::set DATA[3]=YEAR
+::set DATA[4]=HIGGS
+::set DATA[5]=EPSILON
+set "cmd=python main_tabular_data.py --data_root=../Datasets/"
+::set "param=--model=GBDT"
+
+%cmd% --dataset=YAHOO %param%
+%cmd% --dataset=CLICK %param%
+%cmd% --dataset=MICROSOFT %param%
+%cmd% --dataset=YEAR %param%
+%cmd% --dataset=HIGGS %param%
+%cmd% --dataset=EPSILON %param%
+:: --model=GBDT --dataset=MICROSOFT --learning_rate=0.001
+:: --attention=""
+:: --scale="large"
+:: C:/Users/fengnaixing/test_QF.bat
+:: python main_tabular_data.py --data_root=../Datasets/ --dataset=YAHOO --model=GBDT
