Optimize conflicts between CUDA graph and vocab mask tensors (#1392)
@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
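This commit moves token sampling out of the model forward pass, and thus out of the CUDA-graph-captured region: `ModelRunner.forward` now returns only a `LogitsProcessorOutput`, and a new `ModelRunner.sample` applies bias, penalties, and the regex vocab mask eagerly afterwards. The benchmark hunk above shows the new calling convention; a minimal sketch (toy setup, same API):

```python
# Hedged sketch of the new two-step flow; `model_runner` and `batch` are
# assumed to be initialized as in the benchmark helpers above.
logits_output = model_runner.forward(batch)                 # graph-capturable
next_token_ids = model_runner.sample(logits_output, batch)  # eager sampling
```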
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
         self.forward_native = self.forward_cuda
         self.is_torch_compile = False

-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
         # FIXME: Temporary workaround for unknown bugs in torch.compile
         logits.add_(0)

-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)

     def forward_cuda(
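With bias, mask, and penalty handling gone, `Sampler._get_probs` reduces to a plain softmax. For reference, the removed mask step behaves like this standalone PyTorch snippet (toy tensors, not the PR's code): filling masked vocabulary entries with `-inf` before the softmax drives their probabilities to exactly zero.

```python
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -0.3]])
vocab_mask = torch.tensor([[False, True, False, True]])  # True = forbidden token

probs = torch.softmax(logits.masked_fill(vocab_mask, float("-inf")), dim=-1)
assert torch.all(probs[vocab_mask] == 0)  # masked entries get zero probability
```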
@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs

-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -710,18 +706,3 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
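`check_sample_results` is not deleted outright; it reappears verbatim as `ModelRunner._check_sample_results` later in this diff. Isolated as a toy example (invented values), its fallback logic replaces failed samples with a NaN-safe argmax:

```python
import torch

success = torch.tensor([True, False])        # per-request sampling status
sampled = torch.tensor([42, -1])             # junk id where sampling failed
probs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]])

probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)     # the top_k=1 fallback
next_ids = torch.where(success, sampled, argmax_ids)
print(next_ids)  # tensor([42, 0])
```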
@@ -547,8 +547,9 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(batch)
-                next_token_ids = batch.check_sample_results(sample_output)
+                logits_output = self.model_runner.forward(batch)
+                next_token_ids = self.model_runner.sample(logits_output, batch)
+
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
@@ -723,8 +724,8 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(batch)
-        next_token_ids = batch.check_sample_results(sample_output)
+        logits_output = self.model_runner.forward(batch)
+        next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
@@ -129,10 +127,6 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

-        # Sampling info
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         if self.use_torch_compile:
             set_torch_compile_config()
@@ -191,7 +185,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -250,14 +243,9 @@ class CudaGraphRunner:
             bs, self.req_pool_indices, self.seq_lens
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
@@ -269,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:
@@ -290,4 +273,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]

-        return sample_output, logits_output
+        return logits_output
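These `CudaGraphRunner` hunks carry the main fix: the captured region now ends at the logits, and `replay` returns only `logits_output`. The likely motivation is that a replayed CUDA graph always reads the tensor storages it saw at capture time, so per-step tensors such as regex vocab masks cannot be swapped in by rebinding. A standalone sketch of that constraint (plain PyTorch, needs a CUDA device; production capture code additionally warms up on a side stream):

```python
import torch

static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)  # the graph records static_in's address

static_in.copy_(torch.arange(4.0, device="cuda"))  # in-place: same storage
g.replay()
print(static_out)  # tensor([0., 2., 4., 6.], device='cuda:0')

# Rebinding `static_in = torch.arange(4.0, device="cuda")` instead would be
# invisible to replay(): the graph would still read the old buffer. Sampling
# state that changes every step therefore runs outside the captured region.
```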
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -59,7 +58,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -170,7 +168,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -182,8 +179,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.update_penalties()
-        ret.sampling_info.update_regex_vocab_mask(batch)
         ret.compute_positions(batch)

         if not batch.forward_mode.is_decode():
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -107,6 +108,7 @@ class ModelRunner:

         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
         if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -466,11 +468,8 @@ class ModelRunner:
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(self, batch)
@@ -510,9 +509,7 @@ class ModelRunner:
                 input_metadata.image_offsets,
             )

-    def forward(
-        self, batch: ScheduleBatch
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

         if self.is_multimodal_model and batch.forward_mode.is_extend():
@@ -524,6 +521,57 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")

+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
+

 @lru_cache()
 def import_model_classes():
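`ModelRunner.sample` now refreshes the vocab mask and penalty tensors at sample time (they were previously refreshed in `InputMetadata.from_schedule_batch`, removed above) and applies them to the logits before the sampler runs. A self-contained toy pass over the same tensor operations (invented values; the real inputs come from `SamplingBatchInfo`):

```python
import torch

logits = torch.tensor([[1.0, -2.0, 0.5, 3.0]])
logit_bias = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
scaling_penalties = torch.tensor([[1.2]])              # repetition penalty > 1
vocab_mask = torch.tensor([[False, False, False, True]])

logits.add_(logit_bias)
logits = torch.where(                                  # shrink both signs
    logits > 0, logits / scaling_penalties, logits * scaling_penalties
)
logits = logits.masked_fill(vocab_mask, float("-inf"))
next_id = torch.argmax(torch.softmax(logits, dim=-1), dim=-1)
print(next_id)  # tensor([2]): token 3 is masked out, token 2 won via the bias
```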
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
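Every model file from here to the end of the diff receives the same three-part edit just shown for Baichuan: drop the `Sampler` import, drop `self.sampler = Sampler()` from the constructor, and return the `LogitsProcessor` output directly from `forward`. Schematically (a composite sketch, not any single file):

```python
# Before: each model sampled its own tokens inside forward()
#     logits_output = self.logits_processor(...)
#     sample_output = self.sampler(logits_output, input_metadata.sampling_info)
#     return sample_output, logits_output
#
# After: forward() only produces logits; ModelRunner.sample() does the rest.
def forward(self, input_ids, positions, input_metadata):
    hidden_states = self.model(input_ids, positions, input_metadata)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head.weight, input_metadata
    )
```

The one structural outlier is `LlamaForClassification`, which previously fabricated a dummy `SampleOutput` and now simply returns its logits output.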
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
         hidden_states = self.transformer(
             input_ids, positions, input_metadata, input_embeds
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         self.param_dict = dict(self.named_parameters())

@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_hidden_dim(self, module_name):
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel

@@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module):
             output_top_logprobs=None,
         )

-        # A dummy to make this work
-        sample_output = SampleOutput(
-            success=torch.full(
-                size=(scores.shape[0],),
-                fill_value=True,
-                dtype=torch.bool,
-            ),
-            probs=torch.full(
-                size=(scores.shape[0], 1),
-                fill_value=1.0,
-                dtype=torch.float16,
-            ),
-            batch_next_token_ids=torch.full(
-                size=(scores.shape[0],),
-                fill_value=0,
-                dtype=torch.long,
-            ),
-        )
-        return sample_output, logits_output
+        return logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = self.param_dict
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -572,7 +571,6 @@ class MiniCPM3ForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base

         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -590,11 +588,9 @@ class MiniCPM3ForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            logits_output = self.logits_processor(
+            return self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
-            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)

@@ -47,7 +47,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -365,7 +364,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -376,11 +374,9 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.model_runner import InputMetadata


@@ -307,7 +306,6 @@ class XverseForCausalLM(nn.Module):
         self.model = XverseModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         self.param_dict = dict(self.named_parameters())

@@ -320,12 +318,9 @@ class XverseForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         # print(f"{hidden_states=}")
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
@@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -383,7 +382,6 @@ class XverseMoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         self.param_dict = dict(self.named_parameters())

@@ -395,11 +393,9 @@ class XverseMoeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -41,7 +41,6 @@ class SamplingBatchInfo:
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
             self.logit_bias is None
-            and self.vocab_mask is None
            and self.linear_penalties is None
             and self.scaling_penalties is None
             and not self.need_min_p_sampling
@@ -50,9 +49,11 @@ class SamplingBatchInfo:
     @classmethod
     def dummy_one(cls, max_bs: int, vocab_size: int):
         ret = cls(vocab_size=vocab_size)
-        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
-        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
-        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        with torch.device("cuda"):
+            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
+            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
+            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
+            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
         return ret

     def __getitem__(self, key):
@@ -64,6 +65,7 @@ class SamplingBatchInfo:
                 temperatures=self.temperatures[key],
                 top_ps=self.top_ps[key],
                 top_ks=self.top_ks[key],
+                vocab_mask=self.vocab_mask[key],
             )
         else:
             raise NotImplementedError
@@ -77,6 +79,11 @@ class SamplingBatchInfo:
         self.top_ps[:bs] = other.top_ps
         self.top_ks[:bs] = other.top_ks

+        if other.vocab_mask is None:
+            self.vocab_mask[:bs].fill_(False)
+        else:
+            self.vocab_mask[:bs] = other.vocab_mask
+
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         device = "cuda"
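`SamplingBatchInfo` now maintains its vocab mask as a persistent pre-allocated buffer: `dummy_one` allocates a `(max_bs, vocab_size)` boolean mask up front, `__getitem__` slices it alongside the other per-request tensors, and `inplace_assign` copies each step's mask into the fixed storage (clearing the prefix when no mask is active); accordingly, a present vocab mask no longer rules out `can_run_in_cuda_graph`. The buffer-reuse pattern in isolation (toy sizes, CPU tensors for brevity):

```python
import torch

max_bs, vocab_size = 4, 6
vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)  # one buffer

def inplace_assign(bs, new_mask):
    if new_mask is None:
        vocab_mask[:bs].fill_(False)  # no constraint this step: clear the prefix
    else:
        vocab_mask[:bs] = new_mask    # copy into the persistent storage

ptr_before = vocab_mask.data_ptr()
inplace_assign(2, torch.tensor([[True] * 3 + [False] * 3, [False] * 6]))
assert vocab_mask.data_ptr() == ptr_before  # same address: safe to reuse
```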