CANN: refactor mask handling and improve performance in FA (#15561)

* CANN(flash-attn): refactor mask handling and improve performance

1. Refactored the mask computation in Flash Attention, unified the logic without separating prefill and decode.
2. Optimized performance in non-alibi scenarios by reducing one repeat operation.
3. Updated operator management to explicitly mark unsupported cases on 310P devices and when dim is not divisible by 16.

Signed-off-by: noemotiovon <757486878@qq.com>

* [CANN]: fix review

Signed-off-by: noemotiovon <757486878@qq.com>

* [CANN]: Optimization FA BNSD to BSND

Signed-off-by: noemotiovon <757486878@qq.com>

---------

Signed-off-by: noemotiovon <757486878@qq.com>
This commit is contained in:
Chenguang Li
2025-08-27 17:21:41 +08:00
committed by GitHub
parent 1cf123a343
commit 1e7489745a
2 changed files with 97 additions and 88 deletions

View File

@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
// Q4 && Q8 per group is not support on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
// Q4 && Q8 per group is not suppor on 310p device
// Q4 && Q8 per group is not support on 310p device
return false;
#endif
// only support contiguous for quantized types.
@@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
return true;
case GGML_OP_FLASH_ATTN_EXT:{
#ifdef ASCEND_310P
// FA not support on 310p device
return false;
#endif
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
return false;
@@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
if(logitSoftcap != 0.0f) {