Simplify sampler and its error handling (#1441)

This commit is contained in:
Lianmin Zheng
2024-09-16 21:23:31 -07:00
committed by GitHub
parent 27b557aea7
commit 2fa5cec775
4 changed files with 32 additions and 159 deletions

View File

@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput, Sampler
from sglang.srt.layers.sampler import Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
@@ -516,21 +516,6 @@ class ModelRunner:
else:
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
def _check_sample_results(self, sample_output: SampleOutput):
if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
return sample_output.batch_next_token_ids
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
):
@@ -559,13 +544,16 @@ class ModelRunner:
def sample(
    self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
) -> torch.Tensor:
    """Sample the next-token ids for every request in the batch.

    Args:
        logits_output: The logits-processor output holding
            ``next_token_logits`` for this batch.
        batch: The scheduled batch; its ``sampling_info`` is updated in
            place (vocab mask, penalties) before sampling.

    Returns:
        A tensor of next-token ids, one per request.
    """
    # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
    batch.sampling_info.update_regex_vocab_mask(batch)
    batch.sampling_info.update_penalties()
    logits = self._apply_logits_bias(
        logits_output.next_token_logits, batch.sampling_info
    )

    # Sample the next tokens. Error handling (e.g. fallback on failed
    # sampling) now lives inside the sampler itself.
    next_token_ids = self.sampler(logits, batch.sampling_info)
    return next_token_ids
@lru_cache()