Improve tensor parallel performance (#625)

Co-authored-by: Mingyi <wisclmy0611@gmail.com>
Author: Ying Sheng
Date: 2024-07-15 07:10:51 -07:00
Committed by: GitHub
Parent: 5ac8b80677
Commit: 6a2941f4d0
10 changed files with 171 additions and 81 deletions


@@ -53,7 +53,7 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         model_port_args: ModelPortArgs,
-        model_overide_args,
+        model_overide_args: dict,
     ):
         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
         suppress_other_loggers()
@@ -178,7 +178,7 @@ class ModelTpServer:
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
-        if self.tp_size * self.dp_size != 1:
+        if not isinstance(recv_reqs, list):
             recv_reqs = obtain(recv_reqs)

         try:
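Context for the change above, hedged: under rpyc, arguments that cross a process boundary arrive as netref proxies, and obtain() pulls a by-value copy into the local process; when the caller is in-process, recv_reqs is already a plain list and that copy is wasted. A minimal sketch of the pattern, where localize_reqs is an illustrative wrapper name, not from this commit:

from rpyc.utils.classic import obtain

def localize_reqs(recv_reqs):
    # In-process callers pass a real list; skip the RPC round-trip.
    if not isinstance(recv_reqs, list):
        # recv_reqs is a rpyc netref here; obtain() deep-copies it locally.
        recv_reqs = obtain(recv_reqs)
    return recv_reqs

Checking the payload's type rather than the tp_size * dp_size configuration also stays correct if the deployment mode changes.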
@@ -206,11 +206,11 @@
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_fill_batch()
+        new_batch = self.get_new_prefill_batch()

         if new_batch is not None:
-            # Run a new fill batch
-            self.forward_fill_batch(new_batch)
+            # Run a new prefill batch
+            self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
@@ -219,7 +219,7 @@
                 else:
                     self.running_batch.merge(new_batch)
         else:
-            # Run decode batch
+            # Run a decode batch
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(global_config.num_continue_decode_steps):
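A rough sketch of the continuous-decode loop above: running global_config.num_continue_decode_steps decode iterations back to back amortizes scheduling and dispatch overhead across steps instead of paying it per token. The harness below is illustrative; run_decode_burst and forward_decode_batch are stand-in names, not taken from this file:

def run_decode_burst(server, num_steps: int):
    # Decode several steps without returning to the outer event loop,
    # so per-step dispatch overhead is paid once per burst.
    for _ in range(num_steps):
        if server.running_batch is None or server.running_batch.is_empty():
            break  # all requests finished mid-burst
        server.forward_decode_batch(server.running_batch)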
@@ -312,7 +312,7 @@
             )
             self.forward_queue.append(req)

-    def get_new_fill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[Batch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -436,7 +436,7 @@
         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
         return new_batch

-    def forward_fill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: Batch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
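The fill-to-prefill renames above are cosmetic but align with standard LLM-serving terminology: prefill runs one forward pass over the entire prompt, while decode advances one token per pass. A shape-level illustration with arbitrary numbers, not code from this commit:

# Prefill: one forward pass over all prompt tokens per request.
# Decode: one forward pass per newly generated token.
batch_size, prompt_len = 8, 512
prefill_input_shape = (batch_size, prompt_len)
decode_input_shape = (batch_size, 1)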
@@ -746,8 +746,8 @@ class ModelTpClient:
             # Init model
             assert len(gpu_ids) == 1
             self.model_server = ModelTpService().exposed_ModelTpServer(
-                0,
+                gpu_ids[0],
                 0,
                 server_args,
                 model_port_args,
                 model_overide_args,
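The call site now forwards the actual device id (gpu_ids[0]) instead of a hard-coded 0, so a server launched on a different GPU binds the device it was assigned. A hedged sketch of the usual pinning step; torch.cuda.set_device is the real PyTorch API, while init_device is an illustrative wrapper:

import torch

def init_device(gpu_id: int) -> torch.device:
    # Pin this worker to its assigned GPU before any allocation,
    # so tensor-parallel ranks do not all land on cuda:0.
    torch.cuda.set_device(gpu_id)
    return torch.device(f"cuda:{gpu_id}")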