feat: support DP attention with MTP (#6081)
Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -242,13 +242,13 @@ class CudaGraphRunner:
|
||||
# Attention backend
|
||||
self.max_bs = max(self.capture_bs)
|
||||
self.max_num_token = self.max_bs * self.num_tokens_per_bs
|
||||
if global_server_args_dict["attention_backend"] == "flashmla":
|
||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
||||
else:
|
||||
self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
|
||||
self.model_runner.attn_backend.init_cuda_graph_state(
|
||||
self.max_bs, self.max_num_token
|
||||
)
|
||||
self.seq_len_fill_value = (
|
||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
)
|
||||
|
||||
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
||||
self.encoder_len_fill_value = 0
|
||||
self.seq_lens_cpu = torch.full(
|
||||
@@ -323,12 +323,15 @@ class CudaGraphRunner:
|
||||
|
||||
def can_run(self, forward_batch: ForwardBatch):
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
|
||||
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
|
||||
total_global_tokens in self.graphs
|
||||
total_batch_size in self.graphs
|
||||
if self.disable_padding
|
||||
else total_global_tokens <= self.max_bs
|
||||
else total_batch_size <= self.max_bs
|
||||
)
|
||||
else:
|
||||
is_bs_supported = (
|
||||
@@ -460,7 +463,7 @@ class CudaGraphRunner:
|
||||
self.global_num_tokens_gpu.copy_(
|
||||
torch.tensor(
|
||||
[
|
||||
num_tokens // self.dp_size + (i < bs % self.dp_size)
|
||||
num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
|
||||
for i in range(self.dp_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
@@ -605,9 +608,12 @@ class CudaGraphRunner:
|
||||
|
||||
# Pad
|
||||
if self.enable_dp_attention or self.enable_sp_layernorm:
|
||||
index = bisect.bisect_left(
|
||||
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
|
||||
total_batch_size = (
|
||||
sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
|
||||
if self.model_runner.spec_algorithm.is_eagle()
|
||||
else sum(forward_batch.global_num_tokens_cpu)
|
||||
)
|
||||
index = bisect.bisect_left(self.capture_bs, total_batch_size)
|
||||
else:
|
||||
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
||||
bs = self.capture_bs[index]
|
||||
@@ -650,13 +656,13 @@ class CudaGraphRunner:
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
self.req_pool_indices[:bs],
|
||||
self.seq_lens[:bs],
|
||||
forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
|
||||
self.encoder_lens,
|
||||
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
||||
forward_batch.forward_mode,
|
||||
forward_batch.spec_info,
|
||||
seq_lens_cpu=self.seq_lens_cpu,
|
||||
seq_lens_cpu=self.seq_lens_cpu[:bs],
|
||||
)
|
||||
|
||||
# Store fields
|
||||
|
||||
@@ -320,17 +320,30 @@ class ForwardBatch:
|
||||
|
||||
# For DP attention
|
||||
if batch.global_num_tokens is not None:
|
||||
ret.global_num_tokens_cpu = batch.global_num_tokens
|
||||
|
||||
spec_num_draft_tokens = (
|
||||
batch.spec_num_draft_tokens
|
||||
if batch.spec_num_draft_tokens is not None
|
||||
else 1
|
||||
)
|
||||
global_num_tokens = [
|
||||
x * spec_num_draft_tokens for x in batch.global_num_tokens
|
||||
]
|
||||
global_num_tokens_for_logprob = [
|
||||
x * spec_num_draft_tokens for x in batch.global_num_tokens_for_logprob
|
||||
]
|
||||
|
||||
ret.global_num_tokens_cpu = global_num_tokens
|
||||
ret.global_num_tokens_gpu = torch.tensor(
|
||||
batch.global_num_tokens, dtype=torch.int64
|
||||
global_num_tokens, dtype=torch.int64
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
|
||||
ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
|
||||
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
|
||||
batch.global_num_tokens_for_logprob, dtype=torch.int64
|
||||
global_num_tokens_for_logprob, dtype=torch.int64
|
||||
).to(device, non_blocking=True)
|
||||
|
||||
sum_len = sum(batch.global_num_tokens)
|
||||
sum_len = sum(global_num_tokens)
|
||||
ret.gathered_buffer = torch.zeros(
|
||||
(sum_len, model_runner.model_config.hidden_size),
|
||||
dtype=model_runner.dtype,
|
||||
|
||||
@@ -163,6 +163,7 @@ class ModelRunner:
|
||||
logger.addFilter(RankZeroFilter(tp_rank == 0))
|
||||
self.tp_rank = tp_rank
|
||||
self.tp_size = tp_size
|
||||
self.dp_size = server_args.dp_size
|
||||
self.pp_rank = pp_rank
|
||||
self.pp_size = pp_size
|
||||
self.dist_port = nccl_port
|
||||
@@ -196,6 +197,7 @@ class ModelRunner:
|
||||
| {
|
||||
# TODO it is indeed not a "server args"
|
||||
"use_mla_backend": self.use_mla_backend,
|
||||
"speculative_algorithm": self.spec_algorithm,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user