feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
Author: u4lr451
Date: 2025-06-17 15:33:28 +08:00
Committed by: GitHub
Parent: 8a10c4c3d9
Commit: 10d60cd41b
22 changed files with 641 additions and 151 deletions
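
In short: multi-token prediction (MTP) speculative decoding can now run on top of data-parallel (DP) attention. In the excerpts below, the CUDA graph runner converts DP token counts into batch sizes whenever an EAGLE-style speculative algorithm is active, `ForwardBatch` scales the per-rank token counts and the gathered buffer by the number of draft tokens, and `ModelRunner` records `dp_size` and publishes the speculative algorithm in `global_server_args_dict`.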


```diff
@@ -242,13 +242,13 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        if global_server_args_dict["attention_backend"] == "flashmla":
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
-        else:
-            self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
         self.seq_lens_cpu = torch.full(
```
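
Previously the caller had to special-case flashmla, which sizes its graph state per batch entry rather than per token; passing both capacities moves that choice into each backend. A minimal sketch of the idea, using hypothetical backend classes rather than sglang's real ones:

```python
# Hypothetical backends, for illustration only. The point: with both
# capacities in the signature, the caller no longer branches on the
# backend name.
import torch


class PerTokenBackend:
    """Sizes its CUDA-graph buffers by the total token capacity."""

    def init_cuda_graph_state(self, max_bs: int, max_num_token: int) -> None:
        self.kv_indices = torch.zeros(max_num_token, dtype=torch.int32)


class PerSequenceBackend:
    """Sizes its CUDA-graph buffers per batch entry (flashmla-like)."""

    def init_cuda_graph_state(self, max_bs: int, max_num_token: int) -> None:
        self.seq_slots = torch.zeros(max_bs, dtype=torch.int32)


for backend in (PerTokenBackend(), PerSequenceBackend()):
    backend.init_cuda_graph_state(max_bs=8, max_num_token=8 * 4)
```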
```diff
@@ -323,12 +323,15 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_global_tokens in self.graphs
+                total_batch_size in self.graphs
                 if self.disable_padding
-                else total_global_tokens <= self.max_bs
+                else total_batch_size <= self.max_bs
             )
         else:
             is_bs_supported = (
```
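
CUDA graphs are captured per batch size, so when speculation multiplies each request into several token positions, the global token count has to be divided back down before it can be used as a lookup key. A worked example with assumed numbers:

```python
# Assumed values for illustration. With EAGLE-style speculation each
# request occupies num_tokens_per_bs token slots during verification.
global_num_tokens_cpu = [6, 6, 6, 6]  # tokens contributed by each DP rank
num_tokens_per_bs = 3                 # tokens per request during verify

total_tokens = sum(global_num_tokens_cpu)             # 24
total_batch_size = total_tokens // num_tokens_per_bs  # 8

# 8 (not 24) is what must appear in self.graphs, or stay <= self.max_bs
# when padding is enabled.
print(total_batch_size)
```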
```diff
@@ -460,7 +463,7 @@ class CudaGraphRunner:
         self.global_num_tokens_gpu.copy_(
             torch.tensor(
                 [
-                    num_tokens // self.dp_size + (i < bs % self.dp_size)
+                    num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
                     for i in range(self.dp_size)
                 ],
                 dtype=torch.int32,
```
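
This one-line fix matters: the old expression spread the remainder according to `bs`, the padded batch size, rather than the token count actually being distributed, so the per-rank shares could fail to sum to `num_tokens`. A quick property check of the corrected formula, as a standalone sketch:

```python
def split_tokens(num_tokens: int, dp_size: int) -> list[int]:
    # The first (num_tokens % dp_size) ranks each take one extra token;
    # a bool counts as 0/1 in Python integer arithmetic.
    return [
        num_tokens // dp_size + (i < (num_tokens % dp_size))
        for i in range(dp_size)
    ]

assert split_tokens(10, 4) == [3, 3, 2, 2]  # shares sum back to 10
assert sum(split_tokens(7, 3)) == 7
```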
```diff
@@ -605,9 +608,12 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention or self.enable_sp_layernorm:
-            index = bisect.bisect_left(
-                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
             )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
```
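
`bisect_left` over the sorted `capture_bs` list selects the smallest captured batch size that can hold the incoming batch; the gap between the two is made up with padding. A toy illustration with assumed capture sizes:

```python
import bisect

# Example capture list; real values come from the server configuration.
capture_bs = [1, 2, 4, 8, 16]
total_batch_size = 5

index = bisect.bisect_left(capture_bs, total_batch_size)
padded_bs = capture_bs[index]  # 8: the smallest captured graph that fits
assert padded_bs == 8
```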
```diff
@@ -650,13 +656,13 @@ class CudaGraphRunner:
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
-            self.req_pool_indices,
-            self.seq_lens,
+            self.req_pool_indices[:bs],
+            self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
-            self.encoder_lens,
+            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             forward_batch.forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
         # Store fields
```
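
Slicing the persistent buffers to `[:bs]` hands the backend tensors whose length matches the padded batch while still aliasing the memory the graph was captured against: a torch slice is a view, not a copy, so the device pointers stay stable. A small sketch of that property, with assumed sizes:

```python
import torch

# Persistent buffer allocated once for graph capture (assumed size).
seq_lens = torch.ones(16, dtype=torch.int32)
bs = 4

view = seq_lens[:bs]
# Same storage, same base address: the captured graph keeps reading the
# original buffer, and writes through the view remain visible to it.
assert view.data_ptr() == seq_lens.data_ptr()
view.fill_(7)
assert int(seq_lens[0]) == 7
```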


```diff
@@ -320,17 +320,30 @@ class ForwardBatch:
         # For DP attention
         if batch.global_num_tokens is not None:
-            ret.global_num_tokens_cpu = batch.global_num_tokens
+            spec_num_draft_tokens = (
+                batch.spec_num_draft_tokens
+                if batch.spec_num_draft_tokens is not None
+                else 1
+            )
+            global_num_tokens = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens
+            ]
+            global_num_tokens_for_logprob = [
+                x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
+            ]
+            ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
-                batch.global_num_tokens, dtype=torch.int64
+                global_num_tokens, dtype=torch.int64
             ).to(device, non_blocking=True)
-            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
             ret.global_num_tokens_for_logprob_gpu = torch.tensor(
-                batch.global_num_tokens_for_logprob, dtype=torch.int64
+                global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
-            sum_len = sum(batch.global_num_tokens)
+            sum_len = sum(global_num_tokens)
             ret.gathered_buffer = torch.zeros(
                 (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
```
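
This is the key semantic change: under speculation each scheduled request contributes `spec_num_draft_tokens` token positions, so the per-rank counts exchanged for DP attention, the logprob counts, and the size of `gathered_buffer` all scale by that factor, which falls back to 1 when `batch.spec_num_draft_tokens` is unset. A worked example with assumed numbers:

```python
import torch

# Assumed values for illustration only.
global_num_tokens = [2, 3, 1]   # per-rank counts before scaling
spec_num_draft_tokens = 4       # 1 when speculation is disabled
hidden_size = 8

scaled = [x * spec_num_draft_tokens for x in global_num_tokens]  # [8, 12, 4]
sum_len = sum(scaled)                                            # 24

# The gathered buffer needs one hidden-state row per token, not per request.
gathered_buffer = torch.zeros((sum_len, hidden_size))
assert gathered_buffer.shape == (24, 8)
```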


```diff
@@ -163,6 +163,7 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -196,6 +197,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
```