Enable optional FP32 compute for LM Head (#10729)

Thanks to the MiniMax Team and Chenyang Zhao for their support.
This commit is contained in:
narutolhy
2025-09-29 20:45:17 -07:00
committed by GitHub
parent 8831c55c3d
commit d17986f8c6
6 changed files with 130 additions and 2 deletions

View File

@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
self.config = config
self.logit_scale = logit_scale
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
if self.use_attn_tp_group:
self.attn_tp_size = get_attention_tp_size()
self.do_tensor_parallel_all_gather = (
@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
if hasattr(lm_head, "weight"):
if use_intel_amx_backend(lm_head):
if self.use_fp32_lm_head:
logits = torch.matmul(
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
)
elif use_intel_amx_backend(lm_head):
logits = torch.ops.sgl_kernel.weight_packed_linear(
hidden_states.to(lm_head.weight.dtype),
lm_head.weight,
@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
else:
# GGUF models
# TODO: use weight_packed_linear for GGUF models
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
if self.use_fp32_lm_head:
with torch.cuda.amp.autocast(enabled=False):
logits = lm_head.quant_method.apply(
lm_head, hidden_states.to(torch.float32), embedding_bias
)
else:
logits = lm_head.quant_method.apply(
lm_head, hidden_states, embedding_bias
)
if self.logit_scale is not None:
logits.mul_(self.logit_scale)

View File

@@ -90,6 +90,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",

View File

@@ -167,6 +167,7 @@ class ServerArgs:
quantization: Optional[str] = None
quantization_param_path: Optional[str] = None
kv_cache_dtype: str = "auto"
enable_fp32_lm_head: bool = False
# Memory and scheduling
mem_fraction_static: Optional[float] = None
@@ -1392,6 +1393,11 @@ class ServerArgs:
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
)
parser.add_argument(
"--enable-fp32-lm-head",
action="store_true",
help="If set, the LM head outputs (logits) are in FP32.",
)
# Memory and scheduling
parser.add_argument(