Support DP MLA (#1970)
@@ -56,6 +56,7 @@ global_server_args_dict = {
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
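The new entry follows the file's existing pattern: server flags are snapshotted from ServerArgs into the module-level global_server_args_dict so model code can consult them without carrying the args object around. A minimal sketch of that pattern, assuming a simplified ServerArgs (only the flag names come from the hunk; the defaults are illustrative assumptions):

from dataclasses import dataclass

@dataclass
class ServerArgs:
    # Flags visible in this hunk; default values are illustrative assumptions.
    disable_mla: bool = False
    torchao_config: str = ""
    disable_nan_detection: bool = False
    enable_dp_attention: bool = False

# Class-level defaults seed the snapshot, exactly as in the hunk above.
global_server_args_dict = {
    "disable_mla": ServerArgs.disable_mla,
    "torchao_config": ServerArgs.torchao_config,
    "disable_nan_detection": ServerArgs.disable_nan_detection,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
}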
@@ -450,6 +451,9 @@ class ScheduleBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
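global_num_tokens carries one entry per data-parallel rank: each rank builds its own ScheduleBatch under DP attention, but ranks stepping through the same forward pass need to know each other's batch sizes, for example to decide who must run an idle batch. How the list is populated is not shown in this diff; a hedged sketch using a standard torch.distributed collective (the function name and wiring are assumptions, not the commit's code):

from typing import List

import torch.distributed as dist

def gather_global_num_tokens(local_num_tokens: int, world_size: int) -> List[int]:
    # all_gather_object is a real torch.distributed API; using it to fill
    # global_num_tokens is an illustrative assumption.
    counts: List[int] = [0] * world_size
    dist.all_gather_object(counts, local_num_tokens)
    return counts  # e.g. [128, 0, 96] -> rank 1 must submit an idle batch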
@@ -858,6 +862,16 @@ class ScheduleBatch:
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
 
+    def prepare_for_idle(self):
+        self.forward_mode = ForwardMode.IDLE
+        self.input_ids = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.extend_num_tokens = 0
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
 
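prepare_for_idle builds a genuinely empty batch: zero-length input_ids and seq_lens tensors copied to the device with non_blocking=True, and extend_num_tokens set to 0. Given the rest of the commit, the apparent purpose is that a DP rank with nothing scheduled still needs a well-formed batch to step through the forward pass in lockstep with its peers. A self-contained sketch of the same construction (the Batch class is a hypothetical stand-in for ScheduleBatch):

import torch

class Batch:  # hypothetical stand-in for ScheduleBatch
    def __init__(self, device: str = "cpu"):
        self.device = device

    def prepare_for_idle(self):
        # Zero-length tensors give downstream kernels valid, empty inputs;
        # non_blocking matches the async device copy in the diff.
        self.input_ids = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.extend_num_tokens = 0

b = Batch()
b.prepare_for_idle()
assert b.input_ids.numel() == 0 and b.extend_num_tokens == 0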
@@ -969,17 +983,18 @@ class ScheduleBatch:
|
||||
self.has_grammar = self.has_grammar or other.has_grammar
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
if self.forward_mode.is_decode():
|
||||
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
extend_prefix_lens = self.prefix_lens
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
|
||||
if self.has_grammar:
|
||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||
else:
|
||||
self.sampling_info.grammars = None
|
||||
if self.sampling_info is not None:
|
||||
if self.has_grammar:
|
||||
self.sampling_info.grammars = [req.grammar for req in self.reqs]
|
||||
else:
|
||||
self.sampling_info.grammars = None
|
||||
|
||||
global bid
|
||||
bid += 1
|
||||
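Two changes land in get_model_worker_batch. First, idle batches now take the same early exit as decode, so no extend (prefill) metadata is produced for them. Second, the grammar assignment is wrapped in an `if self.sampling_info is not None` guard; an empty idle batch plausibly carries no sampling info, and the old code would have raised AttributeError on it. A reduced control-flow sketch (ForwardMode and the batch attributes below are simplified stand-ins, not the real classes):

from enum import Enum, auto

class ForwardMode(Enum):  # simplified stand-in
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()

    def is_decode(self) -> bool:
        return self is ForwardMode.DECODE

    def is_idle(self) -> bool:
        return self is ForwardMode.IDLE

def extract_extend_metadata(batch):
    if batch.forward_mode.is_decode() or batch.forward_mode.is_idle():
        # Decode and idle steps carry no prefill metadata.
        return None, None, None
    return batch.extend_lens, batch.prefix_lens, batch.extend_logprob_start_lens

def attach_grammars(batch):
    if batch.sampling_info is not None:  # idle batches may lack sampling info
        batch.sampling_info.grammars = (
            [req.grammar for req in batch.reqs] if batch.has_grammar else None
        )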
@@ -995,6 +1010,7 @@ class ScheduleBatch:
             req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            global_num_tokens=self.global_num_tokens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
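On the receiving side, ModelWorkerBatch grows a matching global_num_tokens field with no default, so every construction site (including the one patched above) must pass it explicitly; None then signals that DP attention is off. A minimal dataclass sketch of just the fields visible in this hunk:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ModelWorkerBatch:  # all other fields of the real class elided
    return_logprob: bool
    top_logprobs_nums: Optional[List[int]]
    # For DP attention: one token count per DP rank, or None when disabled.
    global_num_tokens: Optional[List[int]]
    # For extend
    extend_num_tokens: Optional[int]
    extend_seq_lens: Optional[List[int]]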