diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index ad265830f..7093bb90d 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn +from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.utils import is_cuda_available + +_is_cuda_available = is_cuda_available() +if _is_cuda_available: + from sgl_kernel import apply_rope_with_cos_sin_cache_inplace def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp): self.dtype = dtype cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) + # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability + if not _is_cuda_available: + cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp): key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - ops.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + if _is_cuda_available: + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + else: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_xpu( diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 5653e9b69..2915133f4 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" async def async_generate(self, payload): url = self.base_url + "/generate" @@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase): assert response["meta_info"]["finish_reason"]["type"] == "abort" else: # 2. not using session control + requests.post(self.base_url + "/flush_cache") output_ids = tokenizer.encode(gen_so_far) if output_ids[0] == tokenizer.bos_token_id: output_ids = output_ids[1:] @@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase): output_no_session = response["text"] print("second request output without session:") print(output_no_session) - assert second_output == output_no_session + assert ( + second_output == output_no_session + ), f"second_output: {second_output}, output_no_session: {output_no_session}" def test_session_control_backtrack_with_abort(self): asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) @@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase): assert len(x) == len(chunks_per_step[0]) # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase): print(outputs_from_session) print("====== outputs from normal queries: =======") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" def test_session_control_with_branching(self): root_prompt = "First, let me explain in one sentence about AI" @@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase): gen_len = 32 # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" if __name__ == "__main__":