support pangumoe w8a8c8 and docs (#1477)
### What this PR does / why we need it?
Support Pangu MoE W8A8C8 quantization (int8 weights, activations, and KV cache).

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with the newly added test.

Signed-off-by: zhuyilin <809721801@qq.com>
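For reference, the int8 (C8) KV-cache path added below is selected through `additional_config["kv_cache_dtype"]`. A minimal usage sketch, assuming vllm-ascend's `additional_config` plumbing and an already-quantized W8A8C8 checkpoint; the model path and the `quantization="ascend"` argument are illustrative placeholders, not something this PR adds:

```python
# Hypothetical usage sketch, not part of this diff: serving a Pangu MoE
# W8A8C8 checkpoint with the int8 KV cache that this change reads from
# additional_config["kv_cache_dtype"].
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/pangu-moe-w8a8c8",             # placeholder path (assumption)
    quantization="ascend",                         # assumption: Ascend W8A8 quant backend
    additional_config={"kv_cache_dtype": "int8"},  # matches the check added below
)
out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```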
@@ -49,7 +49,8 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
+                        LazyLoader, cdiv)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -169,6 +170,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         else:
             self.chunked_prefill_enabled = True
 
+        if self.cache_config.cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        else:
+            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                self.cache_config.cache_dtype]
+
         self.is_multimodal_model = self.model_config.is_multimodal_model
         if self.is_multimodal_model:
             self.inputs_embeds = torch.zeros(
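Taken on its own, the dtype resolution added to `__init__` above behaves like this standalone sketch (assuming the configured string exists in vLLM's `STR_DTYPE_TO_TORCH_DTYPE` table):

```python
# Standalone sketch of the kv_cache_dtype resolution added above; assumes the
# configured string (e.g. "int8") is a key in STR_DTYPE_TO_TORCH_DTYPE.
import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE


def resolve_kv_cache_dtype(cache_dtype: str, model_dtype: torch.dtype) -> torch.dtype:
    # "auto" keeps the KV cache in the model's own dtype; any other value is
    # looked up in vLLM's name-to-torch-dtype table.
    if cache_dtype == "auto":
        return model_dtype
    return STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]


assert resolve_kv_cache_dtype("auto", torch.bfloat16) is torch.bfloat16
```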
@@ -1924,10 +1931,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
                 # encounter OOM issue
                 if isinstance(kv_cache_spec, FullAttentionSpec):
-                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-                        num_blocks, kv_cache_spec.block_size,
-                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
-                    dtype = kv_cache_spec.dtype
+                    if self.vllm_config.additional_config.get(
+                            "kv_cache_dtype", None) == 'int8':
+                        kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
+                    else:
+                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
+                            num_blocks, kv_cache_spec.block_size,
+                            kv_cache_spec.num_kv_heads,
+                            kv_cache_spec.head_size)
                     if self.torchair_graph_enabled:
                         layer_kv_cache_nope = torch.zeros(
                             kv_cache_shape[:-1] +
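The hunk above switches the cache layout when `additional_config["kv_cache_dtype"]` is `'int8'`. A purely illustrative comparison of the two layouts, assuming `get_bsh_kv_cache_shape` folds the head and head-size axes into one hidden dimension (the usual meaning of a BSH layout) while the default shape keeps them separate; the numbers are made up:

```python
# Illustrative only: how the BSH-shaped int8 cache chosen above could differ
# from the default layout. Shapes and numbers are assumptions for illustration,
# not the backend's actual return values.
num_blocks, block_size, num_kv_heads, head_size = 512, 128, 8, 128

default_shape = (2, num_blocks, block_size, num_kv_heads, head_size)  # separate head axis
bsh_shape = (2, num_blocks, block_size, num_kv_heads * head_size)     # heads folded into hidden

print(default_shape)  # (2, 512, 128, 8, 128)
print(bsh_shape)      # (2, 512, 128, 1024)
```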
@@ -1951,9 +1965,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                 acl_format),
                         )
                     else:
-                        kv_caches[layer_name] = torch.zeros(kv_cache_shape,
-                                                            dtype=dtype,
-                                                            device=self.device)
+                        kv_caches[layer_name] = torch.zeros(
+                            kv_cache_shape,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device)
                         kv_caches[layer_name] = \
                             torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
                 else:
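Net effect of this last hunk: on the non-graph path the cache tensor is now allocated with the `self.kv_cache_dtype` resolved in `__init__` instead of the spec's default `dtype` (yielding an int8 cache when `cache_dtype` resolves to `torch.int8`), and it is still cast to the requested ACL memory format via `torch_npu.npu_format_cast`, as before.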