Update CI threshold & Improve code style (#2159)

This commit is contained in:
Lianmin Zheng
2024-11-24 06:29:38 -08:00
committed by GitHub
parent e3938b2f9c
commit 5652c56535
8 changed files with 126 additions and 41 deletions

View File

@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None,
model_config=model_runner.model_config,
enable_overlap=False,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()

View File

@@ -1,3 +1,8 @@
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch

View File

@@ -437,9 +437,12 @@ class ScheduleBatch:
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
# For utility
# Batch configs
model_config: ModelConfig = None
forward_mode: ForwardMode = None
enable_overlap: bool = False
# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
@@ -488,10 +491,11 @@ class ScheduleBatch:
def init_new(
cls,
reqs: List[Req],
req_to_token_pool,
token_to_kv_pool,
tree_cache,
model_config,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: ReqToTokenPool,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
):
return cls(
reqs=reqs,
@@ -499,6 +503,7 @@ class ScheduleBatch:
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
model_config=model_config,
enable_overlap=enable_overlap,
return_logprob=any(req.return_logprob for req in reqs),
has_stream=any(req.stream for req in reqs),
has_grammar=any(req.grammar for req in reqs),
@@ -612,7 +617,7 @@ class ScheduleBatch:
assert len(self.out_cache_loc) == self.extend_num_tokens
def prepare_for_extend(self, enable_overlap_schedule: bool = False):
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND
bs = len(self.reqs)
@@ -706,7 +711,7 @@ class ScheduleBatch:
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=enable_overlap_schedule,
enable_overlap_schedule=self.enable_overlap,
)
def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -897,7 +902,7 @@ class ScheduleBatch:
self.seq_lens_sum = 0
self.extend_num_tokens = 0
def prepare_for_decode(self, enable_overlap: bool = False):
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
self.input_ids = self.output_ids
@@ -914,7 +919,7 @@ class ScheduleBatch:
else:
locs = self.seq_lens
if enable_overlap:
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.req_to_token_pool.write(
(self.req_pool_indices, locs), self.out_cache_loc

View File

@@ -466,6 +466,7 @@ class Scheduler:
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
idle_batch.prepare_for_idle()
return idle_batch
@@ -842,14 +843,15 @@ class Scheduler:
self.token_to_kv_pool,
self.tree_cache,
self.model_config,
self.enable_overlap,
)
new_batch.prepare_for_extend(self.enable_overlap)
new_batch.prepare_for_extend()
# Mixed-style chunked prefill
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.filter_batch()
if not self.running_batch.is_empty():
self.running_batch.prepare_for_decode(self.enable_overlap)
self.running_batch.prepare_for_decode()
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
@@ -900,7 +902,7 @@ class Scheduler:
self.batch_is_full = False
# Update batch tensors
batch.prepare_for_decode(self.enable_overlap)
batch.prepare_for_decode()
return batch
def run_batch(self, batch: ScheduleBatch):
@@ -1055,6 +1057,7 @@ class Scheduler:
continue
if self.enable_overlap and req.finished():
# Free the one delayed token
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue

View File

@@ -23,7 +23,7 @@ import torch
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,