[bugfix] restore pr-7029 and fix patch error (#7294)

### What this PR does / why we need it?
This PR restores #7029, which adds W8A8C8 support for dsv3.2/glm5 using
the `lightning_indexer_quant` ops in the pd-mix stage.

The original PR was reverted by #7288 because the patch did not work
with the recompute scheduler.

This PR also fixes the patching issue so that it works correctly with
the recompute scheduler.

### Does this PR introduce _any_ user-facing change?
Yes. To enable LI C8, users need to set the `enable_sparse_c8` option to
`"true"` in `additional_config`.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
rjg-lyh
2026-03-16 15:39:42 +08:00
committed by GitHub
parent 9320365dab
commit 4d443b9228
25 changed files with 4309 additions and 78 deletions

View File

@@ -45,6 +45,30 @@ from vllm.v1.utils import ConstantList, record_function_or_nullcontext
logger = init_logger(__name__)
# `spec_manager_map` in single_type_kv_cache_manager is a module-level dict
# whose keys are class objects bound at import time. When the async
# recompute scheduler is enabled, `recompute_scheduler.py` is imported by
# `check_and_update_config()` (via AsyncScheduler → scheduler.py →
# kv_cache_coordinator → single_type_kv_cache_manager) *before*
# this patch file is executed a second time (e.g. triggered by
# unpickling an AscendMLAAttentionSpec in the EngineCoreProc subprocess).
# In that case the dict already contains the original MLAAttentionSpec
# class as a key, so a subsequent lookup with type(AscendMLAAttentionSpec
# instance) raises KeyError.
#
# Fix: whenever this patch is applied, register AscendMLAAttentionSpec as
# an additional key in spec_manager_map (if the module is already loaded).
def register_ascend_mla_spec_in_manager():
    """Make sure AscendMLAAttentionSpec is a key in vllm's spec_manager_map.

    ``spec_manager_map`` binds class objects at import time; when the
    recompute scheduler's module chain imports it before this patch runs
    a second time (e.g. after unpickling in the EngineCoreProc subprocess),
    a lookup with the Ascend spec class raises KeyError. Registering the
    class here is idempotent and safe to call from every __init__.
    """
    import sys
    from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
    from vllm.v1.kv_cache_interface import MLAAttentionSpec as AscendMLAAttentionSpec

    manager_module = sys.modules.get("vllm.v1.core.single_type_kv_cache_manager")
    if manager_module is None:
        # Module not imported yet: its own import will pick up the patched
        # classes, so there is nothing to fix up here.
        return
    spec_map = manager_module.spec_manager_map
    if AscendMLAAttentionSpec not in spec_map:
        spec_map[AscendMLAAttentionSpec] = FullAttentionManager
@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
    """SchedulerConfig variant that selects the Ascend RecomputeScheduler."""
    # Dotted import path (or class object) of the scheduler implementation
    # vllm should instantiate; defaults to the recompute scheduler below.
    scheduler_cls: str | type[object] = "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
@@ -82,6 +106,8 @@ class RecomputeScheduler(Scheduler):
running: list[Request]
def __init__(self, *args, **kwargs):
    # Re-register AscendMLAAttentionSpec in spec_manager_map *before* the
    # base Scheduler __init__ builds its KV-cache coordinator, otherwise the
    # coordinator's spec lookup can raise KeyError when this module was
    # imported ahead of the patch file (see comment at top of this file).
    register_ascend_mla_spec_in_manager()
    super().__init__(*args, **kwargs)
# When is_mtp_kv_consumer is true, we will fill request.spec_token_ids
# with placeholder tokens to enable full graph when decode nodes pull
@@ -993,4 +1019,6 @@ class RecomputeScheduler(Scheduler):
class AsyncRecomputeScheduler(AsyncScheduler, RecomputeScheduler):
    """Async-scheduling variant; MRO puts AsyncScheduler ahead of
    RecomputeScheduler so async hooks take precedence."""

    def __init__(self, *args, **kwargs):
        # Same guard as RecomputeScheduler.__init__: patch spec_manager_map
        # before any base-class initialization touches KV-cache setup.
        register_ascend_mla_spec_in_manager()
        super().__init__(*args, **kwargs)