1
1
import os
2
2
import functools
3
3
from datetime import datetime , timedelta
4
- import argparse
5
- import cProfile , pstats
6
4
import pandas as pd
7
- import mlflow
8
5
import drms
9
6
import torch
10
7
from torch .utils .data import Dataset , DataLoader
11
- from torch .nn import functional as F
12
8
from torchvision .transforms import Compose
13
9
import pytorch_lightning as pl
14
10
15
- from arnet .data .datamodule import ARVideoDataModule as _ARVideoDataModule
16
- from arnet .run_net import train , test , visualize
17
- from arnet import utils
18
- from arnet import const
19
- from config import cfg
20
- from data import query , query_parameters , read_header
21
- from constants import CONSTANTS
11
+ from arnet .utils import query , query_parameters , read_header
12
+ from arnet .transforms import get_transform
13
+ from arnet .constants import CONSTANTS
22
14
23
15
16
+ PROCESSED_DATA_DIR = '/home/zeyusun/work/flare-prediction-smarp/datasets/M1.0_24hr_balanced'
24
17
SPLIT_DIRS = {
25
- 'HARP' : '/home/zeyusun/work/flare-prediction-smarp/datasets/ sharp/' ,
26
- 'TARP' : '/home/zeyusun/work/flare-prediction- smarp/datasets/smarp/' ,
18
+ 'HARP' : os . path . join ( PROCESSED_DATA_DIR , ' sharp' ) ,
19
+ 'TARP' : os . path . join ( PROCESSED_DATA_DIR , ' smarp' ) ,
27
20
}
28
21
DATA_DIRS = {
29
22
'HARP' : '/data2/SHARP/image/' ,
@@ -157,7 +150,8 @@ def __init__(self, cfg):
157
150
self .testmode = 'test'
158
151
159
152
def construct_transforms (self ):
160
- transforms = [utils .get_transform (name , cfg ) for name in cfg .DATA .TRANSFORMS ]
153
+ transforms = [get_transform (name , self .cfg )
154
+ for name in self .cfg .DATA .TRANSFORMS ]
161
155
self .transform = Compose (transforms )
162
156
163
157
def construct_datasets (self ):
@@ -185,8 +179,8 @@ def construct_datasets(self):
185
179
186
180
def train_dataloader (self ):
187
181
dataset = ActiveRegionDataset (self .df_train ,
188
- features = cfg .DATA .FEATURES ,
189
- num_frames = cfg .DATA .NUM_FRAMES ,
182
+ features = self . cfg .DATA .FEATURES ,
183
+ num_frames = self . cfg .DATA .NUM_FRAMES ,
190
184
transform = self .transform )
191
185
#sampler = RandomSampler(dataset, len(dataset) // 2)
192
186
loader = DataLoader (dataset ,
@@ -199,8 +193,8 @@ def train_dataloader(self):
199
193
200
194
def val_dataloader (self ):
201
195
dataset = ActiveRegionDataset (self .df_val ,
202
- features = cfg .DATA .FEATURES ,
203
- num_frames = cfg .DATA .NUM_FRAMES ,
196
+ features = self . cfg .DATA .FEATURES ,
197
+ num_frames = self . cfg .DATA .NUM_FRAMES ,
204
198
transform = self .transform )
205
199
loader = DataLoader (dataset ,
206
200
batch_size = self .cfg .DATA .BATCH_SIZE ,
@@ -211,26 +205,26 @@ def val_dataloader(self):
211
205
def test_dataloader (self ):
212
206
if self .testmode == 'test' :
213
207
dataset = ActiveRegionDataset (self .df_test ,
214
- features = cfg .DATA .FEATURES ,
215
- num_frames = cfg .DATA .NUM_FRAMES ,
208
+ features = self . cfg .DATA .FEATURES ,
209
+ num_frames = self . cfg .DATA .NUM_FRAMES ,
216
210
transform = self .transform )
217
211
loader = DataLoader (dataset ,
218
212
batch_size = self .cfg .DATA .BATCH_SIZE ,
219
213
num_workers = self .cfg .DATA .NUM_WORKERS ,
220
214
pin_memory = True )
221
215
elif self .testmode == 'visualize_predictions' :
222
216
dataset = ActiveRegionDataset (self .df_vis ,
223
- features = cfg .DATA .FEATURES ,
224
- num_frames = cfg .DATA .NUM_FRAMES ,
217
+ features = self . cfg .DATA .FEATURES ,
218
+ num_frames = self . cfg .DATA .NUM_FRAMES ,
225
219
transform = self .transform )
226
220
loader = DataLoader (dataset ,
227
221
batch_size = self .cfg .DATA .BATCH_SIZE ,
228
222
num_workers = 0 ,
229
223
pin_memory = False )
230
224
elif self .testmode == 'visualize_features' :
231
225
dataset = ActiveRegionDataset (self .df_vis ,
232
- features = cfg .DATA .FEATURES ,
233
- num_frames = cfg .DATA .NUM_FRAMES ,
226
+ features = self . cfg .DATA .FEATURES ,
227
+ num_frames = self . cfg .DATA .NUM_FRAMES ,
234
228
transform = self .transform )
235
229
loader = DataLoader (dataset ,
236
230
batch_size = 1 ,
@@ -241,74 +235,3 @@ def test_dataloader(self):
241
235
return loader
242
236
243
237
244
- def main (args ):
245
- """Perform training, testing, and/or visualization"""
246
- if args .smoke :
247
- args .opts .extend ([
248
- 'TRAINER.limit_train_batches' , '10' ,
249
- 'TRAINER.limit_val_batches' , '2' ,
250
- 'TRAINER.limit_test_batches' , '2' ,
251
- 'TRAINER.max_epochs' , '1' ,
252
- 'TRAINER.default_root_dir' , 'lightning_logs_c3d_dev'
253
- ])
254
- experiment_name = 'c3d_smoke'
255
- else :
256
- experiment_name = 'c3d'
257
- mlflow .set_experiment (experiment_name )
258
-
259
- with mlflow .start_run (run_name = args .run_name ) as run :
260
- if args .config is not None :
261
- cfg .merge_from_file (args .config )
262
- cfg .merge_from_list (args .opts )
263
- # cfg.freeze()
264
- mlflow .log_params (cfg .flatten ())
265
-
266
- logger = utils .setup_logger (cfg .MISC .OUTPUT_DIR )
267
- logger .info (cfg )
268
-
269
- dm = ActiveRegionDataModule (cfg )
270
-
271
- if 'train' in args .modes :
272
- logger .info ("======== TRAIN ========" )
273
- if args .resume :
274
- cfg .LEARNER .MODEL .WEIGHTS
275
- best_model_path = train (cfg , dm )
276
- cfg .LEARNER .MODEL .WEIGHTS = best_model_path
277
-
278
- if 'test' in args .modes :
279
- logger .info ("======== TEST ========" )
280
- test (cfg , dm )
281
-
282
- #if 'visualize' in args.modes:
283
- # logger.info("======== VISUALIZE ========")
284
- # visualize(cfg, dm)
285
-
286
-
287
- if __name__ == '__main__' :
288
- parser = argparse .ArgumentParser ()
289
- parser .add_argument ('-s' , '--smoke' , action = 'store_true' ,
290
- help = 'Smoke test' )
291
- parser .add_argument ('-e' , '--experiment_name' , default = 'experiment' ,
292
- help = 'MLflow experiment name' )
293
- parser .add_argument ('-r' , '--run_name' , default = 'CNN' ,
294
- help = 'MLflow run name' )
295
- parser .add_argument ('--config' , metavar = 'FILE' ,
296
- help = "Path to a yaml formatted config file" )
297
- parser .add_argument ('--modes' , default = 'train|test|visualize' ,
298
- help = "Perform training, testing, and/or visualization" )
299
- parser .add_argument ('--resume' , default = False ,
300
- help = "Resume training. Valid only in training mode." ) #TODO
301
- parser .add_argument ('opts' , default = None , nargs = argparse .REMAINDER ,
302
- help = "Modify config options. Use dot(.) to indicate hierarchy." )
303
- args = parser .parse_args ()
304
- args .modes = args .modes .split ('|' )
305
- accepted_modes = ['train' , 'test' , 'visualize' ]
306
- if any ([m not in accepted_modes for m in args .modes ]):
307
- raise AssertionError ('Mode {} is not accepted' .format (args .modes ))
308
- if 'train' not in args .modes and 'LEARNER.MODEL.WEIGHTS' not in args .opts :
309
- raise ValueError ('LEARNER.MODEL.WEIGHTS must be specified in the absence of training mode.' )
310
-
311
- with cProfile .Profile () as p :
312
- main (args )
313
-
314
- pstats .Stats (p ).sort_stats ('cumtime' ).print_stats (50 )
0 commit comments