[CPU] Fix TP padding issue on Phi-4 (#8289)
This commit is contained in:
@@ -49,14 +49,25 @@ def get_num_heads_padding_size(tp_size, weight_block_size):
|
||||
|
||||
|
||||
def update_intermediate_size(model_config, attr_name, intermediate_padding_size):
|
||||
if hasattr(model_config.hf_config, attr_name):
|
||||
attr_value = intermediate_padding_size
|
||||
if hasattr(model_config, "hf_config") and hasattr(
|
||||
model_config.hf_config, attr_name
|
||||
):
|
||||
attr_value = getattr(model_config.hf_config, attr_name)
|
||||
if attr_value % intermediate_padding_size != 0:
|
||||
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
elif hasattr(model_config, attr_name):
|
||||
attr_value = getattr(model_config, attr_name)
|
||||
|
||||
attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
|
||||
if attr_value % intermediate_padding_size != 0:
|
||||
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
attr_value = pad_vocab_size(attr_value, intermediate_padding_size)
|
||||
if hasattr(model_config, "hf_config"):
|
||||
setattr(model_config.hf_config, attr_name, attr_value)
|
||||
setattr(model_config.hf_text_config, attr_name, attr_value)
|
||||
if hasattr(model_config, "hf_text_config"):
|
||||
setattr(model_config.hf_text_config, attr_name, attr_value)
|
||||
else:
|
||||
setattr(model_config, attr_name, attr_value)
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@@ -118,4 +129,28 @@ def adjust_config_with_unaligned_cpu_tp(
|
||||
model_config = update_intermediate_size(
|
||||
model_config, "intermediate_size_mlp", intermediate_padding_size
|
||||
)
|
||||
if (
|
||||
hasattr(model_config.hf_config, "vision_config")
|
||||
and model_config.hf_config.vision_config.model_type == "siglip_vision_model"
|
||||
):
|
||||
model_config.hf_config.vision_config.original_num_attention_heads = (
|
||||
model_config.num_attention_heads
|
||||
)
|
||||
if model_config.hf_config.vision_config.num_attention_heads % tp_size != 0:
|
||||
model_config.hf_config.vision_config.head_dim = (
|
||||
model_config.hf_config.vision_config.hidden_size
|
||||
// model_config.hf_config.vision_config.num_attention_heads
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import pad_vocab_size
|
||||
|
||||
pad_size = get_num_heads_padding_size(tp_size, weight_block_size)
|
||||
model_config.hf_config.vision_config.num_attention_heads = pad_vocab_size(
|
||||
model_config.hf_config.vision_config.num_attention_heads, pad_size
|
||||
)
|
||||
model_config.hf_config.vision_config = update_intermediate_size(
|
||||
model_config.hf_config.vision_config,
|
||||
"intermediate_size",
|
||||
intermediate_padding_size,
|
||||
)
|
||||
|
||||
return model_config
|
||||
|
||||
@@ -129,6 +129,25 @@ def get_config(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
if (
|
||||
config.architectures is not None
|
||||
and config.architectures[0] == "Phi4MMForCausalLM"
|
||||
):
|
||||
# Phi4MMForCausalLM uses a hard-coded vision_config. See:
|
||||
# https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
|
||||
# We set it here to support cases where num_attention_heads is not divisible by the TP size.
|
||||
from transformers import SiglipVisionConfig
|
||||
|
||||
vision_config = {
|
||||
"hidden_size": 1152,
|
||||
"image_size": 448,
|
||||
"intermediate_size": 4304,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
|
||||
"patch_size": 14,
|
||||
}
|
||||
config.vision_config = SiglipVisionConfig(**vision_config)
|
||||
text_config = get_hf_text_config(config=config)
|
||||
|
||||
if isinstance(model, str) and text_config is not None:
|
||||
|
||||
@@ -110,6 +110,20 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
return param[shard_id], loaded_weight
|
||||
|
||||
|
||||
def adjust_shard_offsets(shard_offsets, loaded_weight, dim):
    """Rescale shard offsets/sizes so they match the checkpoint weight's real size.

    Each entry of ``shard_offsets`` is a ``(shard_id, shard_offset, shard_size)``
    tuple describing where a fused sub-weight lives along dimension ``dim``.
    When the loaded weight's extent along ``dim`` differs from the extent the
    offsets were computed for (e.g. a padded fused layer vs. an unpadded
    checkpoint), every shard size is scaled proportionally (integer floor
    division) and the offsets are re-accumulated from zero.

    Returns the original list unchanged when the sizes already agree,
    otherwise a new list of rescaled tuples.
    """
    real_size = loaded_weight.size(dim)
    # Expected total extent = last shard's offset + last shard's size.
    _, last_offset, last_size = shard_offsets[-1]
    expected_size = last_size + last_offset
    if real_size == expected_size:
        return shard_offsets

    rescaled = []
    running_offset = 0
    for shard_id, _, shard_size in shard_offsets:
        # Proportional rescale; floor division may drop a remainder, so the
        # rescaled shards can sum to slightly less than real_size.
        scaled_size = real_size * shard_size // expected_size
        rescaled.append((shard_id, running_offset, scaled_size))
        running_offset += scaled_size
    return rescaled
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
"""Base linear layer.
|
||||
|
||||
@@ -535,6 +549,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
if _is_cpu:
|
||||
shard_offsets = adjust_shard_offsets(
|
||||
shard_offsets, loaded_weight, output_dim
|
||||
)
|
||||
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
@@ -977,6 +996,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if _is_cpu:
|
||||
shard_offsets = adjust_shard_offsets(
|
||||
shard_offsets, loaded_weight, output_dim
|
||||
)
|
||||
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
|
||||
@@ -116,9 +116,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
) -> torch.Tensor:
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||
x_shapes = x.shape
|
||||
if len(x_shapes) == 3:
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.ops.sgl_kernel.weight_packed_linear(
|
||||
x, layer.weight, bias, True # is_vnni
|
||||
)
|
||||
if len(x_shapes) == 3:
|
||||
output = output.view(x_shapes[0], x_shapes[1], -1)
|
||||
return output
|
||||
|
||||
return F.linear(x, layer.weight, bias)
|
||||
|
||||
|
||||
@@ -54,25 +54,6 @@ VISION_ENCODER_TO_PROCESSING_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
def get_navit_vision_model():
    """Build the SigLIP-based vision transformer used for image feature extraction.

    The configuration is hard-coded to the SigLIP tower shipped with the model;
    returns an ``Idefics2VisionTransformer`` without a post-norm layer.
    """
    siglip_config = SiglipVisionConfig(
        hidden_size=1152,
        image_size=448,
        intermediate_size=4304,
        model_type="siglip_vision_model",
        num_attention_heads=16,
        # Model is originally 27-layer, we only need the first 26 layers for
        # feature extraction.
        num_hidden_layers=26,
        patch_size=14,
    )
    return Idefics2VisionTransformer(config=siglip_config, require_post_norm=False)
|
||||
|
||||
|
||||
class Phi4MMImageEncoder(nn.Module):
|
||||
"""Image embedding."""
|
||||
|
||||
@@ -88,8 +69,9 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
# n_embed or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
|
||||
self.type_feature = "patch"
|
||||
|
||||
self.img_processor = get_navit_vision_model()
|
||||
self.img_processor = Idefics2VisionTransformer(
|
||||
config=config.vision_config, require_post_norm=False
|
||||
)
|
||||
|
||||
pe_weight = self.img_processor.embeddings.position_embedding.weight
|
||||
L, D = pe_weight.size()
|
||||
|
||||
Reference in New Issue
Block a user