[Hybrid KV] Follow up UniformTypeKVCacheSpecs (#3070)

### What this PR does / why we need it?
Follow up the `UniformTypeKVCacheSpecs` changes introduced by
https://github.com/vllm-project/vllm/pull/25101, which support different
hidden sizes in uniform-type KV cache specs.

This also fixes the CI issue about `TypeError: AttentionGroup.__init__()
missing 1 required positional argument: 'kv_cache_spec'`

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tests passed with existing e2e tests.

- vLLM version: v0.10.2
- vLLM main:
c60e6137f0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-09-22 15:02:41 +08:00
committed by GitHub
parent f1f2c8f5e5
commit f39bd309b6
4 changed files with 101 additions and 35 deletions

View File

@@ -36,7 +36,7 @@ jobs:
- name: Get vLLM version - name: Get vLLM version
run: | run: |
VLLM_COMMIT=c60e6137f0bf2034853919b3a9d705d7e06b93cf VLLM_COMMIT=9607d5eb449711b349d4c2bee0a9c94afcc7ed14
echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV
- name: Checkout repository - name: Checkout repository

View File

@@ -42,7 +42,7 @@ jobs:
lint: lint:
uses: ./.github/workflows/pre-commit.yml uses: ./.github/workflows/pre-commit.yml
with: with:
vllm: c60e6137f0bf2034853919b3a9d705d7e06b93cf vllm: 9607d5eb449711b349d4c2bee0a9c94afcc7ed14
changes: changes:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -83,7 +83,7 @@ jobs:
VLLM_USE_MODELSCOPE: True VLLM_USE_MODELSCOPE: True
strategy: strategy:
matrix: matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2] vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
steps: steps:
- name: Install packages - name: Install packages
run: | run: |
@@ -138,7 +138,7 @@ jobs:
name: e2e-light name: e2e-light
strategy: strategy:
matrix: matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2] vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
# Note (yikun): If CI resource are limited we can split job into two chain jobs # Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes] needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request. # only trigger e2e test after lint passed and the change is e2e related with pull request.

View File

@@ -68,7 +68,7 @@ jobs:
name: e2e-full name: e2e-full
strategy: strategy:
matrix: matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2] vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
needs: [changes] needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml uses: ./.github/workflows/_e2e_test.yaml

View File

