Optimize mem indices mangement (#619)
This commit is contained in:
@@ -17,7 +17,8 @@ def run_one_batch_size(bs):
|
|||||||
|
|
||||||
if args.input_len:
|
if args.input_len:
|
||||||
input_ids = [
|
input_ids = [
|
||||||
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
|
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
|
||||||
|
for _ in range(bs)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
text = [f"{i, }" for i in range(bs)]
|
text = [f"{i, }" for i in range(bs)]
|
||||||
@@ -116,9 +117,11 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--port", type=int, default=None)
|
parser.add_argument("--port", type=int, default=None)
|
||||||
parser.add_argument("--backend", type=str, default="srt")
|
parser.add_argument("--backend", type=str, default="srt")
|
||||||
parser.add_argument("--input-len", type=int, default=None)
|
parser.add_argument("--input-len", type=int, default=None)
|
||||||
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
|
parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
|
||||||
parser.add_argument("--max-tokens", type=int, default=256)
|
parser.add_argument("--max-tokens", type=int, default=256)
|
||||||
parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
|
parser.add_argument(
|
||||||
|
"--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.port is None:
|
if args.port is None:
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from sglang.utils import http_request
|
|||||||
|
|
||||||
|
|
||||||
class RuntimeEndpoint(BaseBackend):
|
class RuntimeEndpoint(BaseBackend):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.model_info = res.json()
|
self.model_info = res.json()
|
||||||
|
|
||||||
self.chat_template = get_chat_template_by_model_path(
|
self.chat_template = get_chat_template_by_model_path(
|
||||||
self.model_info["model_path"])
|
self.model_info["model_path"]
|
||||||
|
)
|
||||||
|
|
||||||
def get_model_name(self):
|
def get_model_name(self):
|
||||||
return self.model_info["model_path"]
|
return self.model_info["model_path"]
|
||||||
@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||||
|
|
||||||
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
for item in [
|
||||||
|
"return_logprob",
|
||||||
|
"logprob_start_len",
|
||||||
|
"top_logprobs_num",
|
||||||
|
"return_text_in_logprobs",
|
||||||
|
]:
|
||||||
value = getattr(sampling_params, item, None)
|
value = getattr(sampling_params, item, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
data[item] = value
|
data[item] = value
|
||||||
@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||||
|
|
||||||
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
for item in [
|
||||||
|
"return_logprob",
|
||||||
|
"logprob_start_len",
|
||||||
|
"top_logprobs_num",
|
||||||
|
"return_text_in_logprobs",
|
||||||
|
]:
|
||||||
value = getattr(sampling_params, item, None)
|
value = getattr(sampling_params, item, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
data[item] = value
|
data[item] = value
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ import logging
|
|||||||
import multiprocessing
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|||||||
@@ -44,4 +44,5 @@ class GlobalConfig:
|
|||||||
# adjust_cache: Adjust the position embedding of KV cache.
|
# adjust_cache: Adjust the position embedding of KV cache.
|
||||||
self.concate_and_append_mode = "no_adjust"
|
self.concate_and_append_mode = "no_adjust"
|
||||||
|
|
||||||
|
|
||||||
global_config = GlobalConfig()
|
global_config = GlobalConfig()
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ register_chat_template(
|
|||||||
"system": ("SYSTEM:", "\n"),
|
"system": ("SYSTEM:", "\n"),
|
||||||
"user": ("USER:", "\n"),
|
"user": ("USER:", "\n"),
|
||||||
"assistant": ("ASSISTANT:", "\n"),
|
"assistant": ("ASSISTANT:", "\n"),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -177,7 +177,7 @@ register_chat_template(
|
|||||||
"assistant": ("", "<|im_end|>\n"),
|
"assistant": ("", "<|im_end|>\n"),
|
||||||
},
|
},
|
||||||
style=ChatTemplateStyle.PLAIN,
|
style=ChatTemplateStyle.PLAIN,
|
||||||
stop_str=("<|im_end|>",)
|
stop_str=("<|im_end|>",),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ class SglSamplingParams:
|
|||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
return_logprob: Optional[bool] = None
|
return_logprob: Optional[bool] = None
|
||||||
logprob_start_len: Optional[int] = None,
|
logprob_start_len: Optional[int] = (None,)
|
||||||
top_logprobs_num: Optional[int] = None,
|
top_logprobs_num: Optional[int] = (None,)
|
||||||
return_text_in_logprobs: Optional[bool] = None,
|
return_text_in_logprobs: Optional[bool] = (None,)
|
||||||
|
|
||||||
# for constrained generation, not included in to_xxx_kwargs
|
# for constrained generation, not included in to_xxx_kwargs
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
|
|||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitProcessorOutput
|
||||||
from sglang.srt.managers.controller.infer_batch import (
|
from sglang.srt.managers.controller.infer_batch import (
|
||||||
Batch, ForwardMode, InputMetadata, init_flashinfer_args
|
Batch,
|
||||||
|
ForwardMode,
|
||||||
|
InputMetadata,
|
||||||
|
init_flashinfer_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -24,18 +27,28 @@ class CudaGraphRunner:
|
|||||||
# Common inputs
|
# Common inputs
|
||||||
self.max_bs = max_batch_size_to_capture
|
self.max_bs = max_batch_size_to_capture
|
||||||
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||||
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.req_pool_indices = torch.zeros(
|
||||||
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
|
||||||
self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
self.position_ids_offsets = torch.zeros(
|
||||||
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
self.out_cache_loc = torch.zeros(
|
||||||
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
|
)
|
||||||
|
|
||||||
# FlashInfer inputs
|
# FlashInfer inputs
|
||||||
self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
|
self.flashinfer_workspace_buffer = (
|
||||||
|
self.model_runner.flashinfer_workspace_buffers[0]
|
||||||
|
)
|
||||||
self.flashinfer_kv_indptr = torch.zeros(
|
self.flashinfer_kv_indptr = torch.zeros(
|
||||||
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
|
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
self.flashinfer_kv_indices = torch.zeros(
|
self.flashinfer_kv_indices = torch.zeros(
|
||||||
(self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
|
(self.max_bs * model_runner.model_config.context_len,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
)
|
)
|
||||||
self.flashinfer_kv_last_page_len = torch.ones(
|
self.flashinfer_kv_last_page_len = torch.ones(
|
||||||
(self.max_bs,), dtype=torch.int32, device="cuda"
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
||||||
@@ -49,7 +62,12 @@ class CudaGraphRunner:
|
|||||||
with graph_capture() as graph_capture_context:
|
with graph_capture() as graph_capture_context:
|
||||||
self.stream = graph_capture_context.stream
|
self.stream = graph_capture_context.stream
|
||||||
for bs in batch_size_list:
|
for bs in batch_size_list:
|
||||||
graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
|
(
|
||||||
|
graph,
|
||||||
|
input_buffers,
|
||||||
|
output_buffers,
|
||||||
|
flashinfer_handler,
|
||||||
|
) = self.capture_one_batch_size(bs)
|
||||||
self.graphs[bs] = graph
|
self.graphs[bs] = graph
|
||||||
self.input_buffers[bs] = input_buffers
|
self.input_buffers[bs] = input_buffers
|
||||||
self.output_buffers[bs] = output_buffers
|
self.output_buffers[bs] = output_buffers
|
||||||
@@ -71,17 +89,19 @@ class CudaGraphRunner:
|
|||||||
|
|
||||||
# FlashInfer inputs
|
# FlashInfer inputs
|
||||||
if not _grouped_size_compiled_for_decode_kernels(
|
if not _grouped_size_compiled_for_decode_kernels(
|
||||||
self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
|
self.model_runner.model_config.num_attention_heads
|
||||||
|
// self.model_runner.tp_size,
|
||||||
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
|
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
|
||||||
):
|
):
|
||||||
use_tensor_cores = True
|
use_tensor_cores = True
|
||||||
else:
|
else:
|
||||||
use_tensor_cores = False
|
use_tensor_cores = False
|
||||||
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.flashinfer_workspace_buffer, "NHD",
|
self.flashinfer_workspace_buffer,
|
||||||
|
"NHD",
|
||||||
use_cuda_graph=True,
|
use_cuda_graph=True,
|
||||||
use_tensor_cores=use_tensor_cores,
|
use_tensor_cores=use_tensor_cores,
|
||||||
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
|
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
|
||||||
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
||||||
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
||||||
)
|
)
|
||||||
@@ -163,10 +183,14 @@ class CudaGraphRunner:
|
|||||||
else:
|
else:
|
||||||
output = LogitProcessorOutput(
|
output = LogitProcessorOutput(
|
||||||
next_token_logits=output.next_token_logits[:raw_bs],
|
next_token_logits=output.next_token_logits[:raw_bs],
|
||||||
next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
|
next_token_logprobs=output.next_token_logprobs[:raw_bs]
|
||||||
|
if output.next_token_logprobs is not None
|
||||||
|
else None,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
prefill_token_logprobs=None,
|
prefill_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
prefill_top_logprobs=None,
|
||||||
decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
|
decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
|
||||||
|
if output.decode_top_logprobs is not None
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -668,7 +668,9 @@ class Batch:
|
|||||||
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
sampled_index = torch.multinomial(probs_sort, num_samples=1)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
warnings.warn(f"Ignore errors in sampling: {e}")
|
warnings.warn(f"Ignore errors in sampling: {e}")
|
||||||
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
|
sampled_index = torch.ones(
|
||||||
|
probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
|
||||||
|
)
|
||||||
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
|
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
|
||||||
-1
|
-1
|
||||||
)
|
)
|
||||||
@@ -749,8 +751,14 @@ class InputMetadata:
|
|||||||
skip_flashinfer_init=False,
|
skip_flashinfer_init=False,
|
||||||
):
|
):
|
||||||
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
|
||||||
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
|
init_flashinfer_args(
|
||||||
model_runner.flashinfer_decode_wrapper)
|
forward_mode,
|
||||||
|
model_runner,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
model_runner.flashinfer_decode_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
|
|
||||||
@@ -807,16 +815,24 @@ class InputMetadata:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_runner.server_args.disable_flashinfer:
|
if model_runner.server_args.disable_flashinfer:
|
||||||
(ret.triton_max_seq_len,
|
(
|
||||||
|
ret.triton_max_seq_len,
|
||||||
ret.triton_max_extend_len,
|
ret.triton_max_extend_len,
|
||||||
ret.triton_start_loc,
|
ret.triton_start_loc,
|
||||||
ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
ret.triton_prefix_lens,
|
||||||
|
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
|
def init_flashinfer_args(
|
||||||
flashinfer_decode_wrapper):
|
forward_mode,
|
||||||
|
model_runner,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
prefix_lens,
|
||||||
|
flashinfer_decode_wrapper,
|
||||||
|
):
|
||||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||||
head_dim = model_runner.model_config.head_dim
|
head_dim = model_runner.model_config.head_dim
|
||||||
@@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
|
|||||||
else:
|
else:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
|
|
||||||
kv_indptr = torch.zeros(
|
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
|
||||||
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
|
||||||
@@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
|
|||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
).contiguous()
|
).contiguous()
|
||||||
kv_last_page_len = torch.ones(
|
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
(batch_size,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
if forward_mode == ForwardMode.DECODE:
|
if forward_mode == ForwardMode.DECODE:
|
||||||
flashinfer_decode_wrapper.end_forward()
|
flashinfer_decode_wrapper.end_forward()
|
||||||
@@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# extend part
|
# extend part
|
||||||
qo_indptr = torch.zeros(
|
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
|
||||||
(batch_size + 1,), dtype=torch.int32, device="cuda"
|
|
||||||
)
|
|
||||||
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
|
||||||
|
|
||||||
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
|
||||||
|
|||||||
@@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model
|
|||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
|
from sglang.srt.managers.controller.infer_batch import (
|
||||||
|
Batch,
|
||||||
|
ForwardMode,
|
||||||
|
InputMetadata,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
@@ -83,7 +88,9 @@ class ModelRunner:
|
|||||||
|
|
||||||
# Set some global args
|
# Set some global args
|
||||||
global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
|
global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
|
||||||
global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
|
global_server_args_dict[
|
||||||
|
"attention_reduce_in_fp32"
|
||||||
|
] = server_args.attention_reduce_in_fp32
|
||||||
|
|
||||||
# Load the model and create memory pool
|
# Load the model and create memory pool
|
||||||
self.load_model()
|
self.load_model()
|
||||||
@@ -217,7 +224,9 @@ class ModelRunner:
|
|||||||
self.flashinfer_workspace_buffers[1], "NHD"
|
self.flashinfer_workspace_buffers[1], "NHD"
|
||||||
)
|
)
|
||||||
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
|
self.flashinfer_workspace_buffers[0],
|
||||||
|
"NHD",
|
||||||
|
use_tensor_cores=use_tensor_cores,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
@@ -229,7 +238,9 @@ class ModelRunner:
|
|||||||
|
|
||||||
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
|
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
|
||||||
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
|
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
|
||||||
self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
|
self.cuda_graph_runner = CudaGraphRunner(
|
||||||
|
self, max_batch_size_to_capture=max(batch_size_list)
|
||||||
|
)
|
||||||
self.cuda_graph_runner.capture(batch_size_list)
|
self.cuda_graph_runner.capture(batch_size_list)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@@ -125,7 +125,8 @@ class RadixCache:
|
|||||||
if x.lock_ref > 0:
|
if x.lock_ref > 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_evicted += evict_callback(x.value)
|
evict_callback(x.value)
|
||||||
|
num_evicted += len(x.value)
|
||||||
self._delete_leaf(x)
|
self._delete_leaf(x)
|
||||||
|
|
||||||
if len(x.parent.children) == 0:
|
if len(x.parent.children) == 0:
|
||||||
|
|||||||
@@ -314,7 +314,9 @@ class ModelTpServer:
|
|||||||
self.forward_queue.append(req)
|
self.forward_queue.append(req)
|
||||||
|
|
||||||
def get_new_fill_batch(self) -> Optional[Batch]:
|
def get_new_fill_batch(self) -> Optional[Batch]:
|
||||||
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
|
running_bs = (
|
||||||
|
len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||||
|
)
|
||||||
if running_bs >= self.max_running_requests:
|
if running_bs >= self.max_running_requests:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -39,10 +39,12 @@ class ReqToTokenPool:
|
|||||||
class TokenToKVPool:
|
class TokenToKVPool:
|
||||||
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
def __init__(self, size, dtype, head_num, head_dim, layer_num):
|
||||||
self.size = size
|
self.size = size
|
||||||
# mem_state is the reference counter.
|
# This can be promised:
|
||||||
|
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
|
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||||
self.total_ref_ct = 0
|
self.total_size = self.size
|
||||||
|
self.total_alloc = 0
|
||||||
|
|
||||||
# [size, key/value, head_num, head_dim] for each layer
|
# [size, key/value, head_num, head_dim] for each layer
|
||||||
self.kv_data = [
|
self.kv_data = [
|
||||||
@@ -71,7 +73,9 @@ class TokenToKVPool:
|
|||||||
|
|
||||||
addition_size = need_size - buffer_len
|
addition_size = need_size - buffer_len
|
||||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
||||||
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
select_index = (
|
||||||
|
torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
if select_index.shape[0] < addition_size:
|
if select_index.shape[0] < addition_size:
|
||||||
return None
|
return None
|
||||||
@@ -105,26 +109,22 @@ class TokenToKVPool:
|
|||||||
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
return select_index.to(torch.int32), start_loc, start_loc + need_size
|
||||||
|
|
||||||
def used_size(self):
|
def used_size(self):
|
||||||
return len(torch.nonzero(self.mem_state).squeeze(1))
|
return self.total_alloc
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
|
return self.total_size - self.total_alloc + len(self.prefetch_buffer)
|
||||||
|
|
||||||
def add_refs(self, token_index: torch.Tensor):
|
def add_refs(self, token_index: torch.Tensor):
|
||||||
self.total_ref_ct += len(token_index)
|
self.total_alloc += len(token_index)
|
||||||
self.mem_state[token_index] += 1
|
self.mem_state[token_index] ^= True
|
||||||
|
|
||||||
def dec_refs(self, token_index: torch.Tensor):
|
def dec_refs(self, token_index: torch.Tensor):
|
||||||
self.total_ref_ct -= len(token_index)
|
self.total_alloc -= len(token_index)
|
||||||
self.mem_state[token_index] -= 1
|
self.mem_state[token_index] ^= True
|
||||||
|
|
||||||
num_freed = torch.sum(self.mem_state[token_index] == 0)
|
|
||||||
|
|
||||||
return num_freed
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.mem_state.fill_(0)
|
self.mem_state.fill_(0)
|
||||||
self.total_ref_ct = 0
|
self.total_alloc = 0
|
||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.add_refs(torch.tensor([0], dtype=torch.int32))
|
self.mem_state[0] = True
|
||||||
|
|||||||
@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (
|
from vllm.model_executor.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata
|
|||||||
|
|
||||||
|
|
||||||
class MiniCPMMLP(nn.Module):
|
class MiniCPMMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MiniCPMAttention(nn.Module):
|
class MiniCPMAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MiniCPMDecoderLayer(nn.Module):
|
class MiniCPMDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MiniCPMModel(nn.Module):
|
class MiniCPMModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
|
|||||||
@@ -8,24 +8,28 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (
|
||||||
tensor_model_parallel_all_reduce)
|
get_tensor_model_parallel_world_size,
|
||||||
|
tensor_model_parallel_all_reduce,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (
|
||||||
|
MergedColumnParallelLinear,
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
QuantizationConfig)
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
@@ -34,8 +38,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.managers.controller.model_runner import InputMetadata
|
from sglang.srt.managers.controller.model_runner import InputMetadata
|
||||||
|
|
||||||
class Qwen2MoeMLP(nn.Module):
|
|
||||||
|
|
||||||
|
class Qwen2MoeMLP(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_up_proj = MergedColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
hidden_size, [intermediate_size] * 2,
|
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||||
bias=False,
|
)
|
||||||
quant_config=quant_config)
|
self.down_proj = RowParallelLinear(
|
||||||
self.down_proj = RowParallelLinear(intermediate_size,
|
intermediate_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
bias=False,
|
bias=False,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
reduce_results=reduce_results)
|
reduce_results=reduce_results,
|
||||||
|
)
|
||||||
if hidden_act != "silu":
|
if hidden_act != "silu":
|
||||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
raise ValueError(
|
||||||
"Only silu is supported for now.")
|
f"Unsupported activation: {hidden_act}. "
|
||||||
|
"Only silu is supported for now."
|
||||||
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@@ -67,7 +74,6 @@ class Qwen2MoeMLP(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2MoeSparseMoeBlock(nn.Module):
|
class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
if self.tp_size > config.num_experts:
|
if self.tp_size > config.num_experts:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tensor parallel size {self.tp_size} is greater than "
|
f"Tensor parallel size {self.tp_size} is greater than "
|
||||||
f"the number of experts {config.num_experts}.")
|
f"the number of experts {config.num_experts}."
|
||||||
|
)
|
||||||
|
|
||||||
self.experts = FusedMoE(num_experts=config.num_experts,
|
self.experts = FusedMoE(
|
||||||
|
num_experts=config.num_experts,
|
||||||
top_k=config.num_experts_per_tok,
|
top_k=config.num_experts_per_tok,
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
intermediate_size=config.moe_intermediate_size,
|
intermediate_size=config.moe_intermediate_size,
|
||||||
reduce_results=False,
|
reduce_results=False,
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(config.hidden_size,
|
self.gate = ReplicatedLinear(
|
||||||
config.num_experts,
|
config.hidden_size, config.num_experts, bias=False, quant_config=None
|
||||||
bias=False,
|
)
|
||||||
quant_config=None)
|
|
||||||
if config.shared_expert_intermediate_size > 0:
|
if config.shared_expert_intermediate_size > 0:
|
||||||
self.shared_expert = Qwen2MoeMLP(
|
self.shared_expert = Qwen2MoeMLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.shared_expert = None
|
self.shared_expert = None
|
||||||
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
|
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
|
||||||
1,
|
|
||||||
bias=False)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
num_tokens, hidden_dim = hidden_states.shape
|
num_tokens, hidden_dim = hidden_states.shape
|
||||||
@@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
|||||||
if self.shared_expert is not None:
|
if self.shared_expert is not None:
|
||||||
shared_output = self.shared_expert(hidden_states)
|
shared_output = self.shared_expert(hidden_states)
|
||||||
if self.shared_expert_gate is not None:
|
if self.shared_expert_gate is not None:
|
||||||
shared_output = F.sigmoid(
|
shared_output = (
|
||||||
self.shared_expert_gate(hidden_states)) * shared_output
|
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
|
||||||
|
)
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
final_hidden_states = self.experts(
|
||||||
router_logits=router_logits)
|
hidden_states=hidden_states, router_logits=router_logits
|
||||||
|
)
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
final_hidden_states)
|
|
||||||
|
|
||||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeAttention(nn.Module):
|
class Qwen2MoeAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
@@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
base=rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
self.attn = RadixAttention(self.num_heads,
|
self.attn = RadixAttention(
|
||||||
|
self.num_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
self.scaling,
|
self.scaling,
|
||||||
num_kv_heads=self.num_kv_heads,
|
num_kv_heads=self.num_kv_heads,
|
||||||
layer_id=layer_id)
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
input_metadata: InputMetadata
|
input_metadata: InputMetadata,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
@@ -211,7 +219,6 @@ class Qwen2MoeAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2MoeDecoderLayer(nn.Module):
|
class Qwen2MoeDecoderLayer(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
8192)
|
|
||||||
self.self_attn = Qwen2MoeAttention(
|
self.self_attn = Qwen2MoeAttention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
@@ -239,13 +245,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||||
# `mlp_only_layers` in the config.
|
# `mlp_only_layers` in the config.
|
||||||
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
|
mlp_only_layers = (
|
||||||
config.mlp_only_layers)
|
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
||||||
|
)
|
||||||
if (layer_id not in mlp_only_layers) and (
|
if (layer_id not in mlp_only_layers) and (
|
||||||
config.num_experts > 0 and
|
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||||
(layer_id + 1) % config.decoder_sparse_step == 0):
|
):
|
||||||
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
|
self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
|
||||||
quant_config=quant_config)
|
|
||||||
else:
|
else:
|
||||||
self.mlp = Qwen2MoeMLP(
|
self.mlp = Qwen2MoeMLP(
|
||||||
hidden_size=config.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
@@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
eps=config.rms_norm_eps)
|
self.post_attention_layernorm = RMSNorm(
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
eps=config.rms_norm_eps)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
hidden_states, residual)
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
input_metadata=input_metadata
|
input_metadata=input_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
hidden_states, residual)
|
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
class Qwen2MoeModel(nn.Module):
|
class Qwen2MoeModel(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([
|
self.layers = nn.ModuleList(
|
||||||
Qwen2MoeDecoderLayer(config,
|
[
|
||||||
layer_id,
|
Qwen2MoeDecoderLayer(
|
||||||
cache_config,
|
config, layer_id, cache_config, quant_config=quant_config
|
||||||
quant_config=quant_config)
|
)
|
||||||
for layer_id in range(config.num_hidden_layers)
|
for layer_id in range(config.num_hidden_layers)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
input_embeds: torch.Tensor = None
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
@@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
residual = None
|
residual = None
|
||||||
for i in range(len(self.layers)):
|
for i in range(len(self.layers)):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(positions,
|
hidden_states, residual = layer(
|
||||||
hidden_states,
|
positions, hidden_states, input_metadata, residual
|
||||||
input_metadata,
|
)
|
||||||
residual)
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
self.model = Qwen2MoeModel(config, cache_config, quant_config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(
|
||||||
config.hidden_size,
|
config.vocab_size, config.hidden_size, quant_config=quant_config
|
||||||
quant_config=quant_config)
|
)
|
||||||
self.logits_processor = LogitsProcessor(config)
|
self.logits_processor = LogitsProcessor(config)
|
||||||
self.sampler = Sampler()
|
self.sampler = Sampler()
|
||||||
|
|
||||||
@@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
input_metadata: InputMetadata,
|
input_metadata: InputMetadata,
|
||||||
input_embeds: torch.Tensor = None
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, input_metadata,
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
||||||
input_embeds)
|
return self.logits_processor(
|
||||||
return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
input_metadata)
|
)
|
||||||
|
|
||||||
def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor,
|
def compute_logits(
|
||||||
input_metadata: InputMetadata) -> torch.Tensor:
|
self,
|
||||||
logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
|
input_ids: torch.Tensor,
|
||||||
input_metadata)
|
hidden_states: torch.Tensor,
|
||||||
|
input_metadata: InputMetadata,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits = self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
||||||
|
)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
@@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
expert_params_mapping = [
|
expert_params_mapping = [
|
||||||
# These are the weights for the experts
|
# These are the weights for the experts
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
|
(
|
||||||
|
"experts.w13_weight"
|
||||||
|
if weight_name in ["gate_proj", "up_proj"]
|
||||||
else "experts.w2_weight",
|
else "experts.w2_weight",
|
||||||
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
|
f"experts.{expert_id}.{weight_name}.weight",
|
||||||
for expert_id in range(self.config.num_experts) for shard_id,
|
expert_id,
|
||||||
weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
|
shard_id,
|
||||||
|
)
|
||||||
|
for expert_id in range(self.config.num_experts)
|
||||||
|
for shard_id, weight_name in enumerate(
|
||||||
|
["gate_proj", "down_proj", "up_proj"]
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
@@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = param.weight_loader
|
weight_loader = param.weight_loader
|
||||||
weight_loader(param,
|
weight_loader(
|
||||||
|
param,
|
||||||
loaded_weight,
|
loaded_weight,
|
||||||
weight_name,
|
weight_name,
|
||||||
shard_id=shard_id,
|
shard_id=shard_id,
|
||||||
expert_id=expert_id)
|
expert_id=expert_id,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
@@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(
|
||||||
default_weight_loader)
|
param, "weight_loader", default_weight_loader
|
||||||
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
|
||||||
EntryClass = Qwen2MoeForCausalLM
|
EntryClass = Qwen2MoeForCausalLM
|
||||||
|
|||||||
@@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader():
|
|||||||
DummyModelLoader,
|
DummyModelLoader,
|
||||||
LoRAConfig,
|
LoRAConfig,
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
|
MultiModalConfig,
|
||||||
ParallelConfig,
|
ParallelConfig,
|
||||||
SchedulerConfig,
|
SchedulerConfig,
|
||||||
MultiModalConfig,
|
|
||||||
_initialize_model,
|
_initialize_model,
|
||||||
initialize_dummy_weights,
|
initialize_dummy_weights,
|
||||||
nn,
|
nn,
|
||||||
|
|||||||
Reference in New Issue
Block a user