cleanup deps 2/n (#4464)
@@ -23,7 +23,9 @@ import torch.nn.functional as F
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn
 
 
-if not is_cuda_available():
+if not _is_cuda:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
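The two hunks above (from what appears to be the activation module) replace repeated is_cuda_available() calls with a module-level _is_cuda flag that is probed once at import time and reused by both the conditional sgl_kernel import and the fallback warning. A minimal sketch of the same pattern, not the actual sglang module, with a pure-PyTorch fallback for illustration:

import logging

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Probe the platform once at import time and cache the result in a flag.
_is_cuda = torch.cuda.is_available()

if _is_cuda:
    try:
        from sgl_kernel import silu_and_mul  # fused CUDA kernel, NV-only
    except ImportError:  # assumption: treat a missing wheel like a non-NV platform
        _is_cuda = False


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch fallback: SiLU(x[..., :d]) * x[..., d:] with d = last_dim // 2."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


if not _is_cuda:
    logger.info("sgl-kernel is not available. Falling back to a PyTorch implementation.")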
@@ -21,7 +21,9 @@ import torch.nn as nn
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,7 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_cuda_available():
+if not _is_cuda:
    logger.info(
        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
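The layernorm hunks apply the same flag-caching change around the fused_add_rmsnorm / gemma_fused_add_rmsnorm imports. For reference, this is roughly what the non-fused path has to compute when sgl_kernel is unavailable; a hedged sketch in plain PyTorch, with illustrative names rather than the sglang API:

from typing import Optional, Tuple

import torch


def rmsnorm_native(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Reference RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight.

    If `residual` is given, the input is first added to it (the "fused add"
    that a fused_add_rmsnorm kernel performs in one pass) and the sum is
    returned alongside the normalized output.
    """
    orig_dtype = x.dtype
    x = x.float()
    if residual is not None:
        x = x + residual.float()
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    out = (x * weight.float()).to(orig_dtype)
    return out, residual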
@@ -41,9 +41,13 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
 
 
 class DeepseekModelNextN(nn.Module):
@@ -261,14 +265,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-            w = ops.awq_dequantize(
-                self_attn.kv_b_proj.qweight,
-                self_attn.kv_b_proj.scales,
-                self_attn.kv_b_proj.qzeros,
-                0,
-                0,
-                0,
-            ).T
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
         # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
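Both branches of the new dispatch produce the same dense weight; only the provider changes: sgl_kernel's awq_dequantize takes just the packed weight, scales, and zeros, while the existing ops.awq_dequantize call keeps three extra arguments that are passed as 0. A hedged sketch of that dispatch folded into one helper (the vllm import is an assumption about where `ops` comes from in this file):

import torch


def dequantize_awq(qweight, scales, qzeros, use_sgl_kernel: bool) -> torch.Tensor:
    """Return the dense (transposed) weight from AWQ-packed tensors.

    Dispatches between the sgl_kernel CUDA op and the fallback op; the
    trailing zeros in the fallback call mirror the unused arguments in the diff.
    """
    if use_sgl_kernel:
        from sgl_kernel import awq_dequantize

        w = awq_dequantize(qweight, scales, qzeros)
    else:
        from vllm import _custom_ops as ops  # assumed source of `ops` in this file

        w = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    return w.T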
@@ -68,12 +68,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
-if is_cuda_available():
-    from sgl_kernel import bmm_fp8
+if _is_cuda:
+    from sgl_kernel import awq_dequantize, bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
@@ -1174,14 +1175,21 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn = self.model.layers[layer_id].self_attn
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                w = ops.awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
+                if _is_cuda:
+                    w = awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                    ).T
+                else:
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
             else:
                 w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
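The NOTE kept at the end of both hunks explains the step that follows the dequantization: because bmm_fp8 only supports a per-tensor scale, the dense kv_b_proj weight has to be requantized with a single scale. A hedged sketch of per-tensor float8_e4m3fn quantization in plain PyTorch, not the exact sglang helper:

import torch


def per_tensor_quant_fp8(w: torch.Tensor):
    """Quantize a weight to float8_e4m3fn with one per-tensor scale.

    Returns (w_fp8, scale) such that w is approximately w_fp8.float() * scale.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # A single scale for the whole tensor, as a per-tensor bmm_fp8 requires.
    scale = w.abs().max().clamp(min=1e-12).float() / finfo.max
    w_fp8 = (w.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_fp8, scale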