cleanup deps 2/n (#4464)
This commit is contained in:
@@ -23,7 +23,9 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
if is_cuda_available():
|
_is_cuda = is_cuda_available()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
|
||||||
|
|
||||||
from sglang.srt.custom_op import CustomOp
|
from sglang.srt.custom_op import CustomOp
|
||||||
@@ -165,7 +167,7 @@ def get_act_fn(
|
|||||||
return act_fn
|
return act_fn
|
||||||
|
|
||||||
|
|
||||||
if not is_cuda_available():
|
if not _is_cuda:
|
||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,7 +21,9 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import is_cuda_available
|
||||||
|
|
||||||
if is_cuda_available():
|
_is_cuda = is_cuda_available()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
fused_add_rmsnorm,
|
fused_add_rmsnorm,
|
||||||
gemma_fused_add_rmsnorm,
|
gemma_fused_add_rmsnorm,
|
||||||
@@ -117,7 +119,7 @@ class GemmaRMSNorm(CustomOp):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
if not is_cuda_available():
|
if not _is_cuda:
|
||||||
logger.info(
|
logger.info(
|
||||||
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,9 +41,13 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||||
from sglang.srt.utils import add_prefix, is_hip
|
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sgl_kernel import awq_dequantize
|
||||||
|
|
||||||
|
|
||||||
class DeepseekModelNextN(nn.Module):
|
class DeepseekModelNextN(nn.Module):
|
||||||
@@ -261,14 +265,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
self_attn = self.model.decoder.self_attn
|
self_attn = self.model.decoder.self_attn
|
||||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||||
# AWQ compatible
|
# AWQ compatible
|
||||||
w = ops.awq_dequantize(
|
if _is_cuda:
|
||||||
self_attn.kv_b_proj.qweight,
|
w = awq_dequantize(
|
||||||
self_attn.kv_b_proj.scales,
|
self_attn.kv_b_proj.qweight,
|
||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.scales,
|
||||||
0,
|
self_attn.kv_b_proj.qzeros,
|
||||||
0,
|
).T
|
||||||
0,
|
else:
|
||||||
).T
|
w = ops.awq_dequantize(
|
||||||
|
self_attn.kv_b_proj.qweight,
|
||||||
|
self_attn.kv_b_proj.scales,
|
||||||
|
self_attn.kv_b_proj.qzeros,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
).T
|
||||||
else:
|
else:
|
||||||
w = self_attn.kv_b_proj.weight
|
w = self_attn.kv_b_proj.weight
|
||||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||||
|
|||||||
@@ -68,12 +68,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
|
from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
if is_cuda_available():
|
if _is_cuda:
|
||||||
from sgl_kernel import bmm_fp8
|
from sgl_kernel import awq_dequantize, bmm_fp8
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MLP(nn.Module):
|
class DeepseekV2MLP(nn.Module):
|
||||||
@@ -1174,14 +1175,21 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self_attn = self.model.layers[layer_id].self_attn
|
self_attn = self.model.layers[layer_id].self_attn
|
||||||
if hasattr(self_attn.kv_b_proj, "qweight"):
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
||||||
# AWQ compatible
|
# AWQ compatible
|
||||||
w = ops.awq_dequantize(
|
if _is_cuda:
|
||||||
self_attn.kv_b_proj.qweight,
|
w = awq_dequantize(
|
||||||
self_attn.kv_b_proj.scales,
|
self_attn.kv_b_proj.qweight,
|
||||||
self_attn.kv_b_proj.qzeros,
|
self_attn.kv_b_proj.scales,
|
||||||
0,
|
self_attn.kv_b_proj.qzeros,
|
||||||
0,
|
).T
|
||||||
0,
|
else:
|
||||||
).T
|
w = ops.awq_dequantize(
|
||||||
|
self_attn.kv_b_proj.qweight,
|
||||||
|
self_attn.kv_b_proj.scales,
|
||||||
|
self_attn.kv_b_proj.qzeros,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
).T
|
||||||
else:
|
else:
|
||||||
w = self_attn.kv_b_proj.weight
|
w = self_attn.kv_b_proj.weight
|
||||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||||
|
|||||||
Reference in New Issue
Block a user