Problem
When you train a model on data stored in a Unity Catalog table, lineage to the upstream dataset(s) is tracked using mlflow.log_input, which logs the input table with the MLflow run. However, lineage for the output table (containing predictions) is not tracked.
Cause
MLflow has no built-in method for logging output tables that corresponds to mlflow.log_input for input tables.
Solution
Save the output table as a CSV file and log it as an artifact. This way, you can indirectly track the lineage of the output table. You can use the following code.
import os
import tempfile

import mlflow
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Load the input table as an MLflow dataset
dataset = mlflow.data.load_delta(table_name="<your-catalog>.<your-schema>.<your-table-name>", version="0")
pd_df = dataset.df.toPandas()
X = pd_df.drop("species", axis=1)
y = pd_df["species"]

# Name of the table that will hold the predictions
output_table_name = "<your-catalog>.<your-schema>.<your-iris-output>"

# Train the model and log the input table
with mlflow.start_run() as run:
    # RandomForestRegressor needs a numeric target; if "species" holds
    # string labels, use RandomForestClassifier instead
    clf = RandomForestRegressor(n_estimators=100)
    clf.fit(X, y)
    mlflow.log_input(dataset, "training")

    # Make predictions
    predictions = clf.predict(X)
    pd_df["predictions"] = predictions

    # Save predictions to an output table
    output_df = spark.createDataFrame(pd_df)
    output_df.write.format("delta").mode("overwrite").saveAsTable(output_table_name)

    # Log a CSV copy of the output table while the run is still active
    with tempfile.TemporaryDirectory() as tmpdir:
        temp_path = os.path.join(tmpdir, "predictions.csv")
        pd_df.to_csv(temp_path, index=False)

        # Log the temporary file as an artifact before the directory is removed
        mlflow.log_artifact(temp_path, "output_table")

print(f"Output table {output_table_name} created successfully.")
To locate the logged copy of the output table, navigate to the Artifacts page of the specific MLflow run. The CSV file is stored under the output_table folder.
Following the above steps ensures that the output table can be traced back to the corresponding run.
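The temporary-file step above can be isolated into a small helper. The following is a minimal sketch, not part of MLflow: the function name log_rows_as_csv_artifact and its parameters are hypothetical, and the logging function is passed in as an argument so the helper can be tried without a tracking server. It illustrates one detail of the pattern: the file has to be logged while the temporary directory still exists.

```python
import csv
import os
import tempfile
from typing import Callable, Iterable, Mapping, Sequence


def log_rows_as_csv_artifact(
    rows: Iterable[Mapping[str, object]],
    fieldnames: Sequence[str],
    log_artifact: Callable[[str, str], None],
    artifact_dir: str = "output_table",
    filename: str = "predictions.csv",
) -> str:
    """Write rows to a temporary CSV file and hand it to log_artifact.

    Hypothetical helper: log_artifact is any callable that accepts
    (local_path, artifact_path), for example mlflow.log_artifact.
    Returns the relative artifact path of the logged file.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        temp_path = os.path.join(tmpdir, filename)
        with open(temp_path, "w", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)
        # Log inside the block: the file is deleted when the block exits.
        log_artifact(temp_path, artifact_dir)
    return f"{artifact_dir}/{filename}"
```

Inside an active run, a call might look like log_rows_as_csv_artifact(rows, ["id", "prediction"], mlflow.log_artifact), where rows is a list of dictionaries. Passing the logging function as an argument is a design choice made only to keep the sketch testable. Calling mlflow.log_artifact directly, as in the code above, works equally well.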