Ascend attention backend (PA&MLA) (#7722)
Co-authored-by: Maksim <makcum888e@mail.ru>
Co-authored-by: VDV1985 <vladdv85@mail.ru>
@@ -1673,6 +1673,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             )
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["attention_backend"] == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
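This condition guards the host-side copy of the sequence lengths: backends that assemble their attention metadata on the CPU (the MLA variants and, with this change, the Ascend backend) need seq_lens_cpu, while other backends can skip the device-to-host transfer. A minimal sketch of that pattern, assuming a seq_lens tensor and a backend set chosen here for illustration (not the exact surrounding method):

import torch

CPU_LEN_BACKENDS = {"flashmla", "cutlass_mla", "ascend"}  # assumed set, for illustration only

def maybe_copy_seq_lens_to_cpu(
    seq_lens: torch.Tensor,
    attention_backend: str,
    enable_two_batch_overlap: bool,
) -> torch.Tensor | None:
    # Only pay for the device-to-host sync when a backend (or the
    # two-batch-overlap path) actually consumes CPU-side sequence lengths.
    if attention_backend in CPU_LEN_BACKENDS or enable_two_batch_overlap:
        return seq_lens.cpu()
    return None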
@@ -1875,7 +1876,10 @@ def get_last_loc(
     req_pool_indices_tensor: torch.Tensor,
     prefix_lens_tensor: torch.Tensor,
 ) -> torch.Tensor:
-    if global_server_args_dict["attention_backend"] != "torch_native":
+    if (
+        global_server_args_dict["attention_backend"] != "ascend"
+        and global_server_args_dict["attention_backend"] != "torch_native"
+    ):
         impl = get_last_loc_triton
     else:
         impl = get_last_loc_torch
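Previously only the torch_native backend took the pure-PyTorch fallback; the Ascend backend now also routes to get_last_loc_torch, since the Triton kernel is not available on Ascend NPUs. Dispatching through a single impl variable keeps all call sites uniform regardless of backend. Below is a minimal sketch of what a pure-torch fallback could look like; the last two parameter names follow the signature in the hunk, while req_to_token and the function body are assumptions for illustration, not the repository's actual get_last_loc_torch:

import torch

def get_last_loc_torch_sketch(
    req_to_token: torch.Tensor,             # assumed: [num_req_slots, max_context_len] token-location table
    req_pool_indices_tensor: torch.Tensor,  # [bs] slot of each request in req_to_token
    prefix_lens_tensor: torch.Tensor,       # [bs] cached prefix length of each request
) -> torch.Tensor:
    # For each request, return the KV-cache location of its last cached token,
    # or -1 when it has no cached prefix.
    return torch.where(
        prefix_lens_tensor > 0,
        req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
        torch.full_like(prefix_lens_tensor, -1),
    )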