unify is_cuda and is_hip (#4321)
This commit is contained in:
@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
||||
from sglang.srt.utils import add_prefix, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
_is_hip = is_hip()
|
||||
|
||||
|
||||
class DeepseekModelNextN(nn.Module):
|
||||
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
if weight_block_size is not None:
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w,
|
||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
||||
and self_attn.w_scale is None
|
||||
):
|
||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
self_attn.w_scale *= 2.0
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
|
||||
|
||||
is_hip_ = is_hip()
|
||||
_is_hip = is_hip()
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import bmm_fp8
|
||||
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
if no_absorb():
|
||||
return self.forward_normal(positions, hidden_states, forward_batch)
|
||||
else:
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
if (
|
||||
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||
and forward_batch.forward_mode.is_decode()
|
||||
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
if weight_block_size is not None:
|
||||
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=w,
|
||||
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
||||
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
and self_attn.w_scale is None
|
||||
):
|
||||
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
|
||||
if is_hip_:
|
||||
if _is_hip:
|
||||
self_attn.w_scale *= 2.0
|
||||
|
||||
def get_embed_and_head(self):
|
||||
|
||||
Reference in New Issue
Block a user