unify is_cuda and is_hip (#4321)
This commit is contained in:
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs

 logger = logging.getLogger(__name__)

+_is_hip = is_hip()
+


 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

Reference in New Issue
Block a user