Enable optional FP32 compute for LM Head (#10729)
Thanks to MiniMax Team and Chenyang Zhao's support.
This commit is contained in:
106
test/srt/rl/test_fp32_lm_head.py
Normal file
106
test/srt/rl/test_fp32_lm_head.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
|
||||
|
||||
class LMHeadStub(nn.Module):
|
||||
def __init__(self, vocab, hidden, dtype, device="cuda"):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(
|
||||
torch.randn(vocab, hidden, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
|
||||
class DummyMeta:
|
||||
gathered_buffer = None
|
||||
next_token_logits_buffer = None
|
||||
|
||||
def compute_dp_attention_metadata(self): ...
|
||||
|
||||
|
||||
class TestLMHeadFP32(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if not torch.cuda.is_available():
|
||||
raise unittest.SkipTest("needs CUDA GPU")
|
||||
|
||||
def _make_logprocessor(self, vocab_size, enable_fp32):
|
||||
global_server_args_dict["enable_dp_lm_head"] = False
|
||||
global_server_args_dict["enable_fp32_lm_head"] = enable_fp32
|
||||
cfg = SimpleNamespace(vocab_size=vocab_size, final_logit_softcapping=None)
|
||||
return LogitsProcessor(cfg, skip_all_gather=True, logit_scale=None)
|
||||
|
||||
def _run_case(
|
||||
self,
|
||||
hidden_state_dtype,
|
||||
enable_fp32,
|
||||
weights_dtype,
|
||||
expected_a_dtype,
|
||||
expected_b_dtype,
|
||||
):
|
||||
device = "cuda"
|
||||
BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128
|
||||
hidden_state = torch.randn(
|
||||
BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device
|
||||
)
|
||||
head = LMHeadStub(VOCAB_SIZE, HIDDEN_SIZE, dtype=weights_dtype, device=device)
|
||||
meta = DummyMeta()
|
||||
logprocessor = self._make_logprocessor(VOCAB_SIZE, enable_fp32)
|
||||
|
||||
original_matmul = torch.matmul
|
||||
original_linear = F.linear
|
||||
|
||||
state = {
|
||||
"called": False, # Whether a matmul/linear call has been intercepted yet
|
||||
"operation": None, # Which operation was captured ("matmul" or "linear")
|
||||
"a": None, # The dtype of the first input tensor to the operation
|
||||
"b": None, # The dtype of the second input tensor to the operation
|
||||
}
|
||||
|
||||
def probe_matmul(a, b, *args, **kw):
|
||||
if not state["called"]:
|
||||
state.update(called=True, operation="matmul", a=a.dtype, b=b.dtype)
|
||||
return original_matmul(a, b, *args, **kw)
|
||||
|
||||
def probe_linear(x, w, bias=None):
|
||||
if not state["called"]:
|
||||
state.update(called=True, ooperationp="linear", a=x.dtype, b=w.dtype)
|
||||
return original_linear(x, w, bias)
|
||||
|
||||
with patch("torch.matmul", new=probe_matmul), patch(
|
||||
"torch.nn.functional.linear", new=probe_linear
|
||||
):
|
||||
logits = logprocessor._get_logits(hidden_state, head, meta)
|
||||
self.assertEqual(hidden_state.dtype, hidden_state_dtype)
|
||||
self.assertTrue(state["called"], "no call lm head matlmul/linear")
|
||||
self.assertEqual(state["a"], expected_a_dtype)
|
||||
self.assertEqual(state["b"], expected_b_dtype)
|
||||
|
||||
def test_flag_true_fp16_activations(self):
|
||||
self._run_case(torch.float16, True, torch.float16, torch.float32, torch.float32)
|
||||
|
||||
def test_flag_true_bf16_activations(self):
|
||||
self._run_case(
|
||||
torch.bfloat16, True, torch.bfloat16, torch.float32, torch.float32
|
||||
)
|
||||
|
||||
def test_flag_false_fp16_path(self):
|
||||
self._run_case(
|
||||
torch.float16, False, torch.float16, torch.float16, torch.float16
|
||||
)
|
||||
|
||||
def test_flag_false_bf16_path(self):
|
||||
self._run_case(
|
||||
torch.bfloat16, False, torch.bfloat16, torch.bfloat16, torch.bfloat16
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
@@ -59,6 +59,7 @@ suites = {
|
||||
TestFile("quant/test_int8_kernel.py", 8),
|
||||
TestFile("quant/test_triton_scaled_mm.py", 8),
|
||||
TestFile("quant/test_w8a8_quantization.py", 46),
|
||||
TestFile("rl/test_fp32_lm_head.py", 30),
|
||||
TestFile("rl/test_update_weights_from_disk.py", 114),
|
||||
TestFile("rl/test_update_weights_from_tensor.py", 48),
|
||||
TestFile("test_abort.py", 51),
|
||||
|
||||
Reference in New Issue
Block a user