feat: Support DP Attention for step3_vl (#8699)
This commit is contained in:
@@ -11,6 +11,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
|
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
||||||
from sglang.srt.utils import is_cuda, print_info_once
|
from sglang.srt.utils import is_cuda, print_info_once
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
@@ -365,19 +366,20 @@ class VisionAttention(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
self.tp_size = world_size
|
attn_tp_size = get_attention_tp_size()
|
||||||
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
|
self.tp_size = attn_tp_size
|
||||||
|
self.tp_rank = attn_tp_rank
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.head_size = embed_dim // num_heads
|
self.head_size = embed_dim // num_heads
|
||||||
self.hidden_size_per_attention_head = dist_utils.divide(
|
self.hidden_size_per_attention_head = dist_utils.divide(
|
||||||
projection_size, num_heads
|
projection_size, num_heads
|
||||||
)
|
)
|
||||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
num_dummy_heads + num_heads, world_size
|
num_dummy_heads + num_heads, self.tp_size
|
||||||
)
|
)
|
||||||
self.num_attention_kv_heads_per_partition = dist_utils.divide(
|
self.num_attention_kv_heads_per_partition = dist_utils.divide(
|
||||||
num_dummy_heads + num_heads, world_size
|
num_dummy_heads + num_heads, self.tp_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.q_size = self.num_attention_heads_per_partition * self.head_size
|
self.q_size = self.num_attention_heads_per_partition * self.head_size
|
||||||
@@ -427,6 +429,8 @@ class VisionAttention(nn.Module):
|
|||||||
total_num_kv_heads=num_dummy_heads + num_heads,
|
total_num_kv_heads=num_dummy_heads + num_heads,
|
||||||
bias=qkv_bias,
|
bias=qkv_bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
tp_size=self.tp_size,
|
||||||
prefix=add_prefix("qkv_proj", prefix),
|
prefix=add_prefix("qkv_proj", prefix),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -435,6 +439,8 @@ class VisionAttention(nn.Module):
|
|||||||
output_size=3 * self.dummy_dim,
|
output_size=3 * self.dummy_dim,
|
||||||
bias=qkv_bias,
|
bias=qkv_bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
tp_size=self.tp_size,
|
||||||
prefix=add_prefix("qkv_proj", prefix),
|
prefix=add_prefix("qkv_proj", prefix),
|
||||||
)
|
)
|
||||||
self.proj = RowParallelLinear(
|
self.proj = RowParallelLinear(
|
||||||
@@ -442,6 +448,8 @@ class VisionAttention(nn.Module):
|
|||||||
output_size=embed_dim,
|
output_size=embed_dim,
|
||||||
bias=proj_bias,
|
bias=proj_bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
tp_rank=self.tp_rank,
|
||||||
|
tp_size=self.tp_size,
|
||||||
prefix=add_prefix("proj", prefix),
|
prefix=add_prefix("proj", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -531,11 +531,18 @@ class Step3VisionMLP(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
# Since this is a dense model,
|
||||||
|
# the MLP component likewise adopts a DP-MLP approach modeled after DP Attention.
|
||||||
|
# This choice may not represent the optimal solution and remains open to further deliberation.
|
||||||
|
attn_tp_rank = get_attention_tp_rank()
|
||||||
|
attn_tp_size = get_attention_tp_size()
|
||||||
self.fc1 = ColumnParallelLinear(
|
self.fc1 = ColumnParallelLinear(
|
||||||
dim,
|
dim,
|
||||||
intermediate_size,
|
intermediate_size,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
prefix=add_prefix("gate_proj", prefix),
|
prefix=add_prefix("gate_proj", prefix),
|
||||||
)
|
)
|
||||||
self.act = ACT2FN[hidden_act] # quick_gelu
|
self.act = ACT2FN[hidden_act] # quick_gelu
|
||||||
@@ -544,6 +551,8 @@ class Step3VisionMLP(nn.Module):
|
|||||||
dim,
|
dim,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
tp_rank=attn_tp_rank,
|
||||||
|
tp_size=attn_tp_size,
|
||||||
prefix=add_prefix("down_proj", prefix),
|
prefix=add_prefix("down_proj", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms import InterpolationMode
|
from torchvision.transforms import InterpolationMode
|
||||||
from transformers import BatchFeature, TensorType
|
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||||
|
|
||||||
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
|
from sglang.srt.models.step3_vl import Step3VLForConditionalGeneration
|
||||||
from sglang.srt.multimodal.processors.base_processor import (
|
from sglang.srt.multimodal.processors.base_processor import (
|
||||||
@@ -276,6 +276,8 @@ class Step3VLProcessor:
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
if isinstance(tokenizer, ProcessorMixin):
|
||||||
|
tokenizer = tokenizer.tokenizer
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
self.image_size = 728
|
self.image_size = 728
|
||||||
|
|||||||
Reference in New Issue
Block a user