[kernel] Use sgl_kernel rope (#3169)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
|
_is_cuda_available = is_cuda_available()
|
||||||
|
if _is_cuda_available:
|
||||||
|
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||||
|
|
||||||
|
|
||||||
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
|
||||||
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
|
|||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
cache = self._compute_cos_sin_cache()
|
cache = self._compute_cos_sin_cache()
|
||||||
cache = cache.to(dtype)
|
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
|
||||||
|
if not _is_cuda_available:
|
||||||
|
cache = cache.to(dtype)
|
||||||
self.cos_sin_cache: torch.Tensor
|
self.cos_sin_cache: torch.Tensor
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
|
||||||
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
from vllm import _custom_ops as ops
|
if _is_cuda_available:
|
||||||
|
apply_rope_with_cos_sin_cache_inplace(
|
||||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
positions=positions,
|
||||||
ops.rotary_embedding(
|
query=query,
|
||||||
positions,
|
key=key,
|
||||||
query,
|
head_size=self.head_size,
|
||||||
key,
|
cos_sin_cache=self.cos_sin_cache,
|
||||||
self.head_size,
|
is_neox=self.is_neox_style,
|
||||||
self.cos_sin_cache,
|
)
|
||||||
self.is_neox_style,
|
else:
|
||||||
)
|
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
|
||||||
|
ops.rotary_embedding(
|
||||||
|
positions,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
self.head_size,
|
||||||
|
self.cos_sin_cache,
|
||||||
|
self.is_neox_style,
|
||||||
|
)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
chunks_ids[i] = chunks_ids[i][1:]
|
chunks_ids[i] = chunks_ids[i][1:]
|
||||||
|
|
||||||
# 1. using session control
|
# 1. using session control
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
session_id = requests.post(
|
session_id = requests.post(
|
||||||
self.base_url + "/open_session",
|
self.base_url + "/open_session",
|
||||||
json={"capacity_of_str_len": 1000},
|
json={"capacity_of_str_len": 1000},
|
||||||
@@ -215,7 +216,9 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
print(outputs_from_session)
|
print(outputs_from_session)
|
||||||
print("outputs from normal queries:")
|
print("outputs from normal queries:")
|
||||||
print(outputs_normal)
|
print(outputs_normal)
|
||||||
assert outputs_from_session == outputs_normal
|
assert (
|
||||||
|
outputs_from_session == outputs_normal
|
||||||
|
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||||
|
|
||||||
async def async_generate(self, payload):
|
async def async_generate(self, payload):
|
||||||
url = self.base_url + "/generate"
|
url = self.base_url + "/generate"
|
||||||
@@ -250,6 +253,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
chunks_ids[i] = chunks_ids[i][1:]
|
chunks_ids[i] = chunks_ids[i][1:]
|
||||||
|
|
||||||
# 1. using session control
|
# 1. using session control
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
session_id = requests.post(
|
session_id = requests.post(
|
||||||
self.base_url + "/open_session",
|
self.base_url + "/open_session",
|
||||||
json={"capacity_of_str_len": 1000},
|
json={"capacity_of_str_len": 1000},
|
||||||
@@ -320,6 +324,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||||
else:
|
else:
|
||||||
# 2. not using session control
|
# 2. not using session control
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
output_ids = tokenizer.encode(gen_so_far)
|
output_ids = tokenizer.encode(gen_so_far)
|
||||||
if output_ids[0] == tokenizer.bos_token_id:
|
if output_ids[0] == tokenizer.bos_token_id:
|
||||||
output_ids = output_ids[1:]
|
output_ids = output_ids[1:]
|
||||||
@@ -342,7 +347,9 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
output_no_session = response["text"]
|
output_no_session = response["text"]
|
||||||
print("second request output without session:")
|
print("second request output without session:")
|
||||||
print(output_no_session)
|
print(output_no_session)
|
||||||
assert second_output == output_no_session
|
assert (
|
||||||
|
second_output == output_no_session
|
||||||
|
), f"second_output: {second_output}, output_no_session: {output_no_session}"
|
||||||
|
|
||||||
def test_session_control_backtrack_with_abort(self):
|
def test_session_control_backtrack_with_abort(self):
|
||||||
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
|
asyncio.run(self.run_session_control_backtrack_with_abort(replace=True))
|
||||||
@@ -355,6 +362,7 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
assert len(x) == len(chunks_per_step[0])
|
assert len(x) == len(chunks_per_step[0])
|
||||||
|
|
||||||
# 1. using session control
|
# 1. using session control
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
session_id = requests.post(
|
session_id = requests.post(
|
||||||
self.base_url + "/open_session",
|
self.base_url + "/open_session",
|
||||||
json={"capacity_of_str_len": 1000},
|
json={"capacity_of_str_len": 1000},
|
||||||
@@ -459,7 +467,9 @@ class TestSessionControl(unittest.TestCase):
|
|||||||
print(outputs_from_session)
|
print(outputs_from_session)
|
||||||
print("====== outputs from normal queries: =======")
|
print("====== outputs from normal queries: =======")
|
||||||
print(outputs_normal)
|
print(outputs_normal)
|
||||||
assert outputs_from_session == outputs_normal
|
assert (
|
||||||
|
outputs_from_session == outputs_normal
|
||||||
|
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||||
|
|
||||||
def test_session_control_with_branching(self):
|
def test_session_control_with_branching(self):
|
||||||
root_prompt = "First, let me explain in one sentence about AI"
|
root_prompt = "First, let me explain in one sentence about AI"
|
||||||
@@ -525,6 +535,7 @@ class TestSessionControlVision(unittest.TestCase):
|
|||||||
gen_len = 32
|
gen_len = 32
|
||||||
|
|
||||||
# 1. using session control
|
# 1. using session control
|
||||||
|
requests.post(self.base_url + "/flush_cache")
|
||||||
session_id = requests.post(
|
session_id = requests.post(
|
||||||
self.base_url + "/open_session",
|
self.base_url + "/open_session",
|
||||||
json={"capacity_of_str_len": 1000},
|
json={"capacity_of_str_len": 1000},
|
||||||
@@ -691,7 +702,9 @@ class TestSessionControlVision(unittest.TestCase):
|
|||||||
print(outputs_from_session)
|
print(outputs_from_session)
|
||||||
print("outputs from normal queries:")
|
print("outputs from normal queries:")
|
||||||
print(outputs_normal)
|
print(outputs_normal)
|
||||||
assert outputs_from_session == outputs_normal
|
assert (
|
||||||
|
outputs_from_session == outputs_normal
|
||||||
|
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user