Enable optional FP32 compute for LM Head (#10729)
Thanks to the MiniMax Team and Chenyang Zhao for their support.
This commit is contained in:
@@ -220,6 +220,7 @@ class LogitsProcessor(nn.Module):
|
||||
self.config = config
|
||||
self.logit_scale = logit_scale
|
||||
self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
|
||||
self.use_fp32_lm_head = global_server_args_dict["enable_fp32_lm_head"]
|
||||
if self.use_attn_tp_group:
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.do_tensor_parallel_all_gather = (
|
||||
@@ -461,7 +462,11 @@ class LogitsProcessor(nn.Module):
|
||||
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
|
||||
|
||||
if hasattr(lm_head, "weight"):
|
||||
if use_intel_amx_backend(lm_head):
|
||||
if self.use_fp32_lm_head:
|
||||
logits = torch.matmul(
|
||||
hidden_states.to(torch.float32), lm_head.weight.to(torch.float32).T
|
||||
)
|
||||
elif use_intel_amx_backend(lm_head):
|
||||
logits = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
hidden_states.to(lm_head.weight.dtype),
|
||||
lm_head.weight,
|
||||
@@ -475,7 +480,15 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
# GGUF models
|
||||
# TODO: use weight_packed_linear for GGUF models
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states, embedding_bias)
|
||||
if self.use_fp32_lm_head:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states.to(torch.float32), embedding_bias
|
||||
)
|
||||
else:
|
||||
logits = lm_head.quant_method.apply(
|
||||
lm_head, hidden_states, embedding_bias
|
||||
)
|
||||
|
||||
if self.logit_scale is not None:
|
||||
logits.mul_(self.logit_scale)
|
||||
|
||||
@@ -90,6 +90,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
||||
"disable_flashinfer_cutlass_moe_fp4_allgather",
|
||||
"disable_radix_cache",
|
||||
"enable_dp_lm_head",
|
||||
"enable_fp32_lm_head",
|
||||
"flashinfer_mxfp4_moe_precision",
|
||||
"enable_flashinfer_allreduce_fusion",
|
||||
"moe_dense_tp_size",
|
||||
|
||||
@@ -167,6 +167,7 @@ class ServerArgs:
|
||||
quantization: Optional[str] = None
|
||||
quantization_param_path: Optional[str] = None
|
||||
kv_cache_dtype: str = "auto"
|
||||
enable_fp32_lm_head: bool = False
|
||||
|
||||
# Memory and scheduling
|
||||
mem_fraction_static: Optional[float] = None
|
||||
@@ -1392,6 +1393,11 @@ class ServerArgs:
|
||||
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
||||
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-fp32-lm-head",
|
||||
action="store_true",
|
||||
help="If set, the LM head outputs (logits) are in FP32.",
|
||||
)
|
||||
|
||||
# Memory and scheduling
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user