Use public model for FA3 speculative decode testing (#5152)
This commit is contained in:
@@ -7,8 +7,6 @@ import torch
|
|||||||
from sglang.srt.utils import get_device_sm, kill_process_tree
|
from sglang.srt.utils import get_device_sm, kill_process_tree
|
||||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
|
||||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
|
||||||
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
@@ -125,7 +123,7 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest):
|
|||||||
class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
||||||
"""Test FlashAttention3 with speculative decode enabled."""
|
"""Test FlashAttention3 with speculative decode enabled."""
|
||||||
|
|
||||||
model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_server_args(cls):
|
def get_server_args(cls):
|
||||||
@@ -137,7 +135,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
|||||||
"--speculative-algorithm",
|
"--speculative-algorithm",
|
||||||
"EAGLE3",
|
"EAGLE3",
|
||||||
"--speculative-draft",
|
"--speculative-draft",
|
||||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||||
"--speculative-num-steps",
|
"--speculative-num-steps",
|
||||||
"3",
|
"3",
|
||||||
"--speculative-eagle-topk",
|
"--speculative-eagle-topk",
|
||||||
|
|||||||
Reference in New Issue
Block a user