Improve weight loading and code style (#3174)
@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
 
         self.gather_output = gather_output
+        self.use_presharded_weights = use_presharded_weights
 
         # Divide the weight matrix along the last dimension.
         if tp_rank is None:
@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
+        param.load_column_parallel_weight(
+            loaded_weight,
+            tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
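
For context: the hunks above thread a use_presharded_weights flag through ColumnParallelLinear so that checkpoints already saved as tensor-parallel shards are copied as-is, while regular checkpoints are still narrowed to the rank's slice. A minimal sketch of the pattern, with illustrative names that are not part of this diff:

    import torch

    def shard_column_weight(
        loaded_weight: torch.Tensor,  # [output_size, input_size]
        tp_rank: int,
        tp_size: int,
        output_dim: int = 0,
        use_presharded_weights: bool = False,
    ) -> torch.Tensor:
        if use_presharded_weights:
            # The file already stores only this rank's shard; narrowing
            # again would take a slice of a slice and corrupt the weight.
            return loaded_weight
        # Regular checkpoint: every rank reads the full tensor and keeps
        # only its slice along the output dimension.
        shard_size = loaded_weight.shape[output_dim] // tp_size
        return loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
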
@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
+        self.prefix = prefix
 
     def weight_loader(
         self,
@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        load_presharded_attn: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
+        self.use_presharded_weights = load_presharded_attn
 
         super().__init__(
             input_size=input_size,
@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size=shard_size, shard_offset=shard_offset
             )
 
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    param.output_dim, shard_offset, shard_size
+                )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
     def weight_loader_v2(
@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_offset=shard_offset,
                 shard_size=shard_size,
                 tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
             )
 
     def weight_loader(
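
The q/k/v shard offsets these loaders pass around describe where each projection lives inside one rank's fused QKV weight. A toy illustration of that layout (assumed head counts, not taken from this diff):

    # Per-rank fused QKV weight: [ q | k | v ] blocks along the output dim.
    def qkv_shard_offsets(num_heads: int, num_kv_heads: int, head_size: int):
        return {
            # shard_id: (shard_offset, shard_size)
            "q": (0, num_heads * head_size),
            "k": (num_heads * head_size, num_kv_heads * head_size),
            "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),
        }

    # With 32 q heads, 8 kv heads, head_size 128: q occupies rows 0-4095,
    # k rows 4096-5119, and v rows 5120-6143.
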
@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, orig_qkv_offsets, shard_id
                 )
 
-                loaded_weight_shard = loaded_weight.narrow(
-                    output_dim, shard_offset, shard_size
-                )
+                if not self.use_presharded_weights:
+                    loaded_weight_shard = loaded_weight.narrow(
+                        output_dim, shard_offset, shard_size
+                    )
                 self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
 
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for for AQLM codebooks.
@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
         tp_size: Optional[int] = None,
         prefix: str = "",
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
     ):
         super().__init__()
@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
+        self.activation = activation
 
         if quant_config is None:
@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
             dtype=torch.float32,
             device=hidden_states.device,
         )
-        silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-            gateup_output,
-            down_input,
-            gateup_output.shape[1],
-            reorder_topk_ids,
-            self.w2_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            BLOCK_SIZE=512,
-        )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                self.start_expert_id,
+                self.end_expert_id,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
 
         # GroupGemm-1
         down_output = torch.empty(
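
silu_and_mul_triton_kernel fuses the SwiGLU gating step: the first half of gateup_output goes through SiLU and is multiplied by the second half. A plain-PyTorch reference of that computation, ignoring the expert reordering and input scales the kernel also handles:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(gateup_output: torch.Tensor) -> torch.Tensor:
        # gateup_output packs [gate | up] halves along the last dimension.
        gate, up = gateup_output.chunk(2, dim=-1)
        return F.silu(gate) * up
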
@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
         ckpt_up_proj_name: str,
         num_experts: int,
     ) -> List[Tuple[str, str, int, str]]:
-
         return [
             # (param_name, weight_name, expert_id, shard_id)
             (
@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
             )
             return
 
-        expert_data = param.data[expert_id]
         if shard_id == "w2":
             param.data[expert_id] = loaded_weight
         elif shard_id == "w1":
@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
+    def load_qkv_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+        **kwargs,
+    ):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        loaded_weight = loaded_weight.narrow(
-            self.output_dim, shard_id * shard_size, shard_size
-        )
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, shard_id * shard_size, shard_size
+            )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
 
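
In the k/v branch above, num_heads acts as the kv-head replication factor across tensor-parallel ranks, so tp_rank // num_heads maps several ranks onto the same kv shard. A toy illustration with assumed sizes (variable names are mine, not the library's):

    # 8 TP ranks share 2 kv heads -> each kv shard serves 4 consecutive ranks.
    tp_size, num_kv_heads = 8, 2
    replicas_per_kv_head = tp_size // num_kv_heads  # 4
    for tp_rank in range(tp_size):
        kv_shard_id = tp_rank // replicas_per_kv_head
        # ranks 0-3 -> kv shard 0, ranks 4-7 -> kv shard 1
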
@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
         packed_factor: Union[int, Fraction],
         packed_dim: int,
         marlin_tile_size: Optional[int] = None,
-        **kwargs
+        **kwargs,
     ):
         self._packed_factor = packed_factor
         self._packed_dim = packed_dim
@@ -247,6 +247,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.fill_ids = None
         self.session_id = session_id
         self.input_embeds = input_embeds
 
@@ -486,7 +486,7 @@ class Scheduler:
     @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
-        result_queue = deque()
+        self.result_queue = deque()
 
         while True:
             recv_reqs = self.recv_requests()
@@ -497,7 +497,7 @@ class Scheduler:
 
             if batch:
                 result = self.run_batch(batch)
-                result_queue.append((batch.copy(), result))
+                self.result_queue.append((batch.copy(), result))
 
                 if self.last_batch is None:
                     # Create a dummy first batch to start the pipeline for overlap schedule.
@@ -511,7 +511,7 @@ class Scheduler:
 
             if self.last_batch:
                 # Process the results of the last batch
-                tmp_batch, tmp_result = result_queue.popleft()
+                tmp_batch, tmp_result = self.result_queue.popleft()
                 tmp_batch.next_batch_sampling_info = (
                     self.tp_worker.cur_sampling_info if batch else None
                 )
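
Promoting result_queue to an attribute makes the in-flight queue visible to other scheduler methods. The loop itself launches batch N and only afterwards consumes batch N-1's result, so CPU post-processing overlaps the GPU run. A minimal sketch of the pattern with illustrative callbacks, not the scheduler's real API:

    from collections import deque

    def overlap_loop(get_batch, launch, finish):
        in_flight = deque()
        last_batch = None
        while True:
            batch = get_batch()
            if batch:
                # Launch asynchronously; the GPU starts on batch N.
                in_flight.append((batch, launch(batch)))
            if last_batch:
                # Meanwhile the CPU finishes batch N-1.
                done_batch, handle = in_flight.popleft()
                finish(done_batch, handle)
            last_batch = batch
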
@@ -642,7 +642,7 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        # Handle image inputs
+        # Handle multimodal inputs
         if recv_req.image_inputs is not None:
             image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
             # Expand a single image token into multiple dummy tokens for receiving image embeddings
@@ -743,7 +743,13 @@ class Scheduler:
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self.waiting_queue.append(req)
 
-    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
+    def log_prefill_stats(
+        self,
+        adder: PrefillAdder,
+        can_run_list: List[Req],
+        running_bs: ScheduleBatch,
+        has_being_chunked: bool,
+    ):
         self.tree_cache_metrics["total"] += (
             adder.log_input_tokens + adder.log_hit_tokens
         ) / 10**9
@@ -218,7 +218,7 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        # Init torch distributed
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
@@ -404,8 +404,13 @@ def np_cache_weights_iterator(
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-    """Iterate over the weights in the model safetensor files."""
+    """Iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, use a more optimized read path that
+    loads each file in full instead of reading tensors one by one.
+    """
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -415,9 +420,14 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
+        if not is_all_weights_sharded:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+        else:
+            result = load_file(st_file, device="cpu")
+            for name, param in result.items():
                 yield name, param
 
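
Both read paths come from the safetensors library: safe_open memory-maps a file and fetches tensors lazily, while load_file reads the whole file at once, which is faster when every tensor is consumed, as with fully pre-sharded checkpoints. A standalone comparison (file name illustrative):

    from safetensors import safe_open
    from safetensors.torch import load_file

    st_file = "model-00001-of-00002.safetensors"  # illustrative path

    # Per-tensor, lazy reads: cheap if only some tensors (or slices) are needed.
    with safe_open(st_file, framework="pt") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)

    # One whole-file read into a dict of tensors: fewer, larger I/O operations.
    tensors = load_file(st_file, device="cpu")
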
@@ -75,6 +75,7 @@ class ServerArgs:
     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 1
+    stream_output: bool = False
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
@@ -500,6 +501,11 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--stream-output",
+            action="store_true",
+            help="Whether to output as a sequence of disjoint segments.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
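
The new flag follows the file's dataclass-plus-argparse convention: a stream_output field with a False default, mirrored by a --stream-output store_true argument. A condensed sketch of the round trip (illustrative, not the full ServerArgs):

    import argparse
    from dataclasses import dataclass

    @dataclass
    class Args:
        stream_output: bool = False

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--stream-output",
        action="store_true",
        help="Whether to output as a sequence of disjoint segments.",
    )
    ns = parser.parse_args(["--stream-output"])
    args = Args(stream_output=ns.stream_output)  # stream_output is now True
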
@@ -774,7 +774,7 @@ def get_zmq_socket(
 
 
 def dump_to_file(dirpath, name, value):
-    from vllm.distributed import get_tensor_model_parallel_rank
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
 
     if get_tensor_model_parallel_rank() != 0:
         return
@@ -34,7 +34,7 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
-DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
+DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
@@ -135,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred
 
 
-def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
-    raise NotImplementedError()
-
-
 def call_generate_guidance(
     prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
 ):
@@ -530,31 +526,19 @@ def get_similarities(vec1, vec2):
     return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
 
 
-def run_bench_serving(
-    model,
-    num_prompts,
-    request_rate,
-    other_server_args,
-    dataset_name="random",
+def get_benchmark_args(
+    base_url="",
+    dataset_name="",
     dataset_path="",
-    tokenizer=None,
+    tokenizer="",
+    num_prompts=500,
     random_input_len=4096,
     random_output_len=2048,
+    request_rate=float("inf"),
     disable_stream=False,
     disable_ignore_eos=False,
-    need_warmup=False,
 ):
-    # Launch the server
-    base_url = DEFAULT_URL_FOR_TEST
-    process = popen_launch_server(
-        model,
-        base_url,
-        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-        other_args=other_server_args,
-    )
-
-    # Run benchmark
-    args = SimpleNamespace(
+    return SimpleNamespace(
         backend="sglang",
         base_url=base_url,
         host=None,
@@ -583,6 +567,44 @@ def run_bench_serving(
         lora_name=None,
     )
 
+
+def run_bench_serving(
+    model,
+    num_prompts,
+    request_rate,
+    other_server_args,
+    dataset_name="random",
+    dataset_path="",
+    tokenizer=None,
+    random_input_len=4096,
+    random_output_len=2048,
+    disable_stream=False,
+    disable_ignore_eos=False,
+    need_warmup=False,
+):
+    # Launch the server
+    base_url = DEFAULT_URL_FOR_TEST
+    process = popen_launch_server(
+        model,
+        base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_server_args,
+    )
+
+    # Run benchmark
+    args = get_benchmark_args(
+        base_url=base_url,
+        dataset_name=dataset_name,
+        dataset_path=dataset_path,
+        tokenizer=tokenizer,
+        num_prompts=num_prompts,
+        random_input_len=random_input_len,
+        random_output_len=random_output_len,
+        request_rate=request_rate,
+        disable_stream=disable_stream,
+        disable_ignore_eos=disable_ignore_eos,
+    )
+
     try:
         if need_warmup:
             warmup_args = copy.deepcopy(args)
@@ -596,6 +618,38 @@ def run_bench_serving(
     return res
 
 
+def run_bench_serving_multi(
+    model,
+    base_url,
+    other_server_args,
+    benchmark_args,
+    need_warmup=False,
+):
+    # Launch the server
+    process = popen_launch_server(
+        model,
+        base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_server_args,
+    )
+
+    # run benchmark for all
+    res_l = []
+    try:
+        for args in benchmark_args:
+            if need_warmup:
+                warmup_args = copy.deepcopy(args)
+                warmup_args.num_prompts = 16
+                run_benchmark(warmup_args)
+
+            res = run_benchmark(args)
+            res_l.append((args, res))
+    finally:
+        kill_process_tree(process.pid)
+
+    return res_l
+
+
 def run_bench_one_batch(model, other_args):
     command = [
         "python3",
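
Splitting get_benchmark_args out of run_bench_serving lets run_bench_serving_multi sweep several benchmark configurations against a single server launch. A hedged usage sketch (model name, URL, and rates are illustrative):

    # Benchmark two request rates against one server.
    benchmark_args = [
        get_benchmark_args(
            base_url="http://127.0.0.1:30000",  # illustrative
            dataset_name="random",
            num_prompts=200,
            request_rate=rate,
        )
        for rate in (4.0, 8.0)
    ]
    results = run_bench_serving_multi(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative
        base_url="http://127.0.0.1:30000",
        other_server_args=[],
        benchmark_args=benchmark_args,
    )
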