From ddee5fcca623e787363c6eed2ba1d830a72221e9 Mon Sep 17 00:00:00 2001 From: kilianp14 Date: Tue, 30 Jul 2024 16:38:25 +0200 Subject: [PATCH] Fixed bug in HistoricalSignal --- tests/test_signal.py | 54 +++++++++++++++++++++++++++++++++++--------- vessim/signal.py | 14 +++++++++--- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/tests/test_signal.py b/tests/test_signal.py index 6f5022d..590d76b 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -185,24 +185,27 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected): ), ( "2023-01-01T01:00:00", - "2023-01-01T03:00:00", + "2023-01-01T04:00:00", "a", timedelta(minutes=45), "ffill", { np.datetime64("2023-01-01T01:45:00.000000000"): 2.0, # type: ignore np.datetime64("2023-01-01T02:30:00.000000000"): 3.0, # type: ignore + np.datetime64("2023-01-01T03:15:00.000000000"): 1.5, # type: ignore + np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore }, ), ( "2023-01-01T00:00:00", - "2023-01-01T03:00:00", + "2023-01-01T05:00:00", "a", timedelta(hours=1, minutes=30), "linear", { np.datetime64("2023-01-01T01:30:00.000000000"): 2.0, # type: ignore np.datetime64("2023-01-01T03:00:00.000000000"): 2.0, # type: ignore + np.datetime64("2023-01-01T04:30:00.000000000"): 2.0, # type: ignore }, ), ( @@ -252,21 +255,50 @@ def test_forecast(self, hist_signal_forecast, start, end, column, expected): np.datetime64("2023-01-01T00:55:00.000000000"): 2.5, # type: ignore }, ), + ( + "2023-01-01T01:00:00", + "2023-01-01T04:00:00", + "b", + "1H", + "bfill", + { + np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore + np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore + np.datetime64("2023-01-01T04:00:00.000000000"): np.nan, # type: ignore + }, + ), + ( + "2023-01-01T01:00:00", + "2023-01-01T04:00:00", + "b", + "1H", + "nearest", + { + np.datetime64("2023-01-01T02:00:00.000000000"): 2.5, # type: ignore + np.datetime64("2023-01-01T03:00:00.000000000"): 1.5, # type: ignore + np.datetime64("2023-01-01T04:00:00.000000000"): 1.5, # type: ignore + }, + ), ], ) def test_forecast_with_frequency( self, hist_signal_forecast, start, end, column, frequency, method, expected ): - assert ( - hist_signal_forecast.forecast( - start, - end, - column=column, - frequency=frequency, - resample_method=method, - ) - == expected + forecast = hist_signal_forecast.forecast( + start, + end, + column=column, + frequency=frequency, + resample_method=method, ) + # Complicated because np.nan == np.nan is False + assert forecast.keys() == expected.keys() + assert all( + np.isnan(expected[k]) if np.isnan(forecast[k]) + else forecast[k] == expected[k] + for k in forecast.keys() + ) + def test_forecast_fails_if_column_not_specified(self, hist_signal): with pytest.raises(ValueError): diff --git a/vessim/signal.py b/vessim/signal.py index e021cfb..68e932c 100644 --- a/vessim/signal.py +++ b/vessim/signal.py @@ -343,7 +343,17 @@ def _resample_to_frequency( ) new_times_indices = np.searchsorted(times, new_times, side="left") - if not np.array_equal(new_times, times[new_times_indices]) and resample_method != "bfill": + if np.all(new_times_indices < times.size) and np.array_equal( + new_times, times[new_times_indices] + ): + # No resampling necessary + new_data = data[new_times_indices] + elif resample_method == "bfill": + # Perform backward-fill whereas values outside range are filled with NaN + new_data = np.full(new_times_indices.shape, np.nan) + valid_mask = new_times_indices < len(data) + new_data[valid_mask] = data[new_times_indices[valid_mask]] + else: # Actual value is used for interpolation times = np.insert(times, 0, start_time) data = np.insert(data, 0, self.now(start_time, column)) @@ -361,8 +371,6 @@ def _resample_to_frequency( raise ValueError(f"Unknown resample_method '{resample_method}'.") else: raise ValueError(f"Not enough data at frequency '{freq}' without resampling.") - else: - new_data = data[new_times_indices] return dict(zip(new_times, new_data))