Fuse more ops & Simplify token mapping (#1758)
This commit is contained in:
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def resolve_future_token_ids(input_ids, future_token_ids_map):
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||
input_ids,
|
||||
)
|
||||
|
||||
|
||||
class TpModelWorkerClient:
|
||||
"""A tensor parallel model worker."""
|
||||
|
||||
@@ -99,33 +108,25 @@ class TpModelWorkerClient:
|
||||
|
||||
# Resolve future tokens in the input
|
||||
input_ids = model_worker_batch.input_ids
|
||||
input_ids[:] = torch.where(
|
||||
input_ids < 0,
|
||||
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
|
||||
input_ids,
|
||||
)
|
||||
resolve_future_token_ids(input_ids, self.future_token_ids_map)
|
||||
|
||||
# Run forward
|
||||
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
self.launch_event.set()
|
||||
|
||||
# Update the future token ids map
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(future_token_ids_ct + bs),
|
||||
-(future_token_ids_ct),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
|
||||
torch.int32
|
||||
)
|
||||
self.future_token_ids_map[
|
||||
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
|
||||
] = next_token_ids
|
||||
|
||||
# Copy results to the CPU
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_event = torch.cuda.Event(blocking=True)
|
||||
copy_event.record()
|
||||
|
||||
self.launch_event.set()
|
||||
self.copy_queue.put((copy_event, next_token_ids))
|
||||
|
||||
def copy_thread_func(self):
|
||||
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
|
||||
# Allocate output future objects
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
future_next_token_ids = torch.arange(
|
||||
-(self.future_token_ids_ct + bs),
|
||||
-(self.future_token_ids_ct),
|
||||
-(self.future_token_ids_ct + 1),
|
||||
-(self.future_token_ids_ct + 1 + bs),
|
||||
-1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user