Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed bug in HistoricalSignal #232

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions tests/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,27 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected):
),
(
"2023-01-01T01:00:00",
"2023-01-01T03:00:00",
"2023-01-01T04:00:00",
"a",
timedelta(minutes=45),
"ffill",
{
np.datetime64("2023-01-01T01:45:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T02:30:00.000000000"): 3.0, # type: ignore
np.datetime64("2023-01-01T03:15:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore
},
),
(
"2023-01-01T00:00:00",
"2023-01-01T03:00:00",
"2023-01-01T05:00:00",
"a",
timedelta(hours=1, minutes=30),
"linear",
{
np.datetime64("2023-01-01T01:30:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 2.0, # type: ignore
np.datetime64("2023-01-01T04:30:00.000000000"): 2.0, # type: ignore
},
),
(
Expand Down Expand Up @@ -252,21 +255,50 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected):
np.datetime64("2023-01-01T00:55:00.000000000"): 2.5, # type: ignore
},
),
(
"2023-01-01T01:00:00",
"2023-01-01T04:00:00",
"b",
"1H",
"bfill",
{
np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): np.nan, # type: ignore
},
),
(
"2023-01-01T01:00:00",
"2023-01-01T04:00:00",
"b",
"1H",
"nearest",
{
np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore
np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore
np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore
},
),
],
)
def test_forecast_with_frequency(
self, hist_signal_forecast, start, end, column, frequency, method, expected
):
assert (
hist_signal_forecast.forecast(
start,
end,
column=column,
frequency=frequency,
resample_method=method,
)
== expected
forecast = hist_signal_forecast.forecast(
start,
end,
column=column,
frequency=frequency,
resample_method=method,
)
# Compare values individually because np.nan == np.nan is False
assert forecast.keys() == expected.keys()
assert all(
np.isnan(expected[k]) if np.isnan(forecast[k])
else forecast[k] == expected[k]
for k in forecast.keys()
)


def test_forecast_fails_if_column_not_specified(self, hist_signal):
with pytest.raises(ValueError):
Expand Down
14 changes: 11 additions & 3 deletions vessim/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,17 @@ def _resample_to_frequency(
)

new_times_indices = np.searchsorted(times, new_times, side="left")
if not np.array_equal(new_times, times[new_times_indices]) and resample_method != "bfill":
if np.all(new_times_indices < times.size) and np.array_equal(
new_times, times[new_times_indices]
):
# No resampling necessary
new_data = data[new_times_indices]
elif resample_method == "bfill":
# Perform backward-fill; values outside the data range are filled with NaN
new_data = np.full(new_times_indices.shape, np.nan)
valid_mask = new_times_indices < len(data)
new_data[valid_mask] = data[new_times_indices[valid_mask]]
else:
# Actual value is used for interpolation
times = np.insert(times, 0, start_time)
data = np.insert(data, 0, self.now(start_time, column))
Expand All @@ -361,8 +371,6 @@ def _resample_to_frequency(
raise ValueError(f"Unknown resample_method '{resample_method}'.")
else:
raise ValueError(f"Not enough data at frequency '{freq}' without resampling.")
else:
new_data = data[new_times_indices]

return dict(zip(new_times, new_data))

Expand Down
Loading