support pangumoe w8a8c8 and docs (#1477)
### What this PR does / why we need it? support pangu moe w8a8c8 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI passed with new added test. Signed-off-by: zhuyilin <809721801@qq.com>
This commit is contained in:
@@ -69,6 +69,15 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
16)
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
@@ -279,6 +288,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
value=value,
|
||||
output=output,
|
||||
layer_name=layer.layer_name)
|
||||
|
||||
elif hasattr(layer, 'quant_method'):
|
||||
output = layer.quant_method.apply(layer, query, key, value,
|
||||
kv_cache, attn_metadata,
|
||||
self.attn_type, self.scale,
|
||||
output)
|
||||
|
||||
else:
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
@@ -308,11 +324,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if hasattr(layer, 'quant_method'):
|
||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||
pass
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
@@ -414,6 +427,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
out=output)
|
||||
|
||||
# to make in-place change to the output tensor
|
||||
if hasattr(layer, 'quant_method'):
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
ori_output[:, :, :] = output[:num_tokens, :, :]
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user