Simplify sampler and its error handling (#1441)
This commit is contained in:
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput, Sampler
|
||||
from sglang.srt.layers.sampler import Sampler
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
@@ -516,21 +516,6 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||
|
||||
def _check_sample_results(self, sample_output: SampleOutput):
|
||||
if not torch.all(sample_output.success):
|
||||
probs = sample_output.probs
|
||||
batch_next_token_ids = sample_output.batch_next_token_ids
|
||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
sample_output.success, batch_next_token_ids, argmax_ids
|
||||
)
|
||||
sample_output.probs = probs
|
||||
sample_output.batch_next_token_ids = batch_next_token_ids
|
||||
|
||||
return sample_output.batch_next_token_ids
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
@@ -559,13 +544,16 @@ class ModelRunner:
|
||||
def sample(
|
||||
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
|
||||
) -> torch.Tensor:
|
||||
# Put CPU-heavy tasks here. They will be overlapped with the forward pass.
|
||||
batch.sampling_info.update_regex_vocab_mask(batch)
|
||||
batch.sampling_info.update_penalties()
|
||||
logits = self._apply_logits_bias(
|
||||
logits_output.next_token_logits, batch.sampling_info
|
||||
)
|
||||
sample_output = self.sampler(logits, batch.sampling_info)
|
||||
return self._check_sample_results(sample_output)
|
||||
|
||||
# Sample the next tokens.
|
||||
next_token_ids = self.sampler(logits, batch.sampling_info)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
||||
Reference in New Issue
Block a user