[Minor] Fix styles for overlap mode (#2068)
This commit is contained in:
@@ -1002,7 +1002,7 @@ class Scheduler:
|
|||||||
if req.is_retracted:
|
if req.is_retracted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.server_args.enable_overlap_schedule and (req.finished()):
|
if self.enable_overlap and req.finished():
|
||||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1319,7 +1319,7 @@ def run_scheduler_process(
|
|||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||||
pipe_writer.send("ready")
|
pipe_writer.send("ready")
|
||||||
if server_args.enable_overlap_schedule:
|
if scheduler.enable_overlap:
|
||||||
scheduler.event_loop_overlap()
|
scheduler.event_loop_overlap()
|
||||||
else:
|
else:
|
||||||
scheduler.event_loop_normal()
|
scheduler.event_loop_normal()
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ import torch
|
|||||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -176,16 +175,8 @@ class TpModelWorkerClient:
|
|||||||
) % self.future_token_ids_limit
|
) % self.future_token_ids_limit
|
||||||
return None, future_next_token_ids
|
return None, future_next_token_ids
|
||||||
|
|
||||||
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
|
||||||
embeddings = logits_output.embeddings
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||||
success, message = self.model_runner.update_weights(
|
success, message = self.worker.update_weights(recv_req)
|
||||||
recv_req.model_path, recv_req.load_format
|
|
||||||
)
|
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
def __delete__(self):
|
def __delete__(self):
|
||||||
|
|||||||
@@ -276,10 +276,6 @@ class ModelRunner:
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
self.dtype = self.vllm_model_config.dtype
|
self.dtype = self.vllm_model_config.dtype
|
||||||
if self.sliding_window_size:
|
|
||||||
assert (
|
|
||||||
self.server_args.attention_backend == "flashinfer"
|
|
||||||
), "Only flashinfer supports window attention."
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Load weight end. "
|
f"Load weight end. "
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
import subprocess
|
"""
|
||||||
|
Usage:
|
||||||
|
python3 -m unittest test_triton_attention_backend.TestTritonAttnBackend.test_mmlu
|
||||||
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user