Optimize conflicts between CUDA graph and vocab mask tensors (#1392)
This commit is contained in:
@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
|
|||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
)
|
)
|
||||||
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
batch.prepare_for_extend(model_runner.model_config.vocab_size)
|
||||||
sample_output, logits_output = model_runner.forward(batch)
|
logits_output = model_runner.forward(batch)
|
||||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits, batch
|
return next_token_ids, logits_output.next_token_logits, batch
|
||||||
|
|
||||||
|
|
||||||
def decode(input_token_ids, batch, model_runner):
|
def decode(input_token_ids, batch, model_runner):
|
||||||
batch.prepare_for_decode(input_token_ids)
|
batch.prepare_for_decode(input_token_ids)
|
||||||
sample_output, logits_output = model_runner.forward(batch)
|
logits_output = model_runner.forward(batch)
|
||||||
next_token_ids = sample_output.batch_next_token_ids.tolist()
|
next_token_ids = model_runner.sample(logits_output, batch).tolist()
|
||||||
return next_token_ids, logits_output.next_token_logits
|
return next_token_ids, logits_output.next_token_logits
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
|
|||||||
self.forward_native = self.forward_cuda
|
self.forward_native = self.forward_cuda
|
||||||
self.is_torch_compile = False
|
self.is_torch_compile = False
|
||||||
|
|
||||||
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
|
||||||
# min-token, presence, frequency
|
|
||||||
if sampling_info.linear_penalties is not None:
|
|
||||||
logits += sampling_info.linear_penalties
|
|
||||||
|
|
||||||
# repetition
|
|
||||||
if sampling_info.scaling_penalties is not None:
|
|
||||||
logits = torch.where(
|
|
||||||
logits > 0,
|
|
||||||
logits / sampling_info.scaling_penalties,
|
|
||||||
logits * sampling_info.scaling_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
|
||||||
# Post process logits
|
# Post process logits
|
||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
|
|||||||
# FIXME: Temporary workaround for unknown bugs in torch.compile
|
# FIXME: Temporary workaround for unknown bugs in torch.compile
|
||||||
logits.add_(0)
|
logits.add_(0)
|
||||||
|
|
||||||
if sampling_info.logit_bias is not None:
|
|
||||||
logits.add_(sampling_info.logit_bias)
|
|
||||||
|
|
||||||
if sampling_info.vocab_mask is not None:
|
|
||||||
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
|
||||||
|
|
||||||
logits = self._apply_penalties(logits, sampling_info)
|
|
||||||
|
|
||||||
return torch.softmax(logits, dim=-1)
|
return torch.softmax(logits, dim=-1)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
|
|||||||
@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
|||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from sglang.srt.layers.sampler import SampleOutput
|
|
||||||
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -710,18 +706,3 @@ class ScheduleBatch:
|
|||||||
self.out_cache_loc = None
|
self.out_cache_loc = None
|
||||||
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
|
||||||
def check_sample_results(self, sample_output: SampleOutput):
|
|
||||||
if not torch.all(sample_output.success):
|
|
||||||
probs = sample_output.probs
|
|
||||||
batch_next_token_ids = sample_output.batch_next_token_ids
|
|
||||||
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
|
||||||
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
|
||||||
argmax_ids = torch.argmax(probs, dim=-1)
|
|
||||||
batch_next_token_ids = torch.where(
|
|
||||||
sample_output.success, batch_next_token_ids, argmax_ids
|
|
||||||
)
|
|
||||||
sample_output.probs = probs
|
|
||||||
sample_output.batch_next_token_ids = batch_next_token_ids
|
|
||||||
|
|
||||||
return sample_output.batch_next_token_ids
|
|
||||||
|
|||||||
@@ -547,8 +547,9 @@ class ModelTpServer:
|
|||||||
if self.model_runner.is_generation:
|
if self.model_runner.is_generation:
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
sample_output, logits_output = self.model_runner.forward(batch)
|
logits_output = self.model_runner.forward(batch)
|
||||||
next_token_ids = batch.check_sample_results(sample_output)
|
next_token_ids = self.model_runner.sample(logits_output, batch)
|
||||||
|
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
)
|
)
|
||||||
@@ -723,8 +724,8 @@ class ModelTpServer:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward and sample the next tokens
|
# Forward and sample the next tokens
|
||||||
sample_output, logits_output = self.model_runner.forward(batch)
|
logits_output = self.model_runner.forward(batch)
|
||||||
next_token_ids = batch.check_sample_results(sample_output)
|
next_token_ids = self.model_runner.sample(logits_output, batch)
|
||||||
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
|
||||||
next_token_ids
|
next_token_ids
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
|
|||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.sampler import SampleOutput
|
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
||||||
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -129,10 +127,6 @@ class CudaGraphRunner:
|
|||||||
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling info
|
|
||||||
vocab_size = model_runner.model_config.vocab_size
|
|
||||||
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
|
|
||||||
|
|
||||||
if self.use_torch_compile:
|
if self.use_torch_compile:
|
||||||
set_torch_compile_config()
|
set_torch_compile_config()
|
||||||
|
|
||||||
@@ -191,7 +185,6 @@ class CudaGraphRunner:
|
|||||||
def run_once():
|
def run_once():
|
||||||
input_metadata = InputMetadata(
|
input_metadata = InputMetadata(
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
sampling_info=self.sampling_info[:bs],
|
|
||||||
batch_size=bs,
|
batch_size=bs,
|
||||||
req_pool_indices=req_pool_indices,
|
req_pool_indices=req_pool_indices,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
@@ -250,14 +243,9 @@ class CudaGraphRunner:
|
|||||||
bs, self.req_pool_indices, self.seq_lens
|
bs, self.req_pool_indices, self.seq_lens
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sampling inputs
|
|
||||||
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
|
|
||||||
|
|
||||||
# Replay
|
# Replay
|
||||||
torch.cuda.synchronize()
|
|
||||||
self.graphs[bs].replay()
|
self.graphs[bs].replay()
|
||||||
torch.cuda.synchronize()
|
logits_output = self.output_buffers[bs]
|
||||||
sample_output, logits_output = self.output_buffers[bs]
|
|
||||||
|
|
||||||
# Unpad
|
# Unpad
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
@@ -269,11 +257,6 @@ class CudaGraphRunner:
|
|||||||
input_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
output_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
sample_output = SampleOutput(
|
|
||||||
sample_output.success[:raw_bs],
|
|
||||||
sample_output.probs[:raw_bs],
|
|
||||||
sample_output.batch_next_token_ids[:raw_bs],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract logprobs
|
# Extract logprobs
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
@@ -290,4 +273,4 @@ class CudaGraphRunner:
|
|||||||
logits_output.next_token_logprobs, logits_metadata
|
logits_output.next_token_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
|
|
||||||
return sample_output, logits_output
|
return logits_output
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardMode(IntEnum):
|
class ForwardMode(IntEnum):
|
||||||
@@ -59,7 +58,6 @@ class InputMetadata:
|
|||||||
"""Store all inforamtion of a forward pass."""
|
"""Store all inforamtion of a forward pass."""
|
||||||
|
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
sampling_info: SamplingBatchInfo
|
|
||||||
batch_size: int
|
batch_size: int
|
||||||
req_pool_indices: torch.Tensor
|
req_pool_indices: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
@@ -170,7 +168,6 @@ class InputMetadata:
|
|||||||
):
|
):
|
||||||
ret = cls(
|
ret = cls(
|
||||||
forward_mode=batch.forward_mode,
|
forward_mode=batch.forward_mode,
|
||||||
sampling_info=batch.sampling_info,
|
|
||||||
batch_size=batch.batch_size(),
|
batch_size=batch.batch_size(),
|
||||||
req_pool_indices=batch.req_pool_indices,
|
req_pool_indices=batch.req_pool_indices,
|
||||||
seq_lens=batch.seq_lens,
|
seq_lens=batch.seq_lens,
|
||||||
@@ -182,8 +179,6 @@ class InputMetadata:
|
|||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret.sampling_info.update_penalties()
|
|
||||||
ret.sampling_info.update_regex_vocab_mask(batch)
|
|
||||||
ret.compute_positions(batch)
|
ret.compute_positions(batch)
|
||||||
|
|
||||||
if not batch.forward_mode.is_decode():
|
if not batch.forward_mode.is_decode():
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
|
|||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import SampleOutput
|
from sglang.srt.layers.sampler import SampleOutput, Sampler
|
||||||
from sglang.srt.lora.lora_manager import LoRAManager
|
from sglang.srt.lora.lora_manager import LoRAManager
|
||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_available_gpu_memory,
|
get_available_gpu_memory,
|
||||||
@@ -107,6 +108,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Init componnets
|
# Init componnets
|
||||||
min_per_gpu_memory = self.init_torch_distributed()
|
min_per_gpu_memory = self.init_torch_distributed()
|
||||||
|
self.sampler = Sampler()
|
||||||
self.load_model()
|
self.load_model()
|
||||||
if server_args.lora_paths is not None:
|
if server_args.lora_paths is not None:
|
||||||
self.init_lora_manager()
|
self.init_lora_manager()
|
||||||
@@ -466,11 +468,8 @@ class ModelRunner:
|
|||||||
def forward_decode(self, batch: ScheduleBatch):
|
def forward_decode(self, batch: ScheduleBatch):
|
||||||
if self.server_args.lora_paths is not None:
|
if self.server_args.lora_paths is not None:
|
||||||
self.lora_manager.prepare_lora_batch(batch)
|
self.lora_manager.prepare_lora_batch(batch)
|
||||||
if (
|
|
||||||
self.cuda_graph_runner
|
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
|
||||||
and self.cuda_graph_runner.can_run(len(batch.reqs))
|
|
||||||
and batch.sampling_info.can_run_in_cuda_graph()
|
|
||||||
):
|
|
||||||
return self.cuda_graph_runner.replay(batch)
|
return self.cuda_graph_runner.replay(batch)
|
||||||
|
|
||||||
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
input_metadata = InputMetadata.from_schedule_batch(self, batch)
|
||||||
@@ -510,9 +509,7 @@ class ModelRunner:
|
|||||||
input_metadata.image_offsets,
|
input_metadata.image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
|
||||||
self, batch: ScheduleBatch
|
|
||||||
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
|
|
||||||
assert batch.forward_mode is not None
|
assert batch.forward_mode is not None
|
||||||
|
|
||||||
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
if self.is_multimodal_model and batch.forward_mode.is_extend():
|
||||||
@@ -524,6 +521,57 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
|
||||||
|
|
||||||
|
def _check_sample_results(self, sample_output: SampleOutput):
|
||||||
|
if not torch.all(sample_output.success):
|
||||||
|
probs = sample_output.probs
|
||||||
|
batch_next_token_ids = sample_output.batch_next_token_ids
|
||||||
|
logging.warning("Sampling failed, fallback to top_k=1 strategy")
|
||||||
|
probs = probs.masked_fill(torch.isnan(probs), 0.0)
|
||||||
|
argmax_ids = torch.argmax(probs, dim=-1)
|
||||||
|
batch_next_token_ids = torch.where(
|
||||||
|
sample_output.success, batch_next_token_ids, argmax_ids
|
||||||
|
)
|
||||||
|
sample_output.probs = probs
|
||||||
|
sample_output.batch_next_token_ids = batch_next_token_ids
|
||||||
|
|
||||||
|
return sample_output.batch_next_token_ids
|
||||||
|
|
||||||
|
def _apply_logits_bias(
|
||||||
|
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
|
||||||
|
):
|
||||||
|
# Apply logit_bias
|
||||||
|
if sampling_info.logit_bias is not None:
|
||||||
|
logits.add_(sampling_info.logit_bias)
|
||||||
|
|
||||||
|
# min-token, presence, frequency
|
||||||
|
if sampling_info.linear_penalties is not None:
|
||||||
|
logits += sampling_info.linear_penalties
|
||||||
|
|
||||||
|
# repetition
|
||||||
|
if sampling_info.scaling_penalties is not None:
|
||||||
|
logits = torch.where(
|
||||||
|
logits > 0,
|
||||||
|
logits / sampling_info.scaling_penalties,
|
||||||
|
logits * sampling_info.scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply regex vocab_mask
|
||||||
|
if sampling_info.vocab_mask is not None:
|
||||||
|
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch.sampling_info.update_regex_vocab_mask(batch)
|
||||||
|
batch.sampling_info.update_penalties()
|
||||||
|
logits = self._apply_logits_bias(
|
||||||
|
logits_output.next_token_logits, batch.sampling_info
|
||||||
|
)
|
||||||
|
sample_output = self.sampler(logits, batch.sampling_info)
|
||||||
|
return self._check_sample_results(sample_output)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def import_model_classes():
|
def import_model_classes():
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
if self.config.tie_word_embeddings:
|
if self.config.tie_word_embeddings:
|
||||||
self.lm_head.weight = self.model.embed_tokens.weight
|
self.lm_head.weight = self.model.embed_tokens.weight
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
LoraConfig = None
|
LoraConfig = None
|
||||||
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||||
self.lm_head = self.transformer.output_layer
|
self.lm_head = self.transformer.output_layer
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
self.model = CohereModel(config, quant_config)
|
self.model = CohereModel(config, quant_config)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
|
|||||||
positions,
|
positions,
|
||||||
input_metadata,
|
input_metadata,
|
||||||
)
|
)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
expert_params_mapping = [
|
expert_params_mapping = [
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
self.transformer = ExaoneModel(config, quant_config=quant_config)
|
self.transformer = ExaoneModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
|
|||||||
hidden_states = self.transformer(
|
hidden_states = self.transformer(
|
||||||
input_ids, positions, input_metadata, input_embeds
|
input_ids, positions, input_metadata, input_embeds
|
||||||
)
|
)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = GemmaModel(config, quant_config=quant_config)
|
self.model = GemmaModel(config, quant_config=quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return (sample_output, logits_output)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
|
|||||||
from sglang.srt.layers.layernorm import GemmaRMSNorm
|
from sglang.srt.layers.layernorm import GemmaRMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Gemma2Model(config, cache_config, quant_config)
|
self.model = Gemma2Model(config, cache_config, quant_config)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def get_attention_sliding_window_size(self):
|
def get_attention_sliding_window_size(self):
|
||||||
return get_attention_sliding_window_size(self.config)
|
return get_attention_sliding_window_size(self.config)
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import get_act_fn
|
from sglang.srt.layers.activation import get_act_fn
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
if lora_config:
|
if lora_config:
|
||||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
self.model = Grok1Model(config, quant_config=quant_config)
|
self.model = Grok1Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
# Monkey patch _prepare_weights to load pre-sharded weights
|
# Monkey patch _prepare_weights to load pre-sharded weights
|
||||||
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
|
||||||
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
self.model = InternLM2Model(config, quant_config)
|
self.model = InternLM2Model(config, quant_config)
|
||||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.output.weight, input_metadata
|
input_ids, hidden_states, self.output.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
|
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
self.model = LlamaModel(config, quant_config=quant_config)
|
self.model = LlamaModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
self.param_dict = dict(self.named_parameters())
|
self.param_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> LogitsProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def get_hidden_dim(self, module_name):
|
def get_hidden_dim(self, module_name):
|
||||||
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import SampleOutput
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
|
||||||
|
|
||||||
@@ -75,25 +74,7 @@ class LlamaForClassification(nn.Module):
|
|||||||
output_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# A dummy to make this work
|
return logits_output
|
||||||
sample_output = SampleOutput(
|
|
||||||
success=torch.full(
|
|
||||||
size=(scores.shape[0],),
|
|
||||||
fill_value=True,
|
|
||||||
dtype=torch.bool,
|
|
||||||
),
|
|
||||||
probs=torch.full(
|
|
||||||
size=(scores.shape[0], 1),
|
|
||||||
fill_value=1.0,
|
|
||||||
dtype=torch.float16,
|
|
||||||
),
|
|
||||||
batch_next_token_ids=torch.full(
|
|
||||||
size=(scores.shape[0],),
|
|
||||||
fill_value=0,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = self.param_dict
|
params_dict = self.param_dict
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
|
|||||||
lm_head_weight = self.model.embed_tokens.weight
|
lm_head_weight = self.model.embed_tokens.weight
|
||||||
else:
|
else:
|
||||||
lm_head_weight = self.lm_head.weight
|
lm_head_weight = self.lm_head.weight
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
@@ -572,7 +571,6 @@ class MiniCPM3ForCausalLM(nn.Module):
|
|||||||
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
self.scale_width = self.config.hidden_size / self.config.dim_model_base
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -590,11 +588,9 @@ class MiniCPM3ForCausalLM(nn.Module):
|
|||||||
lm_head_weight = self.model.embed_tokens.weight
|
lm_head_weight = self.model.embed_tokens.weight
|
||||||
else:
|
else:
|
||||||
lm_head_weight = self.lm_head.weight
|
lm_head_weight = self.lm_head.weight
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, lm_head_weight, input_metadata
|
input_ids, hidden_states, lm_head_weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
self.model = MixtralModel(config, quant_config=quant_config)
|
self.model = MixtralModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||||
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -262,11 +260,9 @@ class QWenLMHeadModel(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
):
|
):
|
||||||
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
hidden_states = self.transformer(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from sglang.srt.layers.layernorm import RMSNorm
|
|||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
Qwen2Config = None
|
Qwen2Config = None
|
||||||
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
self.model = Qwen2Model(config, quant_config=quant_config)
|
self.model = Qwen2Model(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
if not get_embedding:
|
if not get_embedding:
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
else:
|
else:
|
||||||
return self.pooler(hidden_states, input_metadata)
|
return self.pooler(hidden_states, input_metadata)
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ from sglang.srt.layers.activation import SiluAndMul
|
|||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -365,7 +364,6 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -376,11 +374,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
self.model = StableLMEpochModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(
|
def forward(
|
||||||
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.model_runner import InputMetadata
|
from sglang.srt.model_executor.model_runner import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -307,7 +306,6 @@ class XverseForCausalLM(nn.Module):
|
|||||||
self.model = XverseModel(config, quant_config=quant_config)
|
self.model = XverseModel(config, quant_config=quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
self.param_dict = dict(self.named_parameters())
|
self.param_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
@@ -320,12 +318,9 @@ class XverseForCausalLM(nn.Module):
|
|||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
# print(f"{hidden_states=}")
|
return self.logits_processor(
|
||||||
logits_output = self.logits_processor(
|
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(
|
def load_weights(
|
||||||
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.sampler import Sampler
|
|
||||||
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -383,7 +382,6 @@ class XverseMoeForCausalLM(nn.Module):
|
|||||||
config.vocab_size, config.hidden_size, quant_config=quant_config
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
self.param_dict = dict(self.named_parameters())
|
self.param_dict = dict(self.named_parameters())
|
||||||
|
|
||||||
@@ -395,11 +393,9 @@ class XverseMoeForCausalLM(nn.Module):
|
|||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata)
|
hidden_states = self.model(input_ids, positions, input_metadata)
|
||||||
logits_output = self.logits_processor(
|
return self.logits_processor(
|
||||||
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
)
|
)
|
||||||
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
|
||||||
return sample_output, logits_output
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ class SamplingBatchInfo:
|
|||||||
# Vocab bias and min_ps are not supported in CUDA graph
|
# Vocab bias and min_ps are not supported in CUDA graph
|
||||||
return (
|
return (
|
||||||
self.logit_bias is None
|
self.logit_bias is None
|
||||||
and self.vocab_mask is None
|
|
||||||
and self.linear_penalties is None
|
and self.linear_penalties is None
|
||||||
and self.scaling_penalties is None
|
and self.scaling_penalties is None
|
||||||
and not self.need_min_p_sampling
|
and not self.need_min_p_sampling
|
||||||
@@ -50,9 +49,11 @@ class SamplingBatchInfo:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def dummy_one(cls, max_bs: int, vocab_size: int):
|
def dummy_one(cls, max_bs: int, vocab_size: int):
|
||||||
ret = cls(vocab_size=vocab_size)
|
ret = cls(vocab_size=vocab_size)
|
||||||
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
|
with torch.device("cuda"):
|
||||||
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
|
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
|
||||||
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
|
ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
|
||||||
|
ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
|
||||||
|
ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
@@ -64,6 +65,7 @@ class SamplingBatchInfo:
|
|||||||
temperatures=self.temperatures[key],
|
temperatures=self.temperatures[key],
|
||||||
top_ps=self.top_ps[key],
|
top_ps=self.top_ps[key],
|
||||||
top_ks=self.top_ks[key],
|
top_ks=self.top_ks[key],
|
||||||
|
vocab_mask=self.vocab_mask[key],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -77,6 +79,11 @@ class SamplingBatchInfo:
|
|||||||
self.top_ps[:bs] = other.top_ps
|
self.top_ps[:bs] = other.top_ps
|
||||||
self.top_ks[:bs] = other.top_ks
|
self.top_ks[:bs] = other.top_ks
|
||||||
|
|
||||||
|
if other.vocab_mask is None:
|
||||||
|
self.vocab_mask[:bs].fill_(False)
|
||||||
|
else:
|
||||||
|
self.vocab_mask[:bs] = other.vocab_mask
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|||||||
Reference in New Issue
Block a user