[bugfix] Add 'disaggregation_mode' parameter to the warmup function when compiling deep_gemm manually (#8618)
This commit is contained in:
@@ -17,6 +17,7 @@ import time
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
||||||
from sglang.srt.entrypoints.http_server import launch_server
|
from sglang.srt.entrypoints.http_server import launch_server
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
@@ -52,7 +53,9 @@ class CompileArgs:
|
|||||||
|
|
||||||
|
|
||||||
@warmup("compile-deep-gemm")
|
@warmup("compile-deep-gemm")
|
||||||
async def warm_up_compile(tokenizer_manager: TokenizerManager):
|
async def warm_up_compile(
|
||||||
|
disaggregation_mode: str, tokenizer_manager: TokenizerManager
|
||||||
|
):
|
||||||
print("\nGenerate warm up request for compiling DeepGEMM...\n")
|
print("\nGenerate warm up request for compiling DeepGEMM...\n")
|
||||||
generate_req_input = GenerateReqInput(
|
generate_req_input = GenerateReqInput(
|
||||||
input_ids=[0, 1, 2, 3],
|
input_ids=[0, 1, 2, 3],
|
||||||
@@ -62,6 +65,10 @@ async def warm_up_compile(tokenizer_manager: TokenizerManager):
|
|||||||
"ignore_eos": True,
|
"ignore_eos": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if disaggregation_mode != "null":
|
||||||
|
generate_req_input.bootstrap_room = 0
|
||||||
|
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
||||||
|
|
||||||
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user