[Perf] Supports compute-communication overlap in the forward of sfa_v1 in the Sharded-CP feature. (#5701)
### What this PR does / why we need it?
> Extracted from PR #5513
Based on the Sharded-CP feature PR: #4702
RFC: https://github.com/vllm-project/vllm/issues/30055
### All-gather KV Cache for Communication Overlap:
- This PR reorders the computation in the SFA forward pass so that the KV all-gather can overlap with compute.
- Split `indexer_select` into `indexer_select_pre_process` and
`indexer_select_post_process`.
- Concatenate `nope`, `rope` and `index-k` into a single tensor and perform one
asynchronous all-gather on it.
### Benchmark:
input=40k and num_batch_token=20k
- before:
```
Mean TTFT (ms): 2614.52
Median TTFT (ms): 3148.03
P50 TTFT (ms): 3148.03
P90 TTFT (ms): 3163.48
P99 TTFT (ms): 3170.20
```
- after:
```
Mean TTFT (ms): 2529.92
Median TTFT (ms): 3051.69
P50 TTFT (ms): 3051.69
P90 TTFT (ms): 3067.31
P99 TTFT (ms): 3072.15
```
### Does this PR introduce _any_ user-facing change?
None
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
---------
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
|
|||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
trans_rope_weight, transdata,
|
trans_rope_weight, transdata,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
|
from vllm_ascend.distributed.utils import all_gather_async
|
||||||
from vllm_ascend.ops.layer_shard_linear import (
|
from vllm_ascend.ops.layer_shard_linear import (
|
||||||
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
is_hidden_layer, post_process_after_loading_for_shard_weight_series,
|
||||||
reach_layer_for_shard_weight_series,
|
reach_layer_for_shard_weight_series,
|
||||||
@@ -227,7 +228,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
|
|||||||
|
|
||||||
cos = cos[local_start:local_end_with_pad]
|
cos = cos[local_start:local_end_with_pad]
|
||||||
sin = sin[local_start:local_end_with_pad]
|
sin = sin[local_start:local_end_with_pad]
|
||||||
slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
|
slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
|
||||||
|
fill_value=-1,
|
||||||
|
dtype=slot_mapping.dtype,
|
||||||
|
device=slot_mapping.device)
|
||||||
assert cos.shape[0] == num_tokens_per_device, \
|
assert cos.shape[0] == num_tokens_per_device, \
|
||||||
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
f"cos.shape[0] must be equal to num_tokens_per_device, \
|
||||||
got {cos.shape[0]} and {num_tokens_per_device}"
|
got {cos.shape[0]} and {num_tokens_per_device}"
|
||||||
@@ -505,7 +509,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
kv_cache: Tuple,
|
kv_cache: Tuple,
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
slots_cp: Optional[torch.Tensor],
|
|
||||||
):
|
):
|
||||||
B = kv_no_split.shape[0]
|
B = kv_no_split.shape[0]
|
||||||
N = self.num_kv_heads
|
N = self.num_kv_heads
|
||||||
@@ -516,30 +519,19 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
cache_mode = "PA"
|
cache_mode = "PA"
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
assert slots_cp is not None
|
|
||||||
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
_, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
kv_no_split,
|
kv_no_split,
|
||||||
self.kv_a_layernorm.weight,
|
self.kv_a_layernorm.weight,
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
slots_cp.to(torch.int64),
|
slots.to(torch.int64),
|
||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
kv_cache[0],
|
kv_cache[0],
|
||||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
cache_mode=cache_mode,
|
cache_mode=cache_mode,
|
||||||
is_output_kv=True,
|
is_output_kv=True,
|
||||||
)
|
)
|
||||||
# TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
|
return k_pe, k_nope
|
||||||
k_pe = get_tp_group().all_gather(k_pe, 0)
|
|
||||||
k_nope = get_tp_group().all_gather(k_nope, 0)
|
|
||||||
|
|
||||||
if kv_cache is not None:
|
|
||||||
torch_npu.npu_scatter_nd_update_(
|
|
||||||
kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
|
|
||||||
k_nope.view(-1, k_nope.shape[-1]))
|
|
||||||
torch_npu.npu_scatter_nd_update_(
|
|
||||||
kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
|
|
||||||
k_pe.view(-1, k_pe.shape[-1]))
|
|
||||||
else:
|
else:
|
||||||
torch_npu.npu_kv_rmsnorm_rope_cache(
|
torch_npu.npu_kv_rmsnorm_rope_cache(
|
||||||
kv_no_split,
|
kv_no_split,
|
||||||
@@ -552,6 +544,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
epsilon=self.kv_a_layernorm.variance_epsilon,
|
epsilon=self.kv_a_layernorm.variance_epsilon,
|
||||||
cache_mode=cache_mode,
|
cache_mode=cache_mode,
|
||||||
)
|
)
|
||||||
|
return None, None
|
||||||
|
|
||||||
def rope_single(
|
def rope_single(
|
||||||
self,
|
self,
|
||||||
@@ -744,6 +737,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
if is_hidden_layer(layer):
|
if is_hidden_layer(layer):
|
||||||
reach_layer_for_shard_weight_series(layer)
|
reach_layer_for_shard_weight_series(layer)
|
||||||
return output.fill_(0)
|
return output.fill_(0)
|
||||||
|
|
||||||
has_prefill = attn_metadata.has_prefill
|
has_prefill = attn_metadata.has_prefill
|
||||||
cos = attn_metadata.cos
|
cos = attn_metadata.cos
|
||||||
sin = attn_metadata.sin
|
sin = attn_metadata.sin
|
||||||
@@ -764,6 +758,12 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
need_gather_q_kv=need_gather_q_kv,
|
need_gather_q_kv=need_gather_q_kv,
|
||||||
num_input_tokens=attn_metadata.num_input_tokens,
|
num_input_tokens=attn_metadata.num_input_tokens,
|
||||||
)
|
)
|
||||||
|
q, k = self.indexer_select_pre_process(
|
||||||
|
x=hidden_states,
|
||||||
|
qr=q_c,
|
||||||
|
cos=cos,
|
||||||
|
sin=sin,
|
||||||
|
need_gather_q_kv=need_gather_q_kv)
|
||||||
else:
|
else:
|
||||||
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
|
||||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||||
@@ -782,31 +782,67 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
kv_no_split.contiguous(), need_gather_q_kv)
|
kv_no_split.contiguous(), need_gather_q_kv)
|
||||||
|
|
||||||
|
q, k = self.indexer_select_pre_process(
|
||||||
|
x=hidden_states,
|
||||||
|
qr=q_c,
|
||||||
|
cos=cos,
|
||||||
|
sin=sin,
|
||||||
|
need_gather_q_kv=need_gather_q_kv)
|
||||||
|
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
wait_for_kv_layer_from_connector(layer_name)
|
wait_for_kv_layer_from_connector(layer_name)
|
||||||
|
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
slot_mapping_cp = None
|
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
assert attn_metadata.sfa_cp_context is not None
|
assert attn_metadata.sfa_cp_context is not None
|
||||||
slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
|
slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
|
||||||
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
|
||||||
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
|
||||||
|
|
||||||
self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
|
k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
|
||||||
slot_mapping_cp)
|
slot_mapping)
|
||||||
|
|
||||||
if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
|
if self.enable_sfa_cp:
|
||||||
for layer in (self.layer_sharding_kwargs or []):
|
assert k_pe is not None
|
||||||
if is_hidden_layer(layer):
|
assert k_nope is not None
|
||||||
reach_layer_for_shard_weight_series(layer)
|
# support all_gather kv async for communication calculation overlap
|
||||||
|
fused_kv_no_split, kv_ag_handle = all_gather_async(
|
||||||
|
torch.cat([
|
||||||
|
k_pe.view(-1, k_pe.shape[-1]),
|
||||||
|
k_nope.view(-1, k_nope.shape[-1]),
|
||||||
|
k.view(-1, k.shape[-1])
|
||||||
|
],
|
||||||
|
dim=1), get_tp_group())
|
||||||
|
|
||||||
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
|
||||||
q_pe = self.rope_single(q_pe, cos, sin)
|
q_pe = self.rope_single(q_pe, cos, sin)
|
||||||
|
|
||||||
topk_indices = self.indexer_select(
|
if self.enable_sfa_cp:
|
||||||
|
if kv_ag_handle is not None:
|
||||||
|
kv_ag_handle.wait()
|
||||||
|
for layer in (self.layer_sharding_kwargs or []):
|
||||||
|
if is_hidden_layer(layer):
|
||||||
|
reach_layer_for_shard_weight_series(layer)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
assert fused_kv_no_split is not None
|
||||||
|
k_pe, k_nope, k = fused_kv_no_split.split([
|
||||||
|
self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
|
||||||
|
],
|
||||||
|
dim=-1)
|
||||||
|
slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
|
||||||
|
torch_npu.npu_scatter_nd_update_(
|
||||||
|
kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
|
||||||
|
k_nope)
|
||||||
|
torch_npu.npu_scatter_nd_update_(
|
||||||
|
kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
|
||||||
|
k_pe)
|
||||||
|
|
||||||
|
topk_indices = self.indexer_select_post_process(
|
||||||
x=hidden_states,
|
x=hidden_states,
|
||||||
qr=q_c,
|
qr=q_c,
|
||||||
|
q=q,
|
||||||
|
k=k,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
cos=cos,
|
cos=cos,
|
||||||
@@ -814,6 +850,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
actual_seq_lengths_query=actual_seq_lengths_query,
|
actual_seq_lengths_query=actual_seq_lengths_query,
|
||||||
actual_seq_lengths_key=actual_seq_lengths_key,
|
actual_seq_lengths_key=actual_seq_lengths_key,
|
||||||
need_gather_q_kv=need_gather_q_kv)
|
need_gather_q_kv=need_gather_q_kv)
|
||||||
|
|
||||||
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
|
||||||
query=ql_nope,
|
query=ql_nope,
|
||||||
key=kv_cache[0],
|
key=kv_cache[0],
|
||||||
@@ -830,6 +867,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
layout_kv="PA_BSND",
|
layout_kv="PA_BSND",
|
||||||
sparse_mode=3,
|
sparse_mode=3,
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = self._v_up_proj(attn_output, has_prefill)
|
attn_output = self._v_up_proj(attn_output, has_prefill)
|
||||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||||
dependency=attn_output,
|
dependency=attn_output,
|
||||||
@@ -838,22 +876,14 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
output[...] = self.o_proj(attn_output)[0]
|
output[...] = self.o_proj(attn_output)[0]
|
||||||
return output_padded
|
return output_padded
|
||||||
|
|
||||||
def indexer_select(
|
def indexer_select_pre_process(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
qr: torch.Tensor,
|
qr: torch.Tensor,
|
||||||
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
|
||||||
attn_metadata: M,
|
|
||||||
cos: torch.Tensor,
|
cos: torch.Tensor,
|
||||||
sin: torch.Tensor,
|
sin: torch.Tensor,
|
||||||
actual_seq_lengths_query: torch.Tensor,
|
|
||||||
actual_seq_lengths_key: torch.Tensor,
|
|
||||||
need_gather_q_kv: bool = False,
|
need_gather_q_kv: bool = False,
|
||||||
):
|
):
|
||||||
# q process in new stream
|
|
||||||
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
|
||||||
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
|
||||||
|
|
||||||
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
|
||||||
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
k_proj, need_gather_q_kv)
|
k_proj, need_gather_q_kv)
|
||||||
@@ -861,6 +891,9 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
k = k.view(-1, 1, self.head_dim)
|
k = k.view(-1, 1, self.head_dim)
|
||||||
|
|
||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
|
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||||
|
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||||
|
|
||||||
cos = cos.view(-1, self.qk_rope_head_dim)
|
cos = cos.view(-1, self.qk_rope_head_dim)
|
||||||
sin = sin.view(-1, self.qk_rope_head_dim)
|
sin = sin.view(-1, self.qk_rope_head_dim)
|
||||||
q, k = rope_forward_triton(q,
|
q, k = rope_forward_triton(q,
|
||||||
@@ -870,6 +903,38 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
rope_dim=self.qk_rope_head_dim,
|
rope_dim=self.qk_rope_head_dim,
|
||||||
is_neox_style=True)
|
is_neox_style=True)
|
||||||
else:
|
else:
|
||||||
|
k_pe, k_nope = torch.split(
|
||||||
|
k,
|
||||||
|
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
||||||
|
dim=-1) # [b,s,64+64]
|
||||||
|
|
||||||
|
k_pe = k_pe.unsqueeze(2)
|
||||||
|
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
|
||||||
|
k_pe = k_pe.squeeze(2)
|
||||||
|
|
||||||
|
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
||||||
|
q = None
|
||||||
|
|
||||||
|
return q, k
|
||||||
|
|
||||||
|
def indexer_select_post_process(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
qr: torch.Tensor,
|
||||||
|
q: Optional[torch.Tensor],
|
||||||
|
k: torch.Tensor,
|
||||||
|
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
attn_metadata: M,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
actual_seq_lengths_query: torch.Tensor,
|
||||||
|
actual_seq_lengths_key: torch.Tensor,
|
||||||
|
need_gather_q_kv: bool = False,
|
||||||
|
):
|
||||||
|
if q is None:
|
||||||
|
q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
|
||||||
|
q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128]
|
||||||
|
|
||||||
cos_q, sin_q = cos, sin
|
cos_q, sin_q = cos, sin
|
||||||
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
|
||||||
@@ -884,20 +949,6 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
q_pe = q_pe.squeeze(2)
|
q_pe = q_pe.squeeze(2)
|
||||||
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
|
||||||
|
|
||||||
k_pe, k_nope = torch.split(
|
|
||||||
k,
|
|
||||||
[self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
|
|
||||||
dim=-1) # [b,s,64+64]
|
|
||||||
|
|
||||||
k_pe = k_pe.unsqueeze(2)
|
|
||||||
k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
|
|
||||||
k_pe = k_pe.squeeze(2)
|
|
||||||
|
|
||||||
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
|
|
||||||
|
|
||||||
if self.enable_sfa_cp:
|
|
||||||
k = get_tp_group().all_gather(k, 0)
|
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
if self.is_kv_producer:
|
if self.is_kv_producer:
|
||||||
attn_metadata.reshape_cache_event = torch.npu.Event()
|
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
|
from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
|
||||||
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
|
|||||||
offset += num_tokens_dp
|
offset += num_tokens_dp
|
||||||
x = result
|
x = result
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_async(input: torch.Tensor,
                     group: "GroupCoordinator",
                     output: Optional[torch.Tensor] = None,
                     async_op: bool = True):
    """Start an all-gather of ``input`` along dim 0 across ``group``.

    Args:
        input: Local tensor contributed by this rank.
        group: Coordinator whose ``world_size`` and ``device_group`` describe
            the ranks taking part in the collective.
        output: Optional pre-allocated destination. When omitted, a new tensor
            is allocated with the same dtype/device as ``input`` and with
            dim 0 multiplied by ``group.world_size``.
        async_op: Forwarded to ``dist.all_gather_into_tensor``.

    Returns:
        A ``(tensor, handle)`` tuple. ``handle`` is whatever
        ``dist.all_gather_into_tensor`` returns (a work handle to ``wait()``
        on when ``async_op`` is true), or ``None`` when the group has a
        single rank and no communication is issued.
    """
    if group.world_size == 1:
        # Single rank: nothing to communicate. If the caller handed in a
        # destination buffer, honour it so that reading from that buffer
        # afterwards sees the data instead of uninitialised memory.
        if output is not None:
            output.copy_(input)
            return output, None
        return input, None

    if output is None:
        input_size = input.size()
        # Gathered result stacks every rank's shard along the first dim.
        output_size = (input_size[0] * group.world_size, ) + input_size[1:]
        output = torch.empty(output_size,
                             dtype=input.dtype,
                             device=input.device)

    handle = dist.all_gather_into_tensor(output,
                                         input,
                                         group=group.device_group,
                                         async_op=async_op)
    return output, handle
|
||||||
@@ -1172,21 +1172,14 @@ def singleton(cls):
|
|||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def enable_dsa_cp() -> bool:
|
def enable_dsa_cp() -> bool:
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
if vllm_config is None:
|
is_ds_v32 = hasattr(
|
||||||
|
vllm_config.model_config, "hf_text_config") and hasattr(
|
||||||
|
vllm_config.model_config.hf_text_config, "index_topk")
|
||||||
|
if is_ds_v32 and enable_sp():
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model_config = getattr(vllm_config, "model_config", None)
|
|
||||||
if model_config is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
hf_text_config = getattr(model_config, "hf_text_config", None)
|
|
||||||
if hf_text_config is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return hasattr(hf_text_config, "index_topk")
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
def enable_dsa_cp_with_layer_shard() -> bool:
|
def enable_dsa_cp_with_layer_shard() -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user