Optimize mem indices management (#619)
@@ -12,7 +12,6 @@ from sglang.utils import http_request


class RuntimeEndpoint(BaseBackend):
    def __init__(
        self,
        base_url: str,

@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"])
            self.model_info["model_path"]
        )

    def get_model_name(self):
        return self.model_info["model_path"]

@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value
@@ -32,7 +32,6 @@ import logging
import multiprocessing
import time

import numpy as np
import torch
import torch.distributed as dist

@@ -44,4 +44,5 @@ class GlobalConfig:
        # adjust_cache: Adjust the position embedding of KV cache.
        self.concate_and_append_mode = "no_adjust"


global_config = GlobalConfig()
@@ -84,7 +84,7 @@ register_chat_template(
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        }
        },
    )
)

@@ -177,7 +177,7 @@ register_chat_template(
            "assistant": ("", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",)
        stop_str=("<|im_end|>",),
    )
)
@@ -24,9 +24,9 @@ class SglSamplingParams:
    presence_penalty: float = 0.0
    ignore_eos: bool = False
    return_logprob: Optional[bool] = None
    logprob_start_len: Optional[int] = None,
    top_logprobs_num: Optional[int] = None,
    return_text_in_logprobs: Optional[bool] = None,
    logprob_start_len: Optional[int] = (None,)
    top_logprobs_num: Optional[int] = (None,)
    return_text_in_logprobs: Optional[bool] = (None,)

    # for constrained generation, not included in to_xxx_kwargs
    dtype: Optional[str] = None
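The `(None,)` defaults in the reformatted lines are not new values: they are what the earlier `Optional[int] = None,` declarations already evaluated to, because the trailing comma makes the default a one-element tuple. A minimal standalone illustration of that Python behavior (the class name `Example` is hypothetical, not part of the diff):

class Example:
    # Trailing comma: the default is the tuple (None,), not None.
    logprob_start_len = None,
    # No trailing comma: the default is plain None.
    top_logprobs_num = None

print(Example.logprob_start_len)  # (None,)
print(Example.top_logprobs_num)   # None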
@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
    Batch, ForwardMode, InputMetadata, init_flashinfer_args
    Batch,
    ForwardMode,
    InputMetadata,
    init_flashinfer_args,
)

@@ -24,18 +27,28 @@ class CudaGraphRunner:
        # Common inputs
        self.max_bs = max_batch_size_to_capture
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
        self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.position_ids_offsets = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )
        self.out_cache_loc = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

        # FlashInfer inputs
        self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
        self.flashinfer_workspace_buffer = (
            self.model_runner.flashinfer_workspace_buffers[0]
        )
        self.flashinfer_kv_indptr = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_indices = torch.zeros(
            (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
            (self.max_bs * model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.flashinfer_kv_last_page_len = torch.ones(
            (self.max_bs,), dtype=torch.int32, device="cuda"

@@ -49,7 +62,12 @@ class CudaGraphRunner:
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            for bs in batch_size_list:
                graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
                (
                    graph,
                    input_buffers,
                    output_buffers,
                    flashinfer_handler,
                ) = self.capture_one_batch_size(bs)
                self.graphs[bs] = graph
                self.input_buffers[bs] = input_buffers
                self.output_buffers[bs] = output_buffers

@@ -71,17 +89,19 @@ class CudaGraphRunner:

        # FlashInfer inputs
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
            self.model_runner.model_config.num_attention_heads
            // self.model_runner.tp_size,
            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False
        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffer, "NHD",
            self.flashinfer_workspace_buffer,
            "NHD",
            use_cuda_graph=True,
            use_tensor_cores=use_tensor_cores,
            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
            paged_kv_indices_buffer=self.flashinfer_kv_indices,
            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
        )

@@ -163,10 +183,14 @@ class CudaGraphRunner:
        else:
            output = LogitProcessorOutput(
                next_token_logits=output.next_token_logits[:raw_bs],
                next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
                next_token_logprobs=output.next_token_logprobs[:raw_bs]
                if output.next_token_logprobs is not None
                else None,
                normalized_prompt_logprobs=None,
                prefill_token_logprobs=None,
                prefill_top_logprobs=None,
                decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
                if output.decode_top_logprobs is not None
                else None,
            )
        return output
@@ -668,7 +668,9 @@ class Batch:
            sampled_index = torch.multinomial(probs_sort, num_samples=1)
        except RuntimeError as e:
            warnings.warn(f"Ignore errors in sampling: {e}")
            sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
            sampled_index = torch.ones(
                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
            )
        batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
            -1
        )

@@ -749,8 +751,14 @@ class InputMetadata:
        skip_flashinfer_init=False,
    ):
        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
            init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
                                 model_runner.flashinfer_decode_wrapper)
            init_flashinfer_args(
                forward_mode,
                model_runner,
                req_pool_indices,
                seq_lens,
                prefix_lens,
                model_runner.flashinfer_decode_wrapper,
            )

        batch_size = len(req_pool_indices)

@@ -807,16 +815,24 @@ class InputMetadata:
        )

        if model_runner.server_args.disable_flashinfer:
            (ret.triton_max_seq_len,
             ret.triton_max_extend_len,
             ret.triton_start_loc,
             ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
            (
                ret.triton_max_seq_len,
                ret.triton_max_extend_len,
                ret.triton_start_loc,
                ret.triton_prefix_lens,
            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)

        return ret


def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
                         flashinfer_decode_wrapper):
def init_flashinfer_args(
    forward_mode,
    model_runner,
    req_pool_indices,
    seq_lens,
    prefix_lens,
    flashinfer_decode_wrapper,
):
    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
    head_dim = model_runner.model_config.head_dim

@@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
    else:
        paged_kernel_lens = prefix_lens

    kv_indptr = torch.zeros(
        (batch_size + 1,), dtype=torch.int32, device="cuda"
    )
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()

@@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
        ],
        dim=0,
    ).contiguous()
    kv_last_page_len = torch.ones(
        (batch_size,), dtype=torch.int32, device="cuda"
    )
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

    if forward_mode == ForwardMode.DECODE:
        flashinfer_decode_wrapper.end_forward()

@@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
        )
    else:
        # extend part
        qo_indptr = torch.zeros(
            (batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
@@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
from sglang.srt.managers.controller.infer_batch import (
    Batch,
    ForwardMode,
    InputMetadata,
    global_server_args_dict,
)
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (

@@ -83,7 +88,9 @@ class ModelRunner:

        # Set some global args
        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
        global_server_args_dict[
            "attention_reduce_in_fp32"
        ] = server_args.attention_reduce_in_fp32

        # Load the model and create memory pool
        self.load_model()

@@ -217,7 +224,9 @@ class ModelRunner:
            self.flashinfer_workspace_buffers[1], "NHD"
        )
        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
            self.flashinfer_workspace_buffers[0],
            "NHD",
            use_tensor_cores=use_tensor_cores,
        )

    def init_cuda_graphs(self):

@@ -229,7 +238,9 @@ class ModelRunner:

        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
        self.cuda_graph_runner = CudaGraphRunner(
            self, max_batch_size_to_capture=max(batch_size_list)
        )
        self.cuda_graph_runner.capture(batch_size_list)

    @torch.inference_mode()
@@ -125,7 +125,8 @@ class RadixCache:
            if x.lock_ref > 0:
                continue

            num_evicted += evict_callback(x.value)
            evict_callback(x.value)
            num_evicted += len(x.value)
            self._delete_leaf(x)

            if len(x.parent.children) == 0:
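With this hunk, eviction no longer depends on the callback's return value: the number of evicted tokens is taken from the length of the leaf's token slice itself. A minimal sketch of the resulting loop shape; `value` and `lock_ref` mirror names in the hunk, while `leaves`, `last_access_time`, and the tree bookkeeping are placeholder assumptions, not the actual RadixCache internals:

def evict(leaves, num_tokens, evict_callback):
    # Evict least-recently-used, unreferenced leaves until enough tokens are freed.
    num_evicted = 0
    for leaf in sorted(leaves, key=lambda x: x.last_access_time):
        if num_evicted >= num_tokens:
            break
        if leaf.lock_ref > 0:
            continue  # still referenced by a running request
        evict_callback(leaf.value)      # release the KV-cache indices
        num_evicted += len(leaf.value)  # count freed tokens from the slice itself
        # ... delete the leaf and, if its parent becomes childless, consider it next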
@@ -314,7 +314,9 @@ class ModelTpServer:
        self.forward_queue.append(req)

    def get_new_fill_batch(self) -> Optional[Batch]:
        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
        )
        if running_bs >= self.max_running_requests:
            return
@@ -39,10 +39,12 @@ class ReqToTokenPool:
class TokenToKVPool:
    def __init__(self, size, dtype, head_num, head_dim, layer_num):
        self.size = size
        # mem_state is the reference counter.
        # This can be promised:
        # assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
        # We also add one slot. This slot is used for writing dummy output from padded tokens.
        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
        self.total_ref_ct = 0
        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
        self.total_size = self.size
        self.total_alloc = 0

        # [size, key/value, head_num, head_dim] for each layer
        self.kv_data = [

@@ -71,7 +73,9 @@ class TokenToKVPool:

        addition_size = need_size - buffer_len
        alloc_size = max(addition_size, self.prefetch_chunk_size)
        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
        select_index = (
            torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
        )

        if select_index.shape[0] < addition_size:
            return None

@@ -105,26 +109,22 @@ class TokenToKVPool:
        return select_index.to(torch.int32), start_loc, start_loc + need_size

    def used_size(self):
        return len(torch.nonzero(self.mem_state).squeeze(1))
        return self.total_alloc

    def available_size(self):
        return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
        return self.total_size - self.total_alloc + len(self.prefetch_buffer)

    def add_refs(self, token_index: torch.Tensor):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1
        self.total_alloc += len(token_index)
        self.mem_state[token_index] ^= True

    def dec_refs(self, token_index: torch.Tensor):
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1

        num_freed = torch.sum(self.mem_state[token_index] == 0)

        return num_freed
        self.total_alloc -= len(token_index)
        self.mem_state[token_index] ^= True

    def clear(self):
        self.mem_state.fill_(0)
        self.total_ref_ct = 0
        self.total_alloc = 0

        # We also add one slot. This slot is used for writing dummy output from padded tokens.
        self.add_refs(torch.tensor([0], dtype=torch.int32))
        self.mem_state[0] = True
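The core of the optimization is visible in this hunk: `mem_state` changes from an int16 reference counter to a boolean occupancy mask, and `used_size`/`available_size` become O(1) reads of a running `total_alloc` counter instead of full-tensor scans. A minimal sketch of that bookkeeping pattern; `SlotPool` and its methods are simplified, CPU-only stand-ins, not the actual TokenToKVPool API:

import torch

class SlotPool:
    def __init__(self, size: int):
        # Boolean mask: True means the slot is in use. Slot 0 is reserved
        # for dummy writes from padded tokens and is never handed out.
        self.mem_state = torch.zeros((size + 1,), dtype=torch.bool)
        self.total_size = size
        self.total_alloc = 0
        self.mem_state[0] = True

    def alloc(self, need_size: int):
        free_index = torch.nonzero(~self.mem_state).squeeze(1)[:need_size]
        if free_index.shape[0] < need_size:
            return None
        self.mem_state[free_index] ^= True   # flip free -> used
        self.total_alloc += need_size
        return free_index.to(torch.int32)

    def free(self, token_index: torch.Tensor):
        self.mem_state[token_index] ^= True  # flip used -> free
        self.total_alloc -= len(token_index)

    def available_size(self) -> int:
        # O(1): no scan over mem_state
        return self.total_size - self.total_alloc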
@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn

from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size

from vllm.model_executor.layers.activation import SiluAndMul

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,

@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata


class MiniCPMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,

@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):


class MiniCPMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,

@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):


class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,

@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):


class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,

@@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module):
    ) -> None:
        super().__init__()
        self.config = config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config, quant_config=quant_config)
@@ -8,24 +8,28 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput

@@ -34,8 +38,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata


class Qwen2MoeMLP(nn.Module):


class Qwen2MoeMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
@@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module):
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):

@@ -67,7 +74,6 @@ class Qwen2MoeMLP(nn.Module):


class Qwen2MoeSparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,

@@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")
                f"the number of experts {config.num_experts}."
            )

        self.experts = FusedMoE(num_experts=config.num_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=False,
                                renormalize=config.norm_topk_prob,
                                quant_config=quant_config)
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
        )

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None)
        self.gate = ReplicatedLinear(
            config.hidden_size, config.num_experts, bias=False, quant_config=None
        )
        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,

@@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
            )
        else:
            self.shared_expert = None
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
@@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = F.sigmoid(
                    self.shared_expert_gate(hidden_states)) * shared_output
                shared_output = (
                    F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                )

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)


class Qwen2MoeAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,

@@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module):
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = RadixAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   layer_id=layer_id)
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

@@ -211,7 +219,6 @@ class Qwen2MoeAttention(nn.Module):


class Qwen2MoeDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,

@@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@@ -239,13 +245,13 @@ class Qwen2MoeDecoderLayer(nn.Module):

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                           config.mlp_only_layers)
        mlp_only_layers = (
            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
        )
        if (layer_id not in mlp_only_layers) and (
                config.num_experts > 0 and
                (layer_id + 1) % config.decoder_sparse_step == 0):
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config)
            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
        ):
            self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,

@@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module):
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,

@@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            input_metadata=input_metadata
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen2MoeModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,

@@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module):
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            Qwen2MoeDecoderLayer(config,
                                 layer_id,
                                 cache_config,
                                 quant_config=quant_config)
            for layer_id in range(config.num_hidden_layers)
        ])
        self.layers = nn.ModuleList(
            [
                Qwen2MoeDecoderLayer(
                    config, layer_id, cache_config, quant_config=quant_config
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(

@@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module):
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)

@@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module):
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions,
                                            hidden_states,
                                            input_metadata,
                                            residual)
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
@@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module):
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

@@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module):
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata,
                                   input_embeds)
        return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
                                     input_metadata)
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor,
                       input_metadata: InputMetadata) -> torch.Tensor:
        logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
                                       input_metadata)
    def compute_logits(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        logits = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        return logits

    def sample(

@@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module):
        expert_params_mapping = [
            # These are the weights for the experts
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
             else "experts.w2_weight",
             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
            for expert_id in range(self.config.num_experts) for shard_id,
            weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
            (
                "experts.w13_weight"
                if weight_name in ["gate_proj", "up_proj"]
                else "experts.w2_weight",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
                shard_id,
            )
            for expert_id in range(self.config.num_experts)
            for shard_id, weight_name in enumerate(
                ["gate_proj", "down_proj", "up_proj"]
            )
        ]

        params_dict = dict(self.named_parameters())

@@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module):
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              weight_name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                weight_loader(
                    param,
                    loaded_weight,
                    weight_name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                # Skip loading extra bias for GPTQ models.

@@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)


EntryClass = Qwen2MoeForCausalLM
@@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader():
        DummyModelLoader,
        LoRAConfig,
        ModelConfig,
        MultiModalConfig,
        ParallelConfig,
        SchedulerConfig,
        MultiModalConfig,
        _initialize_model,
        initialize_dummy_weights,
        nn,