Support DP MLA (#1970)

This commit is contained in:
Ke Bao
2024-11-16 17:01:43 +08:00
committed by GitHub
parent 2f2e07439c
commit 976bc302e5
12 changed files with 395 additions and 63 deletions

View File

@@ -56,6 +56,7 @@ global_server_args_dict = {
"disable_mla": ServerArgs.disable_mla,
"torchao_config": ServerArgs.torchao_config,
"disable_nan_detection": ServerArgs.disable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
}
@@ -450,6 +451,9 @@ class ScheduleBatch:
# The sum of all sequence lengths
seq_lens_sum: int = None
# For DP attention
global_num_tokens: Optional[List[int]] = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -858,6 +862,16 @@ class ScheduleBatch:
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
def prepare_for_idle(self):
    """Reset this batch to an idle (zero-token) forward pass.

    NOTE(review): presumably this lets a data-parallel rank with no
    scheduled requests still participate in collective ops under DP
    attention -- confirm with the scheduler call sites.
    """
    self.forward_mode = ForwardMode.IDLE
    self.extend_num_tokens = 0
    # Both token buffers become empty int32 tensors, moved to the
    # batch's device with a non-blocking copy.
    for field in ("input_ids", "seq_lens"):
        empty = torch.empty(0, dtype=torch.int32)
        setattr(self, field, empty.to(self.device, non_blocking=True))
def prepare_for_decode(self, enable_overlap: bool = False):
self.forward_mode = ForwardMode.DECODE
@@ -969,17 +983,18 @@ class ScheduleBatch:
self.has_grammar = self.has_grammar or other.has_grammar
def get_model_worker_batch(self):
if self.forward_mode.is_decode():
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
if self.sampling_info is not None:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
global bid
bid += 1
@@ -995,6 +1010,7 @@ class ScheduleBatch:
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
return_logprob: bool
top_logprobs_nums: Optional[List[int]]
# For DP attention
global_num_tokens: Optional[List[int]]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]