Batch Predictions
See also
Forecast UI Tutorial: Predict and Evaluate | Concepts: Predictions | API Tutorial: Credit Default — Step 12 | API Tutorial: Store Sales Forecast — Step 7
Prerequisites
This page uses the `client` object and the `wait_for_task` helper defined in API Overview.
Generate predictions for an entire observation table.
Create a Prediction Table
```python
import featurebyte as fb

client = fb.Configurations().get_client()

# Start an asynchronous job that scores every row of the observation table
response = client.post(
    f"/ml_model/{ml_model_id}/prediction_table",
    json={
        "request_input": {
            "request_type": "observation_table",
            "table_id": observation_table_id,
        },
        "include_input_features": False,
    },
)
task_id = response.json()["id"]
task = wait_for_task(client, task_id)

# Get the prediction table ID from the completed task
prediction_table_id = task.get("payload", {}).get("output_document_id")
```
Parameters:
| Parameter | Type | Required | Description |
|---|---|---|---|
| `request_input` | object | Yes | Specifies the input data for predictions |
| `request_input.request_type` | string | Yes | Input type: `"observation_table"` |
| `request_input.table_id` | string | Yes | ID of the observation table to predict on |
| `include_input_features` | boolean | No | Include computed feature values in the output (default: `false`) |
| `include_shap_values` | boolean | No | Include SHAP values for each prediction (default: `false`) |
| `normalize_shap_values` | boolean | No | Normalize SHAP values (default: `false`; applies only when `include_shap_values` is `true`) |
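The optional flags all go into the same request body. A minimal sketch, assuming you want SHAP values returned alongside the predictions; `build_prediction_request` is a hypothetical helper, not part of the FeatureByte SDK. It rejects `normalize_shap_values=True` without `include_shap_values=True`, since the table above says normalization only applies when SHAP values are included.

```python
def build_prediction_request(
    observation_table_id: str,
    include_input_features: bool = False,
    include_shap_values: bool = False,
    normalize_shap_values: bool = False,
) -> dict:
    """Build the JSON body for POST /ml_model/{ml_model_id}/prediction_table.

    Hypothetical helper: field names mirror the parameter table above.
    """
    if normalize_shap_values and not include_shap_values:
        # Normalization is only meaningful when SHAP values are requested
        raise ValueError("normalize_shap_values requires include_shap_values=True")
    return {
        "request_input": {
            "request_type": "observation_table",
            "table_id": observation_table_id,
        },
        "include_input_features": include_input_features,
        "include_shap_values": include_shap_values,
        "normalize_shap_values": normalize_shap_values,
    }


# Request predictions with normalized SHAP values for each row
body = build_prediction_request(
    "my_observation_table_id",  # placeholder ID for illustration
    include_shap_values=True,
    normalize_shap_values=True,
)
```

Pass the result as `json=body` to `client.post(f"/ml_model/{ml_model_id}/prediction_table", ...)`.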
Check for Existing Predictions
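Before creating new predictions, check whether a prediction table already exists for this model and observation table. This lookup inspects only the first page of results (up to 100 prediction tables):

```python
# Look for a prediction table built from the same observation table
prediction_table_id = None
response = client.get(
    "/catalog/prediction_table",
    params={"ml_model_id": ml_model_id, "page_size": 100},
)
for pt in response.json().get("data", []):
    request_input = pt.get("request_input", {})
    if request_input.get("table_id") == observation_table_id:
        prediction_table_id = pt["_id"]
        break
```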
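The existence check and the create step above combine naturally into a get-or-create helper. A minimal sketch; `get_or_create_prediction_table` is a hypothetical helper, not part of the FeatureByte SDK. It takes the `client` object and the `wait_for_task` helper from API Overview as arguments so it can be exercised with stand-ins, and, like the lookup above, it searches only the first page of up to 100 prediction tables.

```python
def get_or_create_prediction_table(client, wait_for_task, ml_model_id, observation_table_id):
    """Return an existing prediction table ID, or create one and return its ID.

    Hypothetical helper built from the two requests shown on this page.
    """
    # 1. Look for a prediction table already built from this observation table
    response = client.get(
        "/catalog/prediction_table",
        params={"ml_model_id": ml_model_id, "page_size": 100},
    )
    for pt in response.json().get("data", []):
        if pt.get("request_input", {}).get("table_id") == observation_table_id:
            return pt["_id"]

    # 2. None found: start a prediction job and wait for it to finish
    response = client.post(
        f"/ml_model/{ml_model_id}/prediction_table",
        json={
            "request_input": {
                "request_type": "observation_table",
                "table_id": observation_table_id,
            },
            "include_input_features": False,
        },
    )
    task = wait_for_task(client, response.json()["id"])
    return task.get("payload", {}).get("output_document_id")
```

Call it as `prediction_table_id = get_or_create_prediction_table(client, wait_for_task, ml_model_id, observation_table_id)`. Reusing an existing table avoids recomputing features, but the match is on `table_id` alone, so a table created earlier with different options (for example, `include_shap_values`) would also be returned.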
Download Predictions
Download the prediction table in Parquet format and load it into a pandas DataFrame:

```python
import io

import pyarrow.parquet as pq

# Stream the Parquet bytes into an in-memory buffer
response = client.get(
    f"/prediction_table/parquet/{prediction_table_id}",
    stream=True,
)
buffer = io.BytesIO()
for chunk in response.iter_content(chunk_size=8192):
    if chunk:
        buffer.write(chunk)

# Rewind the buffer and load it into a pandas DataFrame
buffer.seek(0)
predictions_df = pq.read_table(buffer).to_pandas()
```
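For large prediction tables, holding both the raw bytes and the DataFrame in memory can be costly; streaming the response to a file first is one alternative. A minimal sketch; `download_to_file` is a hypothetical helper that works with any response object exposing `iter_content`, such as the streamed response above.

```python
def download_to_file(response, path, chunk_size=8192):
    """Write a streamed HTTP response to `path`; return the number of bytes written."""
    written = 0
    with open(path, "wb") as f:
        for chunk in response.iter_content(chunk_size=chunk_size):
            if chunk:  # skip empty keep-alive chunks
                f.write(chunk)
                written += len(chunk)
    return written
```

After `download_to_file(response, "predictions.parquet")`, load the file with `pq.read_table("predictions.parquet").to_pandas()`. The file name here is only an example.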
Downloaded DataFrame columns:
| Column | Description |
|---|---|
| Entity columns (e.g., `item_store_id`) | Entity identifiers from the observation table |
| `POINT_IN_TIME` | The point in time for the prediction |
| `FORECAST_POINT` | The forecast point (forecast use cases only) |
| `prediction` | The model's predicted value |
| Target column (e.g., `sales`) | Actual target value, if the observation table has a target |
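When the observation table includes the target, predictions can be compared with actuals directly in pandas. A minimal sketch, assuming a regression target named `sales` and the entity column `item_store_id` from the examples above; the rows are invented for illustration, and in practice you would use `predictions_df` from the download step. Mean absolute error is just one possible metric.

```python
import pandas as pd

# Stand-in for predictions_df; values are invented for illustration
df = pd.DataFrame(
    {
        "item_store_id": ["A", "A", "B", "B"],
        "prediction": [10.0, 12.0, 5.0, 7.0],
        "sales": [11.0, 12.0, 4.0, 9.0],
    }
)

target_column = "sales"  # assumption: replace with your model's target column

# Rows without an actual value cannot be scored
scored = df.dropna(subset=[target_column])
errors = (scored["prediction"] - scored[target_column]).abs()

# Overall error and error per entity
mae = errors.mean()
mae_by_entity = errors.groupby(scored["item_store_id"]).mean()
```

For forecast use cases, grouping `errors` by `FORECAST_POINT` instead shows how accuracy varies across forecast points.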