diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 0792fa76..906b5081 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -105,6 +105,7 @@ jobs:
           pytest -sv --durations=0 tests/e2e/singlecard/test_xlite.py
           pytest -sv --durations=0 tests/e2e/singlecard/pooling/
           pytest -sv --durations=0 tests/e2e/singlecard/compile/test_norm_quant_fusion.py
+          pytest -sv --durations=0 tests/e2e/singlecard/test_cross_layer_attn_model.py

           # ------------------------------------ v1 spec decode test ------------------------------------ #
           pytest -sv --durations=0 tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py
diff --git a/requirements-dev.txt b/requirements-dev.txt
index aa4701b7..8d955047 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -22,4 +22,5 @@ msserviceprofiler>=1.2.2
 mindstudio-probe>=8.3.0
 arctic-inference==0.1.1
 xlite
-uc-manager
\ No newline at end of file
+uc-manager
+timm
diff --git a/tests/e2e/singlecard/test_cross_layer_attn_model.py b/tests/e2e/singlecard/test_cross_layer_attn_model.py
new file mode 100644
index 00000000..61d82ae1
--- /dev/null
+++ b/tests/e2e/singlecard/test_cross_layer_attn_model.py
@@ -0,0 +1,69 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Compare the outputs of a cross-layer attention model with and without aclgraph.
+
+Run `pytest tests/e2e/singlecard/test_cross_layer_attn_model.py`.
+""" + +import os + +import pytest + +from tests.e2e.conftest import VllmRunner +from tests.e2e.model_utils import check_outputs_equal + +os.environ["VLLM_USE_MODELSCOPE"] = "True" +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODELS = [ + "google/gemma-3n-E2B-it", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +def test_models_with_aclgraph( + model: str, + max_tokens: int, +) -> None: + prompts = [ + "Hello, my name is", "The president of the United States is", + "The capital of France is", "The future of AI is" + ] + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=False, + cudagraph_capture_sizes=[4], + ) as vllm_model: + vllm_aclgraph_outputs = vllm_model.generate_greedy(prompts, max_tokens) + + with VllmRunner( + model, + max_model_len=1024, + enforce_eager=True, + ) as vllm_model: + vllm_eager_outputs = vllm_model.generate_greedy(prompts, max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_eager_outputs, + outputs_1_lst=vllm_aclgraph_outputs, + name_0="vllm_eager_outputs", + name_1="vllm_aclgraph_outputs", + ) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 001d58fb..3032630c 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -307,6 +307,7 @@ class AscendAttentionBackendImpl(AttentionImpl): device="npu") self.alibi_slopes = alibi_slopes self.attn_type = attn_type + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -618,24 +619,26 @@ class AscendAttentionBackendImpl(AttentionImpl): if len(kv_cache) > 1: if self.key_cache is None: self.key_cache, self.value_cache = kv_cache[0], kv_cache[1] - slots = attn_metadata.slot_mapping - if get_ascend_device_type() == AscendDeviceType.A5: - # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping. - # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key. - # If it's necessary, the slots should be sliced. - torch_npu.npu_scatter_pa_kv_cache( - key=key[:attn_metadata.num_actual_tokens], - value=value[:attn_metadata.num_actual_tokens].contiguous(), - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_mapping=slots) - else: - torch_npu._npu_reshape_and_cache( - key=key[:attn_metadata.num_actual_tokens], - value=value[:attn_metadata.num_actual_tokens], - key_cache=self.key_cache, - value_cache=self.value_cache, - slot_indices=slots[:attn_metadata.num_actual_tokens]) + if self.kv_sharing_target_layer_name is None: + slots = attn_metadata.slot_mapping + if get_ascend_device_type() == AscendDeviceType.A5: + # TODO: Once eagle running to here, it may has error because of the 0 dim of slot_mapping. + # Should check if the 0 dim of slot_mapping must equal to the 0 dim of key. + # If it's necessary, the slots should be sliced. + torch_npu.npu_scatter_pa_kv_cache( + key=key[:attn_metadata.num_actual_tokens], + value=value[:attn_metadata. 
+                                    num_actual_tokens].contiguous(),
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        slot_mapping=slots)
+                else:
+                    torch_npu._npu_reshape_and_cache(
+                        key=key[:attn_metadata.num_actual_tokens],
+                        value=value[:attn_metadata.num_actual_tokens],
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        slot_indices=slots[:attn_metadata.num_actual_tokens])
         return key, value

     def forward_impl(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a72cdeae..d8454674 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1195,6 +1195,10 @@ class NPUModelRunner(GPUModelRunner):

     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
+        if self.shared_kv_cache_layers is not None:
+            # Sharing KV across layers requires reading the KV cache, so
+            # return chunked prefill directly in this scenario.
+            return AscendAttentionState.ChunkedPrefill
         if np.array_equal(self.seq_lens.np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
@@ -2243,6 +2247,7 @@
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
         self.may_add_encoder_only_layers_to_kv_cache_config()
+        self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
         # NOTE(cmq): initialize_attn_backend must before using self.attn_groups
         self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
@@ -2282,6 +2287,13 @@
         kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
                                                    kv_cache_raw_tensors)

+        # Set up cross-layer KV cache sharing
+        for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
+        ):
+            logger.debug("%s reuses KV cache of %s", layer_name,
+                         target_layer_name)
+            kv_caches[layer_name] = kv_caches[target_layer_name]
+
         from vllm.v1.worker.utils import bind_kv_cache
         bind_kv_cache(kv_caches,
                       self.compilation_config.static_forward_context,
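
For context, a minimal self-contained sketch (not part of the patch) of what the aliasing loop added to model_runner_v1.py does. Layer names and tensor shapes here are hypothetical toys standing in for the real per-layer caches:

import torch

# Mapping direction matches shared_kv_cache_layers: sharing layer -> target layer.
shared_kv_cache_layers = {"model.layers.20.attn": "model.layers.10.attn"}

# Only the target layer gets real storage from the KV cache allocator.
kv_caches = {"model.layers.10.attn": torch.zeros(2, 16, 128)}

for layer_name, target_layer_name in shared_kv_cache_layers.items():
    # Alias, not copy: both names reference the same tensor, so the sharing
    # layer reads whatever the target layer wrote, at zero extra memory.
    kv_caches[layer_name] = kv_caches[target_layer_name]

assert kv_caches["model.layers.20.attn"] is kv_caches["model.layers.10.attn"]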
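
Likewise, a sketch of the write-skip guard added in attention_v1.py, using a toy dict-based "cache" and hypothetical class/method names: a layer with kv_sharing_target_layer_name set never writes its own K/V, because the target layer owns and populates that storage.

class ToyAttentionImpl:
    def __init__(self, kv_sharing_target_layer_name: str | None = None):
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    def maybe_write_kv(self, key, value, kv_cache: dict, slot: int) -> None:
        # Only cache-owning layers (target name is None) perform the write.
        if self.kv_sharing_target_layer_name is None:
            kv_cache[slot] = (key, value)

owner = ToyAttentionImpl()                         # owns its cache: writes
sharer = ToyAttentionImpl("model.layers.10.attn")  # shares: skips the write
cache: dict = {}
owner.maybe_write_kv("k", "v", cache, 0)
sharer.maybe_write_kv("k2", "v2", cache, 0)
assert cache[0] == ("k", "v")  # the sharer did not overwrite the owner's entry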