@@ -27,7 +27,8 @@ from contextlib import contextmanager, nullcontext
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import Manager from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
Union, cast)
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@@ -72,8 +73,12 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \ from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills reorder_batch_to_split_decodes_and_prefills
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec) KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, MambaSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput) DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
@@ -134,6 +139,11 @@ if is_310p():
else: else:
ACL_FORMAT = ACL_FORMAT_FRACTAL_ND ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
if not vllm_version_is("0.10.2"):
from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
else:
UniformTypeKVCacheSpecs = None
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
@@ -2584,10 +2594,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: Dict[str, torch.Tensor] = {} kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator( for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
): if vllm_version_is("0.10.2"):
attn_backend = kv_cache_group.backend kv_cache_spec, group = group
for layer_name in kv_cache_group.layer_names: else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers: if layer_name in self.runner_only_attn_layers:
continue continue
tensor_size = kv_cache_sizes[layer_name] tensor_size = kv_cache_sizes[layer_name]
@@ -2729,10 +2742,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)), "Some layers are not correctly initialized" )), "Some layers are not correctly initialized"
kv_caches: Dict[str, torch.Tensor] = {} kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator( for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
): if vllm_version_is("0.10.2"):
attn_backend = kv_cache_group.backend kv_cache_spec, group = group
for layer_name in kv_cache_group.layer_names: else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers: if layer_name in self.runner_only_attn_layers:
continue continue
@@ -2829,15 +2845,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_caches return kv_caches
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def may_reinitialize_input_batch(self, def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None: kv_cache_config: KVCacheConfig) -> None:
""" """
@@ -2917,9 +2924,45 @@ class NPUModelRunner(LoRAModelRunnerMixin):
assert len(self.attn_groups) == 0, \ assert len(self.attn_groups) == 0, \
"Attention backends are already initialized" "Attention backends are already initialized"
class AttentionGroupKey(NamedTuple):
attn_backend: type[AttentionBackend]
kv_cache_spec: KVCacheSpec
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase,
kv_cache_group_spec.layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than
# using the class itself as the key because when we create dynamic
# attention backend subclasses (e.g. ChunkedLocalAttention) unless
# they are cached correctly, there will be different objects per
# layer.
for layer_name in kv_cache_group_spec.layer_names:
attn_backend = layers[layer_name].get_attn_backend()
full_cls_name = attn_backend.full_cls_name()
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name]
key = (full_cls_name, layer_kv_cache_spec)
attn_backends[key] = AttentionGroupKey(attn_backend,
layer_kv_cache_spec)
attn_backend_layers[key].append(layer_name)
return {
attn_backends[k]: v
for k, v in attn_backend_layers.items()
}
def get_attn_backends_for_layers( def get_attn_backends_for_layers(
layer_names: list[str] layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]: ) -> dict[type[AttentionBackend], list[str]]:
"""Get attention_backend for all attention layers
TODO: Only used in v0.10.2, drop me when 0.10.2 is dropped
"""
layers = get_layers_from_vllm_config(self.vllm_config, layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase, AttentionLayerBase,
layer_names) layer_names)
@@ -2960,10 +3003,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def create_attn_groups( def create_attn_groups(
attn_backends_map: dict[AttentionBackend, list[str]], attn_backends_map: dict[AttentionBackend, list[str]],
kv_cache_spec: KVCacheSpec,
) -> list[AttentionGroup]: ) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = [] attn_groups: list[AttentionGroup] = []
for attn_backend, layer_names in attn_backends_map.items(): for (attn_backend,
kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = [] attn_metadata_builders = []
attn_metadata_builders.append(attn_backend.get_builder_cls()( attn_metadata_builders.append(attn_backend.get_builder_cls()(
kv_cache_spec, kv_cache_spec,
@@ -2973,20 +3016,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)) ))
attn_group = AttentionGroup(attn_backend, attn_group = AttentionGroup(attn_backend,
attn_metadata_builders, attn_metadata_builders,
layer_names) layer_names, kv_cache_spec)
attn_groups.append(attn_group) attn_groups.append(attn_group)
return attn_groups return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups: if vllm_version_is("0.10.2"):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_layers( kv_cache_spec = kv_cache_group_spec.kv_cache_spec
kv_cache_group_spec.layer_names) attn_backends = get_attn_backends_for_layers(
if vllm_version_is("0.10.2"): kv_cache_group_spec.layer_names)
self.attn_groups.append( self.attn_groups.append(
create_attn_groups_v0102(attn_backends, kv_cache_spec)) create_attn_groups_v0102(attn_backends, kv_cache_spec))
else: else:
self.attn_groups.append( for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
create_attn_groups(attn_backends, kv_cache_spec)) attn_backends = get_attn_backends_for_group( # type: ignore
kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
# Calculate reorder batch threshold (if needed) # Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold() self.calculate_reorder_batch_threshold()
@@ -2994,6 +3039,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _attn_group_iterator(self) -> Iterator[AttentionGroup]: def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups) return itertools.chain.from_iterable(self.attn_groups)
def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
if not self.kv_cache_config.kv_cache_groups:
return
for attn_groups in self.attn_groups:
yield from attn_groups
def _kv_cache_spec_attn_group_iterator_v0102(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def _kv_cache_spec_attn_group_iterator_dispatcher(self):
if vllm_version_is("0.10.2"):
return self._kv_cache_spec_attn_group_iterator_v0102()
else:
return self._kv_cache_spec_attn_group_iterator()
def calculate_reorder_batch_threshold(self) -> None: def calculate_reorder_batch_threshold(self) -> None:
""" """
Check that if any backends reorder batches; that the reordering Check that if any backends reorder batches; that the reordering