CANN: Add broadcast for softmax and FA (#15208)

* refactor softmax

* fix fa

* fix mask shape

* format

* add comments

* Remove whitespace
This commit is contained in:
hipudding
2025-08-11 22:50:31 +08:00
committed by GitHub
parent cf9e5648a7
commit be48528b06
2 changed files with 216 additions and 343 deletions

View File

@@ -2391,7 +2391,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// only support F32 and F16.
return false;
}
return true;
return ggml_is_contiguous(op);
} break;
case GGML_OP_CONT: {
// TODO: support GGML_TYPE_BF16
@@ -2456,8 +2456,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// value of paddingW should be at most half of kernelW
return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
}
case GGML_OP_SUM:
case GGML_OP_DUP:
return ggml_is_contiguous(op);
case GGML_OP_SUM:
case GGML_OP_IM2COL:
case GGML_OP_CONCAT:
case GGML_OP_REPEAT:
@@ -2503,9 +2504,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
return true;
case GGML_OP_FLASH_ATTN_EXT:{
// derived from [ggml-cuda.cu]
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
@@ -2532,11 +2531,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// DeepSeek MLA
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
float logitSoftcap = 0.0f;
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
if(logitSoftcap != 0.0f) {