-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain_explainer_cxr_domain_transfer.py
83 lines (76 loc) · 5.15 KB
/
train_explainer_cxr_domain_transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import os
import sys

# Project root on the cluster. It must be on sys.path *before* the project
# import below, otherwise `Explainer` cannot be resolved when the script is
# launched from a directory outside the project root.
_PROJECT_ROOT = os.path.abspath("/ocean/projects/asc170022p/shg121/PhD/ICLR-2022")
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

from Explainer.experiments_explainer_cxr_domain_transfer import train_glt  # noqa: E402
def config(argv=None):
    """Build and parse the CLI for CXR explainer domain-transfer training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before. Passing a
            list lets tests or notebooks build the configuration directly.

    Returns:
        ``argparse.Namespace`` holding every option defined below. Dashed
        option names are exposed with underscores (e.g. ``--batch-size`` ->
        ``args.batch_size``).
    """
    # Shared cluster project root; the path defaults below are derived from it.
    project_root = "/ocean/projects/asc170022p/shg121/PhD/ICLR-2022"

    parser = argparse.ArgumentParser()

    # --- Paths -------------------------------------------------------------
    parser.add_argument('--base_path', metavar='DIR',
                        default=project_root,
                        help='project base path')
    parser.add_argument('--image_dir_path', metavar='DIR',
                        default='/ocean/projects/asc170022p/shared/Data/chestXRayDatasets/StanfordCheXpert',
                        help='image path in ocean')
    parser.add_argument('--image_source_dir', metavar='DIR', type=str, default='CheXpert-v1.0-small',
                        help='dataset directory')
    parser.add_argument('--image_col_header', metavar='COL', type=str, default='Path',
                        help='name of the image-path column in the dataset metadata')
    parser.add_argument('--checkpoints', metavar='DIR',
                        default=f"{project_root}/checkpoints",
                        help='path to checkpoints')
    parser.add_argument('--output', metavar='DIR',
                        default=f"{project_root}/out",
                        help='path to output logs')
    parser.add_argument('--logs', metavar='DIR',
                        default=f"{project_root}/log",
                        help='path to tensorboard logs')

    # --- Data / labels -----------------------------------------------------
    # NOTE(review): "Support Devices" uses a space while every other name uses
    # underscores; kept as-is because downstream matching may depend on it.
    parser.add_argument('--chexpert_names', nargs='+',
                        default=["No_Finding", "Enlarged_Cardiomediastinum", "Cardiomegaly", "Lung_Opacity",
                                 "Lung_Lesion",
                                 "Edema", "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Effusion",
                                 "Pleural_Other", "Fracture", "Support Devices"])

    # --- Training hyper-parameters -----------------------------------------
    parser.add_argument('--batch-size', default=16, type=int, metavar='N', help='batch size BB')
    parser.add_argument('--flattening-type', type=str, default="flatten", help='flatten or adaptive or maxpool')
    parser.add_argument('--uncertain', default=1, type=int,
                        help='uncertain-label setting passed through to training')
    parser.add_argument('--seed', type=int, default=0, help='seed')
    parser.add_argument('--epochs', type=int, default=500, help='epochs')
    parser.add_argument('--disease', type=str, default="effusion", help='disease name')
    parser.add_argument('--arch', type=str, default="densenet121", help='arch')
    # Integer defaults are given as ints (argparse would otherwise re-parse
    # string defaults through `type`; the resulting value is the same).
    parser.add_argument('--iter', type=int, default=1, help='iteration')
    parser.add_argument('--cov', type=float, default=0, help='coverage')
    parser.add_argument('--tot_samples', type=int, default=1000, help='tot_samples')
    parser.add_argument('--image_size', default=512, type=int, help='image_size.')
    parser.add_argument('--crop_size', default=512, type=int, help='crop_size.')
    parser.add_argument('--model', default='MoIE', type=str, help='MoIE')
    parser.add_argument('--target_dataset', default='stanford_cxr', type=str, help='dataset')
    parser.add_argument('--source_dataset', type=str, default="mimic_cxr", help='source dataset name')

    # --- Checkpoints -------------------------------------------------------
    parser.add_argument('--source-checkpoint-t-path', type=str,
                        default="lr_0.01_epochs_60_loss_BCE_W_flattening_type_flatten_layer_features_denseblock4",
                        help='dataset folder of concepts')
    parser.add_argument('--target-checkpoint-t-path', type=str,
                        default="lr_0.1_epochs_90_loss_BCE_W_flattening_type_flatten_layer_features_denseblock4",
                        help='dataset folder of concepts')
    # nargs='+' yields a list on the command line, so the default is a list too
    # (a bare "xxxx" default would be iterated character by character).
    parser.add_argument('--prev_chk_pt_explainer_folder', nargs='+', type=str,
                        default=["xxxx"],
                        help='chkpt explainer')
    parser.add_argument('--checkpoint-model', metavar='file', nargs="+",
                        default=['xxxx'],
                        help='checkpoint file of the model GatedLogicNet')

    # --- Runtime / misc ----------------------------------------------------
    parser.add_argument('--channels', default=3, type=int, help='channels ')
    parser.add_argument('--optim', type=str, default="SGD", help='optimizer of GLT')
    parser.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    parser.add_argument('--train_phi', default="n", type=str, metavar='TYPE')
    parser.add_argument('--profile', default="n", type=str, metavar='TYPE')
    parser.add_argument('--initialize_w_mimic', default="n", type=str, metavar='TYPE')
    parser.add_argument('--selection_threshold', default=0.5, type=float,
                        help='selection threshold of the selector for the test/val set')
    parser.add_argument('--expert-to-train', default="explainer", type=str, metavar='EXPERT',
                        help='which expert to train? explainer or residual')
    parser.add_argument('--source-checkpoint-bb', metavar='file', help='checkpoint file of BB')
    parser.add_argument('--source-checkpoint-bb_path', metavar='file', help='checkpoint file of BB')
    # type=float so values given on the command line have the same type as the
    # float defaults (without it they would arrive as strings).
    parser.add_argument('--prev_covs', nargs='+', type=float, default=[0.4, 0.3, 0.3],
                        help='coverage values used by previous iterations')

    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the command line once, then hand the resulting
    # namespace to the domain-transfer explainer training routine.
    args = config()
    train_glt(args)