Format (#593)
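Every hunk below is a mechanical reformatting: the changes are consistent with running the black code formatter (double-quoted strings, 88-column wrapping, magic trailing commas, exploded argument lists) together with isort-style import ordering, and they do not alter behavior. A minimal sketch of the convention being applied (illustrative only, not taken from the diff itself):

    # Before: hanging-indent style, single quotes, no trailing comma
    #   config = {'BLOCK_SIZE_M': 64,
    #             'BLOCK_SIZE_N': 64}
    # After: black style, double quotes, one item per line, magic trailing comma
    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
    }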
@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
     for i in range(len(prompts)):
         assert len(input_ids[i]) > bench_args.cut_len

-        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
 def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
-            i, :bench_args.cut_len
+            i, : bench_args.cut_len
         ]
     return reqs

@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
-        tree_cache=None)
+        tree_cache=None,
+    )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
     next_token_ids, _ = batch.sample(output.next_token_logits)
@@ -212,7 +213,9 @@ def latency_test(

     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )

     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -232,7 +235,9 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )

         # Decode
         for i in range(output_len):
@@ -243,13 +248,24 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
-
-        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )

     # Warm up
     run_once(4)
@@ -298,4 +314,4 @@ if __name__ == "__main__":
         format="%(message)s",
     )

-    main(server_args, bench_args)
\ No newline at end of file
+    main(server_args, bench_args)

@@ -39,4 +39,5 @@ class GlobalConfig:
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192

+
 global_config = GlobalConfig()

@@ -185,8 +185,10 @@ class SglFunction:
         batch_kwargs = [
             {self.arg_names[i]: v for i, v in enumerate(arg_values)}
             for arg_values in batch_kwargs
-            if isinstance(arg_values, (list, tuple)) and
-            len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+            if isinstance(arg_values, (list, tuple))
+            and len(self.arg_names) - len(self.arg_defaults)
+            <= len(arg_values)
+            <= len(self.arg_names)
         ]
         # Ensure to raise an exception if the number of arguments mismatch
         if len(batch_kwargs) != num_programs:
@@ -5,13 +5,14 @@ from pydantic import BaseModel

 try:
     from outlines.caching import cache as disk_cache
-    from outlines.fsm.guide import RegexGuide
     from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
     from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
     from outlines.models.transformers import TransformerTokenizer
 except ImportError as e:
-    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    print(
+        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+    )
     raise

 try:

@@ -264,7 +264,9 @@ class TiktokenTokenizer:
         return self.tokenizer.decode_batch(batch)

     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
         return self.encode(ret) if tokenize else ret

@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
         return self.tokenizer.decode(batch)

     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
-        return self.encode(ret) if tokenize else ret
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret

@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger

@@ -108,12 +107,16 @@ def fused_moe_kernel(

     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
-                      offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )

     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )

     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -129,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -146,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk

     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]

     if use_fp8:
@@ -158,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)


 def moe_align_block_size(
-    topk_ids: torch.Tensor, block_size: int,
-    num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,32 +204,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad


-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None

-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )

     fused_moe_kernel[grid](
         A,
@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:


 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.

@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)

     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}

@@ -319,35 +325,35 @@ def get_default_config(
 ) -> Dict[str, int]:
     if dtype == "float8":
         config = {
-            'BLOCK_SIZE_M': 128,
-            'BLOCK_SIZE_N': 256,
-            'BLOCK_SIZE_K': 128,
-            'GROUP_SIZE_M': 32,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4
+            "num_stages": 4,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 64,
-                'BLOCK_SIZE_N': 128,
-                'BLOCK_SIZE_K': 128,
-                'GROUP_SIZE_M': 1,
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4
+                "num_stages": 4,
             }
     else:
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
     return config

@@ -358,23 +364,17 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     M, _ = hidden_states.shape

-    topk_weights = torch.empty(M,
-                               topk,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
     ops.topk_softmax(
         topk_weights,
         topk_ids,
@@ -388,27 +388,27 @@ def fused_topk(
     return topk_weights, topk_ids


-def fused_experts(hidden_states: torch.Tensor,
-                  w1: torch.Tensor,
-                  w2: torch.Tensor,
-                  topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor,
-                  inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
-                  use_fp8: bool = False,
-                  w1_scale: Optional[torch.Tensor] = None,
-                  w2_scale: Optional[torch.Tensor] = None,
-                  a1_scale: Optional[torch.Tensor] = None,
-                  a2_scale: Optional[torch.Tensor] = None):
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

     M, _ = hidden_states.shape
     E, N, _ = w1.shape
@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)

         if configs:
             # If an optimal configuration map has been found, look up the
@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )

-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )

     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16

-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )

     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )

     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)


 def fused_moe(
@@ -532,25 +542,28 @@ def fused_moe(
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

     if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
-                                                   renormalize)
-
-    return fused_experts(hidden_states,
-                         w1,
-                         w2,
-                         topk_weights,
-                         topk_ids,
-                         inplace=inplace,
-                         override_config=override_config,
-                         use_fp8=use_fp8,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
+
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )


 def fused_topk_v0_4_3(
@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
     renormalize: bool,
 ):
     import vllm._moe_C as moe_kernels
+
     M, _ = hidden_states.shape

     topk_weights = torch.empty(
@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

-    return topk_weights, topk_ids
\ No newline at end of file
+    return topk_weights, topk_ids

@@ -1,4 +1,5 @@
 """Radix attention."""
+
 import numpy as np
 import torch
 from torch import nn
@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada

 class RadixAttention(nn.Module):
     def __init__(
-        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
-        layer_id: int, logit_cap: int = -1
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
@@ -112,6 +118,7 @@ class RadixAttention(nn.Module):
             )

             from flashinfer.cascade import merge_state
+
             o, _ = merge_state(o1, s1, o2, s2)

             if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
@@ -99,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
\ No newline at end of file
+        kill_parent_process()

@@ -127,7 +127,7 @@ class InputMetadata:
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
-                1
+                1,
             )
         else:
             self.flashinfer_decode_wrapper.end_forward()
@@ -140,7 +140,7 @@ class InputMetadata:
                 head_dim,
                 1,
                 pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )

     def init_extend_args(self):
@@ -228,7 +228,7 @@ class InputMetadata:
         ret.init_flashinfer_args(
             model_runner.model_config.num_attention_heads // tp_size,
             model_runner.model_config.get_num_kv_heads(tp_size),
-            model_runner.model_config.head_dim
+            model_runner.model_config.head_dim,
         )

         return ret
@@ -269,7 +269,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -341,7 +341,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -384,15 +390,16 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
-                BatchPrefillWithRaggedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
             )
             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

             if not _grouped_size_compiled_for_decode_kernels(
                 self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size)):
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
                 use_tensor_cores = True
             else:
                 use_tensor_cores = False
@@ -400,8 +407,8 @@ class ModelRunner:
             workspace_buffers = torch.empty(
                 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffers[1], "NHD"
@@ -410,7 +417,9 @@ class ModelRunner:
                 workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
             self.flashinfer_decode_wrapper = None

     @torch.inference_mode()
@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -368,9 +368,11 @@ class ModelTpServer:
                 if (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
-                    and (req.extend_input_len + new_batch_input_tokens
-                         <= self.max_prefill_tokens
-                         or len(can_run_list) == 0)
+                    and (
+                        req.extend_input_len + new_batch_input_tokens
+                        <= self.max_prefill_tokens
+                        or len(can_run_list) == 0
+                    )
                 ):
                     delta = self.tree_cache.inc_lock_ref(req.last_node)
                     available_size += delta
@@ -452,7 +454,9 @@ class ModelTpServer:
                 next_token_ids,
             ].tolist()
             output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-            output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+            output.normalized_prompt_logprobs = (
+                output.normalized_prompt_logprobs.tolist()
+            )

             next_token_ids = next_token_ids.tolist()
         else:
@@ -582,7 +586,9 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
                 if req.top_logprobs_num > 0:
                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])
@@ -759,16 +765,27 @@ class ModelTpClient:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
                 if server_args.nnodes == 1:
-                    self.procs = list(executor.map(
-                        lambda args: start_rpyc_service_process(*args),
-                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                    ))
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
                 else:
-                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]

-                self.model_services = list(executor.map(
-                    lambda args: connect_rpyc_service(*args), addrs))
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                )

                 # Init model
                 def init_model(i):

@@ -334,15 +334,15 @@ class TokenizerManager:
                     ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
                 )
             if top_logprobs_num > 0:
-                ret["meta_info"][
-                    "prefill_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
                 )
-                ret["meta_info"][
-                    "decode_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
                 )
         return ret

@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import Gemma2Config

 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size

+# FIXME: temporary solution, remove after next vllm release
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import GeluAndMul

 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata


-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.

@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):

 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


 class GemmaRotaryEmbedding(RotaryEmbedding):

     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
+                / self.rotary_dim
+            )
+        )
         return inv_freq

@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+        )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
                 "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                 "function. Please set `hidden_act` and `hidden_activation` to "
-                "`gelu_pytorch_tanh`.")
+                "`gelu_pytorch_tanh`."
+            )
         self.act_fn = GeluAndMul(approximate="tanh")

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):

 class Gemma2Attention(nn.Module):

-    def __init__(self,
-                 layer_idx: int,
-                 config: Gemma2Config,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int,
-                 rope_theta: float,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        layer_idx: int,
+        config: Gemma2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
         # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
         # odd layer, vLLM currently ignores it and uses global attention for
         # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
+        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
         del use_sliding_window  # Unused.
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_idx,
-                                   logit_cap=self.config.attn_logit_softcapping)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
+        )

     def forward(
         self,
@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
-        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                      eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                       eps=config.rms_norm_eps)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )

     def forward(
         self,
@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.post_attention_layernorm(hidden_states)

         hidden_states, residual = self.pre_feedforward_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         return hidden_states, residual
@@ -289,10 +301,12 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         # Normalize the embedding by sqrt(hidden_size)
@@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+            for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
@@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)

@@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module):
         if unloaded_params:
             raise RuntimeError(
                 "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}")
+                f"{unloaded_params}"
+            )


-EntryClass = Gemma2ForCausalLM
\ No newline at end of file
+EntryClass = Gemma2ForCausalLM

@@ -5,14 +5,12 @@ import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-)
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)

-        self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+        self.classification_head = nn.Linear(
+            config.hidden_size, config.classification_out_size
+        )
         self.eos_token_id = config.eos_token_id

     def forward(
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):

         if scores.shape[0] != input_metadata.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+            scores = torch.ones(
+                (input_metadata.batch_size, self.config.classification_out_size)
+            ).to(input_ids.device)

         return LogitProcessorOutput(
             next_token_logits=scores,
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

-EntryClass = LlamaForClassification
\ No newline at end of file
+
+EntryClass = LlamaForClassification

@@ -51,13 +51,12 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    send_addrs_to_rank_0,
     receive_addrs,
+    send_addrs_to_rank_0,
     start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback

-
 logger = logging.getLogger(__name__)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
-                           "reinstall the latest version by following the instructions "
-                           "at https://docs.flashinfer.ai/installation.html.")
+        assert_pkg_version(
+            "flashinfer",
+            "0.0.8",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             ModelPortArgs(
                 nccl_port=ports[3 + i * (tp_size_local + 1)],
                 model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+                model_tp_ports=ports[
+                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+                ],
             )
         )
     port_args = PortArgs(
@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         else:
             receive_addrs(model_port_args[0], server_args)
         for i in range(tp_size_local):
-            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+            start_rpyc_service_process(
+                ModelTpService, model_port_args[0].model_tp_ports[i]
+            )
         if server_args.node_rank != 0:
-            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            logger.info(
+                f"[node_rank={server_args.node_rank}]: Listen for connections..."
+            )
             while True:
                 pass

@@ -137,17 +137,16 @@ class ServerArgs:
             "--dtype",
             type=str,
             default=ServerArgs.dtype,
-            choices=[
-                "auto", "half", "float16", "bfloat16", "float", "float32"
-            ],
-            help='Data type for model weights and activations.\n\n'
+            choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
+            help="Data type for model weights and activations.\n\n"
             '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            'BF16 precision for BF16 models.\n'
+            "BF16 precision for BF16 models.\n"
             '* "half" for FP16. Recommended for AWQ quantization.\n'
             '* "float16" is the same as "half".\n'
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.')
+            '* "float32" for FP32 precision.',
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -271,19 +270,12 @@ class ServerArgs:
         parser.add_argument(
             "--nccl-init-addr",
             type=str,
-            help="The nccl init address of multi-node server."
+            help="The nccl init address of multi-node server.",
         )
         parser.add_argument(
-            "--nnodes",
-            type=int,
-            default=1,
-            help="The number of nodes."
+            "--nnodes", type=int, default=1, help="The number of nodes."
         )
-        parser.add_argument(
-            "--node-rank",
-            type=int,
-            help="The node rank."
-        )
+        parser.add_argument("--node-rank", type=int, help="The node rank.")

         # Optimization/debug options
         parser.add_argument(

@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
         if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
             raise Exception(
                 f"{pkg} is installed with version {installed_version}, which "
-                f"is less than the minimum required version {min_version}. " +
-                message
+                f"is less than the minimum required version {min_version}. " + message
             )
     except PackageNotFoundError:
         raise Exception(
-            f"{pkg} with minimum required version {min_version} is not installed. " +
-            message
+            f"{pkg} with minimum required version {min_version} is not installed. "
+            + message
         )

@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader():
     """

     from vllm.model_executor.model_loader.loader import (
-        ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
-        ParallelConfig, SchedulerConfig, CacheConfig, nn,
-        set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
-        DummyModelLoader
+        CacheConfig,
+        DeviceConfig,
+        DummyModelLoader,
+        LoRAConfig,
+        ModelConfig,
+        ParallelConfig,
+        SchedulerConfig,
+        VisionLanguageConfig,
+        _initialize_model,
+        initialize_dummy_weights,
+        nn,
+        set_default_torch_dtype,
     )

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   vision_language_config: Optional[VisionLanguageConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+    ) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config,
-                                          cache_config)
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                    lora_config,
+                    vision_language_config,
+                    cache_config,
+                )

                 for _, module in model.named_modules():
                     quant_method = getattr(module, "quant_method", None)
@@ -541,7 +556,7 @@ def get_ip_address(ifname):
     ip_address = fcntl.ioctl(
         s.fileno(),
         0x8915,  # SIOCGIFADDR
-        struct.pack('256s', bytes(ifname[:15], 'utf-8'))
+        struct.pack("256s", bytes(ifname[:15], "utf-8")),
     )[20:24]
     return socket.inet_ntoa(ip_address)

@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args):
     assert server_args.node_rank != 0 and server_args.dp_size == 1
     import torch.distributed as dist

-    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
     ip_addr = get_ip_address(ifname)

     num_tp_ports = server_args.tp_size // server_args.nnodes
     model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
     ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
+    addrs_tensor = torch.tensor(
+        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
+    )

     init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )
     dist.send(addrs_tensor, dst=0)
-    print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
+    print(
+        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
+    )

     dist.barrier()
-    dist.destroy_process_group()
+    dist.destroy_process_group()


 def receive_addrs(model_port_args, server_args):
     assert server_args.node_rank == 0 and server_args.dp_size == 1
     import torch.distributed as dist

-    ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
+    ifname = os.environ.get(
+        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
+    )
     ip_addr = get_ip_address(ifname)

     num_tp_ports = server_args.tp_size // server_args.nnodes
     model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports

     init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
+    dist.init_process_group(
+        backend="gloo",
+        init_method=init_method,
+        rank=server_args.node_rank,
+        world_size=server_args.nnodes,
+    )

     for src_rank in range(1, server_args.nnodes):
         tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
         dist.recv(tensor, src=src_rank)
         ip = ".".join([str(x) for x in tensor[:4].tolist()])
         ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
+        model_port_args.model_tp_ips[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = [ip] * num_tp_ports
+        model_port_args.model_tp_ports[
+            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
+        ] = ports
         print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")

     dist.barrier()
-    dist.destroy_process_group()
\ No newline at end of file
+    dist.destroy_process_group()
