Tiny skip_sample adjust (#11225)
This commit is contained in:
@@ -663,7 +663,11 @@ class Req:
|
|||||||
@property
|
@property
|
||||||
def is_prefill_only(self) -> bool:
|
def is_prefill_only(self) -> bool:
|
||||||
"""Check if this request is prefill-only (no token generation needed)."""
|
"""Check if this request is prefill-only (no token generation needed)."""
|
||||||
return self.sampling_params.max_new_tokens == 0
|
# NOTE: when spec is enabled, prefill_only optimizations are disabled
|
||||||
|
return (
|
||||||
|
self.sampling_params.max_new_tokens == 0
|
||||||
|
and global_server_args_dict["speculative_algorithm"] is None
|
||||||
|
)
|
||||||
|
|
||||||
def add_latency(self, stage: RequestStage):
|
def add_latency(self, stage: RequestStage):
|
||||||
if self.metrics_collector is None:
|
if self.metrics_collector is None:
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ class TpModelWorker:
|
|||||||
self,
|
self,
|
||||||
model_worker_batch: ModelWorkerBatch,
|
model_worker_batch: ModelWorkerBatch,
|
||||||
launch_done: Optional[threading.Event] = None,
|
launch_done: Optional[threading.Event] = None,
|
||||||
skip_sample: bool = False,
|
is_verify: bool = False,
|
||||||
) -> ForwardBatchOutput:
|
) -> ForwardBatchOutput:
|
||||||
# update the consumer index of hicache to the running batch
|
# update the consumer index of hicache to the running batch
|
||||||
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
|
||||||
@@ -259,19 +259,16 @@ class TpModelWorker:
|
|||||||
if launch_done is not None:
|
if launch_done is not None:
|
||||||
launch_done.set()
|
launch_done.set()
|
||||||
|
|
||||||
if skip_sample:
|
skip_sample = is_verify or model_worker_batch.is_prefill_only
|
||||||
next_token_ids = None
|
next_token_ids = None
|
||||||
# For prefill-only requests, we still need to compute logprobs even when sampling is skipped
|
|
||||||
if (
|
if not skip_sample:
|
||||||
model_worker_batch.is_prefill_only
|
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
||||||
and model_worker_batch.return_logprob
|
elif model_worker_batch.return_logprob and not is_verify:
|
||||||
):
|
# NOTE: Compute logprobs without full sampling
|
||||||
# Compute logprobs without full sampling
|
|
||||||
self.model_runner.compute_logprobs_only(
|
self.model_runner.compute_logprobs_only(
|
||||||
logits_output, model_worker_batch
|
logits_output, model_worker_batch
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
next_token_ids = self.model_runner.sample(logits_output, forward_batch)
|
|
||||||
|
|
||||||
return ForwardBatchOutput(
|
return ForwardBatchOutput(
|
||||||
logits_output=logits_output,
|
logits_output=logits_output,
|
||||||
|
|||||||
@@ -164,8 +164,6 @@ class TpModelWorkerClient:
|
|||||||
forward_batch_output = self.worker.forward_batch_generation(
|
forward_batch_output = self.worker.forward_batch_generation(
|
||||||
model_worker_batch,
|
model_worker_batch,
|
||||||
model_worker_batch.launch_done,
|
model_worker_batch.launch_done,
|
||||||
# Skip sampling for prefill-only requests
|
|
||||||
skip_sample=model_worker_batch.is_prefill_only,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logits_output, next_token_ids, can_run_cuda_graph = (
|
logits_output, next_token_ids, can_run_cuda_graph = (
|
||||||
|
|||||||
@@ -823,7 +823,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
forward_batch_output = self.target_worker.forward_batch_generation(
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch, skip_sample=True
|
model_worker_batch, is_verify=True
|
||||||
)
|
)
|
||||||
logits_output, can_run_cuda_graph = (
|
logits_output, can_run_cuda_graph = (
|
||||||
forward_batch_output.logits_output,
|
forward_batch_output.logits_output,
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class NGRAMWorker:
|
|||||||
|
|
||||||
if model_worker_batch.forward_mode.is_target_verify():
|
if model_worker_batch.forward_mode.is_target_verify():
|
||||||
forward_batch_output = self.target_worker.forward_batch_generation(
|
forward_batch_output = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch, skip_sample=True
|
model_worker_batch, is_verify=True
|
||||||
)
|
)
|
||||||
logits_output, can_run_cuda_graph = (
|
logits_output, can_run_cuda_graph = (
|
||||||
forward_batch_output.logits_output,
|
forward_batch_output.logits_output,
|
||||||
|
|||||||
Reference in New Issue
Block a user