From 65b7c9b78f8e12366de0709468f4d8bb60738ab4 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 15 Mar 2025 23:06:17 -0700 Subject: [PATCH] cleanup deps 2/n (#4464) --- python/sglang/srt/layers/activation.py | 6 +++-- python/sglang/srt/layers/layernorm.py | 6 +++-- python/sglang/srt/models/deepseek_nextn.py | 29 ++++++++++++++------- python/sglang/srt/models/deepseek_v2.py | 30 ++++++++++++++-------- 4 files changed, 47 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 1b8da93fe..1ee10c0aa 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -23,7 +23,9 @@ import torch.nn.functional as F from sglang.srt.utils import is_cuda_available -if is_cuda_available(): +_is_cuda = is_cuda_available() + +if _is_cuda: from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from sglang.srt.custom_op import CustomOp @@ -165,7 +167,7 @@ def get_act_fn( return act_fn -if not is_cuda_available(): +if not _is_cuda: logger.info( "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 289d75b36..7f6ac5f69 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -21,7 +21,9 @@ import torch.nn as nn from sglang.srt.utils import is_cuda_available -if is_cuda_available(): +_is_cuda = is_cuda_available() + +if _is_cuda: from sgl_kernel import ( fused_add_rmsnorm, gemma_fused_add_rmsnorm, @@ -117,7 +119,7 @@ class GemmaRMSNorm(CustomOp): return out -if not is_cuda_available(): +if not _is_cuda: logger.info( "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py index 9d3159326..721c41ff1 100644 --- a/python/sglang/srt/models/deepseek_nextn.py +++ b/python/sglang/srt/models/deepseek_nextn.py @@ -41,9 +41,13 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM -from sglang.srt.utils import add_prefix, is_hip +from sglang.srt.utils import add_prefix, is_cuda, is_hip _is_hip = is_hip() +_is_cuda = is_cuda() + +if _is_cuda: + from sgl_kernel import awq_dequantize class DeepseekModelNextN(nn.Module): @@ -261,14 +265,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM): self_attn = self.model.decoder.self_attn if hasattr(self_attn.kv_b_proj, "qweight"): # AWQ compatible - w = ops.awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T else: w = self_attn.kv_b_proj.weight # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index f51107b82..ed5fb4e84 100755 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -68,12 +68,13 @@ from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, is_cuda_available, is_hip +from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip _is_hip = is_hip() +_is_cuda = is_cuda() -if is_cuda_available(): - from sgl_kernel import bmm_fp8 +if _is_cuda: + from sgl_kernel import awq_dequantize, bmm_fp8 class DeepseekV2MLP(nn.Module): @@ -1174,14 +1175,21 @@ class DeepseekV2ForCausalLM(nn.Module): self_attn = self.model.layers[layer_id].self_attn if hasattr(self_attn.kv_b_proj, "qweight"): # AWQ compatible - w = ops.awq_dequantize( - self_attn.kv_b_proj.qweight, - self_attn.kv_b_proj.scales, - self_attn.kv_b_proj.qzeros, - 0, - 0, - 0, - ).T + if _is_cuda: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = ops.awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T else: w = self_attn.kv_b_proj.weight # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.