Problem
When using a Hugging Face Transformers model with Flash Attention, you receive attribute errors or ModuleNotFoundError errors such as the following.
“AttributeError: 'CLIPTextTransformer' object has no attribute '_use_flash_attention_2'.”
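For reference, this class of error typically surfaces for models that were opted into Flash Attention 2 at load time. A minimal sketch of such a load call, assuming a transformers release that supports the attn_implementation parameter and using a hypothetical CLIP checkpoint as the example:

from transformers import AutoModel

# Hypothetical checkpoint; requesting Flash Attention 2 here requires a
# working flash-attn installation in the serving environment.
model = AutoModel.from_pretrained(
    "openai/clip-vit-base-patch32",
    attn_implementation="flash_attention_2",
)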
Cause
Flash-attention models are not supported with the pyfunc flavor in MLflow because flash-attn requires a custom wheel built against specific PyTorch and CUDA versions. MLflow’s Model Serving infrastructure also does not accommodate using flash-attention in pyfunc models.
Solution
When logging a model that requires flash-attention, use mlflow.transformers.log_model with a custom wheel version of flash-attn.
Specify all pip requirements as a list and pass the list as the extra_pip_requirements parameter in the mlflow.transformers.log_model function call.
Make sure the versions of torch and torchvision you indicate are compatible with the CUDA version you specify in your flash-attention wheel.
Databricks recommends using the following versions and wheels.
- PyTorch index page with file download links: https://download.pytorch.org/whl/cu118
- torch 2.0.1+cu118
- torchvision 0.15.2+cu118
- flash-attention .whl download from the GitHub releases page (Dao-AILab/flash-attention)
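If you want to verify these pins in a notebook before logging the model, you can install them directly. A minimal sketch, assuming a Databricks notebook (the %pip magic) and a Python 3.11 runtime to match the cp311 wheel:

%pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.0.1+cu118 torchvision==0.15.2+cu118 https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu118torch2.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl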
Example
import mlflow

mlflow.transformers.log_model(
    transformers_model=model,  # the loaded Transformers pipeline or model/tokenizer components
    artifact_path=model_name,
    registered_model_name=model_name,
    extra_pip_requirements=[
        "git+https://github.com/huggingface/diffusers@v0.22.1",
        "peft==0.12.0",
        "compel==2.0.3",
        "boto3==1.34.39",
        "transformers==4.39.2",
        "--extra-index-url https://download.pytorch.org/whl/cu118",
        "torch==2.0.1+cu118",
        "torchvision==0.15.2+cu118",
        "https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu118torch2.0cxx11abiFALSE-cp311-cp311-linux_x86_64.whl",
    ],
)
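After logging, load the model back with the transformers flavor rather than pyfunc, since pyfunc does not support flash-attention. A minimal sketch, assuming the model was registered under model_name and version 1 is the one you want (both hypothetical):

import mlflow

# Hypothetical URI; substitute your registered model name and version.
loaded_model = mlflow.transformers.load_model(f"models:/{model_name}/1")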
Additional context
When using custom GPU serving, Databricks Model Serving first resolves the version of PyTorch or TensorFlow that your model uses, then installs a compatible CUDA version for the PyTorch or TensorFlow version it detects.
Note
For PyTorch, Databricks relies on the torch pip package to determine the compatible CUDA version. For TensorFlow, Databricks determines a compatible version based on the GPU section of TensorFlow’s Build from source documentation.
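As an illustration of how the torch package encodes its CUDA compatibility, you can inspect the installed build; a minimal sketch:

import torch

print(torch.__version__)   # e.g. "2.0.1+cu118" -- the +cu118 tag marks the CUDA 11.8 build
print(torch.version.cuda)  # e.g. "11.8"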