Fuse more ops & Simplify token mapping (#1758)

This commit is contained in:
Lianmin Zheng
2024-10-22 23:20:43 -07:00
committed by GitHub
parent 17536e7e3d
commit ad4125d1a9
9 changed files with 99 additions and 75 deletions

View File

@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True)
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
@@ -99,33 +108,25 @@ class TpModelWorkerClient:
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
input_ids[:] = torch.where(
input_ids < 0,
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
)
self.launch_event.set()
# Update the future token ids map
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(future_token_ids_ct + bs),
-(future_token_ids_ct),
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_map[
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
# Copy results to the CPU
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()
self.launch_event.set()
self.copy_queue.put((copy_event, next_token_ids))
def copy_thread_func(self):
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(self.future_token_ids_ct + bs),
-(self.future_token_ids_ct),
-(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs),
-1,
dtype=torch.int32,
device=self.device,
)