diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 68b6432fc..30d6c7f39 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -7,8 +7,6 @@ import torch from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, - DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_MLA_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -125,7 +123,7 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest): class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): """Test FlashAttention3 with speculative decode enabled.""" - model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST + model = "meta-llama/Llama-3.1-8B-Instruct" @classmethod def get_server_args(cls): @@ -137,7 +135,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): "--speculative-algorithm", "EAGLE3", "--speculative-draft", - DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", "--speculative-num-steps", "3", "--speculative-eagle-topk",