Skip to content

Commit

Permalink
fix(sourcedata.py::TransientArraySourceData): refactor TransientSourc…
Browse files Browse the repository at this point in the history
…eDataMixin to have a general stress_period_mapping property to handle cases with and without a parent model (previously expected a parent model)
  • Loading branch information
aleaf committed Feb 14, 2024
1 parent 6fadcf0 commit 6f434e9
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions mfsetup/sourcedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(self, period_stats, dest_model):
self._period_stats = None

# attributes
self.dest_model = dest_model
self.perioddata = dest_model.perioddata.sort_values(by='per').reset_index(drop=True)

@property
Expand All @@ -192,6 +193,21 @@ def period_stats(self):
self._period_stats = self.get_period_stats()
return self._period_stats

@property
def stress_period_mapping(self):
# if there is a parent/source model,
# get the mapping between the parent model and
# inset/destination model stress periods {inset_kper: parent_kper}
if self.dest_model.parent is not None:
# for now, just assume one-to-one correspondance
# between source and dest model stress periods
return self.dest_model.parent_stress_periods
# otherwise, just return a dictionary of the same
# key, value pairs for consistency
# with logic of subclass get_data() methods
else:
return dict(zip(self.perioddata['per'], self.perioddata['per']))

def get_period_stats(self):
"""Populate each stress period with period_stat information
for temporal resampling (tdis.aggregate_dataframe_to_stress_period and
Expand Down Expand Up @@ -576,7 +592,6 @@ def __init__(self, filenames, variable, period_stats=None,

self.variable = variable
self.resample_method = resample_method
self.dest_model = dest_model

def get_data(self):

Expand All @@ -599,11 +614,9 @@ def get_data(self):
# would follow logic of netcdf files, but trickier because steady-state periods need to be handled
#da = transient2d_to_xarray(data, time)

# for now, just assume one-to-one correspondance
# between source and dest model stress periods
results = {}
for inset_kper, parent_kper in self.dest_model.parent_stress_periods.items():
data = source_data[parent_kper].copy()
for dest_kper, source_kper in self.stress_period_mapping.items():
data = source_data[source_kper].copy()
if regrid:
# sample the data onto the model grid
resampled = self.regrid_from_source_model(data, method=self.resample_method)
Expand All @@ -612,7 +625,7 @@ def get_data(self):
# reshape results to model grid
period_mean2d = resampled.reshape(self.dest_model.nrow,
self.dest_model.ncol)
results[inset_kper] = period_mean2d
results[dest_kper] = period_mean2d
self.data = results
return results

Expand Down

0 comments on commit 6f434e9

Please sign in to comment.