Optimize conflicts between CUDA graph and vocab mask tensors (#1392)
This commit is contained in:
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput
|
||||
from sglang.srt.layers.sampler import SampleOutput, Sampler
|
||||
from sglang.srt.lora.lora_manager import LoRAManager
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
@@ -107,6 +108,7 @@ class ModelRunner:
|
||||
|
||||
# Init components
|
||||
min_per_gpu_memory = self.init_torch_distributed()
|
||||
self.sampler = Sampler()
|
||||
self.load_model()
|
||||
if server_args.lora_paths is not None:
|
||||
self.init_lora_manager()
|
||||
@@ -466,11 +468,8 @@ class ModelRunner:
|
||||
def forward_decode(self, batch: ScheduleBatch):
|
||||
if self.server_args.lora_paths is not None:
|
||||
self.lora_manager.prepare_lora_batch(batch)
|
||||
if (
|
||||
self.cuda_graph_runner
|
||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
||||
and batch.sampling_info.can_run_in_cuda_graph()
|
||||
):
|
||||
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||
return self.cuda_graph_runner.replay(batch)
|
||||
|
||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||
@@ -510,9 +509,7 @@ class ModelRunner:
|
||||
input_metadata.image_offsets,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, batch: ScheduleBatch
|
||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
||||
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
|
||||
assert batch.forward_mode is not None
|
||||
|
||||
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||
@@ -524,6 +521,57 @@ class ModelRunner:
|
||||
else:
|
||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||
|
||||
def _check_sample_results(self, sample_output: SampleOutput):
|
||||
if not torch.all(sample_output.success):
|
||||
probs = sample_output.probs
|
||||
batch_next_token_ids = sample_output.batch_next_token_ids
|
||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||
argmax_ids = torch.argmax(probs, dim=-1)
|
||||
batch_next_token_ids = torch.where(
|
||||
sample_output.success, batch_next_token_ids, argmax_ids
|
||||
)
|
||||
sample_output.probs = probs
|
||||
sample_output.batch_next_token_ids = batch_next_token_ids
|
||||
|
||||
return sample_output.batch_next_token_ids
|
||||
|
||||
def _apply_logits_bias(
|
||||
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||
):
|
||||
# Apply logit_bias
|
||||
if sampling_info.logit_bias is not None:
|
||||
logits.add_(sampling_info.logit_bias)
|
||||
|
||||
# min-token, presence, frequency
|
||||
if sampling_info.linear_penalties is not None:
|
||||
logits += sampling_info.linear_penalties
|
||||
|
||||
# repetition
|
||||
if sampling_info.scaling_penalties is not None:
|
||||
logits = torch.where(
|
||||
logits > 0,
|
||||
logits / sampling_info.scaling_penalties,
|
||||
logits * sampling_info.scaling_penalties,
|
||||
)
|
||||
|
||||
# Apply regex vocab_mask
|
||||
if sampling_info.vocab_mask is not None:
|
||||
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
def sample(
    self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
) -> torch.Tensor:
    """Sample next-token ids for ``batch`` from ``logits_output``.

    Updates the batch's regex vocab mask and penalties, passes
    ``logits_output.next_token_logits`` through ``_apply_logits_bias``,
    runs ``self.sampler`` on the result, and returns whatever
    ``_check_sample_results`` produces for the sampler output.
    """
    info = batch.sampling_info

    # Update per-step sampling state before touching the logits.
    info.update_regex_vocab_mask(batch)
    info.update_penalties()

    adjusted_logits = self._apply_logits_bias(
        logits_output.next_token_logits, info
    )
    sampler_result = self.sampler(adjusted_logits, info)
    return self._check_sample_results(sampler_result)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_model_classes():
|
||||
|
||||
Reference in New Issue
Block a user