[Hybrid KV] Follow up UniformTypeKVCacheSpecs (#3070)

### What this PR does / why we need it?
Follow up `UniformTypeKVCacheSpecs` changes introduced by
https://github.com/vllm-project/vllm/pull/25101, which support different
hidden size in uniform type kvcache specs

This also fix the CI issue about `TypeError: AttentionGroup.__init__()
missing 1 required positional argument: 'kv_cache_spec'`

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
Tests passed with exsiting e2e tests.

- vLLM version: v0.10.2
- vLLM main:
c60e6137f0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-09-22 15:02:41 +08:00
committed by GitHub
parent f1f2c8f5e5
commit f39bd309b6
4 changed files with 101 additions and 35 deletions

View File

@@ -36,7 +36,7 @@ jobs:
- name: Get vLLM version
run: |
VLLM_COMMIT=c60e6137f0bf2034853919b3a9d705d7e06b93cf
VLLM_COMMIT=9607d5eb449711b349d4c2bee0a9c94afcc7ed14
echo "VLLM_COMMIT=https://github.com/vllm-project/vllm/commit/$VLLM_COMMIT" >> $GITHUB_ENV
- name: Checkout repository

View File

@@ -42,7 +42,7 @@ jobs:
lint:
uses: ./.github/workflows/pre-commit.yml
with:
vllm: c60e6137f0bf2034853919b3a9d705d7e06b93cf
vllm: 9607d5eb449711b349d4c2bee0a9c94afcc7ed14
changes:
runs-on: ubuntu-latest
@@ -83,7 +83,7 @@ jobs:
VLLM_USE_MODELSCOPE: True
strategy:
matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
steps:
- name: Install packages
run: |
@@ -138,7 +138,7 @@ jobs:
name: e2e-light
strategy:
matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
# Note (yikun): If CI resource are limited we can split job into two chain jobs
needs: [lint, changes]
# only trigger e2e test after lint passed and the change is e2e related with pull request.

View File

@@ -68,7 +68,7 @@ jobs:
name: e2e-full
strategy:
matrix:
vllm_version: [c60e6137f0bf2034853919b3a9d705d7e06b93cf, v0.10.2]
vllm_version: [9607d5eb449711b349d4c2bee0a9c94afcc7ed14, v0.10.2]
needs: [changes]
if: ${{ needs.changes.outputs.e2e_tracker == 'true' }}
uses: ./.github/workflows/_e2e_test.yaml

View File

@@ -27,7 +27,8 @@ from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from typing import (TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional,
Union, cast)
import numpy as np
import numpy.typing as npt
@@ -72,8 +73,12 @@ from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import \
reorder_batch_to_split_decodes_and_prefills
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
# yapf conflicts with isort for this block
# yapf: disable
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec)
KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, MambaSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
@@ -134,6 +139,11 @@ if is_310p():
else:
ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
if not vllm_version_is("0.10.2"):
from vllm.v1.kv_cache_interface import UniformTypeKVCacheSpecs
else:
UniformTypeKVCacheSpecs = None
@dataclass
class GraphCaptureContext:
@@ -2584,10 +2594,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
):
attn_backend = kv_cache_group.backend
for layer_name in kv_cache_group.layer_names:
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
if vllm_version_is("0.10.2"):
kv_cache_spec, group = group
else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
tensor_size = kv_cache_sizes[layer_name]
@@ -2729,10 +2742,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)), "Some layers are not correctly initialized"
kv_caches: Dict[str, torch.Tensor] = {}
for kv_cache_spec, kv_cache_group in self._kv_cache_spec_attn_group_iterator(
):
attn_backend = kv_cache_group.backend
for layer_name in kv_cache_group.layer_names:
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
if vllm_version_is("0.10.2"):
kv_cache_spec, group = group
else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
@@ -2829,15 +2845,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return kv_caches
def _kv_cache_spec_attn_group_iterator(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def may_reinitialize_input_batch(self,
kv_cache_config: KVCacheConfig) -> None:
"""
@@ -2917,9 +2924,45 @@ class NPUModelRunner(LoRAModelRunnerMixin):
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
class AttentionGroupKey(NamedTuple):
attn_backend: type[AttentionBackend]
kv_cache_spec: KVCacheSpec
def get_attn_backends_for_group(
kv_cache_group_spec: KVCacheGroupSpec,
) -> dict[AttentionGroupKey, list[str]]:
layers = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase,
kv_cache_group_spec.layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than
# using the class itself as the key because when we create dynamic
# attention backend subclasses (e.g. ChunkedLocalAttention) unless
# they are cached correctly, there will be different objects per
# layer.
for layer_name in kv_cache_group_spec.layer_names:
attn_backend = layers[layer_name].get_attn_backend()
full_cls_name = attn_backend.full_cls_name()
layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs):
layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[
layer_name]
key = (full_cls_name, layer_kv_cache_spec)
attn_backends[key] = AttentionGroupKey(attn_backend,
layer_kv_cache_spec)
attn_backend_layers[key].append(layer_name)
return {
attn_backends[k]: v
for k, v in attn_backend_layers.items()
}
def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
"""Get attention_backend for all attention layers
TODO: Only used in v0.10.2, drop me when 0.10.2 is dropped
"""
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
@@ -2960,10 +3003,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def create_attn_groups(
attn_backends_map: dict[AttentionBackend, list[str]],
kv_cache_spec: KVCacheSpec,
) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = []
for attn_backend, layer_names in attn_backends_map.items():
for (attn_backend,
kv_cache_spec), layer_names in attn_backends_map.items():
attn_metadata_builders = []
attn_metadata_builders.append(attn_backend.get_builder_cls()(
kv_cache_spec,
@@ -2973,20 +3016,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
))
attn_group = AttentionGroup(attn_backend,
attn_metadata_builders,
layer_names)
layer_names, kv_cache_spec)
attn_groups.append(attn_group)
return attn_groups
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
if vllm_version_is("0.10.2"):
if vllm_version_is("0.10.2"):
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
self.attn_groups.append(
create_attn_groups_v0102(attn_backends, kv_cache_spec))
else:
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))
else:
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
attn_backends = get_attn_backends_for_group( # type: ignore
kv_cache_group_spec)
self.attn_groups.append(create_attn_groups(attn_backends))
# Calculate reorder batch threshold (if needed)
self.calculate_reorder_batch_threshold()
@@ -2994,6 +3039,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
return itertools.chain.from_iterable(self.attn_groups)
def _kv_cache_spec_attn_group_iterator(self) -> Iterator[AttentionGroup]:
if not self.kv_cache_config.kv_cache_groups:
return
for attn_groups in self.attn_groups:
yield from attn_groups
def _kv_cache_spec_attn_group_iterator_v0102(
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
if not self.kv_cache_config.kv_cache_groups:
return
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
for attn_group in attn_groups:
yield self.kv_cache_config.kv_cache_groups[
kv_cache_spec_id].kv_cache_spec, attn_group
def _kv_cache_spec_attn_group_iterator_dispatcher(self):
if vllm_version_is("0.10.2"):
return self._kv_cache_spec_attn_group_iterator_v0102()
else:
return self._kv_cache_spec_attn_group_iterator()
def calculate_reorder_batch_threshold(self) -> None:
"""
Check that if any backends reorder batches; that the reordering