Support DP MLA (#1970)
This commit is contained in:
@@ -110,7 +110,7 @@ class Scheduler:
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
|
||||
if self.tp_rank == 0:
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
self.recv_from_tokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.scheduler_input_ipc_name
|
||||
)
|
||||
@@ -347,6 +347,10 @@ class Scheduler:
|
||||
self.process_input_requests(recv_reqs)
|
||||
|
||||
batch = self.get_next_batch_to_run()
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
batch = self.prepare_dp_attn_batch(batch)
|
||||
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
@@ -361,6 +365,8 @@ class Scheduler:
|
||||
self.update_running_batch()
|
||||
if not self.running_batch:
|
||||
break
|
||||
if self.server_args.enable_dp_attention:
|
||||
batch = self.prepare_dp_attn_batch(batch)
|
||||
result = self.run_batch(batch)
|
||||
self.process_batch_result(batch, result)
|
||||
else:
|
||||
@@ -396,8 +402,48 @@ class Scheduler:
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
    """Synchronize this DP-attention worker with its peers before running.

    All-gathers the per-rank pending token count over the TP device group.
    A rank that has no batch of its own, while some peer does, is given an
    idle (empty) batch instead of ``None`` — presumably so it still joins
    the distributed collectives of the forward pass (NOTE(review): confirm
    against the model forward path).  Returns the (possibly newly created)
    batch, with ``global_num_tokens`` populated on it, or ``None`` when no
    rank has any work.
    """
    # This rank's contribution: decode counts one token per request,
    # extend/prefill counts every token being extended.
    if local_batch is None:
        token_count = 0
    elif local_batch.forward_mode.is_decode():
        token_count = local_batch.batch_size()
    else:
        token_count = local_batch.extend_num_tokens

    mine = torch.tensor(token_count, dtype=torch.int64, device=self.device)
    everyone = torch.empty(self.tp_size, dtype=torch.int64, device=self.device)
    torch.distributed.all_gather_into_tensor(
        everyone,
        mine,
        group=self.tp_worker.get_tp_device_group(),
    )

    # Idle here, but at least one peer has tokens: participate with an
    # empty batch rather than sitting out the step.
    if local_batch is None and everyone.max().item() > 0:
        local_batch = self.get_idle_batch()

    if local_batch is not None:
        local_batch.global_num_tokens = everyone.tolist()

    return local_batch
|
||||
|
||||
def get_idle_batch(self):
    """Create an empty ScheduleBatch prepared for an idle forward step.

    Used by DP attention: a worker with no requests of its own still
    needs a batch object to run through the step machinery.
    """
    batch = ScheduleBatch.init_new(
        [],
        self.req_to_token_pool,
        self.token_to_kv_pool,
        self.tree_cache,
        self.model_config,
    )
    batch.prepare_for_idle()
    return batch
|
||||
|
||||
def recv_requests(self):
|
||||
if self.tp_rank == 0:
|
||||
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
||||
recv_reqs = []
|
||||
|
||||
while True:
|
||||
@@ -409,7 +455,7 @@ class Scheduler:
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
if self.tp_size != 1:
|
||||
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
|
||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||
return recv_reqs
|
||||
|
||||
@@ -812,6 +858,10 @@ class Scheduler:
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
elif batch.forward_mode.is_idle():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
self.tp_worker.forward_batch_idle(model_worker_batch)
|
||||
return
|
||||
else:
|
||||
logits_output = None
|
||||
if self.skip_tokenizer_init:
|
||||
@@ -830,6 +880,8 @@ class Scheduler:
|
||||
return ret
|
||||
|
||||
def process_batch_result(self, batch: ScheduleBatch, result):
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
if batch.is_empty():
|
||||
|
||||
Reference in New Issue
Block a user