Fix W8A8 fused moe bug (#1529)
### What this PR does / why we need it? 1. drop some useless code for w8a8 fusedmoe 2. Add in8 kv cache check 3. Add more ut. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: zhuyilin <809721801@qq.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -274,6 +274,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
shape = [batch_size * seq_len, num_heads, head_size]
|
||||
"""
|
||||
num_tokens = query.shape[0]
|
||||
use_kv_cache_int8 = kv_cache.numel(
|
||||
) > 0 and kv_cache[0].dtype == torch.int8
|
||||
if output is None:
|
||||
output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
@@ -289,7 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output=output,
|
||||
layer_name=layer.layer_name)
|
||||
|
||||
elif hasattr(layer, 'quant_method'):
|
||||
elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
@@ -429,7 +431,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method'):
|
||||
if hasattr(layer, 'quant_method') and use_kv_cache_int8:
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
@@ -251,7 +251,8 @@ class VLLMAscendQuantizer:
|
||||
# Attention
|
||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['fa_quant_type']
|
||||
if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
# Use KVCache int8
|
||||
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['kv_quant_type']
|
||||
# Linear
|
||||
else:
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
@@ -219,53 +218,34 @@ class AscendW8A8FusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group,
|
||||
group_count=num_expert_group,
|
||||
group_select_mode=1,
|
||||
renorm=0,
|
||||
norm_type=1,
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
else:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
raise NotImplementedError("W8A8FusedMoe are not "
|
||||
"mplemented for VLLM_ENABLE_MC2")
|
||||
|
||||
else:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
# torch.npu.config.allow_internal_format = True
|
||||
@@ -299,8 +279,10 @@ class AscendW8A8FusedMoEMethod:
|
||||
torch.int8)
|
||||
|
||||
# NZ
|
||||
# layer.w13_weight.data = torch_npu.npu_format_cast(layer.w13_weight.data, 29).contiguous()
|
||||
# layer.w2_weight.data = torch_npu.npu_format_cast(layer.w2_weight.data, 29).contiguous()
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, 29).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, 29).contiguous()
|
||||
|
||||
|
||||
class AscendC8KVCacheMethod:
|
||||
|
||||
@@ -2194,7 +2194,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_size=self.block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=attn_module.dtype,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
|
||||
Reference in New Issue
Block a user