support pangumoe w8a8c8 and docs (#1477)
### What this PR does / why we need it?
Support Pangu Pro MoE W8A8C8 quantization (int8 weights, activations, and kv cache) and update the docs.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with the newly added test.

Signed-off-by: zhuyilin <809721801@qq.com>
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM

| `refresh` | bool | `false` | Whether to refresh the global ascend config content. This value is usually used by RLHF or unit/e2e test cases. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MoE model, an expert map path needs to be passed in. |
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
| `kv_cache_dtype` | str | `None` | When using kv cache quantization, the kv cache dtype needs to be set; currently only `int8` is supported. |

The details of each config option are as follows:
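For reference, a minimal launch sketch that exercises the new `kv_cache_dtype` option. The checkpoint path, the `quantization="ascend"` argument, and passing `additional_config` through the `LLM` constructor are illustrative assumptions, not something this diff verifies:

```python
# Hypothetical usage sketch: enable int8 kv cache quantization via
# additional_config. Model path and quantization backend name are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/pangu-pro-moe-w8a8c8",          # placeholder checkpoint
    quantization="ascend",                          # assumed Ascend quant backend
    additional_config={"kv_cache_dtype": "int8"},   # the option documented above
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```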
@@ -69,6 +69,15 @@ class AscendAttentionBackend(AttentionBackend):
                              16)
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_bsh_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size, num_kv_heads * head_size)
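As a quick illustration (not part of the diff), the BSH shape simply folds the per-head dimensions of the default kv cache shape into a single hidden dimension; sizes below are made up:

```python
# Illustrative only: default shape vs. BSH shape for the int8 kv cache.
num_blocks, block_size, num_kv_heads, head_size = 512, 128, 8, 128

default_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
bsh_shape = (2, num_blocks, block_size, num_kv_heads * head_size)

print(default_shape)  # (2, 512, 128, 8, 128)
print(bsh_shape)      # (2, 512, 128, 1024)
```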

    @staticmethod
    def swap_blocks(
        src_kv_cache: List[torch.Tensor],
@@ -279,6 +288,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
                value=value,
                output=output,
                layer_name=layer.layer_name)

        elif hasattr(layer, 'quant_method'):
            output = layer.quant_method.apply(layer, query, key, value,
                                              kv_cache, attn_metadata,
                                              self.attn_type, self.scale,
                                              output)

        else:
            if attn_metadata is None:
                return output.view(num_tokens, self.hidden_size)
@@ -308,11 +324,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                value_cache=self.value_cache,
                slot_indices=slots)

            if hasattr(layer, 'quant_method'):
                # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
                pass
            # V0-Style scheduler situation.
            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                assert attn_metadata is not None
                assert attn_metadata.attn_mask is not None
                mask = attn_metadata.attn_mask
@@ -414,6 +427,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
                out=output)

        # to make in-place change to the output tensor
        if hasattr(layer, 'quant_method'):
            output = output.view(num_tokens, self.num_heads, self.head_size)
            ori_output[:, :, :] = output[:num_tokens, :, :]
        return output.view(num_tokens, self.hidden_size)
@@ -505,7 +505,7 @@ class PanguProMoESparseMoeBlock(nn.Module):
        # native FusedMoE. here we need to design a better FusedMoE
        # (maybe using AscendFusedMoE) to enable these different
        # communication schema.
        final_hidden_states = self.experts.quant_method(
        final_hidden_states = self.experts.quant_method.apply(
            layer=self.experts,
            x=hidden_states,
            router_logits=router_logits,
@@ -937,6 +937,8 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        tp_size = get_tp_group().world_size
        tp_rank = get_tp_group().rank_in_group
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
@@ -972,6 +974,51 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
            if "module" in name:
                continue

            if name.endswith('kv_cache_offset'):
                continue

            if name.endswith("k_proj.kv_cache_scale"):
                remapped_kv_scale_name = name.replace(
                    "k_proj.kv_cache_scale", "attn.key_antiquant_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv scale in the checkpoint "
                        f"(e.g. {name}), but not found the expected "
                        f"name in the model "
                        f"(e.g. {remapped_kv_scale_name}). "
                        "kv-scale is not loaded.")
                    continue
                else:
                    name = remapped_kv_scale_name
                param = params_dict[name]
                loaded_weight = torch.tensor_split(loaded_weight,
                                                   tp_size,
                                                   dim=0)[tp_rank]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            if name.endswith("v_proj.kv_cache_scale"):
                remapped_kv_scale_name = name.replace(
                    "v_proj.kv_cache_scale", "attn.value_antiquant_scale")
                if remapped_kv_scale_name not in params_dict:
                    logger.warning_once(
                        "Found kv scale in the checkpoint "
                        f"(e.g. {name}), but not found the expected "
                        f"name in the model "
                        f"(e.g. {remapped_kv_scale_name}). "
                        "kv-scale is not loaded.")
                    continue
                else:
                    name = remapped_kv_scale_name
                param = params_dict[name]
                loaded_weight = torch.tensor_split(loaded_weight,
                                                   tp_size,
                                                   dim=0)[tp_rank]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
@@ -124,6 +124,10 @@ class NPUPlatform(Platform):
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        kv_cache_dtype = vllm_config.additional_config.get(
            "kv_cache_dtype", None)
        if kv_cache_dtype is not None:
            vllm_config.cache_config.cache_dtype = kv_cache_dtype

        if parallel_config:
            # Default value for expert tensor parallel size
@@ -98,6 +98,9 @@ class AscendQuantConfig(QuantizationConfig):
                'fa_quant_type' in self.quant_description.keys() and \
                self.quant_description['fa_quant_type'] is not None:
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, Attention) and self.quant_description.get(
                'kv_quant_type') == 'C8':
            return AscendKVCacheMethod(self, prefix)
        elif isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix,
                                            self.packed_modules_mapping):
@@ -235,32 +238,11 @@ class AscendKVCacheMethod(BaseKVCacheMethod):
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(self,
              layer: torch.nn.Module,
              query: torch.Tensor,
              key: torch.Tensor,
              value: torch.Tensor,
              k_cache: List[torch.Tensor],
              v_cache: List[torch.Tensor],
              scale: torch.Tensor,
              block_tables: torch.Tensor,
              isPrefill: bool,
              attn_metadata,
              output,
              seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor:
        return self.quant_method.apply(layer,
                                       query,
                                       key,
                                       value,
                                       k_cache,
                                       v_cache,
                                       scale,
                                       block_tables,
                                       isPrefill,
                                       attn_metadata.attn_mask,
                                       attn_metadata.slot_mapping,
                                       output,
                                       seq_lens_tensor_cpu=seq_lens_tensor_cpu)
    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        return self.quant_method.apply(layer, query, key, value, kv_cache,
                                       attn_metadata, attn_type, scale, output)


class AscendFusedMoEMethod(FusedMoEMethodBase):
@@ -24,7 +24,8 @@ from vllm.logger import logger

from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
                           wrapper_rmsnorm_init)
from .w8a8 import AscendW8A8LinearMethod
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                   AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                           AscendW8A8DynamicLinearMethod)
@@ -250,6 +251,8 @@ class VLLMAscendQuantizer:
        # Attention
        if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
            quant_type = quant_description['fa_quant_type']
        if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
            quant_type = quant_description['kv_quant_type']
        # Linear
        else:
            quant_type = cls.get_linear_quant_type(quant_description, prefix,
@@ -269,6 +272,14 @@ class W8A8Quantizer(VLLMAscendQuantizer):
    def build_linear_method():
        return AscendW8A8LinearMethod()

    @staticmethod
    def build_moe_method():
        return AscendW8A8FusedMoEMethod()

    @staticmethod
    def build_attention_method():
        return AscendC8KVCacheMethod()


class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
@@ -284,4 +295,5 @@ class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
SUPPORT_ASCEND_QUANTIZER_TYPE = {
    "W8A8": W8A8Quantizer,
    "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
    "C8": W8A8Quantizer,
}
@@ -15,16 +15,23 @@
# limitations under the License.
#

from typing import Any, Dict, Optional
import os
from typing import Any, Callable, Dict, Optional

import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.distributed.parallel_state import get_ep_group


def quant_per_tensor(in_tensor: torch.Tensor, input_scale: torch.Tensor,
                     input_offset: torch.Tensor):
def quant_per_tensor(in_tensor: torch.Tensor,
                     input_scale: torch.Tensor,
                     input_offset: torch.Tensor,
                     function=False):
    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
                                  torch.qint8, -1, False)
                                  torch.qint8, -1, function)
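For intuition, an NPU-free sketch of what per-tensor int8 quantization does; the exact rounding/offset convention of `torch_npu.npu_quantize` (and the meaning of its last flag) is an assumption here, not taken from this diff:

```python
# Reference-only sketch, assuming q = clamp(round(x / scale) + offset, -128, 127).
from typing import Optional

import torch


def quant_per_tensor_reference(x: torch.Tensor,
                               scale: torch.Tensor,
                               offset: Optional[torch.Tensor] = None) -> torch.Tensor:
    q = torch.round(x / scale)
    if offset is not None:
        q = q + offset
    return q.clamp(-128, 127).to(torch.int8)


x = torch.randn(4, 8)
scale = x.abs().max() / 127  # simple symmetric per-tensor scale
print(quant_per_tensor_reference(x, scale).dtype)  # torch.int8
```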

class AscendW8A8LinearMethod:
@@ -86,19 +93,17 @@ class AscendW8A8LinearMethod:
    ) -> torch.Tensor:
        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = quant_per_tensor(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
            )
            x = quant_per_tensor(x, layer.aclnn_input_scale,
                                 layer.aclnn_input_offset)
        quant_bias = layer.quant_bias if tp_rank == 0 else None
        return torch_npu.npu_quant_matmul(
        output = torch_npu.npu_quant_matmul(
            x,
            layer.weight,
            layer.deq_scale,
            bias=quant_bias,
            output_dtype=original_dtype,
        )
        return output

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
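For intuition, a CPU-only sketch of the quantized linear path above, under the assumption that `torch_npu.npu_quant_matmul` performs an int8 x int8 matmul with integer accumulation and then rescales with a per-channel `deq_scale`; bias handling here is likewise an assumption:

```python
# Reference-only W8A8 linear sketch; not the Ascend kernel.
from typing import Optional

import torch


def w8a8_linear_reference(x_q: torch.Tensor,          # int8 activations [M, K]
                          w_q: torch.Tensor,          # int8 weights     [K, N]
                          deq_scale: torch.Tensor,    # float per-channel [N]
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Integer accumulation (int64 here for CPU portability; a kernel would use int32).
    acc = x_q.to(torch.int64) @ w_q.to(torch.int64)
    if bias is not None:
        acc = acc + bias                              # integer bias term
    return acc.to(torch.float32) * deq_scale          # per-channel dequantize


x_q = torch.randint(-128, 127, (4, 16), dtype=torch.int8)
w_q = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
out = w8a8_linear_reference(x_q, w_q, deq_scale=torch.full((8,), 1e-4))
print(out.shape)  # torch.Size([4, 8])
```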
@@ -113,3 +118,561 @@ class AscendW8A8LinearMethod:
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)


class AscendW8A8FusedMoEMethod:
    """FusedMoe method for Ascend W8A8.
    """

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, hidden_sizes,
            dtype=torch.int8, requires_grad=False)
        param_dict["w2_weight"] = torch.empty(
            num_experts, hidden_sizes, intermediate_size_per_partition,
            dtype=torch.int8, requires_grad=False)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, 1,
            dtype=torch.float32)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, 1,
            dtype=torch.float16)
        param_dict["w2_weight_scale"] = torch.empty(num_experts, hidden_sizes,
                                                    1, dtype=torch.float32)
        param_dict["w2_weight_offset"] = torch.empty(num_experts, hidden_sizes,
                                                     1, dtype=torch.float16)
        param_dict["w2_deq_scale"] = torch.empty(num_experts, hidden_sizes,
                                                 dtype=torch.float32)
        param_dict["w13_deq_scale"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition,
            dtype=torch.float32)
        param_dict["w2_input_scale"] = torch.empty(num_experts, 1,
                                                   dtype=torch.float32)
        param_dict["w13_input_scale"] = torch.empty(num_experts, 1,
                                                    dtype=torch.float32)
        param_dict["w2_input_offset"] = torch.empty(num_experts, 1,
                                                    dtype=torch.int8)
        param_dict["w13_input_offset"] = torch.empty(num_experts, 1,
                                                     dtype=torch.int8)
        param_dict["quant_bias"] = torch.empty(num_experts, hidden_sizes,
                                               dtype=torch.int32)

        return param_dict
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        # NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
        if global_num_experts == 256:
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=top_k,
                bias=e_score_correction_bias,
                k_group=topk_group,
                group_count=num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=1,
                eps=float(1e-20))
        else:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                top_k=top_k,
                use_grouped_topk=use_grouped_topk,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
                global_num_experts=global_num_experts,
            )

        if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
            raise NotImplementedError("W8A8FusedMoe is not "
                                      "implemented for VLLM_ENABLE_MC2")

        else:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w1_scale=layer.w13_weight_scale,
                                 w1_input_scale=layer.w13_input_scale,
                                 w1_input_offset=layer.w13_input_offset,
                                 w2=layer.w2_weight,
                                 w2_scale=layer.w2_weight_scale,
                                 w2_input_scale=layer.w2_input_scale,
                                 w2_input_offset=layer.w2_input_offset,
                                 topk_weights=topk_weights,
                                 topk_ids=topk_ids,
                                 top_k=top_k,
                                 global_num_experts=global_num_experts,
                                 expert_map=expert_map)
    def process_weights_after_loading(self, layer):
        # torch.npu.config.allow_internal_format = True
        layer.w13_weight.data = layer.w13_weight.data.transpose(
            1, 2).contiguous()
        layer.w2_weight.data = layer.w2_weight.data.transpose(1,
                                                              2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)

        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1).to(torch.float16)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1).to(torch.float32)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1).to(torch.float16)
        expanding_factor_w13 = layer.w13_weight.data.shape[1]
        expanding_factor_w2 = layer.w2_weight.data.shape[1]
        layer.w13_input_scale.data = torch.nn.Parameter(
            layer.w13_input_scale.data.repeat(
                1, expanding_factor_w13)[0:1]).to(torch.float16)

        layer.w2_input_scale.data = torch.nn.Parameter(
            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
                torch.float16)
        layer.w13_input_offset.data = torch.nn.Parameter(
            layer.w13_input_scale.data.repeat(
                1, expanding_factor_w13)[0:1]).to(torch.int8)
        layer.w2_input_offset.data = torch.nn.Parameter(
            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1]).to(
                torch.int8)

        # NZ
        # layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous()
        # layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous()


class AscendC8KVCacheMethod:

    def __init__(self) -> None:
        self.antiquant_scale_comb = None

    @staticmethod
    def create_weights(layer) -> None:
        param_dict = {}  # num_kv_heads * head_size
        param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                        layer.head_size,
                                                        dtype=torch.float16,
                                                        requires_grad=False)
        param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                          layer.head_size,
                                                          dtype=torch.float16,
                                                          requires_grad=False)
        for weight_name, weight_param in param_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            layer.register_parameter(weight_name, param)

    def process_weights_after_loading(self, layer):
        self.antiquant_scale_comb = torch.cat(
            (layer.key_antiquant_scale.data.unsqueeze(0),
             layer.value_antiquant_scale.data.unsqueeze(0)),
            dim=0).to(torch.float16).contiguous()
    def apply(self, layer, query, key, value, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.view(num_tokens, layer.num_heads * layer.head_size)
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendC8KVCacheMethod")

        # C8
        quant_key = quant_per_tensor(
            key.view(-1, layer.num_kv_heads * layer.head_size),
            layer.key_antiquant_scale.data.view(-1), None, True)
        quant_value = quant_per_tensor(
            value.view(-1, layer.num_kv_heads * layer.head_size),
            layer.value_antiquant_scale.data.view(-1), None, True)

        # View q k v to BSH.
        query = query.view(-1, layer.num_heads, layer.head_size)
        key = key.view(-1, layer.num_kv_heads, layer.head_size)
        value = value.view(-1, layer.num_kv_heads, layer.head_size)
        # TODO: Remove this contiguous in the future.
        value = value.contiguous()

        if kv_cache[0].numel() > 0:
            # if key_cache is None:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

            block_size = key_cache.shape[1]
            slots_indices = slots.reshape(-1, 1)
            block_indices = slots_indices // block_size
            slots_indices = slots_indices % block_size
            indices = torch.cat((block_indices, slots_indices), dim=1)

            # C8
            torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
            torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)

        # V0-Style scheduler situation.
        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask
            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
                                           mask=mask,
                                           seq_len=attn_metadata.seq_lens,
                                           scale_value=scale,
                                           num_heads=layer.num_heads,
                                           num_kv_heads=layer.num_kv_heads,
                                           out=output.reshape(query.shape))

        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "PrefillCacheHit")
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            # torch_air
            # decode_meta = attn_metadata.decode
            # seq_lens = decode_meta.seq_lens_list
            seq_lens = attn_metadata.seq_lens
            block_size = key_cache.shape[1]
            query = query.view(num_tokens, 1, layer.num_heads *
                               layer.head_size).contiguous()

            # [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
            key = key_cache
            value = value_cache

            output = torch_npu.npu_incre_flash_attention(
                query,
                key,
                value,
                num_key_value_heads=layer.num_kv_heads,
                num_heads=layer.num_heads,
                actual_seq_lengths=seq_lens,
                scale_value=scale,
                input_layout='BSH',
                block_size=block_size,
                block_table=attn_metadata.block_tables,
                antiquant_scale=self.antiquant_scale_comb,
            )

        # Normal V1 situation.
        else:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "other cases")
        return output

def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w1_scale: torch.Tensor,
    w1_input_scale: torch.Tensor,
    w1_input_offset: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
    w2_input_offset: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    expert_map: torch.Tensor = None,
) -> torch.Tensor:
    """
    Fused experts with top-k routing.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        expert_map: Expert mapping of shape (num_experts,).

    Returns:
        hidden_states: Hidden states after routing.
    """
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    """

    original_dtype = hidden_states.dtype
    ep_size = get_ep_group().world_size
    local_num_experts = global_num_experts // ep_size
    w1_input_scale, _ = w1_input_scale.max(0)
    quant_sorted_hidden_states = quant_per_tensor(
        hidden_states,
        w1_input_scale,
        None,
        True,
    )
    if expert_map is not None:
        expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
            quant_sorted_hidden_states,
            topk_ids,
            scale=None,
            active_num=topk_ids.numel(),
            expert_capacity=-1,
            expert_num=local_num_experts,
            drop_pad_mode=0,
            expert_tokens_num_type=1,
            expert_tokens_num_flag=True,
            quant_mode=-1,
            active_expert_range=[0, local_num_experts],
            row_idx_type=0,
        )

    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")
    if expanded_x.dtype != w1.dtype:
        w1_input_scale, _ = w1_input_scale.max(0)
        quant_sorted_hidden_states = quant_per_tensor(
            expanded_x,
            w1_input_scale,
            None,
            True,
        )
    else:
        quant_sorted_hidden_states = expanded_x
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[quant_sorted_hidden_states],
        weight=[w1],
        scale=[w1_scale * w1_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    if gate_up_out.dtype != w2.dtype:
        w2_input_scale, _ = w2_input_scale.max(0)
        quant_gate_up_out = quant_per_tensor(
            gate_up_out,
            w2_input_scale,
            None,
            True,
        )
    else:
        quant_gate_up_out = gate_up_out

    down_out = torch_npu.npu_grouped_matmul(
        x=[quant_gate_up_out],
        weight=[w2],
        scale=[w2_scale * w2_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]

    if expert_map is not None:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights.to(down_out.dtype),
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
            drop_pad_mode=2,
        )
    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")

    return final_hidden_states

def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    global_num_experts=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """

    if scoring_func == "softmax":
        # NOTE: vLLM uses dtype=torch.float here
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None

        if e_score_correction_bias is not None:
            # Store original scores before applying correction bias. We use biased
            # scores for expert selection but original scores for routing weights
            original_weights = topk_weights
            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
                                           topk_group)
        # TODO bfloat16 is not supported in torch.topk with ge graph.
        if e_score_correction_bias is not None:
            topk_ids = torch.topk(topk_weights.to(torch.float32),
                                  k=top_k,
                                  dim=-1,
                                  sorted=False)[1]
            # Use original unbiased scores for the routing weights
            topk_weights = original_weights.gather(1, topk_ids)
        else:
            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
                                                k=top_k,
                                                dim=-1,
                                                sorted=False)
    elif custom_routing_function is None:
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
        topk_weights = topk_weights.to(hidden_states.dtype)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
        )
        # Required by npu_moe_init_routing
        topk_ids = topk_ids.to(torch.int32)
        return topk_weights, topk_ids

    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids

def native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
):
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group,
                                    dim=-1,
                                    sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

    return topk_weights
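To make the routing logic above concrete, a small NPU-free walk-through of the grouped top-k idea used by `native_grouped_topk`; the sizes are made up for the example:

```python
# Illustration only: group experts, keep the best groups, then top-k within them.
import torch

num_tokens, num_experts, num_groups, topk_group, top_k = 2, 8, 4, 2, 2
scores = torch.rand(num_tokens, num_experts).softmax(dim=-1)

# Score each group by its best expert and keep the top `topk_group` groups.
group_scores = scores.view(num_tokens, num_groups, -1).max(dim=-1).values
keep = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1)

# Expand the group mask back to expert granularity and mask out other experts.
expert_mask = group_mask.unsqueeze(-1).expand(
    num_tokens, num_groups, num_experts // num_groups).reshape(num_tokens, -1)
masked_scores = scores.masked_fill(~expert_mask.bool(), 0.0)

topk_weights, topk_ids = torch.topk(masked_scores, k=top_k, dim=-1, sorted=False)
print(topk_ids.shape)  # torch.Size([2, 2])
```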
@@ -49,7 +49,8 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LazyLoader, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
@@ -169,6 +170,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        else:
            self.chunked_prefill_enabled = True

        if self.cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.is_multimodal_model = self.model_config.is_multimodal_model
        if self.is_multimodal_model:
            self.inputs_embeds = torch.zeros(
@@ -1924,10 +1931,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                # TODO: remove this after the OOM issue is located and fixed, otherwise, some model may
                # encounter OOM issue
                if isinstance(kv_cache_spec, FullAttentionSpec):
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
                    if self.vllm_config.additional_config.get(
                            "kv_cache_dtype", None) == 'int8':
                        kv_cache_shape = self.attn_backend.get_bsh_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    else:
                        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                            num_blocks, kv_cache_spec.block_size,
                            kv_cache_spec.num_kv_heads,
                            kv_cache_spec.head_size)
                    if self.torchair_graph_enabled:
                        layer_kv_cache_nope = torch.zeros(
                            kv_cache_shape[:-1] +
@@ -1951,9 +1965,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                            acl_format),
                        )
                    else:
                        kv_caches[layer_name] = torch.zeros(kv_cache_shape,
                                                            dtype=dtype,
                                                            device=self.device)
                        kv_caches[layer_name] = torch.zeros(
                            kv_cache_shape,
                            dtype=self.kv_cache_dtype,
                            device=self.device)
                        kv_caches[layer_name] = \
                            torch_npu.npu_format_cast(kv_caches[layer_name], acl_format)
                else: