CANN: refactor mask handling and improve performance in FA (#15561)
* CANN(flash-attn): refactor mask handling and improve performance 1. Refactored the mask computation in Flash Attention, unified the logic without separating prefill and decode. 2. Optimized performance in non-alibi scenarios by reducing one repeat operation. 3. Updated operator management to explicitly mark unsupported cases on 310P devices and when dim is not divisible by 16. Signed-off-by: noemotiovon <757486878@qq.com> * [CANN]: fix review Signed-off-by: noemotiovon <757486878@qq.com> * [CANN]: Optimization FA BNSD to BSND Signed-off-by: noemotiovon <757486878@qq.com> --------- Signed-off-by: noemotiovon <757486878@qq.com>
This commit is contained in:
@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
// Q4 && Q8 per group is not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
#ifdef ASCEND_310P
|
||||
// Q4 && Q8 per group is not suppor on 310p device
|
||||
// Q4 && Q8 per group is not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// only support contiguous for quantized types.
|
||||
@@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:{
|
||||
#ifdef ASCEND_310P
|
||||
// FA not support on 310p device
|
||||
return false;
|
||||
#endif
|
||||
// derived from [ggml-cuda.cu]
|
||||
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
|
||||
return false;
|
||||
@@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
|
||||
// DeepSeek MLA
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] % 16 != 0) {
|
||||
// TODO: padding to support
|
||||
return false;
|
||||
}
|
||||
float logitSoftcap = 0.0f;
|
||||
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
|
||||
if(logitSoftcap != 0.0f) {
|
||||
|
||||
Reference in New Issue
Block a user