Skip to content

Commit

Permalink
Fix: handling for context vars. add support for disabling prophet option. (#1209)
Browse files Browse the repository at this point in the history

* added proper handling of context variables

* added unit test on prepare_data

* update fix for prophet enabling

* updated timestamp formats

* update redundant dtype conversion

* update notebook reference visuals

* update notebook reference visuals

* update notebook references

* add back unit test file

* clear nb outputs

* remove nb previews

* update pyproject toml file

* add nb outputs

* add nb outputs
  • Loading branch information
alxlyj authored Feb 5, 2025
1 parent 98a2261 commit 025eded
Show file tree
Hide file tree
Showing 7 changed files with 1,106 additions and 812 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
.DS_Store
.Rproj.user
.Rhistory
.pyvenv/
.venv/
robynpy.egg-info/
node_modules/
Expand Down
181 changes: 0 additions & 181 deletions python/docs/robyn/modeling/feature_engineering.md

This file was deleted.

3 changes: 2 additions & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "robynpy"
version = "0.1.1"
version = "0.1.2"
authors = [
{ name="Gufeng Zhou", email="[email protected]" },
{ name="Igor Skokan", email="[email protected]" },
Expand Down Expand Up @@ -36,6 +36,7 @@ dependencies = [
"plotly",
"nlopt",
"cmake",
"ipykernel",
]
[tool.setuptools.packages.find]
where = ["src"]
Expand Down
71 changes: 60 additions & 11 deletions python/src/robyn/modeling/feature_engineering.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,34 @@ def perform_feature_engineering(self, quiet: bool = False) -> FeaturizedMMMData:
dt_transform = self._prepare_data()
self.logger.debug(f"Prepared data shape: {dt_transform.shape}")

if any(
var in self.holidays_data.prophet_vars
for var in ["trend", "season", "holiday", "monthly", "weekday"]
):
# Check if Prophet decomposition should be performed
prophet_enabled = (
self.holidays_data is not None
and self.holidays_data.prophet_vars is not None
and len(self.holidays_data.prophet_vars) > 0
)

if prophet_enabled:
self.logger.info("Starting Prophet decomposition")
dt_transform = self._prophet_decomposition(dt_transform)
if not quiet:
self.logger.info("Prophet decomposition complete")
else:
self.logger.info(
"Prophet decomposition disabled - no prophet variables specified"
)

# Include all independent variables
all_ind_vars = (
self.holidays_data.prophet_vars
+ self.mmm_data.mmmdata_spec.context_vars
all_ind_vars = []
if prophet_enabled:
all_ind_vars.extend(self.holidays_data.prophet_vars)

all_ind_vars.extend(
self.mmm_data.mmmdata_spec.context_vars
+ self.mmm_data.mmmdata_spec.paid_media_spends
+ self.mmm_data.mmmdata_spec.organic_vars
)

self.logger.debug(f"Processing {len(all_ind_vars)} independent variables")

dt_mod = dt_transform
Expand Down Expand Up @@ -107,11 +119,48 @@ def _prepare_data(self) -> pd.DataFrame:
dt_transform = self.mmm_data.data.copy()
dt_transform["ds"] = pd.to_datetime(
dt_transform[self.mmm_data.mmmdata_spec.date_var]
).dt.strftime("%Y-%m-%d")
dt_transform["dep_var"] = dt_transform[self.mmm_data.mmmdata_spec.dep_var]
dt_transform["competitor_sales_B"] = dt_transform["competitor_sales_B"].astype(
"int64"
)
dt_transform["dep_var"] = dt_transform[self.mmm_data.mmmdata_spec.dep_var]

# Set default factor_vars if None
if self.mmm_data.mmmdata_spec.factor_vars is None:
self.mmm_data.set_default_factor_vars()
factor_vars = self.mmm_data.mmmdata_spec.factor_vars or []

# Handle factor variables conversion first
for factor_var in factor_vars:
try:
dt_transform[factor_var] = dt_transform[factor_var].astype("category")
self.logger.debug(f"Converted {factor_var} to categorical")
except Exception as e:
self.logger.warning(
f"Could not convert {factor_var} to categorical: {str(e)}"
)

# Only convert context variables that are used in numerical calculations
# i.e., those that aren't categorical/factor variables
numeric_context_vars = [
var
for var in self.mmm_data.mmmdata_spec.context_vars
if var not in factor_vars
and pd.api.types.is_object_dtype(dt_transform[var])
]

if numeric_context_vars:
self.logger.debug(
f"Converting numeric context variables: {numeric_context_vars}"
)
for var in numeric_context_vars:
try:
dt_transform[var] = pd.to_numeric(
dt_transform[var], errors="coerce"
)
self.logger.debug(
f"Converted {var} to numeric: {dt_transform[var].dtype}"
)
except Exception as e:
self.logger.warning(f"Could not convert {var} to numeric: {str(e)}")

self.logger.debug("Data preparation complete")
return dt_transform

Expand Down
Loading

0 comments on commit 025eded

Please sign in to comment.