Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

Author: Liangsheng Yin
Date:   2024-09-13 20:27:53 -07:00
Committed by: GitHub
Parent: f3d32f888a
Commit: 70b6802982

32 changed files with 103 additions and 224 deletions
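
As the diff below shows, sampling moves out of the graph-replayed decode path: ModelRunner.forward now returns only the LogitsProcessorOutput, the replay gate in forward_decode no longer consults batch.sampling_info.can_run_in_cuda_graph(), and logit bias, penalties, the regex vocab mask, and sampling itself run afterwards in the new ModelRunner.sample. The sketch below is not part of the commit; it is a minimal illustration, assuming plain PyTorch on a CUDA device, of why this separation avoids the conflict: a captured CUDA graph replays fixed kernels on fixed tensor addresses, so a mask tensor that is rebuilt every step is simplest to apply to the graph's output after replay. forward_only and the toy shapes are made up for the example.

# Minimal sketch (not from this commit) of applying a per-step vocab mask
# after CUDA graph replay instead of inside the captured region.
import torch

device = "cuda"
vocab_size = 8

# Input buffer whose address is baked into the captured graph.
static_hidden = torch.zeros(1, vocab_size, device=device)

def forward_only(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model forward pass that produces next-token logits.
    return x * 2.0 + 1.0

# Warm up on a side stream, then capture the forward pass only:
# no masking and no sampling inside the graph.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        forward_only(static_hidden)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_logits = forward_only(static_hidden)  # output buffer is also static

# Per decode step: refill the static input, replay, then build the fresh
# regex vocab mask and apply it to the replayed logits outside the graph.
static_hidden.copy_(torch.randn(1, vocab_size, device=device))
graph.replay()

vocab_mask = torch.zeros(1, vocab_size, dtype=torch.bool, device=device)
vocab_mask[:, vocab_size // 2 :] = True  # tokens a regex FSM would forbid
masked_logits = static_logits.masked_fill(vocab_mask, float("-inf"))
next_token = torch.argmax(masked_logits, dim=-1)
print(next_token)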


@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -107,6 +108,7 @@ class ModelRunner:
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
         if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -466,11 +468,8 @@ class ModelRunner:
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(self, batch)
@@ -510,9 +509,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None
         if self.is_multimodal_model and batch.forward_mode.is_extend():
@@ -524,6 +521,57 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
@lru_cache()
def import_model_classes():
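
With the changes above, one decode step ends by updating the regex vocab mask and penalties on SamplingBatchInfo, folding logit bias, penalties, and the mask into the logits, and handing the result to the Sampler, with an argmax fallback for rows whose sampling failed. The rough, CPU-only sketch below is illustration only, not project code: SamplingBatchInfo is mocked with a SimpleNamespace that carries just the fields _apply_logits_bias reads, and the numbers are arbitrary.

# Rough CPU-only sketch of the logit-bias, penalty, and vocab-mask arithmetic
# that the _apply_logits_bias hunk above performs (illustrative values only).
from types import SimpleNamespace
import torch

vocab_size = 6
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0, -2.0, 1.0]])

# Mocked stand-in for SamplingBatchInfo with only the fields used here.
info = SimpleNamespace(
    logit_bias=torch.tensor([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0]]),
    linear_penalties=torch.tensor([[0.0, -0.5, 0.0, 0.0, 0.0, 0.0]]),
    scaling_penalties=torch.full((1, vocab_size), 1.2),
    vocab_mask=torch.tensor([[False, False, False, True, True, False]]),
)

logits = logits + info.logit_bias        # additive per-token bias
logits = logits + info.linear_penalties  # min-token / presence / frequency penalties
logits = torch.where(                    # repetition penalty: shrink positive logits,
    logits > 0,                          # push negative logits further down
    logits / info.scaling_penalties,
    logits * info.scaling_penalties,
)
logits = logits.masked_fill(info.vocab_mask, float("-inf"))  # regex vocab mask
print(logits)  # tokens 3 and 4 are now -inf and can never be sampled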