Remove chunked_prefill_for_mla and fix ring_mla bug (#2781)
### What this PR does / why we need it?
Remove chunked prefill for mla branch in mla , and change dtype of
prefill_mask to avoid accuracy problem
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
ef7eefe17a
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
This commit is contained in:
@@ -148,10 +148,6 @@ msgid ""
|
|||||||
" to be passed in."
|
" to be passed in."
|
||||||
msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。"
|
msgstr "在为MOE模型使用专家负载均衡时,需要传入专家映射路径。"
|
||||||
|
|
||||||
#: ../../user_guide/configuration/additional_config.md
|
|
||||||
msgid "`chunked_prefill_for_mla`"
|
|
||||||
msgstr "`chunked_prefill_for_mla`"
|
|
||||||
|
|
||||||
#: ../../user_guide/configuration/additional_config.md
|
#: ../../user_guide/configuration/additional_config.md
|
||||||
msgid "`False`"
|
msgid "`False`"
|
||||||
msgstr "`False`"
|
msgstr "`False`"
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ The following table lists the additional configuration options available in vLLM
|
|||||||
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
|
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
|
||||||
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
|
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
|
||||||
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
||||||
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
|
||||||
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
|
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
|
||||||
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
|
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
|
||||||
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
|
| `enable_shared_expert_dp` | bool | `False` | When the shared expert in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported to use. |
|
||||||
|
|||||||
@@ -70,9 +70,7 @@ vllm serve /models/deepseek_r1_w8a8 \
|
|||||||
"kv_port": "20001",
|
"kv_port": "20001",
|
||||||
"engine_id": "0",
|
"engine_id": "0",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||||
}' \
|
}'
|
||||||
--additional-config \
|
|
||||||
'{"chunked_prefill_for_mla":true}'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Run prefill server P2 on second node:
|
Run prefill server P2 on second node:
|
||||||
@@ -114,9 +112,7 @@ vllm serve /models/deepseek_r1_w8a8 \
|
|||||||
"kv_port": "20001",
|
"kv_port": "20001",
|
||||||
"engine_id": "0",
|
"engine_id": "0",
|
||||||
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
"kv_connector_module_path": "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
|
||||||
}' \
|
}'
|
||||||
--additional-config \
|
|
||||||
'{"chunked_prefill_for_mla":true}'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Run decode server d1 on third node:
|
Run decode server d1 on third node:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
|
||||||
from vllm_ascend.utils import npu_prefetch
|
from vllm_ascend.utils import npu_prefetch
|
||||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||||
|
|
||||||
@@ -491,7 +490,6 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
self.enable_prefetch = ascend_config.enable_prefetch
|
self.enable_prefetch = ascend_config.enable_prefetch
|
||||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||||
self.chunked_prefill_for_mla = ascend_config.chunked_prefill_for_mla
|
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.ring_mla_mask_size = 512
|
self.ring_mla_mask_size = 512
|
||||||
@@ -673,84 +671,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.v_head_dim,
|
self.v_head_dim,
|
||||||
dtype=q_nope.dtype,
|
dtype=q_nope.dtype,
|
||||||
device=q_nope.device)
|
device=q_nope.device)
|
||||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
attn_lse = torch.empty(self.num_heads,
|
||||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
num_tokens,
|
||||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
dtype=torch.float32,
|
||||||
torch_npu._npu_flash_attention(
|
device=q_nope.device)
|
||||||
query=query,
|
if self.prefill_mask is None:
|
||||||
key=key,
|
if q_nope.dtype == torch.float16:
|
||||||
value=value,
|
mask_value = torch.finfo(torch.float32).min
|
||||||
mask=attn_metadata.attn_mask,
|
else:
|
||||||
seq_len=attn_metadata.prefill.context_lens,
|
mask_value = 1
|
||||||
scale_value=self.scale,
|
prefill_mask = torch.triu(
|
||||||
num_heads=self.num_heads,
|
torch.ones(self.ring_mla_mask_size,
|
||||||
num_kv_heads=self.num_heads,
|
self.ring_mla_mask_size,
|
||||||
out=attn_output)
|
device=q_nope.device,
|
||||||
elif self.chunked_prefill_for_mla:
|
dtype=q_nope.dtype), 1)
|
||||||
attn_lse = torch.empty(self.num_heads,
|
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||||
num_tokens,
|
0).to(q_nope.dtype)
|
||||||
dtype=torch.float32,
|
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||||
device=q_nope.device)
|
q_rope=q_pe,
|
||||||
if self.prefill_mask is None:
|
k_nope=k_nope,
|
||||||
self.prefill_mask = torch.triu(
|
k_rope=k_pe,
|
||||||
torch.ones(self.ring_mla_mask_size,
|
value=value,
|
||||||
self.ring_mla_mask_size,
|
mask=self.prefill_mask,
|
||||||
device=q_nope.device,
|
seqlen=torch.tensor(
|
||||||
dtype=q_nope.dtype), 1)
|
attn_metadata.prefill.query_lens,
|
||||||
torch_npu.atb.npu_ring_mla(
|
dtype=torch.int32),
|
||||||
q_nope=q_nope,
|
head_num=self.num_heads,
|
||||||
q_rope=q_pe,
|
kv_head_num=self.num_heads,
|
||||||
k_nope=k_nope,
|
pre_out=None,
|
||||||
k_rope=k_pe,
|
prev_lse=None,
|
||||||
value=value,
|
qk_scale=self.scale,
|
||||||
mask=self.prefill_mask,
|
kernel_type="kernel_type_high_precision",
|
||||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
mask_type="mask_type_triu",
|
||||||
dtype=torch.int32),
|
input_layout="type_bsnd",
|
||||||
head_num=self.num_heads,
|
calc_type="calc_type_first_ring",
|
||||||
kv_head_num=self.num_heads,
|
output=attn_output,
|
||||||
pre_out=None,
|
softmax_lse=attn_lse)
|
||||||
prev_lse=None,
|
attn_output, attn_lse = self._compute_prefill_context( \
|
||||||
qk_scale=self.scale,
|
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||||
kernel_type="kernel_type_high_precision",
|
|
||||||
mask_type="mask_type_triu",
|
|
||||||
input_layout="type_bsnd",
|
|
||||||
calc_type="calc_type_first_ring",
|
|
||||||
output=attn_output,
|
|
||||||
softmax_lse=attn_lse)
|
|
||||||
attn_output, attn_lse = self._compute_prefill_context( \
|
|
||||||
q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
|
||||||
else:
|
|
||||||
query = torch.cat((q_nope, q_pe), dim=-1)
|
|
||||||
attn_output_torch = torch.empty(num_tokens,
|
|
||||||
self.num_heads * self.v_head_dim,
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device)
|
|
||||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
|
||||||
vanilla_chunked_prefill_mla(
|
|
||||||
output=attn_output_torch,
|
|
||||||
query=query,
|
|
||||||
kv_cache=kv_c_and_k_pe_cache,
|
|
||||||
block_tables=attn_metadata.prefill.block_table,
|
|
||||||
query_lens=attn_metadata.prefill.query_lens,
|
|
||||||
context_lens=attn_metadata.prefill.context_lens,
|
|
||||||
kv_b_proj=self.kv_b_proj,
|
|
||||||
max_query_len=attn_metadata.prefill.max_query_len,
|
|
||||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
|
||||||
nope_dim=self.qk_nope_head_dim,
|
|
||||||
rope_dim=self.qk_rope_head_dim,
|
|
||||||
v_head_dim=self.v_head_dim,
|
|
||||||
scale=self.scale,
|
|
||||||
alibi_slopes=None,
|
|
||||||
causal=True)
|
|
||||||
|
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
if attn_metadata.attn_state in [
|
|
||||||
AscendAttentionState.ChunkedPrefill,
|
|
||||||
AscendAttentionState.SpecDecoding,
|
|
||||||
AscendAttentionState.PrefillCacheHit
|
|
||||||
] and not self.chunked_prefill_for_mla:
|
|
||||||
attn_output = attn_output_torch
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
def exec_kv_decode(
|
def exec_kv_decode(
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
|||||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
|
||||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||||
npu_stream_switch, npu_wait_tensor)
|
npu_stream_switch, npu_wait_tensor)
|
||||||
from vllm_ascend.utils import npu_prefetch
|
from vllm_ascend.utils import npu_prefetch
|
||||||
@@ -674,6 +673,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||||
self.running_in_graph = False
|
self.running_in_graph = False
|
||||||
|
self.prefill_mask = None
|
||||||
|
self.ring_mla_mask_size = 512
|
||||||
|
|
||||||
# Adapt torch air graph mode with spec decoding.
|
# Adapt torch air graph mode with spec decoding.
|
||||||
speculative_config = get_current_vllm_config().speculative_config
|
speculative_config = get_current_vllm_config().speculative_config
|
||||||
@@ -820,16 +821,13 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
k_nope, v = kv_nope\
|
k_nope, v = kv_nope\
|
||||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
mask = torch.triu(
|
|
||||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
|
||||||
1)
|
|
||||||
torch_npu.atb.npu_ring_mla(
|
torch_npu.atb.npu_ring_mla(
|
||||||
q_nope=q_nope,
|
q_nope=q_nope,
|
||||||
q_rope=q_pe,
|
q_rope=q_pe,
|
||||||
k_nope=k_nope,
|
k_nope=k_nope,
|
||||||
k_rope=k_pe,
|
k_rope=k_pe,
|
||||||
value=v,
|
value=v,
|
||||||
mask=mask,
|
mask=self.prefill_mask,
|
||||||
seqlen=seq_len,
|
seqlen=seq_len,
|
||||||
head_num=self.num_heads,
|
head_num=self.num_heads,
|
||||||
kv_head_num=self.num_heads,
|
kv_head_num=self.num_heads,
|
||||||
@@ -861,104 +859,54 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
|
|||||||
self.v_head_dim,
|
self.v_head_dim,
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
device=query.device)
|
device=query.device)
|
||||||
|
attn_lse = torch.empty(self.num_heads,
|
||||||
|
num_tokens,
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=query.device)
|
||||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||||
ascend_config = get_ascend_config()
|
q_pe = query[..., self.qk_nope_head_dim:]
|
||||||
|
q_nope = query[..., :self.qk_nope_head_dim]
|
||||||
|
if self.prefill_mask is None:
|
||||||
|
if q_nope.dtype == torch.float16:
|
||||||
|
mask_value = torch.finfo(torch.float32).min
|
||||||
|
else:
|
||||||
|
mask_value = 1
|
||||||
|
prefill_mask = torch.triu(
|
||||||
|
torch.ones(self.ring_mla_mask_size,
|
||||||
|
self.ring_mla_mask_size,
|
||||||
|
device=q_nope.device,
|
||||||
|
dtype=q_nope.dtype), 1)
|
||||||
|
self.prefill_mask = torch.where(prefill_mask == 1, mask_value,
|
||||||
|
0).to(q_nope.dtype)
|
||||||
|
torch_npu.atb.npu_ring_mla(q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=value,
|
||||||
|
mask=self.prefill_mask,
|
||||||
|
seqlen=torch.tensor(
|
||||||
|
attn_metadata.prefill.query_lens,
|
||||||
|
dtype=torch.int32),
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=None,
|
||||||
|
prev_lse=None,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="mask_type_triu",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_first_ring",
|
||||||
|
output=attn_output,
|
||||||
|
softmax_lse=attn_lse)
|
||||||
|
attn_output, attn_lse = self._compute_prefill_context( \
|
||||||
|
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||||
|
|
||||||
if attn_metadata.attn_state in [
|
|
||||||
AscendAttentionState.ChunkedPrefill,
|
|
||||||
AscendAttentionState.SpecDecoding,
|
|
||||||
AscendAttentionState.PrefillCacheHit
|
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
|
||||||
attn_output_torch = torch.empty(num_tokens,
|
|
||||||
self.num_heads * self.v_head_dim,
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device)
|
|
||||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
|
||||||
vanilla_chunked_prefill_mla(
|
|
||||||
output=attn_output_torch,
|
|
||||||
query=query,
|
|
||||||
kv_cache=kv_c_and_k_pe_cache,
|
|
||||||
block_tables=attn_metadata.prefill.block_table,
|
|
||||||
query_lens=attn_metadata.prefill.query_lens,
|
|
||||||
context_lens=attn_metadata.prefill.context_lens,
|
|
||||||
kv_b_proj=self.kv_b_proj,
|
|
||||||
max_query_len=attn_metadata.prefill.max_query_len,
|
|
||||||
max_context_len=attn_metadata.prefill.max_seq_lens,
|
|
||||||
nope_dim=self.qk_nope_head_dim,
|
|
||||||
rope_dim=self.qk_rope_head_dim,
|
|
||||||
v_head_dim=self.v_head_dim,
|
|
||||||
scale=self.scale,
|
|
||||||
alibi_slopes=None,
|
|
||||||
causal=True)
|
|
||||||
elif attn_metadata.attn_state in [
|
|
||||||
AscendAttentionState.ChunkedPrefill,
|
|
||||||
AscendAttentionState.SpecDecoding,
|
|
||||||
AscendAttentionState.PrefillCacheHit
|
|
||||||
]:
|
|
||||||
attn_lse = torch.empty(self.num_heads,
|
|
||||||
num_tokens,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=query.device)
|
|
||||||
q_pe = query[..., self.qk_nope_head_dim:]
|
|
||||||
q_nope = query[..., :self.qk_nope_head_dim]
|
|
||||||
mask = torch.triu(
|
|
||||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
|
||||||
1) # 512: mask only support 512
|
|
||||||
if attn_metadata.num_prefills > 1:
|
|
||||||
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
|
||||||
1)
|
|
||||||
torch_npu.atb.npu_ring_mla(
|
|
||||||
q_nope=q_nope,
|
|
||||||
q_rope=q_pe,
|
|
||||||
k_nope=k_nope,
|
|
||||||
k_rope=k_pe,
|
|
||||||
value=value,
|
|
||||||
mask=mask,
|
|
||||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
|
||||||
dtype=torch.int32),
|
|
||||||
head_num=self.num_heads,
|
|
||||||
kv_head_num=self.num_heads,
|
|
||||||
pre_out=None,
|
|
||||||
prev_lse=None,
|
|
||||||
qk_scale=self.scale,
|
|
||||||
kernel_type="kernel_type_high_precision",
|
|
||||||
mask_type="mask_type_triu",
|
|
||||||
input_layout="type_bsnd",
|
|
||||||
calc_type="calc_type_first_ring",
|
|
||||||
output=attn_output,
|
|
||||||
softmax_lse=attn_lse)
|
|
||||||
attn_output, attn_lse = self._compute_prefill_context( \
|
|
||||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
|
||||||
|
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
|
||||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
|
||||||
torch_npu._npu_flash_attention(
|
|
||||||
query=query,
|
|
||||||
key=key,
|
|
||||||
value=value,
|
|
||||||
mask=attn_metadata.attn_mask,
|
|
||||||
seq_len=attn_metadata.prefill.context_lens,
|
|
||||||
scale_value=self.scale,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
num_kv_heads=self.num_heads,
|
|
||||||
out=attn_output)
|
|
||||||
attn_output = attn_output.view(-1, self.num_heads, self.v_head_dim)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Unexpected path reached, AscendMLATorchairImpl should only have PrefillNoCache, PrefillCacheHit, ChunkedPrefill and SpecDecoding scenario in forward prefill, please file a bug to vllm-ascend !"
|
|
||||||
)
|
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
if attn_metadata.attn_state in [
|
|
||||||
AscendAttentionState.ChunkedPrefill,
|
|
||||||
AscendAttentionState.SpecDecoding,
|
|
||||||
AscendAttentionState.PrefillCacheHit
|
|
||||||
] and not ascend_config.chunked_prefill_for_mla:
|
|
||||||
attn_output = attn_output_torch
|
|
||||||
|
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user