support pangumoe w8a8c8 and docs (#1477)

### What this PR does / why we need it?
support pangu moe w8a8c8

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with new added test.

Signed-off-by: zhuyilin <809721801@qq.com>
This commit is contained in:
Zhu Yi Lin
2025-06-28 18:51:07 +08:00
committed by GitHub
parent c59d69d9e6
commit b308a7a258
8 changed files with 689 additions and 50 deletions

View File

@@ -69,6 +69,15 @@ class AscendAttentionBackend(AttentionBackend):
16)
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_bsh_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return (2, num_blocks, block_size, num_kv_heads * head_size)
@staticmethod
def swap_blocks(
src_kv_cache: List[torch.Tensor],
@@ -279,6 +288,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
value=value,
output=output,
layer_name=layer.layer_name)
elif hasattr(layer, 'quant_method'):
output = layer.quant_method.apply(layer, query, key, value,
kv_cache, attn_metadata,
self.attn_type, self.scale,
output)
else:
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
@@ -308,11 +324,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
pass
# V0-Style scheduler situation.
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
assert attn_metadata is not None
assert attn_metadata.attn_mask is not None
mask = attn_metadata.attn_mask
@@ -414,6 +427,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
out=output)
# to make in-place change to the output tensor
if hasattr(layer, 'quant_method'):
output = output.view(num_tokens, self.num_heads, self.head_size)
ori_output[:, :, :] = output[:num_tokens, :, :]
return output.view(num_tokens, self.hidden_size)