Improve tensor parallel performance (#625)
Co-authored-by: Mingyi <wisclmy0611@gmail.com>
@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
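
Note on the exposed_step change above: rpyc delivers cross-process arguments as netref proxies, and obtain() (the same helper this file already applies to server_args) copies a proxy into the local process. Gating the copy on isinstance(recv_reqs, list) rather than on the tp_size * dp_size product means a caller living in the same process, which already passes a plain list, skips the serialization round trip entirely. A minimal sketch of the pattern, where normalize_recv_reqs is a hypothetical helper name, not something this diff adds:

    # Sketch only; rpyc is the library the diff's obtain() comes from.
    from rpyc.utils.classic import obtain

    def normalize_recv_reqs(recv_reqs):
        if not isinstance(recv_reqs, list):
            # A netref from a remote caller: materialize it locally once,
            # so later element accesses do not each pay a network round trip.
            recv_reqs = obtain(recv_reqs)
        # A plain list means the caller shares our process; obtain() on it
        # would only add copying overhead.
        return recv_reqs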
@@ -206,11 +206,11 @@ class ModelTpServer:

     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,7 +219,7 @@ class ModelTpServer:
                 else:
                     self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(global_config.num_continue_decode_steps):
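
Note on the two hunks above: the fill -> prefill rename aligns the code with the now-standard prefill/decode terminology, while the control flow stays the same: try to form a new prefill batch first, and only when there is none run global_config.num_continue_decode_steps decode iterations back to back so scheduling overhead is amortized over several generated tokens. A self-contained toy of that loop shape, in which ToyScheduler and the module-level constant are illustrative stand-ins rather than code from this repository:

    NUM_CONTINUE_DECODE_STEPS = 4  # stands in for global_config.num_continue_decode_steps

    class ToyScheduler:
        def __init__(self, waiting):
            self.waiting = list(waiting)  # requests not yet prefilled
            self.running = []             # requests currently decoding

        def get_new_prefill_batch(self):
            batch, self.waiting = self.waiting, []
            return batch or None

        def forward_step(self):
            new_batch = self.get_new_prefill_batch()
            if new_batch is not None:
                print("prefill:", new_batch)    # one prefill pass over new requests
                self.running.extend(new_batch)  # merge into the running decode batch
            elif self.running:
                # Run a few decode batches continuously for reducing overhead
                for _ in range(NUM_CONTINUE_DECODE_STEPS):
                    print("decode: ", self.running)

    s = ToyScheduler(["req0", "req1"])
    s.forward_step()  # prefills req0 and req1
    s.forward_step()  # no new arrivals, so several decode steps run in a row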
@@ -312,7 +312,7 @@ class ModelTpServer:
             )
             self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -436,7 +436,7 @@ class ModelTpServer:
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
@@ -746,8 +746,8 @@ class ModelTpClient:
             # Init model
             assert len(gpu_ids) == 1
             self.model_server = ModelTpService().exposed_ModelTpServer(
-                0,
+                gpu_ids[0],
                 0,
                 server_args,
                 model_port_args,
                 model_overide_args,
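
Note on the ModelTpClient hunk: the single-GPU path now forwards the actual device id gpu_ids[0] instead of a hard-coded 0, with tp_rank staying 0. A worker typically needs its real GPU id so it can pin itself to that device before allocating weights; a hedged illustration of that pattern (the set_device call is an assumption for this sketch, not code shown in the diff):

    import torch

    def init_worker(gpu_id: int, tp_rank: int) -> None:
        # Pin this worker to its own device so later allocations land on
        # the right GPU; guarded so the sketch also runs without CUDA.
        if torch.cuda.is_available():
            torch.cuda.set_device(gpu_id)
        print(f"tp_rank={tp_rank} bound to device {gpu_id}")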