Skip to content

Commit 6decfe1

Browse files
committed
introduce arnet, cleaning
1 parent 57cd8e7 commit 6decfe1

40 files changed

+1965
-205791
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
mlruns/
44
archive/
55
download/
6+
datasets/
67
outputs/
78
outputs_smoke/
9+
lightning_logs*
810
logs/
911
# output_* will mess up the working directory
1012

config.py arnet/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#cfg.LEARNER.CLASS_WEIGHT = 1 # Positive/Negative weight ratio
3636
cfg.LEARNER.LOSS_PN_RATIO = 1
3737
cfg.LEARNER.LEARNING_RATE = 1e-4 # Don't change it here!
38+
cfg.LEARNER.CHECKPOINT = "" # path to checkpoint file to read from
3839
### For visualization
3940
cfg.LEARNER.VIS = CN()
4041
cfg.LEARNER.VIS.GRADCAM_LAYERS = []#['convs.conv3']
@@ -43,7 +44,6 @@
4344

4445
cfg.LEARNER.MODEL = CN()
4546
cfg.LEARNER.MODEL.NAME = 'SimpleLSTM'
46-
cfg.LEARNER.MODEL.WEIGHTS = "" # path to checkpoint file to read from
4747
cfg.LEARNER.MODEL.SETTINGS = 'cnn'
4848

4949
cfg.TRAINER = CN()

constants.py arnet/constants.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66

77

88
FEATURES = ['AREA', 'USFLUX', 'MEANGBZ', 'R_VALUE', 'FLARE_INDEX']
9+
PROCESSED_DATA_DIR = '/home/zeyusun/work/flare-prediction-smarp/datasets/M1.0_24hr_balanced'
910

1011
@lru_cache
1112
def get_constants():
1213
CONSTANTS = {}
1314
for dataset in ['sharp', 'smarp']:
14-
filepath = os.path.join('datasets', dataset, 'train.csv')
15+
filepath = os.path.join(PROCESSED_DATA_DIR, dataset, 'train.csv')
1516
df = pd.read_csv(filepath)
1617

1718
CONSTANTS[dataset.upper() + '_MEAN'] = df[FEATURES].mean().to_dict()

run_net.py arnet/dataset.py

+18-95
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,22 @@
11
import os
22
import functools
33
from datetime import datetime, timedelta
4-
import argparse
5-
import cProfile, pstats
64
import pandas as pd
7-
import mlflow
85
import drms
96
import torch
107
from torch.utils.data import Dataset, DataLoader
11-
from torch.nn import functional as F
128
from torchvision.transforms import Compose
139
import pytorch_lightning as pl
1410

15-
from arnet.data.datamodule import ARVideoDataModule as _ARVideoDataModule
16-
from arnet.run_net import train, test, visualize
17-
from arnet import utils
18-
from arnet import const
19-
from config import cfg
20-
from data import query, query_parameters, read_header
21-
from constants import CONSTANTS
11+
from arnet.utils import query, query_parameters, read_header
12+
from arnet.transforms import get_transform
13+
from arnet.constants import CONSTANTS
2214

2315

