Migrate state_demand analysis script to Kaggle example notebook.
zaneselvans committed Dec 2, 2023
1 parent 6b2f836 commit 4293db9
Showing 2 changed files with 6 additions and 244 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -124,7 +124,6 @@ ferc_to_sqlite = "pudl.ferc_to_sqlite.cli:main"
pudl_datastore = "pudl.workspace.datastore:main"
pudl_etl = "pudl.etl.cli:pudl_etl"
pudl_setup = "pudl.workspace.setup_cli:main"
-state_demand = "pudl.analysis.state_demand:main"
pudl_check_fks = "pudl.etl.check_foreign_keys:main"
# pudl_territories currently blows up memory usage to 100+ GB.
# See https://github.com/catalyst-cooperative/pudl/issues/1174
249 changes: 6 additions & 243 deletions src/pudl/analysis/state_demand.py
@@ -1,4 +1,4 @@
"""Predict state-level electricity demand.
"""Estimate historical hourly state-level electricity demand.
Using hourly electricity demand reported at the balancing authority and utility level in
the FERC 714, and service territories for utilities and balancing authorities inferred
@@ -15,22 +15,15 @@
manual and could certainly be improved, but overall the results seem reasonable.
Additional predictive spatial variables will be required to obtain more granular
electricity demand estimates (e.g. at the county level).

-Currently the script takes no arguments and simply runs a predefined analysis across all
-states and all years for which both EIA 861 and FERC 714 data are available, and outputs
-the results as a CSV in PUDL_DIR/local/state-demand/demand.csv
"""
-import argparse
import datetime
-import sys
from collections.abc import Iterable
from typing import Any

import geopandas as gpd
-import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-from dagster import AssetKey, AssetOut, Field, asset, multi_asset
+from dagster import AssetOut, Field, asset, multi_asset

import pudl.analysis.timeseries_cleaning
import pudl.logging_helpers
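
The allocation strategy the module docstring describes — spread each respondent's hourly demand across the counties in its service territory in proportion to population, then sum the allocated demand up to states — can be sketched with toy numbers. The respondent IDs, FIPS codes, and weights below are invented for illustration; the real pipeline derives them from EIA 861 service territories and Census DP1 populations.

import pandas as pd

# Hypothetical hourly demand (MWh) reported by two respondents.
demand = pd.DataFrame(
    {
        "respondent_id_ferc714": [1, 1, 2, 2],
        "utc_datetime": pd.to_datetime(
            ["2020-01-01 00:00", "2020-01-01 01:00"] * 2, utc=True
        ),
        "demand_mwh": [100.0, 110.0, 50.0, 40.0],
    }
)

# Hypothetical service territories: respondent 1 serves two counties that
# straddle a state line; respondent 2 serves one county. The weight is the
# county's share of the respondent's served population.
territories = pd.DataFrame(
    {
        "respondent_id_ferc714": [1, 1, 2],
        "county_id_fips": ["08001", "56021", "56025"],
        "state_id_fips": ["08", "56", "56"],
        "population_weight": [0.75, 0.25, 1.0],
    }
)

# Allocate each respondent's hourly demand to its counties by weight,
# then sum the allocated demand by state and UTC hour.
allocated = demand.merge(territories, on="respondent_id_ferc714")
allocated["demand_mwh"] *= allocated["population_weight"]
state_demand = allocated.groupby(["state_id_fips", "utc_datetime"], as_index=False)[
    "demand_mwh"
].sum()
print(state_demand)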
@@ -592,17 +585,17 @@ def predicted_state_hourly_demand(
    Args:
        imputed_hourly_demand_ferc714: Hourly demand timeseries, with columns
-            `respondent_id_ferc714`, report `year`, `utc_datetime`, and `demand_mwh`.
+            ``respondent_id_ferc714``, report ``year``, ``utc_datetime``, and
+            ``demand_mwh``.
        county_censusdp1: The county layer of the Census DP1 shapefile.
        fipsified_respondents_ferc714: Annual respondents with the county FIPS IDs
            for their service territories.
        sales_eia861: EIA 861 sales data. If provided, the predicted hourly demand is
            scaled to match these totals.

    Returns:
-        Dataframe with columns
-        `state_id_fips`, `utc_datetime`, `demand_mwh`, and
-        (if `state_totals` was provided) `scaled_demand_mwh`.
+        Dataframe with columns ``state_id_fips``, ``utc_datetime``, ``demand_mwh``, and
+        (if ``state_totals`` was provided) ``scaled_demand_mwh``.
    """
    # Get config
    mean_overlaps = context.op_config["mean_overlaps"]
@@ -665,233 +658,3 @@ def predicted_state_hourly_demand(
    # Sum demand by state by matching UTC time
    fields = [x for x in ["demand_mwh", "scaled_demand_mwh"] if x in df]
    return df.groupby(["state_id_fips", "utc_datetime"], as_index=False)[fields].sum()


def plot_demand_timeseries(
    a: pd.DataFrame,
    b: pd.DataFrame | None = None,
    window: int = 168,
    title: str | None = None,
    path: str | None = None,
) -> None:
    """Make a timeseries plot of predicted and reference demand.

    Args:
        a: Predicted demand with columns `utc_datetime` and any of
            `demand_mwh` (in grey) and `scaled_demand_mwh` (in orange).
        b: Reference demand with columns `utc_datetime` and `demand_mwh` (in red).
        window: Width of window (in rows) to use to compute rolling means,
            or `None` to plot raw values.
        title: Plot title.
        path: Plot path. If provided, the figure is saved to file and closed.
    """
    plt.figure(figsize=(16, 8))
    # Plot predicted
    for field, color in [("demand_mwh", "grey"), ("scaled_demand_mwh", "orange")]:
        if field not in a:
            continue
        y = a[field]
        if window:
            y = y.rolling(window).mean()
        plt.plot(
            a["utc_datetime"], y, color=color, alpha=0.5, label=f"Predicted ({field})"
        )
    # Plot expected
    if b is not None:
        y = b["demand_mwh"]
        if window:
            y = y.rolling(window).mean()
        plt.plot(
            b["utc_datetime"], y, color="red", alpha=0.5, label="Reference (demand_mwh)"
        )
    if title:
        plt.title(title)
    plt.ylabel("Demand (MWh)")
    plt.legend()
    if path:
        plt.savefig(path, bbox_inches="tight")
        plt.close()
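
For example, called with a year of synthetic hourly data (made-up values; assumes the function above and its module-level imports are in scope, and that "demand.png" is an acceptable output path), the default 168-hour window smooths the hourly series to its weekly rolling mean:

import numpy as np
import pandas as pd

# Synthetic hourly demand for one year: a daily cycle plus noise.
rng = np.random.default_rng(0)
idx = pd.date_range("2020-01-01", "2020-12-31 23:00", freq="h", tz="UTC")
hours = np.arange(len(idx))
base = 1000 + 200 * np.sin(2 * np.pi * hours / 24)
a = pd.DataFrame(
    {
        "utc_datetime": idx,
        "demand_mwh": base + rng.normal(0, 50, len(idx)),
        "scaled_demand_mwh": 1.1 * base + rng.normal(0, 50, len(idx)),
    }
)
plot_demand_timeseries(a, window=168, title="Synthetic demand", path="demand.png")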


def plot_demand_scatter(
    a: pd.DataFrame,
    b: pd.DataFrame,
    title: str | None = None,
    path: str | None = None,
) -> None:
    """Make a scatter plot comparing predicted and reference demand.

    Args:
        a: Predicted demand with columns `utc_datetime` and any of
            `demand_mwh` (in grey) and `scaled_demand_mwh` (in orange).
        b: Reference demand with columns `utc_datetime` and `demand_mwh`.
            Every element in `utc_datetime` must match the corresponding
            element in `a`.
        title: Plot title.
        path: Plot path. If provided, the figure is saved to file and closed.

    Raises:
        ValueError: Datetime columns do not match.
    """
    if not a["utc_datetime"].equals(b["utc_datetime"]):
        raise ValueError("Datetime columns do not match")
    plt.figure(figsize=(8, 8))
    plt.gca().set_aspect("equal")
    plt.axline((0, 0), (1, 1), linestyle=":", color="grey")
    for field, color in [("demand_mwh", "grey"), ("scaled_demand_mwh", "orange")]:
        if field not in a:
            continue
        plt.scatter(
            b["demand_mwh"],
            a[field],
            c=color,
            s=0.1,
            alpha=0.5,
            label=f"Prediction ({field})",
        )
    if title:
        plt.title(title)
    plt.xlabel("Reference (MWh)")
    plt.ylabel("Predicted (MWh)")
    plt.legend()
    if path:
        plt.savefig(path, bbox_inches="tight")
        plt.close()


def compare_state_demand(
    a: pd.DataFrame, b: pd.DataFrame, scaled: bool = True
) -> pd.DataFrame:
    """Compute statistics comparing predicted and reference demand.

    Statistics are computed for each year.

    Args:
        a: Predicted demand with columns `utc_datetime` and either
            `demand_mwh` (if `scaled=False`) or `scaled_demand_mwh`
            (if `scaled=True`).
        b: Reference demand with columns `utc_datetime` and `demand_mwh`.
            Every element in `utc_datetime` must match the corresponding
            element in `a`.

    Returns:
        Dataframe with columns `year`,
        `rmse` (root mean square error), and `mae` (mean absolute error).

    Raises:
        ValueError: Datetime columns do not match.
    """
    if not a["utc_datetime"].equals(b["utc_datetime"]):
        raise ValueError("Datetime columns do not match")
    field = "scaled_demand_mwh" if scaled else "demand_mwh"
    df = pd.DataFrame(
        {
            "year": a["utc_datetime"].dt.year,
            "diff": a[field] - b["demand_mwh"],
        }
    )
    return df.groupby(["year"], as_index=False)["diff"].agg(
        {
            "rmse": lambda x: np.sqrt(np.sum(x**2) / x.size),
            "mae": lambda x: np.sum(np.abs(x)) / x.size,
        }
    )
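
A quick sanity check with synthetic values: a prediction that runs a constant 10 MWh above the reference yields an RMSE and MAE of exactly 10 for that year. The two frames must already share an aligned ``utc_datetime`` column, as in the index-intersection step in ``main()`` below.

import numpy as np
import pandas as pd

idx = pd.date_range("2020-01-01", periods=8760, freq="h", tz="UTC")
b = pd.DataFrame({"utc_datetime": idx, "demand_mwh": np.full(8760, 100.0)})
a = pd.DataFrame({"utc_datetime": idx, "scaled_demand_mwh": np.full(8760, 110.0)})
print(compare_state_demand(a, b, scaled=True))
#    year  rmse   mae
# 0  2020  10.0  10.0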


# --- Parse Command Line Args --- #
def parse_command_line(argv):
    """Skeletal command line argument parser to provide a help message."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--logfile",
        default=None,
        type=str,
        help="If specified, write logs to this file.",
    )
    parser.add_argument(
        "--loglevel",
        help="Set logging level (DEBUG, INFO, WARNING, ERROR, or CRITICAL).",
        default="INFO",
    )
    return parser.parse_args(argv[1:])


# --- Example usage --- #


def main():
    """Predict state demand."""
    # --- Parse command line args --- #
    args = parse_command_line(sys.argv)

    # --- Connect to PUDL logger --- #
    pudl.logging_helpers.configure_root_logger(
        logfile=args.logfile, loglevel=args.loglevel
    )

    # --- Connect to PUDL database --- #

    # --- Read in inputs from PUDL + dagster cache --- #
    prediction = pudl.etl.defs.load_asset_value(
        AssetKey("predicted_state_hourly_demand")
    )

    # --- Export results --- #

    local_dir = pudl.workspace.setup.PudlPaths().data_dir / "local"
    ventyx_path = local_dir / "ventyx/state_level_load_2007_2018.csv"
    base_dir = local_dir / "state-demand"
    base_dir.mkdir(parents=True, exist_ok=True)
    demand_path = base_dir / "demand.csv"
    stats_path = base_dir / "demand-stats.csv"
    timeseries_dir = base_dir / "timeseries"
    timeseries_dir.mkdir(parents=True, exist_ok=True)
    scatter_dir = base_dir / "scatter"
    scatter_dir.mkdir(parents=True, exist_ok=True)

    # Write predicted hourly state demand
    prediction.to_csv(
        demand_path, index=False, date_format="%Y%m%dT%H", float_format="%.1f"
    )

    # Load Ventyx as reference if available
    reference = None
    if ventyx_path.exists():
        reference = load_ventyx_hourly_state_demand(ventyx_path)

    # Plots and statistics
    stats = []
    for fips in prediction["state_id_fips"].unique():
        state = lookup_state(fips)
        # Filter demand by state
        a = prediction.query(f"state_id_fips == '{fips}'")
        b = None
        title = f'{state["fips"]}: {state["name"]} ({state["code"]})'
        plot_name = f'{state["fips"]}-{state["name"]}.png'
        if reference is not None:
            b = reference.query(f"state_id_fips == '{fips}'")
        # Save timeseries plot
        plot_demand_timeseries(
            a, b=b, window=168, title=title, path=timeseries_dir / plot_name
        )
        if b is None or b.empty:
            continue
        # Align predicted and reference demand
        a = a.set_index("utc_datetime")
        b = b.set_index("utc_datetime")
        index = a.index.intersection(b.index)
        a = a.loc[index].reset_index()
        b = b.loc[index].reset_index()
        # Compute statistics
        stat = compare_state_demand(a, b, scaled=True)
        stat["state_id_fips"] = fips
        stats.append(stat)
        # Save scatter plot
        plot_demand_scatter(a, b=b, title=title, path=scatter_dir / plot_name)

    # Write statistics
    if reference is not None:
        pd.concat(stats, ignore_index=True).to_csv(
            stats_path, index=False, float_format="%.1f"
        )


if __name__ == "__main__":
    sys.exit(main())
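
With this entry point removed, the Kaggle example notebook presumably reproduces the CSV export by loading the already-materialized Dagster asset directly. A minimal sketch of the equivalent steps, assuming a populated PUDL workspace and mirroring the paths and calls used in the deleted main() above:

from dagster import AssetKey

import pudl.etl
from pudl.workspace.setup import PudlPaths

# Load the materialized asset instead of re-running the analysis.
prediction = pudl.etl.defs.load_asset_value(AssetKey("predicted_state_hourly_demand"))

# Write the same CSV the old script produced.
out_dir = PudlPaths().data_dir / "local" / "state-demand"
out_dir.mkdir(parents=True, exist_ok=True)
prediction.to_csv(
    out_dir / "demand.csv", index=False, date_format="%Y%m%dT%H", float_format="%.1f"
)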
