-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path06_staging_batch_inference.py
57 lines (34 loc) · 1.17 KB
/
06_staging_batch_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Databricks notebook source
# MAGIC %md
# MAGIC ## Churn Prediction Batch Inference
# MAGIC
# MAGIC <img src="https://github.com/RafiKurlansik/laughing-garbanzo/blob/main/step6.png?raw=true">
# COMMAND ----------
# MAGIC %run ./00_includes
# COMMAND ----------
# MAGIC %md
# MAGIC #### Load Model
# MAGIC
# MAGIC Loading as a Spark UDF to set us up for future scale.
# COMMAND ----------
import mlflow

# Model Registry URI for the churn model, addressed by its "staging" stage.
# `database_name` is not defined in this notebook (presumably it comes from
# ./00_includes); you may need to replace the name with your own model name.
staging_model_uri = f"models:/{database_name}_churn/staging"

# Wrap the registered model as a Spark UDF so it can be applied to DataFrame columns.
model = mlflow.pyfunc.spark_udf(spark, model_uri=staging_model_uri)
# COMMAND ----------
# MAGIC %md
# MAGIC #### Load Features
# COMMAND ----------
from databricks.feature_store import FeatureStoreClient

fs = FeatureStoreClient()

# Fully qualified name of the churn feature table inside the working database.
churn_features_table = f"{database_name}.churn_features"

# Read the feature table through the Feature Store client.
features = fs.read_table(churn_features_table)
# COMMAND ----------
# MAGIC %md
# MAGIC #### Inference
# COMMAND ----------
# Every column of the feature table is handed to the model UDF positionally,
# in the table's own column order.
# NOTE(review): assumes the model accepts all of these columns in this order — confirm.
input_columns = features.columns
predictions = features.withColumn("predictions", model(*input_columns))

# Show only the customer key next to its prediction.
display(predictions.select("customerId", "predictions"))
# COMMAND ----------
# MAGIC %md
# MAGIC #### Write to Delta Lake
# COMMAND ----------
# Target Delta table for the scored rows; each run appends to whatever is already there.
predictions_table = f"{database_name}.churn_preds"
(
    predictions.write
    .format("delta")
    .mode("append")
    .saveAsTable(predictions_table)
)
# COMMAND ----------