Add Tensor Parallel to torch_native_llama (#1876)

This commit is contained in:
Ke Wen
2024-11-15 21:26:00 -08:00
committed by GitHub
parent e5c6715003
commit cf2489762b
5 changed files with 246 additions and 82 deletions


@@ -17,6 +17,31 @@ limitations under the License.
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
# PyTorch Tensor Parallel Available for This Model
"""
This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
Here is a quick example to enable TP:
```python
import torch

from sglang.srt.model_parallel import tensor_parallel
device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
tensor_parallel(model, device_mesh)
```
An end-to-end example can be found in `python/sglang/bench_latency.py`.
You can run it with the following command:
```bash
$ python3 -m sglang.bench_latency --correct \
--model meta-llama/Meta-Llama-3-8B \
--json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
--tensor-parallel-size 2 \
--disable-cuda-graph
```
We will enable CUDA Graph support soon.
"""
import types
from typing import Any, Dict, Iterable, Optional, Tuple
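
Editor's note: the `tensor_parallel(model, device_mesh)` helper invoked in the docstring is not part of this diff, so its internals are not shown here. As a rough sketch of what applying the per-module `_tp_plan` attributes defined below can look like with the stock PyTorch TP API — the plan-string mapping and the module traversal are assumptions, not sglang's actual implementation:

```python
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumed mapping from the plan strings used in this file to torch TP styles.
# "Colwise_Sharded" is treated like ColwiseParallel here; in this commit the
# custom weight loaders below already shard the weights at load time.
_STYLES = {
    "Colwise": ColwiseParallel(),
    "Colwise_Sharded": ColwiseParallel(),
    "Rowwise": RowwiseParallel(),
}

def apply_tp_plans(model: nn.Module, mesh: DeviceMesh) -> None:
    """Parallelize every submodule that declares a ``_tp_plan`` attribute."""
    for module in model.modules():
        plan = getattr(module, "_tp_plan", None)
        if plan:
            parallelize_module(
                module, mesh, {name: _STYLES[style] for name, style in plan.items()}
            )
```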
@@ -24,7 +49,10 @@ import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
def gate_up_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None,
loaded_shard_id: int,
):
if loaded_shard_id is None:
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
# shard_id: (shard_offset, shard_size)
gate_up_offsets = {}
current_shard_offset = 0
for i, output_size in enumerate(self.output_sizes):
# Everything shrinks by tp_size if TP enabled
output_size = output_size // tp_size
gate_up_offsets[i] = (current_shard_offset, output_size)
current_shard_offset += output_size
# Re-size the param to the size after TP
if current_shard_offset != param.shape[0]:
# The clone will free the original, full tensor
param.data = param.data.narrow(0, 0, current_shard_offset).clone()
# Now load gate or up
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class LlamaMLP(nn.Module):
_tp_plan = {
"gate_up_proj": "Colwise_Sharded",
"down_proj": "Rowwise",
}
def __init__(
self,
hidden_size: int,
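
Editor's note: the new `gate_up_proj_weight_loader` above shrinks the merged gate/up parameter to `1/tp_size` of its checkpoint height and copies each rank's row slice into place. A standalone check of that offset arithmetic, with made-up sizes (hidden_size=8, intermediate_size=16, tp_size=2) purely for illustration:

```python
import torch

hidden_size, intermediate_size, tp_size, tp_rank = 8, 16, 2, 1

# Full (unsharded) checkpoint weights for gate_proj and up_proj.
gate_full = torch.randn(intermediate_size, hidden_size)
up_full = torch.randn(intermediate_size, hidden_size)

# Per-rank merged parameter, as resized by the loader: gate rows first, then up.
shard = intermediate_size // tp_size
param = torch.empty(2 * shard, hidden_size)

for shard_id, full in enumerate([gate_full, up_full]):
    offset = shard_id * shard  # gate at rows [0, shard), up at [shard, 2*shard)
    param[offset : offset + shard] = full.narrow(0, tp_rank * shard, shard)

# Rank 1 holds the second half of gate_proj in its top half.
assert torch.equal(param[:shard], gate_full[tp_rank * shard : (tp_rank + 1) * shard])
```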
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
return x
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def qkv_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
loaded_shard_id: str,
):
if loaded_shard_id is None:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
num_heads = self.num_heads // tp_size
num_kv_heads = self.num_kv_heads // tp_size
# shard_id: (shard_offset, shard_size)
qkv_offsets = {
"q": (0, num_heads * self.head_size),
"k": (num_heads * self.head_size, num_kv_heads * self.head_size),
"v": (
(num_heads + num_kv_heads) * self.head_size,
num_kv_heads * self.head_size,
),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
# Re-size the param to the size after TP
if total_size != param.shape[0]:
# The clone will free the original, full tensor
param.data = param.data.narrow(0, 0, total_size).clone()
# Now load q, k or v
shard_offset, shard_size = qkv_offsets[loaded_shard_id]
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class LlamaAttention(nn.Module):
_tp_plan = {
"qkv_proj": "Colwise_Sharded",
"o_proj": "Rowwise",
}
def __init__(
self,
config: LlamaConfig,
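
Editor's note: the `qkv_offsets` computation above can be sanity-checked in isolation. With illustrative GQA-style numbers (32 query heads, 8 KV heads, head_size 128, tp_size 4), each rank's merged qkv parameter holds exactly `1/tp_size` of the full rows the layer was constructed with:

```python
head_size, total_heads, total_kv_heads, tp_size = 128, 32, 8, 4

num_heads = total_heads // tp_size        # 8 query heads per rank
num_kv_heads = total_kv_heads // tp_size  # 2 KV heads per rank

# shard_id: (shard_offset, shard_size), mirroring the loader above.
qkv_offsets = {
    "q": (0, num_heads * head_size),
    "k": (num_heads * head_size, num_kv_heads * head_size),
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]

# Per-rank rows: (8 + 2 + 2) * 128 = 1536, i.e. 1/4 of the full
# (32 + 2 * 8) * 128 = 6144 rows passed to the qkv_proj constructor.
assert total_size == (num_heads + 2 * num_kv_heads) * head_size == 1536
```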
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
(self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
bias=False,
)
self.qkv_proj.total_num_heads = self.total_num_heads
self.qkv_proj.head_size = self.head_dim
self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
self.qkv_proj.num_heads = self.total_num_heads
self.qkv_proj.num_kv_heads = self.total_num_kv_heads
self.qkv_proj.weight_loader = types.MethodType(
qkv_proj_weight_loader, self.qkv_proj
)
self.qkv_proj._get_shard_offset_mapping = types.MethodType(
_get_shard_offset_mapping, self.qkv_proj
)
self.qkv_proj._get_shard_size_mapping = types.MethodType(
_get_shard_size_mapping, self.qkv_proj
)
self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
self.qkv_proj.weight.output_dim = 0
self.o_proj = torch.nn.Linear(
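
Editor's note: the hunk above binds the free-standing loader function onto a plain `torch.nn.Linear` with `types.MethodType`, then mirrors it onto the weight so a checkpoint-loading loop that looks for `param.weight_loader` (as vllm's default path does) will invoke it. A minimal sketch of that binding pattern on a toy layer, with a trivially simplified loader:

```python
import types

import torch
from torch import nn
from torch.nn.parameter import Parameter

def toy_weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # Real loaders in this commit narrow to the rank's shard first.
    param.data.copy_(loaded_weight)

layer = nn.Linear(16, 16, bias=False)
layer.weight_loader = types.MethodType(toy_weight_loader, layer)
layer.weight.weight_loader = layer.weight_loader  # discovered via the Parameter
layer.weight.output_dim = 0  # shard along rows, matching the code above

ckpt_weight = torch.randn(16, 16)
layer.weight_loader(layer.weight, ckpt_weight)  # what a loading loop would call
```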
@@ -385,6 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)