Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it? This PR added the initial DeepSeek V3.2 support with [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete vLLM adaptation as soon as possible. This feature will be ready in recent 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 . ### Does this PR introduce _any_ user-facing change? Yes! ### How was this patch tested? CI passed and Run deepseek doc soon. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wxsIcey <1790571317@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -285,8 +285,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||
self.runner_only_attn_layers: set[str] = set()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.ascend_scheduler_config.enabled:
|
||||
self.ascend_config = get_ascend_config()
|
||||
if self.ascend_config.ascend_scheduler_config.enabled:
|
||||
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
||||
else:
|
||||
self.chunked_prefill_enabled = True
|
||||
@@ -298,6 +298,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.cache_config.cache_dtype]
|
||||
# use_hybrid_blocks: if hybrid blocks is used.
|
||||
self.use_hybrid_blocks: bool = False
|
||||
self.need_accepted_tokens: bool = False
|
||||
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.is_pooling_model = self.model_config.pooler_config is not None
|
||||
@@ -315,7 +316,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
use_sfa=self.ascend_config.use_sfa)
|
||||
else:
|
||||
self.attn_backend = get_attn_backend(
|
||||
0,
|
||||
@@ -323,7 +324,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
None,
|
||||
self.block_size,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
use_sfa=self.ascend_config.use_sfa)
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
self.attn_mask_builder = AttentionMaskBuilder(
|
||||
self.scheduler_config.max_num_batched_tokens, self.dtype,
|
||||
@@ -457,7 +458,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.dynamic_eplb = self.ascend_config.dynamic_eplb
|
||||
if self.dynamic_eplb:
|
||||
self.is_eplb_warmuped = False
|
||||
self.eplb_loader = D2DExpertWeightLoader()
|
||||
@@ -890,15 +891,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
def _make_attention_mask(self, seq_lens, position,
|
||||
attn_state) -> torch.Tensor:
|
||||
# Chunk Prefill situation.
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
|
||||
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
|
||||
if torch.version.cann.startswith("8.3"):
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||
else:
|
||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||
seq_lens, position, self.dtype, self.device)
|
||||
|
||||
# Prefill without cache situation.
|
||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||
max_seq_len = max(seq_lens, default=0)
|
||||
max_seq_len = max(seq_lens.max().item(), 0)
|
||||
return self.attn_mask_builder.get_attn_mask(
|
||||
max_seq_len, self.dtype, self.device)
|
||||
# Prefill with cache hit.
|
||||
@@ -1252,7 +1254,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
req_ids = self.input_batch.req_ids
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = max(tokens)
|
||||
max_num_scheduled_tokens = num_scheduled_tokens.max()
|
||||
num_valid_tokens = np.array([
|
||||
num_tokens -
|
||||
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
||||
@@ -1376,8 +1378,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
positions_cpu = self.positions_cpu[:num_input_tokens]
|
||||
positions = self.positions[:num_input_tokens]
|
||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
||||
num_valid_tokens)
|
||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||
position=positions_cpu,
|
||||
attn_state=attn_state)
|
||||
@@ -1477,7 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_computed_tokens_cpu = (
|
||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||
spec_decode_common_attn_metadata = None
|
||||
if use_spec_decode:
|
||||
if use_spec_decode and self.need_accepted_tokens:
|
||||
self.num_accepted_tokens.np[:num_reqs] = (
|
||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
@@ -1550,7 +1550,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
model=self.model,
|
||||
**extra_attn_metadata_args)
|
||||
|
||||
if self.vllm_config.model_config.use_mla:
|
||||
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||
attn_metadata_i.num_input_tokens = num_input_tokens
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
@@ -2060,7 +2060,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
sampling_metadata,
|
||||
)
|
||||
sampler_output.sampled_token_ids = output_token_ids
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
if self.need_accepted_tokens:
|
||||
self._update_states_after_model_execute(output_token_ids)
|
||||
|
||||
discard_sampled_tokens_req_indices: list[int] = []
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
@@ -2683,10 +2684,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
||||
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.need_accepted_tokens = any([
|
||||
isinstance(
|
||||
self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
|
||||
MambaSpec) for attn_group in self.attn_groups
|
||||
])
|
||||
else:
|
||||
self.need_accepted_tokens = any([
|
||||
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
|
||||
for attn_group in self.attn_groups
|
||||
])
|
||||
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
|
||||
if self.model_config.is_deepseek_mla:
|
||||
kv_caches = self.initialize_kv_cache_tensors_deepseek(
|
||||
if self.ascend_config.is_deepseek_sfa:
|
||||
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
||||
kv_cache_config)
|
||||
elif self.model_config.is_deepseek_mla:
|
||||
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
|
||||
kv_cache_config)
|
||||
else:
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
@@ -2701,7 +2718,116 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||
return tensor[int(offset):]
|
||||
|
||||
def initialize_kv_cache_tensors_deepseek(
|
||||
def initialize_kv_cache_tensors_deepseek_sfa(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
kv_cache_sizes = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||
"KV cache tensor shared by multiple layers is not supported in "
|
||||
"NPU.")
|
||||
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||
|
||||
kv_caches: Dict[str, torch.Tensor] = {}
|
||||
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
|
||||
if vllm_version_is("0.10.2"):
|
||||
kv_cache_spec, group = group
|
||||
else:
|
||||
kv_cache_spec = group.kv_cache_spec
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
tensor_size = kv_cache_sizes[layer_name]
|
||||
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||
if self.vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None) == 'int8':
|
||||
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
elif hasattr(
|
||||
attn_backend, "get_supported_block_size"
|
||||
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks * block_size_chunk, block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
else:
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
|
||||
alignment = 2 * 1024 * 1024
|
||||
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
||||
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
nope_dim = head_size - rope_dim
|
||||
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
||||
nope_dim)
|
||||
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
||||
rope_dim)
|
||||
#### k cache
|
||||
# TODO(zzzzwwjj): wait transformers add these params
|
||||
k_cache_shape = (num_blocks, block_size, 1, 128)
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
# For no disaggregate pd scenario, allocate kv cache in normal way
|
||||
rope_cache = torch.zeros(rope_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
nope_cache = torch.zeros(nope_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
rope_cache = self._convert_torch_format(rope_cache)
|
||||
nope_cache = self._convert_torch_format(nope_cache)
|
||||
|
||||
#### k cache
|
||||
k_cache = torch.zeros(k_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
k_cache = self._convert_torch_format(k_cache)
|
||||
else:
|
||||
|
||||
# In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
|
||||
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
|
||||
# we found there are also some exceptions during test, so we manual align those memory here, this part
|
||||
# of code may consume 2M * 2 * elem_size memory every layer.
|
||||
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
||||
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
||||
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
||||
rope_allocate_shape_alignment = rope_allocate_shape + alignment
|
||||
|
||||
nope_cache = torch.zeros(nope_allocate_shape_alignment,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
#### k cache
|
||||
# TODO(zzzzwwjj): wait transformers add these params
|
||||
k_allocate_shape = num_blocks * block_size * 1 * 128
|
||||
k_allocate_shape_alignment = k_allocate_shape + alignment
|
||||
k_cache = torch.zeros(k_allocate_shape_alignment,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
|
||||
nope_cache = self._align_memory(
|
||||
nope_cache,
|
||||
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
||||
rope_cache = self._align_memory(
|
||||
rope_cache,
|
||||
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
||||
k_cache = self._align_memory(
|
||||
k_cache,
|
||||
alignment)[:k_allocate_shape].view(k_cache_shape)
|
||||
|
||||
kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
|
||||
return kv_caches
|
||||
|
||||
def initialize_kv_cache_tensors_deepseek_mla(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
kv_cache_sizes = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
@@ -3217,6 +3343,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
use_sfa = self.ascend_config.use_sfa
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
@@ -3243,7 +3370,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
|
||||
@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
@@ -88,6 +88,17 @@ class NPUWorker(WorkerBase):
|
||||
# init ascend config and soc version
|
||||
init_ascend_config(vllm_config)
|
||||
init_ascend_soc_version()
|
||||
if get_ascend_config().use_sfa:
|
||||
# Direct import instead of using try_register_lib to ensure proper error handling when
|
||||
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
|
||||
# yapf: disable
|
||||
import custom_ops # type: ignore # noqa
|
||||
|
||||
# yapf: enable
|
||||
logger.info(
|
||||
"custom_ops module loaded successfully. Custom operators like "
|
||||
"torch.ops.custom.npu_sparse_flash_attention are now available."
|
||||
)
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
|
||||
Reference in New Issue
Block a user