diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index c2fd9dd5..32a66c0e 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -5,7 +5,7 @@ import torch
 import torch_npu
 from torch import nn
 from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import CUDAGraphMode, VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import logger
@@ -148,6 +148,12 @@ class AscendSFAMetadataBuilder:
         self.enable_sfa_cp = enable_sp() and \
             hasattr(self.model_config.hf_config, "index_topk")
 
+        assert not (
+            self.enable_sfa_cp
+            and self.vllm_config.compilation_config.cudagraph_mode
+            == CUDAGraphMode.FULL_DECODE_ONLY
+        ), "FlashComm1 is not compatible with FULL_DECODE_ONLY. Please set graph_mode to 'piecewise' or disable FlashComm1."
+
     def reorder_batch(self, input_batch: "NPUInputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         # No need to reorder for Ascend SFA