Batch Predictions

Prerequisites

This page uses the client and wait_for_task helpers defined in API Overview.

Generate predictions for an entire observation table.

Create a Prediction Table

# API client, as set up in API Overview
client = fb.Configurations().get_client()

response = client.post(
    f"/ml_model/{ml_model_id}/prediction_table",
    json={
        "request_input": {
            "request_type": "observation_table",
            "table_id": observation_table_id,
        },
        "include_input_features": False,
    },
)

# Prediction runs asynchronously: the response carries a task ID to wait on
task_id = response.json()["id"]
task = wait_for_task(client, task_id)

# Get the prediction table ID from the completed task
prediction_table_id = task.get("payload", {}).get("output_document_id")

Parameters:

| Parameter | Type | Required | Description |
| --- | --- | --- | --- |
| request_input | object | Yes | Specifies the input data for predictions |
| request_input.request_type | string | Yes | Input type: "observation_table" |
| request_input.table_id | string | Yes | ID of the observation table to predict on |
| include_input_features | boolean | No | Include computed feature values in the output (default: false) |
| include_shap_values | boolean | No | Include SHAP values for each prediction (default: false) |
| normalize_shap_values | boolean | No | Normalize SHAP values (default: false; only applies when include_shap_values is true) |
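
As a sketch of how the optional flags combine, the request body can be assembled by a small helper. The function name `build_prediction_request` is hypothetical and not part of the API. Rejecting `normalize_shap_values` without `include_shap_values` is a design choice made here; the parameter table only says the flag does not apply in that case.

```python
from typing import Any, Dict


def build_prediction_request(
    observation_table_id: str,
    include_input_features: bool = False,
    include_shap_values: bool = False,
    normalize_shap_values: bool = False,
) -> Dict[str, Any]:
    """Build the JSON body for the prediction_table endpoint (hypothetical helper)."""
    # Assumption: fail early instead of sending a flag that would not apply.
    if normalize_shap_values and not include_shap_values:
        raise ValueError("normalize_shap_values requires include_shap_values=True")
    return {
        "request_input": {
            "request_type": "observation_table",
            "table_id": observation_table_id,
        },
        "include_input_features": include_input_features,
        "include_shap_values": include_shap_values,
        "normalize_shap_values": normalize_shap_values,
    }
```

The result can be passed as the `json=` argument of the `client.post` call shown earlier.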

Check for Existing Predictions

Before creating new predictions, check whether a prediction table already exists for the same model and observation table:

response = client.get(
    "/catalog/prediction_table",
    params={"ml_model_id": ml_model_id, "page_size": 100},
)

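# Stays None if no existing prediction table matches.
# Only the first page (up to 100 entries) is scanned.
prediction_table_id = None
for pt in response.json().get("data", []):
    request_input = pt.get("request_input", {})
    if request_input.get("table_id") == observation_table_id:
        prediction_table_id = pt["_id"]
        break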
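
The lookup can be combined with creation into a "reuse or create" step. The following is a minimal sketch, not part of the API: `find_prediction_table_id` and `get_or_create_prediction_table` are hypothetical helper names, and `create_fn` stands for any callable that runs the creation request and returns the new prediction table ID.

```python
from typing import Any, Callable, Dict, Iterable, Optional, Tuple


def find_prediction_table_id(
    records: Iterable[Dict[str, Any]], observation_table_id: str
) -> Optional[str]:
    """Return the ID of the first record built from the given observation table."""
    for record in records:
        # Tolerate records where request_input is missing or null.
        request_input = record.get("request_input") or {}
        if request_input.get("table_id") == observation_table_id:
            return record["_id"]
    return None


def get_or_create_prediction_table(
    records: Iterable[Dict[str, Any]],
    observation_table_id: str,
    create_fn: Callable[[], str],
) -> Tuple[str, bool]:
    """Return (prediction_table_id, created); create_fn runs only if nothing matches."""
    existing = find_prediction_table_id(records, observation_table_id)
    if existing is not None:
        return existing, False
    return create_fn(), True
```

Here `records` would be the `data` list from the catalog response, and `create_fn` would wrap the `client.post` call and `wait_for_task` from the previous section.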

Download Predictions

Download the prediction table as a Parquet file:

import io
import pyarrow.parquet as pq

response = client.get(
    f"/prediction_table/parquet/{prediction_table_id}",
    stream=True,
)

buffer = io.BytesIO()
for chunk in response.iter_content(chunk_size=8192):
    if chunk:
        buffer.write(chunk)
buffer.seek(0)

predictions_df = pq.read_table(buffer).to_pandas()
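
For large prediction tables, holding the whole file in memory may be undesirable. A minimal sketch, assuming the response exposes `iter_content` as in the snippet above, streams the chunks to a local file instead. The helper name `write_chunks_to_file` and the output path are illustrative.

```python
from typing import Iterable


def write_chunks_to_file(chunks: Iterable[bytes], path: str) -> int:
    """Write byte chunks to `path` and return the number of bytes written."""
    total = 0
    with open(path, "wb") as fh:
        for chunk in chunks:
            if chunk:  # skip empty chunks
                fh.write(chunk)
                total += len(chunk)
    return total
```

A call such as `write_chunks_to_file(response.iter_content(chunk_size=8192), "predictions.parquet")` would save the download, after which `pq.read_table("predictions.parquet").to_pandas()` loads it from disk.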

Downloaded DataFrame columns:

| Column | Description |
| --- | --- |
| Entity columns (e.g., item_store_id) | Entity identifiers from the observation table |
| POINT_IN_TIME | The point in time for the prediction |
| FORECAST_POINT | The forecast point (forecast use cases only) |
| prediction | The model's predicted value |
| Target column (e.g., sales) | Actual target value, if the observation table has a target |
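
When the observation table includes a target, the downloaded frame holds both the prediction and the actual value, so a quick accuracy check is possible. The sketch below computes the mean absolute error with pandas. The helper name is hypothetical, the target column name `sales` follows the example above, and the values in the toy frame are made up for illustration.

```python
import pandas as pd


def mean_absolute_error(df: pd.DataFrame, target_column: str) -> float:
    """Mean absolute difference between `prediction` and the target column."""
    # Ignore rows where either the prediction or the actual value is missing.
    scored = df.dropna(subset=["prediction", target_column])
    return float((scored["prediction"] - scored[target_column]).abs().mean())


# Toy frame with made-up values, shaped like the columns described above.
example_df = pd.DataFrame(
    {
        "item_store_id": ["a", "b", "c"],
        "prediction": [10.0, 12.0, 9.0],
        "sales": [11.0, 12.0, 7.0],
    }
)
mae = mean_absolute_error(example_df, "sales")
```

With the real download, `predictions_df` would take the place of `example_df`.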