Support Pangu MoE W8A8C8 quantization and add docs (#1477)

### What this PR does / why we need it?
Add support for W8A8C8 quantization (8-bit weights, activations, and KV cache) for the Pangu MoE model.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with the newly added test.

Signed-off-by: zhuyilin <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2025-06-28 18:51:07 +08:00
committed by GitHub
parent c59d69d9e6
commit b308a7a258
8 changed files with 689 additions and 50 deletions

View File

@@ -49,7 +49,8 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
@@ -169,6 +170,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
self.chunked_prefill_enabled = True
if self.cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype]
self.is_multimodal_model = self.model_config.is_multimodal_model
if self.is_multimodal_model:
self.inputs_embeds = torch.zeros(
@@ -1924,10 +1931,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
# encounter OOM issue
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
if self.vllm_config.additional_config.get(
"kv_cache_dtype", None) == 'int8':
kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
if self.torchair_graph_enabled:
layer_kv_cache_nope = torch.zeros(
kv_cache_shape[:-1] +
@@ -1951,9 +1965,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
acl_format),
)
else:
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
kv_caches[layer_name] = torch.zeros(
kv_cache_shape,
dtype=self.kv_cache_dtype,
device=self.device)
kv_caches[layer_name] = \
torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
else: