Fix OOM issues with FP8 for Llama (#1454)
.github/workflows/pr-test.yml
@@ -144,18 +144,18 @@ jobs:
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_radix_cache
 
       - name: Benchmark Offline Throughput (w/o ChunkedPrefill)
         timeout-minutes: 10
         run: |
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_without_chunked_prefill
 
       - name: Benchmark Offline Throughput (w/ Triton)
         timeout-minutes: 10
         run: |
           cd test/srt
           python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_with_triton_attention_backend
 
+      - name: Benchmark Offline Throughput (w/ FP8)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8
 
   performance-test-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
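
The new CI step only wraps a unittest invocation, so it can be reproduced outside GitHub Actions. A minimal sketch, assuming a checkout of the repo as the working directory and a GPU with FP8 support:

    # Run the same unittest target as the "Benchmark Offline Throughput (w/ FP8)" step.
    import subprocess

    subprocess.run(
        [
            "python3", "-m", "unittest",
            "test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8",
        ],
        cwd="test/srt",
        check=True,
    )
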
@@ -305,8 +305,6 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -374,7 +372,7 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
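
The two hunks above, and the matching ones below for LlamaForClassification, XverseForCausalLM, and XverseMoeForCausalLM, all make the same change: the parameter dict is no longer cached on the module at construction time but rebuilt inside load_weights. A cached self.param_dict pins the Parameter objects that existed at construction, so when FP8 loading later replaces weights the old tensors cannot be freed, which is presumably the OOM the commit title refers to (an inference from this diff, not stated explicitly). A minimal sketch of the pattern with a toy module, not the real sglang model code:

    from torch import nn


    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(16, 16)
            # Removed by this commit: caching the dict here keeps references to
            # the construction-time parameters for the lifetime of the module.
            # self.param_dict = dict(self.named_parameters())

        def load_weights(self, weights):
            # Built fresh on every call; it goes out of scope when loading ends,
            # so nothing extra holds on to tensors that were swapped out.
            params_dict = dict(self.named_parameters())
            for name, loaded_weight in weights:
                params_dict[name].data.copy_(loaded_weight)
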
@@ -36,6 +36,7 @@ class LlamaForClassification(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+        self.torchao_config = None
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
@@ -44,8 +45,6 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -77,7 +76,7 @@ class LlamaForClassification(nn.Module):
         return logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "classification_head" in name:
@@ -307,8 +307,6 @@ class XverseForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -333,7 +331,7 @@ class XverseForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -383,8 +383,6 @@ class XverseMoeForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
-        self.param_dict = dict(self.named_parameters())
-
     @torch.no_grad()
     def forward(
         self,
@@ -406,8 +404,7 @@ class XverseMoeForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-
-        params_dict = self.param_dict
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -22,6 +22,7 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.utils import kill_child_process
 from sglang.utils import get_exception_traceback
 
+DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
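
The new constant points the test suite at an FP8-quantized Llama 3.1 checkpoint. A usage sketch for serving it directly, assuming the standard sglang.launch_server entry point and default flag set (the empty extra-argument list mirrors other_server_args=[] in the new test below):

    # Serve the FP8 checkpoint referenced by DEFAULT_FP8_MODEL_NAME_FOR_TEST.
    import subprocess

    subprocess.Popen(
        [
            "python3", "-m", "sglang.launch_server",
            "--model-path", "neuralmagic/Meta-Llama-3.1-8B-FP8",
            "--port", "30000",
        ]
    )
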
@@ -1,6 +1,7 @@
 import unittest
 
 from sglang.test.test_utils import (
+    DEFAULT_FP8_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
     is_in_ci,
@@ -59,6 +60,17 @@ class TestBenchServing(unittest.TestCase):
         if is_in_ci():
             assert res["output_throughput"] > 2600
 
+    def test_offline_throughput_default_fp8(self):
+        res = run_bench_serving(
+            model=DEFAULT_FP8_MODEL_NAME_FOR_TEST,
+            num_prompts=500,
+            request_rate=float("inf"),
+            other_server_args=[],
+        )
+
+        if is_in_ci():
+            assert res["output_throughput"] > 3100
+
     def test_online_latency_default(self):
         res = run_bench_serving(
             model=DEFAULT_MODEL_NAME_FOR_TEST,
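
The output_throughput threshold is only asserted in CI; locally the helper can be driven by hand. A sketch, assuming run_bench_serving is exported from sglang.test.test_utils alongside the constants imported above (the visible import block is truncated, so treat that as an assumption):

    from sglang.test.test_utils import (
        DEFAULT_FP8_MODEL_NAME_FOR_TEST,
        run_bench_serving,  # assumed to live next to the constants above
    )

    res = run_bench_serving(
        model=DEFAULT_FP8_MODEL_NAME_FOR_TEST,
        num_prompts=500,
        request_rate=float("inf"),
        other_server_args=[],
    )
    print(res["output_throughput"])
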
@@ -12,8 +12,10 @@ from sglang.test.test_utils import (
 
 
 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
-        other_args = ["--chunked-prefill-size", "32"]
+    def run_mmlu(
+        self, disable_radix_cache, enable_mixed_chunk, chunked_prefill_size=32
+    ):
+        other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
@@ -55,6 +57,11 @@ class TestChunkedPrefill(unittest.TestCase):
     def test_mixed_chunked_prefill_without_radix_cache(self):
         self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)
 
+    def test_no_chunked_prefill(self):
+        self.run_mmlu(
+            disable_radix_cache=False, enable_mixed_chunk=False, chunked_prefill_size=-1
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
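
The new test_no_chunked_prefill case passes chunked_prefill_size=-1, which run_mmlu forwards as "--chunked-prefill-size -1"; in this suite that is how chunked prefill is switched off. A launch sketch using the same flag (flag names as used by the test helper; the model path is only an example, and the meaning of -1 is inferred from this diff):

    # Start a server with chunked prefill disabled, mirroring chunked_prefill_size=-1.
    import subprocess

    subprocess.Popen(
        [
            "python3", "-m", "sglang.launch_server",
            "--model-path", "meta-llama/Meta-Llama-3.1-8B-Instruct",
            "--chunked-prefill-size", "-1",
        ]
    )
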