unify is_cuda and is_hip (#4321)

This commit is contained in:
Yineng Zhang
2025-03-11 18:12:56 -07:00
committed by GitHub
parent 1cf63485c1
commit d1da58e275
18 changed files with 104 additions and 92 deletions

View File

@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
class DeepseekModelNextN(nn.Module):
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-if is_hip_:
+if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-if is_hip_:
+if _is_hip:
self_attn.w_scale *= 2.0

View File

@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
if is_cuda_available():
from sgl_kernel import bmm_fp8
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if no_absorb():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
-if is_hip_:
+if _is_hip:
if (
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
and forward_batch.forward_mode.is_decode()
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-if is_hip_:
+if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-if is_hip_:
+if _is_hip:
self_attn.w_scale *= 2.0
def get_embed_and_head(self):