From 70b6802982198a739b233a1c72a8fa9871aabec8 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Fri, 13 Sep 2024 20:27:53 -0700 Subject: [PATCH] Optimize conflicts between CUDA graph and vocab mask tensors (#1392) --- python/sglang/bench_latency.py | 8 +-- python/sglang/srt/layers/sampler.py | 23 ------- python/sglang/srt/managers/schedule_batch.py | 19 ------ python/sglang/srt/managers/tp_worker.py | 9 +-- .../srt/model_executor/cuda_graph_runner.py | 21 +----- .../srt/model_executor/forward_batch_info.py | 5 -- .../sglang/srt/model_executor/model_runner.py | 66 ++++++++++++++++--- python/sglang/srt/models/baichuan.py | 7 +- python/sglang/srt/models/chatglm.py | 6 +- python/sglang/srt/models/commandr.py | 6 +- python/sglang/srt/models/dbrx.py | 6 +- python/sglang/srt/models/deepseek.py | 6 +- python/sglang/srt/models/deepseek_v2.py | 6 +- python/sglang/srt/models/exaone.py | 6 +- python/sglang/srt/models/gemma.py | 6 +- python/sglang/srt/models/gemma2.py | 6 +- python/sglang/srt/models/gpt_bigcode.py | 6 +- python/sglang/srt/models/grok.py | 6 +- python/sglang/srt/models/internlm2.py | 6 +- python/sglang/srt/models/llama.py | 6 +- .../sglang/srt/models/llama_classification.py | 21 +----- python/sglang/srt/models/minicpm.py | 6 +- python/sglang/srt/models/minicpm3.py | 6 +- python/sglang/srt/models/mixtral.py | 6 +- python/sglang/srt/models/mixtral_quant.py | 6 +- python/sglang/srt/models/qwen.py | 6 +- python/sglang/srt/models/qwen2.py | 6 +- python/sglang/srt/models/qwen2_moe.py | 6 +- python/sglang/srt/models/stablelm.py | 6 +- python/sglang/srt/models/xverse.py | 7 +- python/sglang/srt/models/xverse_moe.py | 6 +- .../srt/sampling/sampling_batch_info.py | 15 +++-- 32 files changed, 103 insertions(+), 224 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index bfe739432..b1ac43e9d 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -207,15 +207,15 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - sample_output, logits_output = model_runner.forward(batch) - next_token_ids = sample_output.batch_next_token_ids.tolist() + logits_output = model_runner.forward(batch) + next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): batch.prepare_for_decode(input_token_ids) - sample_output, logits_output = model_runner.forward(batch) - next_token_ids = sample_output.batch_next_token_ids.tolist() + logits_output = model_runner.forward(batch) + next_token_ids = model_runner.sample(logits_output, batch).tolist() return next_token_ids, logits_output.next_token_logits diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 16b6b80e9..c0e1d4c7b 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -35,21 +35,6 @@ class Sampler(CustomOp): self.forward_native = self.forward_cuda self.is_torch_compile = False - def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): - # min-token, presence, frequency - if sampling_info.linear_penalties is not None: - logits += sampling_info.linear_penalties - - # repetition - if sampling_info.scaling_penalties is not None: - logits = torch.where( - logits > 0, - logits / sampling_info.scaling_penalties, - logits * sampling_info.scaling_penalties, - ) - - return logits - def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): # Post process logits logits = logits.contiguous() @@ -58,14 +43,6 @@ class Sampler(CustomOp): # FIXME: Temporary workaround for unknown bugs in torch.compile logits.add_(0) - if sampling_info.logit_bias is not None: - logits.add_(sampling_info.logit_bias) - - if sampling_info.vocab_mask is not None: - logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf")) - - logits = self._apply_penalties(logits, sampling_info) - return torch.softmax(logits, dim=-1) def forward_cuda( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 17d13c7a5..13339cddc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs -if TYPE_CHECKING: - from sglang.srt.layers.sampler import SampleOutput - - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -710,18 +706,3 @@ class ScheduleBatch: self.out_cache_loc = None self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) - - def check_sample_results(self, sample_output: SampleOutput): - if not torch.all(sample_output.success): - probs = sample_output.probs - batch_next_token_ids = sample_output.batch_next_token_ids - logging.warning("Sampling failed, fallback to top_k=1 strategy") - probs = probs.masked_fill(torch.isnan(probs), 0.0) - argmax_ids = torch.argmax(probs, dim=-1) - batch_next_token_ids = torch.where( - sample_output.success, batch_next_token_ids, argmax_ids - ) - sample_output.probs = probs - sample_output.batch_next_token_ids = batch_next_token_ids - - return sample_output.batch_next_token_ids diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 05619aae1..3952e081a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -547,8 +547,9 @@ class ModelTpServer: if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - sample_output, logits_output = self.model_runner.forward(batch) - next_token_ids = batch.check_sample_results(sample_output) + logits_output = self.model_runner.forward(batch) + next_token_ids = self.model_runner.sample(logits_output, batch) + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) @@ -723,8 +724,8 @@ class ModelTpServer: batch.prepare_for_decode() # Forward and sample the next tokens - sample_output, logits_output = self.model_runner.forward(batch) - next_token_ids = batch.check_sample_results(sample_output) + logits_output = self.model_runner.forward(batch) + next_token_ids = self.model_runner.sample(logits_output, batch) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 27699f65d..d25694329 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import ( LogitsProcessor, LogitsProcessorOutput, ) -from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather if TYPE_CHECKING: @@ -129,10 +127,6 @@ class CudaGraphRunner: self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) - # Sampling info - vocab_size = model_runner.model_config.vocab_size - self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) - if self.use_torch_compile: set_torch_compile_config() @@ -191,7 +185,6 @@ class CudaGraphRunner: def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, - sampling_info=self.sampling_info[:bs], batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, @@ -250,14 +243,9 @@ class CudaGraphRunner: bs, self.req_pool_indices, self.seq_lens ) - # Sampling inputs - self.sampling_info.inplace_assign(raw_bs, batch.sampling_info) - # Replay - torch.cuda.synchronize() self.graphs[bs].replay() - torch.cuda.synchronize() - sample_output, logits_output = self.output_buffers[bs] + logits_output = self.output_buffers[bs] # Unpad if bs != raw_bs: @@ -269,11 +257,6 @@ class CudaGraphRunner: input_top_logprobs=None, output_top_logprobs=None, ) - sample_output = SampleOutput( - sample_output.success[:raw_bs], - sample_output.probs[:raw_bs], - sample_output.batch_next_token_ids[:raw_bs], - ) # Extract logprobs if batch.return_logprob: @@ -290,4 +273,4 @@ class CudaGraphRunner: logits_output.next_token_logprobs, logits_metadata )[1] - return sample_output, logits_output + return logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8542ced35..0ad568860 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -28,7 +28,6 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -59,7 +58,6 @@ class InputMetadata: """Store all inforamtion of a forward pass.""" forward_mode: ForwardMode - sampling_info: SamplingBatchInfo batch_size: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor @@ -170,7 +168,6 @@ class InputMetadata: ): ret = cls( forward_mode=batch.forward_mode, - sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, @@ -182,8 +179,6 @@ class InputMetadata: top_logprobs_nums=batch.top_logprobs_nums, ) - ret.sampling_info.update_penalties() - ret.sampling_info.update_regex_vocab_mask(batch) ret.compute_positions(batch) if not batch.forward_mode.is_decode(): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9fcb85454..b754e41c7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.sampler import SampleOutput +from sglang.srt.layers.sampler import SampleOutput, Sampler from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( @@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import ( ReqToTokenPool, ) from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( get_available_gpu_memory, @@ -107,6 +108,7 @@ class ModelRunner: # Init componnets min_per_gpu_memory = self.init_torch_distributed() + self.sampler = Sampler() self.load_model() if server_args.lora_paths is not None: self.init_lora_manager() @@ -466,11 +468,8 @@ class ModelRunner: def forward_decode(self, batch: ScheduleBatch): if self.server_args.lora_paths is not None: self.lora_manager.prepare_lora_batch(batch) - if ( - self.cuda_graph_runner - and self.cuda_graph_runner.can_run(len(batch.reqs)) - and batch.sampling_info.can_run_in_cuda_graph() - ): + + if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): return self.cuda_graph_runner.replay(batch) input_metadata = InputMetadata.from_schedule_batch(self, batch) @@ -510,9 +509,7 @@ class ModelRunner: input_metadata.image_offsets, ) - def forward( - self, batch: ScheduleBatch - ) -> Tuple[SampleOutput, LogitsProcessorOutput]: + def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]: assert batch.forward_mode is not None if self.is_multimodal_model and batch.forward_mode.is_extend(): @@ -524,6 +521,57 @@ class ModelRunner: else: raise ValueError(f"Invaid forward mode: {batch.forward_mode}") + def _check_sample_results(self, sample_output: SampleOutput): + if not torch.all(sample_output.success): + probs = sample_output.probs + batch_next_token_ids = sample_output.batch_next_token_ids + logging.warning("Sampling failed, fallback to top_k=1 strategy") + probs = probs.masked_fill(torch.isnan(probs), 0.0) + argmax_ids = torch.argmax(probs, dim=-1) + batch_next_token_ids = torch.where( + sample_output.success, batch_next_token_ids, argmax_ids + ) + sample_output.probs = probs + sample_output.batch_next_token_ids = batch_next_token_ids + + return sample_output.batch_next_token_ids + + def _apply_logits_bias( + self, logits: torch.Tensor, sampling_info: SamplingBatchInfo + ): + # Apply logit_bias + if sampling_info.logit_bias is not None: + logits.add_(sampling_info.logit_bias) + + # min-token, presence, frequency + if sampling_info.linear_penalties is not None: + logits += sampling_info.linear_penalties + + # repetition + if sampling_info.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / sampling_info.scaling_penalties, + logits * sampling_info.scaling_penalties, + ) + + # Apply regex vocab_mask + if sampling_info.vocab_mask is not None: + logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf")) + + return logits + + def sample( + self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch + ) -> torch.Tensor: + batch.sampling_info.update_regex_vocab_mask(batch) + batch.sampling_info.update_penalties() + logits = self._apply_logits_bias( + logits_output.next_token_logits, batch.sampling_info + ) + sample_output = self.sampler(logits, batch.sampling_info) + return self._check_sample_results(sample_output) + @lru_cache() def import_model_classes(): diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index d699064b3..19d2a3384 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() def forward( self, @@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 94b405f8e..783dbb4c5 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module): self.transformer = ChatGLMModel(config, cache_config, quant_config) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index c360106f9..f6d6f6e1f 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module): self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.model = CohereModel(config, quant_config) @torch.no_grad() @@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module): positions, input_metadata, ) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b3a76b56a..39ac4aefa 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module): padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index b939602c1..59fd1ec7e 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index bb80e2da2..0ff236f8f 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() def forward( self, @@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index bb077f2c8..63d40be7a 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module): self.transformer = ExaoneModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module): hidden_states = self.transformer( input_ids, positions, input_metadata, input_embeds ) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 5a6e5df37..ae3b1b194 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module): self.quant_config = quant_config self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return (sample_output, logits_output) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 77ebd8564..3223424d7 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module): self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index dc828f014..94b7f6153 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module): if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 3c2a2c65e..daf6f25da 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module): self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) @@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index c0e4d19e1..f2947e991 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module): self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.output.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index c45d9bcd8..b7842f192 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import torchao_quantize_param_data from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module): self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.param_dict = dict(self.named_parameters()) @@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def get_hidden_dim(self, module_name): if module_name in ["q_proj", "o_proj", "qkv_proj"]: diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index db424ff18..de37a00e6 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.sampler import SampleOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module): output_top_logprobs=None, ) - # A dummy to make this work - sample_output = SampleOutput( - success=torch.full( - size=(scores.shape[0],), - fill_value=True, - dtype=torch.bool, - ), - probs=torch.full( - size=(scores.shape[0], 1), - fill_value=1.0, - dtype=torch.float16, - ), - batch_next_token_ids=torch.full( - size=(scores.shape[0],), - fill_value=0, - dtype=torch.long, - ), - ) - return sample_output, logits_output + return logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = self.param_dict diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 0028ae67a..49ff1926f 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module): self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module): lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index c559d0f31..ce40a94a7 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -572,7 +571,6 @@ class MiniCPM3ForCausalLM(nn.Module): self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -590,11 +588,9 @@ class MiniCPM3ForCausalLM(nn.Module): lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 85f4576c4..87e3bb030 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module): self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() def forward( self, @@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index 97ac09ee6..b02e925c5 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module): self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 4958a8129..8787d0d58 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module): vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module): input_metadata: InputMetadata, ): hidden_states = self.transformer(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 6bb5c0b90..4e7d8e57b 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module): self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() @@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module): ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) if not get_embedding: - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output else: return self.pooler(hidden_states, input_metadata) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 67b5a6ce6..1ff2190ed 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -47,7 +47,6 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -365,7 +364,6 @@ class Qwen2MoeForCausalLM(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -376,11 +374,9 @@ class Qwen2MoeForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index a3102baab..9e10f12f2 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module): self.model = StableLMEpochModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() @torch.no_grad() def forward( @@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index 460ae820f..0f4e2a89f 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.model_runner import InputMetadata @@ -307,7 +306,6 @@ class XverseForCausalLM(nn.Module): self.model = XverseModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.param_dict = dict(self.named_parameters()) @@ -320,12 +318,9 @@ class XverseForCausalLM(nn.Module): input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - # print(f"{hidden_states=}") - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 3aeb9ea98..b8cf2c8af 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -383,7 +382,6 @@ class XverseMoeForCausalLM(nn.Module): config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) - self.sampler = Sampler() self.param_dict = dict(self.named_parameters()) @@ -395,11 +393,9 @@ class XverseMoeForCausalLM(nn.Module): input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - logits_output = self.logits_processor( + return self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - sample_output = self.sampler(logits_output, input_metadata.sampling_info) - return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6f6bb6126..8eb8e0882 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -41,7 +41,6 @@ class SamplingBatchInfo: # Vocab bias and min_ps are not supported in CUDA graph return ( self.logit_bias is None - and self.vocab_mask is None and self.linear_penalties is None and self.scaling_penalties is None and not self.need_min_p_sampling @@ -50,9 +49,11 @@ class SamplingBatchInfo: @classmethod def dummy_one(cls, max_bs: int, vocab_size: int): ret = cls(vocab_size=vocab_size) - ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda") - ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda") - ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda") + with torch.device("cuda"): + ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float) + ret.top_ps = torch.ones((max_bs,), dtype=torch.float) + ret.top_ks = torch.ones((max_bs,), dtype=torch.int) + ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool) return ret def __getitem__(self, key): @@ -64,6 +65,7 @@ class SamplingBatchInfo: temperatures=self.temperatures[key], top_ps=self.top_ps[key], top_ks=self.top_ks[key], + vocab_mask=self.vocab_mask[key], ) else: raise NotImplementedError @@ -77,6 +79,11 @@ class SamplingBatchInfo: self.top_ps[:bs] = other.top_ps self.top_ks[:bs] = other.top_ks + if other.vocab_mask is None: + self.vocab_mask[:bs].fill_(False) + else: + self.vocab_mask[:bs] = other.vocab_mask + @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): device = "cuda"