Skip to content

Commit 7a62aab

Browse files
committed
Pick colab:e42a4f~..20716d: backwd comptb py3.7.12
1 parent 67b2ec6 commit 7a62aab

File tree

5 files changed

+28
-15
lines changed

5 files changed

+28
-15
lines changed

arnet/constants.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
FEATURES = ['AREA', 'USFLUXL', 'MEANGBL', 'R_VALUE', 'FLARE_INDEX']
99
#TODO: constants calculated on training set
10-
PROCESSED_DATA_DIR = '/home/zeyusun/work/flare-prediction-smarp/datasets/M_Q_24hr'
10+
PROCESSED_DATA_DIR = 'datasets/M_Q_24hr'
1111

12-
@lru_cache
12+
@lru_cache(8)
1313
def get_constants():
1414
CONSTANTS = {}
1515
for dataset in ['sharp', 'smarp']:

arnet/dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __getitem__(self, idx):
9696
t_end = datetime.strptime(s['t_end'], '%Y-%m-%d %H:%M:%S').strftime('%Y%m%d%H%M%S')
9797
largest_flare = max(s['flares'].split('|')) #WARNING: X10+
9898
meta = f'{idx}_{s["prefix"]}{s["arpnum"]:06d}_{t_end}_H0_W0_{largest_flare}.npy'
99-
return *data_list, label, meta
99+
return (*data_list, label, meta)
100100

101101
def load_video(self, prefix, arpnum, t_now, bad_img_idx):
102102
t_now = drms.to_datetime(t_now)

download.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from tqdm import tqdm, trange
99
import drms
1010
from sunpy.time import TimeRange
11-
from sunpy.instr.goes import get_goes_event_list
12-
11+
from sunpy.net import Fido
12+
from sunpy.net import attrs as a
1313

1414
######### Change these ##########
1515
@@ -139,14 +139,27 @@ def download_smarp_images(tarpnum):
139139

140140

141141
def download_goes_per_year(year):
142+
print(year)
142143
t_start = datetime(year=year, month=1, day=1)
143144
t_end = datetime(year=year+1, month=1, day=1)
144-
timerange = TimeRange(t_start, t_end)
145-
event_list = get_goes_event_list(timerange)
146-
if len(event_list) == 0:
145+
results = Fido.search(
146+
a.Time(t_start, t_end),
147+
a.hek.EventType("FL"),
148+
# a.hek.FL.GOESCls > "M1.0",
149+
a.hek.OBS.Observatory == "GOES"
150+
)
151+
if not results.all_colnames: # no columns / no results
147152
return None
148153

149-
event_df = pd.DataFrame(event_list)
154+
event_table = results['hek']["event_starttime", "event_peaktime", "event_endtime", "fl_goescls", "ar_noaanum"]
155+
event_df = event_table.to_pandas().rename(columns={
156+
'event_starttime': 'start_time',
157+
'event_peaktime': 'peak_time',
158+
'event_endtime': 'end_time',
159+
'fl_goescls': 'goes_class',
160+
#'hgc_coord': 'goes_location',
161+
'ar_noaanum': 'noaa_active_region',
162+
})
150163
event_df = event_df[event_df['noaa_active_region'] != 0]
151164
if len(event_df) == 0:
152165
return None

requirements.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ tomli==1.2.2
108108
torch==1.10.0
109109
torchinfo==1.5.3
110110
torchmetrics==0.6.0
111-
torchvision==0.11.1
111+
torchvision==0.11.2
112112
tqdm==4.62.2
113113
traitlets==5.0.5
114114
typing-extensions==3.10.0.1
@@ -121,3 +121,5 @@ yacs==0.1.8
121121
yarl==1.6.3
122122
zipp==3.6.0
123123
scikit-image==0.19.1
124+
sunpy[net]==3.1.2 # stable as of Dec 18, 2021
125+
torchinfo==1.5.3

run_arnet.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,9 @@ def main():
126126
'TRAINER.default_root_dir', 'lightning_logs_dev'
127127
])
128128

129-
with cProfile.Profile() as p:
130-
mlflow.set_experiment(experiment_name=args.experiment_name)
131-
with mlflow.start_run(run_name=args.run_name) as run:
132-
launch(args.config, args.modes, args.resume, args.opts)
133-
pstats.Stats(p).sort_stats('cumtime').print_stats(50)
129+
mlflow.set_experiment(experiment_name=args.experiment_name)
130+
with mlflow.start_run(run_name=args.run_name) as run:
131+
launch(args.config, args.modes, args.resume, args.opts)
134132

135133

136134
def sweep():

0 commit comments

Comments
 (0)