diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 088d94a78..f38397212 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
             last_logits = torch.matmul(last_hidden, weight.T)
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
+            last_logits = last_logits[:, : self.config.vocab_size].float()
 
             if hasattr(self.config, "final_logit_softcapping"):
                 last_logits /= self.config.final_logit_softcapping
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
         all_logits = torch.matmul(hidden_states, weight.T)
         if self.tp_size > 1:
             all_logits = tensor_model_parallel_all_gather(all_logits)
-        all_logits = all_logits[:, : self.config.vocab_size]
+        all_logits = all_logits[:, : self.config.vocab_size].float()
 
-        all_logprobs = all_logits.float()
+        all_logprobs = all_logits
         del all_logits
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py
index e19ec5897..5ef3552ba 100644
--- a/python/sglang/srt/managers/controller/infer_batch.py
+++ b/python/sglang/srt/managers/controller/infer_batch.py
@@ -687,13 +687,21 @@ class Batch:
 
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        max_top_k_round, batch_size = 32, probs.shape[0]
-        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-            probs, uniform_samples, self.top_ks, self.top_ps
-        )
+        if True:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )
 
-        if torch.any(~success):
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
@@ -933,3 +941,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
     max_extend_len = int(torch.max(extend_seq_lens))
 
     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success