feat: use gelu_tanh_and_mul (#1193)
This commit is contained in:
@@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from flashinfer.activation import silu_and_mul
|
from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
|
|
||||||
@@ -37,3 +37,19 @@ class SiluAndMul(CustomOp):
|
|||||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
silu_and_mul(x, out)
|
silu_and_mul(x, out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GeluAndMul(CustomOp):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
d = x.shape[-1] // 2
|
||||||
|
return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
|
||||||
|
|
||||||
|
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
d = x.shape[-1] // 2
|
||||||
|
output_shape = x.shape[:-1] + (d,)
|
||||||
|
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||||
|
gelu_tanh_and_mul(x, out)
|
||||||
|
return out
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
|||||||
|
|
||||||
# FIXME: temporary solution, remove after next vllm release
|
# FIXME: temporary solution, remove after next vllm release
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.layers.activation import GeluAndMul
|
|
||||||
|
|
||||||
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
@@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
@@ -135,7 +135,7 @@ class Gemma2MLP(nn.Module):
|
|||||||
"function. Please set `hidden_act` and `hidden_activation` to "
|
"function. Please set `hidden_act` and `hidden_activation` to "
|
||||||
"`gelu_pytorch_tanh`."
|
"`gelu_pytorch_tanh`."
|
||||||
)
|
)
|
||||||
self.act_fn = GeluAndMul(approximate="tanh")
|
self.act_fn = GeluAndMul()
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
|
|||||||
55
python/sglang/test/test_activation.py
Normal file
55
python/sglang/test/test_activation.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import itertools
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeluAndMul(unittest.TestCase):
|
||||||
|
DTYPES = [torch.half, torch.bfloat16]
|
||||||
|
NUM_TOKENS = [7, 83, 2048]
|
||||||
|
D = [512, 4096, 5120, 13824]
|
||||||
|
SEEDS = [0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise unittest.SkipTest("CUDA is not available")
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
layer = GeluAndMul().to(dtype=dtype)
|
||||||
|
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
ref_out = layer.forward_native(x)
|
||||||
|
out = layer.forward_cuda(x)
|
||||||
|
|
||||||
|
if dtype == torch.bfloat16:
|
||||||
|
atol = rtol = 1e-2
|
||||||
|
else:
|
||||||
|
atol = rtol = 1e-3
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))
|
||||||
|
|
||||||
|
def test_gelu_and_mul(self):
|
||||||
|
for params in itertools.product(
|
||||||
|
self.NUM_TOKENS,
|
||||||
|
self.D,
|
||||||
|
self.DTYPES,
|
||||||
|
self.SEEDS,
|
||||||
|
):
|
||||||
|
with self.subTest(
|
||||||
|
num_tokens=params[0],
|
||||||
|
d=params[1],
|
||||||
|
dtype=params[2],
|
||||||
|
seed=params[3],
|
||||||
|
):
|
||||||
|
self._run_gelu_and_mul_test(*params)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
Reference in New Issue
Block a user