unify is_cuda and is_hip (#4321)

This commit is contained in:
Yineng Zhang
2025-03-11 18:12:56 -07:00
committed by GitHub
parent 1cf63485c1
commit d1da58e275
18 changed files with 104 additions and 92 deletions

View File

@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
logger = logging.getLogger(__name__)
+ _is_hip = is_hip()
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
- fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+ fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)