diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 5d941a489..4e3d39e77 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -288,16 +288,11 @@ class GemmaRMSNorm(CustomOp): x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - orig_dtype = x.dtype if residual is not None: x = x + residual residual = x - x = x.float() - variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True) - x = x * torch_npu.rsqrt(variance + self.variance_epsilon) - x = x * (1.0 + self.weight.float()) - x = x.to(orig_dtype) + x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon) return x if residual is None else (x, residual) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a4fb29929..f6603907a 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.utils import dump_to_file, use_intel_amx_backend +from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend logger = logging.getLogger(__name__) +_is_npu = is_npu() + @dataclasses.dataclass class LogitsProcessorOutput: @@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module): logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: - fused_softcap(logits, self.final_logit_softcapping) + if not _is_npu: + fused_softcap(logits, self.final_logit_softcapping) + else: + logits = self.final_logit_softcapping * torch.tanh( + logits / self.final_logit_softcapping + ) return logits diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index ef0d01e44..0a3723e0b 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -550,7 +550,7 @@ class PrefillAdder: ) else: # Make sure at least one page is available - trunc_len = self.rem_chunk_tokens - self.page_size + 1 + trunc_len = self.rem_chunk_tokens // self.page_size * self.page_size if trunc_len <= 0: return AddReqResult.OTHER