From 2f79f58873b0bf38d8c64f5b6ac6fbb5ab50d7b4 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 27 Jan 2025 21:39:52 +0800 Subject: [PATCH] feat: use sgl-kernel 0.0.3 in sglang (#3179) --- python/pyproject.toml | 2 +- python/sglang/srt/layers/activation.py | 10 +++++----- python/sglang/srt/layers/layernorm.py | 10 +++++----- python/sglang/srt/layers/sampler.py | 10 +++------- python/sglang/srt/models/deepseek_v2.py | 6 +++--- python/sglang/srt/models/minicpm3.py | 6 +++--- 6 files changed, 20 insertions(+), 24 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 97e0771cd..d4063cf01 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ runtime_common = [ ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.2.post18", "torch", "vllm==0.6.4.post1", + "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index ebb0652c5..d69d854ab 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,10 +20,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +if is_cuda_available(): + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm.model_executor.custom_op import CustomOp @@ -149,8 +149,8 @@ def get_act_fn( return act_fn -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index bd95b9bcc..207ba8d1b 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.norm import ( +if is_cuda_available(): + from sgl_kernel import ( fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, @@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp): return out -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 3173d533d..b24bfc8da 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -10,14 +10,10 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import ( - crash_on_warnings, - get_bool_env_var, - is_flashinfer_available, -) +from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available -if is_flashinfer_available(): - from flashinfer.sampling import ( +if is_cuda_available(): + from sgl_kernel import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 17d7fcf89..438441047 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available, is_hip +from sglang.srt.utils import is_cuda_available, is_hip is_hip_ = is_hip() -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 118be8ff6..31ea7cd9f 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,10 +40,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class MiniCPM3MLP(nn.Module):