[Feat] support basic pcp&dcp for qwen3next (#6091)
### What this PR does / why we need it?
This PR implements Context Parallelism (CP) support for the Qwen3-Next
model, including PCP (prefill context parallelism) and DCP
(decode context parallelism).
- vLLM version: v0.15.0
- vLLM main:
f176443446
---------
Signed-off-by: SunnyLee219 <3294305115@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: 白永斌 <baiyongbin3@h-partners.com>
Signed-off-by: Bai Yongbin <845473182@qq.com>
Co-authored-by: SunnyLee219 <3294305115@qq.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: 白永斌 <baiyongbin3@h-partners.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -892,10 +892,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def reshape_and_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: tuple[torch.Tensor],
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
if len(kv_cache) > 1:
|
||||
if self.is_kv_producer:
|
||||
@@ -915,7 +917,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
)
|
||||
if self.is_kv_producer:
|
||||
attn_metadata.reshape_cache_event.record()
|
||||
return key, value
|
||||
return query, key, value, output
|
||||
|
||||
def forward_impl(
|
||||
self,
|
||||
@@ -970,12 +972,20 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.fill_(0)
|
||||
output_padded = None
|
||||
if key is not None and value is not None:
|
||||
key, value = self.reshape_and_cache(key, value, kv_cache, attn_metadata)
|
||||
output_padded = output
|
||||
query, key, value, output_padded = self.reshape_and_cache(
|
||||
query, key, value, kv_cache, attn_metadata, output
|
||||
)
|
||||
# pooling model branch
|
||||
if attn_metadata.model_runner_type == "pooling":
|
||||
attn_output = self._forward_encoder_attention(query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
if output_padded is not None:
|
||||
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output_padded)
|
||||
else:
|
||||
attn_output = self.forward_impl(query, key, value, kv_cache, attn_metadata, output)
|
||||
output[:num_tokens] = attn_output[:num_tokens]
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user