feat: Support DP Attention for step3_vl (#8699)
This commit is contained in:
@@ -11,6 +11,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||
from sglang.srt.utils import is_cuda, print_info_once
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
||||
self.tp_size = world_size
|
||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
self.tp_size = attn_tp_size
|
||||
self.tp_rank = attn_tp_rank
|
||||
self.dropout = dropout
|
||||
self.head_size = embed_dim // num_heads
|
||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||
projection_size, num_heads
|
||||
)
|
||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||
num_dummy_heads + num_heads, world_size
|
||||
num_dummy_heads + num_heads, self.tp_size
|
||||
)
|
||||
self.num_attention_kv_heads_per_partition = dist_utils.divide(
|
||||
num_dummy_heads + num_heads, world_size
|
||||
num_dummy_heads + num_heads, self.tp_size
|
||||
)
|
||||
|
||||
self.q_size = self.num_attention_heads_per_partition * self.head_size
|
||||
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
|
||||
total_num_kv_heads=num_dummy_heads + num_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
else:
|
||||
@@ -435,6 +439,8 @@ class VisionAttention(nn.Module):
|
||||
output_size=3 * self.dummy_dim,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
self.proj = RowParallelLinear(
|
||||
@@ -442,6 +448,8 @@ class VisionAttention(nn.Module):
|
||||
output_size=embed_dim,
|
||||
bias=proj_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=self.tp_rank,
|
||||
tp_size=self.tp_size,
|
||||
prefix=add_prefix("proj", prefix),
|
||||
)
|
||||
|
||||
|
||||
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Since this is a dense model,
|
||||
# the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
|
||||
# This choice may not represent the optimal solution and remains open to further deliberation.
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
dim,
|
||||
intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("gate_proj", prefix),
|
||||
)
|
||||
self.act = ACT2FN[hidden_act] # quick_gelu
|
||||
@@ -544,6 +551,8 @@ class Step3VisionMLP(nn.Module):
|
||||
dim,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, TensorType
|
||||
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||
|
||||
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
|
||||
from sglang.srt.multimodal.processors.base_processor import (
|
||||
@@ -276,6 +276,8 @@ class Step3VLProcessor:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
if isinstance(tokenizer, ProcessorMixin):
|
||||
tokenizer = tokenizer.tokenizer
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.image_size = 728
|
||||
|
||||
Reference in New Issue
Block a user