16+
PROCESSED_DATA_DIR = '/home/zeyusun/work/flare-prediction-smarp/datasets/M1.0_24hr_balanced'
2417
SPLIT_DIRS = {
25-
'HARP': '/home/zeyusun/work/flare-prediction-smarp/datasets/sharp/',
26-
'TARP': '/home/zeyusun/work/flare-prediction-smarp/datasets/smarp/',
18+
'HARP': os.path.join(PROCESSED_DATA_DIR, 'sharp'),
19+
'TARP': os.path.join(PROCESSED_DATA_DIR, 'smarp'),
2720
}
2821
DATA_DIRS = {
2922
'HARP': '/data2/SHARP/image/',
@@ -157,7 +150,8 @@ def __init__(self, cfg):
157150
self.testmode = 'test'
158151

159152
def construct_transforms(self):
160-
transforms = [utils.get_transform(name, cfg) for name in cfg.DATA.TRANSFORMS]
153+
transforms = [get_transform(name, self.cfg)
154+
for name in self.cfg.DATA.TRANSFORMS]
161155
self.transform = Compose(transforms)
162156

163157
def construct_datasets(self):
@@ -185,8 +179,8 @@ def construct_datasets(self):
185179

186180
def train_dataloader(self):
187181
dataset = ActiveRegionDataset(self.df_train,
188-
features=cfg.DATA.FEATURES,
189-
num_frames=cfg.DATA.NUM_FRAMES,
182+
features=self.cfg.DATA.FEATURES,
183+
num_frames=self.cfg.DATA.NUM_FRAMES,
190184
transform=self.transform)
191185
#sampler = RandomSampler(dataset, len(dataset) // 2)
192186
loader = DataLoader(dataset,
@@ -199,8 +193,8 @@ def train_dataloader(self):
199193

200194
def val_dataloader(self):
201195
dataset = ActiveRegionDataset(self.df_val,
202-
features=cfg.DATA.FEATURES,
203-
num_frames=cfg.DATA.NUM_FRAMES,
196+
features=self.cfg.DATA.FEATURES,
197+
num_frames=self.cfg.DATA.NUM_FRAMES,
204198
transform=self.transform)
205199
loader = DataLoader(dataset,
206200
batch_size=self.cfg.DATA.BATCH_SIZE,
@@ -211,26 +205,26 @@ def val_dataloader(self):
211205
def test_dataloader(self):
212206
if self.testmode == 'test':
213207
dataset = ActiveRegionDataset(self.df_test,
214-
features=cfg.DATA.FEATURES,
215-
num_frames=cfg.DATA.NUM_FRAMES,
208+
features=self.cfg.DATA.FEATURES,
209+
num_frames=self.cfg.DATA.NUM_FRAMES,
216210
transform=self.transform)
217211
loader = DataLoader(dataset,
218212
batch_size=self.cfg.DATA.BATCH_SIZE,
219213
num_workers=self.cfg.DATA.NUM_WORKERS,
220214
pin_memory=True)
221215
elif self.testmode == 'visualize_predictions':
222216
dataset = ActiveRegionDataset(self.df_vis,
223-
features=cfg.DATA.FEATURES,
224-
num_frames=cfg.DATA.NUM_FRAMES,
217+
features=self.cfg.DATA.FEATURES,
218+
num_frames=self.cfg.DATA.NUM_FRAMES,
225219
transform=self.transform)
226220
loader = DataLoader(dataset,
227221
batch_size=self.cfg.DATA.BATCH_SIZE,
228222
num_workers=0,
229223
pin_memory=False)
230224
elif self.testmode == 'visualize_features':
231225
dataset = ActiveRegionDataset(self.df_vis,
232-
features=cfg.DATA.FEATURES,
233-
num_frames=cfg.DATA.NUM_FRAMES,
226+
features=self.cfg.DATA.FEATURES,
227+
num_frames=self.cfg.DATA.NUM_FRAMES,
234228
transform=self.transform)
235229
loader = DataLoader(dataset,
236230
batch_size=1,
@@ -241,74 +235,3 @@ def test_dataloader(self):
241235
return loader
242236

243237

244-
def main(args):
245-
"""Perform training, testing, and/or visualization"""
246-
if args.smoke:
247-
args.opts.extend([
248-
'TRAINER.limit_train_batches', '10',
249-
'TRAINER.limit_val_batches', '2',
250-
'TRAINER.limit_test_batches', '2',
251-
'TRAINER.max_epochs', '1',
252-
'TRAINER.default_root_dir', 'lightning_logs_c3d_dev'
253-
])
254-
experiment_name = 'c3d_smoke'
255-
else:
256-
experiment_name = 'c3d'
257-
mlflow.set_experiment(experiment_name)
258-
259-
with mlflow.start_run(run_name=args.run_name) as run:
260-
if args.config is not None:
261-
cfg.merge_from_file(args.config)
262-
cfg.merge_from_list(args.opts)
263-
# cfg.freeze()
264-
mlflow.log_params(cfg.flatten())
265-
266-
logger = utils.setup_logger(cfg.MISC.OUTPUT_DIR)
267-
logger.info(cfg)
268-
269-
dm = ActiveRegionDataModule(cfg)
270-
271-
if 'train' in args.modes:
272-
logger.info("======== TRAIN ========")
273-
if args.resume:
274-
cfg.LEARNER.MODEL.WEIGHTS
275-
best_model_path = train(cfg, dm)
276-
cfg.LEARNER.MODEL.WEIGHTS = best_model_path
277-
278-
if 'test' in args.modes:
279-
logger.info("======== TEST ========")
280-
test(cfg, dm)
281-
282-
#if 'visualize' in args.modes:
283-
# logger.info("======== VISUALIZE ========")
284-
# visualize(cfg, dm)
285-
286-
287-
if __name__ == '__main__':
288-
parser = argparse.ArgumentParser()
289-
parser.add_argument('-s', '--smoke', action='store_true',
290-
help='Smoke test')
291-
parser.add_argument('-e', '--experiment_name', default='experiment',
292-
help='MLflow experiment name')
293-
parser.add_argument('-r', '--run_name', default='CNN',
294-
help='MLflow run name')
295-
parser.add_argument('--config', metavar='FILE',
296-
help="Path to a yaml formatted config file")
297-
parser.add_argument('--modes', default='train|test|visualize',
298-
help="Perform training, testing, and/or visualization")
299-
parser.add_argument('--resume', default=False,
300-
help="Resume training. Valid only in training mode.") #TODO
301-
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
302-
help="Modify config options. Use dot(.) to indicate hierarchy.")
303-
args = parser.parse_args()
304-
args.modes = args.modes.split('|')
305-
accepted_modes = ['train', 'test', 'visualize']
306-
if any([m not in accepted_modes for m in args.modes]):
307-
raise AssertionError('Mode {} is not accepted'.format(args.modes))
308-
if 'train' not in args.modes and 'LEARNER.MODEL.WEIGHTS' not in args.opts:
309-
raise ValueError('LEARNER.MODEL.WEIGHTS must be specified in the absence of training mode.')
310-
311-
with cProfile.Profile() as p:
312-
main(args)
313-
314-
pstats.Stats(p).sort_stats('cumtime').print_stats(50)

arnet/modeling/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)