Fix W8A8 fused moe bug (#1529)
### What this PR does / why we need it?
1. Drop some unused code in the W8A8 fused MoE path.
2. Add an int8 KV cache check.
3. Add more unit tests.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with the newly added test.

---------

Signed-off-by: zhuyilin <809721801@qq.com>
Signed-off-by: tianyitang <tangtianyi4@huawei.com>
Co-authored-by: tianyitang <tangtianyi4@huawei.com>
@@ -274,6 +274,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
+        use_kv_cache_int8 = kv_cache.numel(
+        ) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
                                  self.num_heads,
@@ -289,7 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 output=output,
                 layer_name=layer.layer_name)

-        elif hasattr(layer, 'quant_method'):
+        elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = layer.quant_method.apply(layer, query, key, value,
                                               kv_cache, attn_metadata,
                                               self.attn_type, self.scale,
@@ -429,7 +431,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 out=output)

         # to make in-place change to the output tensor
-        if hasattr(layer, 'quant_method'):
+        if hasattr(layer, 'quant_method') and use_kv_cache_int8:
             output = output.view(num_tokens, self.num_heads, self.head_size)
             ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)
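For context, the core of the change is the `use_kv_cache_int8` flag: the quantized attention branch should only run when the KV cache actually holds int8 data. The snippet below is not code from the repository; it restates the dtype check added in the first hunk as a standalone helper (the name `use_int8_kv_cache` is hypothetical) so it can be exercised with dummy tensors.

```python
import torch


def use_int8_kv_cache(kv_cache: torch.Tensor) -> bool:
    # Mirrors the check added in the diff: the cache counts as quantized
    # only when it is non-empty and its first plane is stored as int8.
    return kv_cache.numel() > 0 and kv_cache[0].dtype == torch.int8


# Tiny demonstration with dummy tensors (shapes are illustrative only).
empty_cache = torch.empty(0)
int8_cache = torch.zeros(2, 4, 8, dtype=torch.int8)
fp16_cache = torch.zeros(2, 4, 8, dtype=torch.float16)

print(use_int8_kv_cache(empty_cache))  # False: nothing cached yet
print(use_int8_kv_cache(int8_cache))   # True: quantized path applies
print(use_int8_kv_cache(fp16_cache))   # False: fall back to the regular path
```

Computing the flag once and reusing it in both `hasattr(layer, 'quant_method')` branches means the `quant_method` path is only taken when the cache really is int8; a floating-point or empty cache falls through to the regular attention path.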