[CPU] support the case where num_attention_heads or intermediate_size is not divisible by the TP size (#6771)
This commit is contained in:
119
python/sglang/srt/configs/update_config.py
Normal file
119
python/sglang/srt/configs/update_config.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
DEFAULT_MOE_PADDING_SIZE = 32
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
|
||||
def may_get_weight_block_size(model_config, load_config):
|
||||
from sglang.srt.model_loader.loader import _get_quantization_config
|
||||
from sglang.srt.model_loader.utils import get_model_architecture
|
||||
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
|
||||
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, load_config, packed_modules_mapping
|
||||
)
|
||||
|
||||
if quant_config is not None and hasattr(quant_config, "weight_block_size"):
|
||||
return getattr(quant_config, "weight_block_size")
|
||||
return None
|
||||
|
||||
|
||||
def get_moe_padding_size(weight_block_size):
|
||||
if weight_block_size is not None:
|
||||
# See NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
|
||||
assert (
|
||||
len(weight_block_size) == 2
|
||||
), "Only len(weight_block_size) == 2 is supported"
|
||||
assert (
|
||||
weight_block_size[0] == weight_block_size[1]
|
||||
), "Only weight_block_size[0] == weight_block_size[1] is supported"
|
||||
|
||||
return weight_block_size[0]
|
||||
|
||||
return DEFAULT_MOE_PADDING_SIZE
|
||||
|
||||
|
||||
def get_num_heads_padding_size(tp_size, weight_block_size):
|
||||
pad_size = (
|
||||
tp_size * 2 if tp_size % 2 == 1 and weight_block_size is not None else tp_size
|
||||
)
|
||||
return pad_size
|
||||
|
||||
|
||||
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
|
||||
if hasattr(model_config.hf_config, attr_name):
|
||||
attr_value = getattr(model_config.hf_config, attr_name)
|
||||
if attr_value % intermediate_padding_size != 0:
|
||||
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
|
||||
setattr(model_config.hf_config, attr_name, attr_value)
|
||||
setattr(model_config.hf_text_config, attr_name, attr_value)
|
||||
return model_config
|
||||
|
||||
|
||||
def adjust_config_with_unaligned_cpu_tp(
|
||||
model_config: ModelConfig, load_config: LoadConfig, tp_size: int
|
||||
) -> ModelConfig:
|
||||
# Support the case where the num_attention_heads is not divisible by the TP size.
|
||||
weight_block_size = may_get_weight_block_size(model_config, load_config)
|
||||
|
||||
model_config.hf_config.original_num_attention_heads = (
|
||||
model_config.num_attention_heads
|
||||
)
|
||||
model_config.hf_text_config.original_num_attention_heads = (
|
||||
model_config.num_attention_heads
|
||||
)
|
||||
|
||||
model_config.hf_config.original_total_num_kv_heads = (
|
||||
model_config.get_total_num_kv_heads()
|
||||
)
|
||||
model_config.hf_text_config.original_total_num_kv_heads = (
|
||||
model_config.get_total_num_kv_heads()
|
||||
)
|
||||
|
||||
if (
|
||||
model_config.num_attention_heads % tp_size != 0
|
||||
or model_config.get_total_num_kv_heads() % tp_size != 0
|
||||
):
|
||||
# Compute the head_dim using the model_config.num_attention_heads before padding
|
||||
if not hasattr(model_config.hf_config, "head_dim"):
|
||||
model_config.hf_config.head_dim = (
|
||||
model_config.hidden_size // model_config.num_attention_heads
|
||||
)
|
||||
|
||||
query_heads_per_kv = (
|
||||
model_config.num_attention_heads // model_config.get_total_num_kv_heads()
|
||||
)
|
||||
total_kv_heads = model_config.get_total_num_kv_heads()
|
||||
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
|
||||
num_key_value_heads = pad_vocab_size(total_kv_heads, pad_size)
|
||||
|
||||
model_config.num_key_value_heads = num_key_value_heads
|
||||
model_config.hf_config.num_key_value_heads = num_key_value_heads
|
||||
model_config.hf_text_config.num_key_value_heads = num_key_value_heads
|
||||
|
||||
num_attention_heads = num_key_value_heads * query_heads_per_kv
|
||||
model_config.num_attention_heads = num_attention_heads
|
||||
model_config.hf_config.num_attention_heads = num_attention_heads
|
||||
model_config.hf_text_config.num_attention_heads = num_attention_heads
|
||||
|
||||
intermediate_padding_size = tp_size * get_moe_padding_size(weight_block_size)
|
||||
model_config = update_intermediate_size(
|
||||
model_config, "moe_intermediate_size", intermediate_padding_size
|
||||
)
|
||||
model_config = update_intermediate_size(
|
||||
model_config, "intermediate_size", intermediate_padding_size
|
||||
)
|
||||
|
||||
return model_config
|
||||
@@ -426,8 +426,26 @@ class ColumnParallelLinear(LinearBase):
|
||||
if output_dim is not None and not use_bitsandbytes_4bit:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
if _is_cpu:
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
start_idx,
|
||||
output_dim,
|
||||
shard_size,
|
||||
not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
output_dim, start_idx, shard_size
|
||||
)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
@@ -644,10 +662,29 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
if _is_cpu:
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
start_idx,
|
||||
output_dim,
|
||||
shard_size,
|
||||
not use_bitsandbytes_4bit and not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
output_dim, start_idx, shard_size
|
||||
)
|
||||
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
@@ -1112,10 +1149,27 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_id = self.tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
if _is_cpu:
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
start_idx,
|
||||
output_dim,
|
||||
shard_size,
|
||||
not use_bitsandbytes_4bit and not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow here
|
||||
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
output_dim, start_idx, shard_size
|
||||
)
|
||||
|
||||
# Special case for for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
@@ -1257,7 +1311,22 @@ class RowParallelLinear(LinearBase):
|
||||
):
|
||||
shard_size = param_data.shape[input_dim]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||
|
||||
if _is_cpu:
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
start_idx,
|
||||
input_dim,
|
||||
shard_size,
|
||||
)
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
|
||||
@@ -19,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
|
||||
from sglang.srt.utils import (
|
||||
cpu_has_amx_support,
|
||||
get_bool_env_var,
|
||||
@@ -573,11 +574,6 @@ class FusedMoE(torch.nn.Module):
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
@@ -588,7 +584,24 @@ class FusedMoE(torch.nn.Module):
|
||||
start = shard_size
|
||||
else:
|
||||
start = 0
|
||||
expert_data = expert_data.narrow(shard_dim, start, shard_size)
|
||||
|
||||
if _is_cpu:
|
||||
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
expert_data,
|
||||
loaded_weight,
|
||||
start,
|
||||
shard_size * tp_rank,
|
||||
shard_dim,
|
||||
shard_size,
|
||||
not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
|
||||
expert_data = expert_data.narrow(shard_dim, start, shard_size)
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
def _load_w2(
|
||||
@@ -605,10 +618,21 @@ class FusedMoE(torch.nn.Module):
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
if _is_cpu:
|
||||
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
expert_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
shard_size * tp_rank,
|
||||
shard_dim,
|
||||
shard_size,
|
||||
not self.use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
if not self.use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
shard_dim, shard_size * tp_rank, shard_size
|
||||
)
|
||||
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight)
|
||||
|
||||
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from sglang.srt.utils import is_cpu
|
||||
|
||||
__all__ = [
|
||||
"BasevLLMParameter",
|
||||
"PackedvLLMParameter",
|
||||
@@ -21,6 +23,8 @@ __all__ = [
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
|
||||
class BasevLLMParameter(Parameter):
|
||||
"""
|
||||
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
):
|
||||
if not use_presharded_weights:
|
||||
shard_size = self.data.shape[self.output_dim]
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, tp_rank * shard_size, shard_size
|
||||
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
if _is_cpu:
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
self.data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
tp_rank * shard_size,
|
||||
self.output_dim,
|
||||
shard_size,
|
||||
)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert self.data.shape == loaded_weight.shape
|
||||
self.data.copy_(loaded_weight)
|
||||
|
||||
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
param_data = self.data
|
||||
|
||||
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
|
||||
if not use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, tp_rank * shard_size, shard_size
|
||||
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
if _is_cpu:
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
tp_rank * shard_size,
|
||||
self.output_dim,
|
||||
shard_size,
|
||||
not use_presharded_weights,
|
||||
)
|
||||
else:
|
||||
if not use_presharded_weights:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.output_dim, tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
|
||||
):
|
||||
if not use_presharded_weights:
|
||||
shard_size = self.data.shape[self.input_dim]
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, tp_rank * shard_size, shard_size
|
||||
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
narrow_padded_param_and_loaded_weight,
|
||||
)
|
||||
|
||||
if _is_cpu:
|
||||
param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
|
||||
self.data,
|
||||
loaded_weight,
|
||||
0, # param_data_start
|
||||
tp_rank * shard_size,
|
||||
self.input_dim,
|
||||
shard_size,
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
return
|
||||
else:
|
||||
loaded_weight = loaded_weight.narrow(
|
||||
self.input_dim, tp_rank * shard_size, shard_size
|
||||
)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
|
||||
@@ -246,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.tp_size = 1
|
||||
|
||||
self.num_embeddings = num_embeddings
|
||||
self.padding_size = padding_size
|
||||
self.org_vocab_size = org_num_embeddings or num_embeddings
|
||||
|
||||
# Support the case where the vocab size is not divisible by the TP size.
|
||||
if (
|
||||
_is_cpu
|
||||
and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
|
||||
):
|
||||
padding_size *= self.tp_size
|
||||
self.padding_size = padding_size
|
||||
|
||||
num_added_embeddings = num_embeddings - self.org_vocab_size
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
if use_presharded_weights:
|
||||
|
||||
@@ -149,6 +149,7 @@ from sglang.srt.utils import (
|
||||
get_available_gpu_memory,
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
is_cpu,
|
||||
kill_itself_when_parent_died,
|
||||
point_to_point_pyobj,
|
||||
pyspy_dump_schedulers,
|
||||
@@ -167,6 +168,8 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
@@ -2115,11 +2118,14 @@ class Scheduler(
|
||||
"kvcache": round(
|
||||
self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
|
||||
),
|
||||
"cuda_graph": round(
|
||||
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
|
||||
),
|
||||
"token_capacity": int(self.max_total_num_tokens),
|
||||
}
|
||||
|
||||
if not _is_cpu:
|
||||
ret["memory_usage"]["cuda_graph"] = round(
|
||||
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
|
||||
)
|
||||
|
||||
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
|
||||
ret["avg_spec_accept_length"] = (
|
||||
self.cum_spec_accept_length / self.cum_spec_accept_count
|
||||
|
||||
@@ -29,6 +29,7 @@ import torch.distributed as dist
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig
|
||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
|
||||
from sglang.srt.distributed import (
|
||||
get_tp_group,
|
||||
@@ -165,7 +166,6 @@ class ModelRunner:
|
||||
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
|
||||
):
|
||||
# Parse args
|
||||
self.model_config = model_config
|
||||
self.mem_fraction_static = mem_fraction_static
|
||||
self.device = server_args.device
|
||||
self.gpu_id = gpu_id
|
||||
@@ -178,6 +178,7 @@ class ModelRunner:
|
||||
self.dp_size = server_args.dp_size
|
||||
self.pp_rank = pp_rank
|
||||
self.pp_size = pp_size
|
||||
self.model_config = model_config
|
||||
self.dist_port = nccl_port
|
||||
self.server_args = server_args
|
||||
self.is_draft_worker = is_draft_worker
|
||||
@@ -604,6 +605,10 @@ class ModelRunner:
|
||||
download_dir=self.server_args.download_dir,
|
||||
model_loader_extra_config=self.server_args.model_loader_extra_config,
|
||||
)
|
||||
if self.device == "cpu":
|
||||
self.model_config = adjust_config_with_unaligned_cpu_tp(
|
||||
self.model_config, self.load_config, self.tp_size
|
||||
)
|
||||
if self.server_args.load_format == "gguf":
|
||||
monkey_patch_vllm_gguf_config()
|
||||
|
||||
|
||||
@@ -961,3 +961,57 @@ def kv_cache_scales_loader(
|
||||
tp_rank,
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def get_actual_shard_size(shard_size, weight_start, weight_end):
|
||||
if weight_end < weight_start:
|
||||
return 0
|
||||
|
||||
return min(shard_size, weight_end - weight_start)
|
||||
|
||||
|
||||
def reset_param_data_if_needed(param_data, dim, start, length):
|
||||
if length == 0:
|
||||
return
|
||||
|
||||
assert length > 0, f"Length should be positive, but got {length}"
|
||||
|
||||
param_data.narrow(dim, start, length).zero_()
|
||||
return
|
||||
|
||||
|
||||
def narrow_padded_param_and_loaded_weight(
|
||||
param_data,
|
||||
loaded_weight,
|
||||
param_data_start,
|
||||
weight_start,
|
||||
dim,
|
||||
shard_size,
|
||||
narrow_weight=True,
|
||||
):
|
||||
actual_shard_size = get_actual_shard_size(
|
||||
shard_size, weight_start, loaded_weight.size(dim)
|
||||
)
|
||||
|
||||
if narrow_weight:
|
||||
if actual_shard_size > 0:
|
||||
loaded_weight = loaded_weight.narrow(dim, weight_start, actual_shard_size)
|
||||
else:
|
||||
# No real data to load; create a dummy tensor filled with zeros
|
||||
loaded_weight = torch.zeros_like(
|
||||
param_data.narrow(dim, param_data_start, actual_shard_size)
|
||||
)
|
||||
|
||||
# [Note] Reset padded weights to zero.
|
||||
# If the actual shard size is less than the shard size, we need to reset
|
||||
# the padded param_data to zero and then copy the loaded_weight into it.
|
||||
reset_param_data_if_needed(
|
||||
param_data,
|
||||
dim,
|
||||
param_data_start + actual_shard_size,
|
||||
shard_size - actual_shard_size,
|
||||
)
|
||||
|
||||
param_data = param_data.narrow(dim, param_data_start, actual_shard_size)
|
||||
|
||||
return param_data, loaded_weight
|
||||
|
||||
@@ -16,7 +16,9 @@ from sglang.srt.managers.mm_utils import (
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
from sglang.srt.utils import add_prefix, is_cpu
|
||||
|
||||
_is_cpu = is_cpu()
|
||||
|
||||
|
||||
class Llama4ForConditionalGeneration(nn.Module):
|
||||
@@ -107,13 +109,17 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
|
||||
# rotary embeds should be sliced
|
||||
if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
|
||||
loaded_weight = permute(
|
||||
loaded_weight, self.language_model.config.num_key_value_heads
|
||||
)
|
||||
if _is_cpu:
|
||||
dim = self.language_model.config.original_total_num_kv_heads
|
||||
else:
|
||||
dim = self.language_model.config.num_key_value_heads
|
||||
loaded_weight = permute(loaded_weight, dim)
|
||||
elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
|
||||
loaded_weight = permute(
|
||||
loaded_weight, self.language_model.config.num_attention_heads
|
||||
)
|
||||
if _is_cpu:
|
||||
dim = self.language_model.config.original_num_attention_heads
|
||||
else:
|
||||
dim = self.language_model.config.num_attention_heads
|
||||
loaded_weight = permute(loaded_weight, dim)
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ class Qwen2Attention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: Optional[int] = None,
|
||||
layer_id: int = 0,
|
||||
rope_theta: float = 1000000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
@@ -123,7 +124,10 @@ class Qwen2Attention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
if head_dim is not None:
|
||||
self.head_dim = head_dim
|
||||
else:
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@@ -191,10 +195,12 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
|
||||
head_dim = getattr(config, "head_dim", None)
|
||||
self.self_attn = Qwen2Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
layer_id=layer_id,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# ==============================================================================
|
||||
"""Common utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
import ctypes
|
||||
|
||||
Reference in New Issue
Block a user