forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
@@ -0,0 +1 @@
from .backends import mlu_attn
@@ -0,0 +1,58 @@
from typing import Optional, Type

import torch
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
from vllm_mlu.attention.backends.mlu_attn import MLUFlashAttentionImpl_V2

from .ring_attn import zigzag_ring_attn
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
    get_context_model_parallel_world_size)


vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org = MLUFlashAttentionImpl_V2.forward


def vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wrapper(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: MLUFlashAttentionMetadata,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: AttentionType = AttentionType.DECODER,
        use_mla: bool = False,
) -> torch.Tensor:
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: use ring attention for the prefill phase when context parallel is enabled
    '''
    if get_context_model_parallel_world_size() > 1 and attn_metadata.prefill_metadata:
        return zigzag_ring_attn(self,
                                query=query.view(-1, self.num_heads, self.head_size),
                                key=key.view(-1, self.num_kv_heads, self.head_size),
                                value=value.view(-1, self.num_kv_heads, self.head_size),
                                kv_cache=kv_cache,
                                attn_metadata=attn_metadata)
    '''
    =======================
    End of Context Parallel
    =======================
    '''
    return vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_org(
        self,
        query=query,
        key=key,
        value=value,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        k_scale=k_scale,
        v_scale=v_scale,
        attn_type=attn_type)


MluHijackObject.apply_hijack(MLUFlashAttentionImpl_V2,
                             MLUFlashAttentionImpl_V2.forward,
                             vllm__attention__backends__flash_attn__MLUFlashAttentionImpl__forward_wrapper)
@@ -0,0 +1,216 @@
|
||||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.mlu_attn import MLUFlashAttentionMetadata
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import get_context_model_parallel_group
|
||||
from ...distributed.ring_comm import RingComm
|
||||
|
||||
|
||||
# code references: https://github.com/zhuzilin/ring-flash-attention
|
||||
def _update_out_and_lse(
|
||||
out: torch.Tensor,
|
||||
lse: torch.Tensor,
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
block_out = block_out.to(torch.float32)
|
||||
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
|
||||
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
|
||||
lse = lse - F.logsigmoid(lse - block_lse)
|
||||
return out, lse
|
||||
|
||||
|
||||
def update_out_and_lse(
|
||||
out: Optional[torch.Tensor],
|
||||
lse: Optional[torch.Tensor],
|
||||
block_out: torch.Tensor,
|
||||
block_lse: torch.Tensor,
|
||||
slice_=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if out is None:
|
||||
if slice_ is not None:
|
||||
raise RuntimeError("the first call to update_out_and_lse must not pass the slice_ argument")
|
||||
out = block_out.to(torch.float32)
|
||||
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
|
||||
elif slice_ is not None:
|
||||
slice_out, slice_lse = out[slice_], lse[slice_]
|
||||
slice_out, slice_lse = _update_out_and_lse(
|
||||
slice_out, slice_lse, block_out, block_lse
|
||||
)
|
||||
out[slice_], lse[slice_] = slice_out, slice_lse
|
||||
else:
|
||||
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
|
||||
return out, lse
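# Illustrative sketch (reviewer-added, not used by the hijack): the sigmoid /
# logsigmoid update above is the numerically stable form of the usual
# log-sum-exp combination of two partial attention results,
#   merged_out = (exp(lse) * out + exp(block_lse) * block_out) / (exp(lse) + exp(block_lse)).
# A minimal check of that equivalence on toy float32 tensors:
def _lse_merge_equivalence_sketch():
    out, block_out = torch.rand(4, 2, 8), torch.rand(4, 2, 8)
    lse, block_lse = torch.rand(4, 2, 1), torch.rand(4, 2, 1)
    merged = out - F.sigmoid(block_lse - lse) * (out - block_out)
    expected = (lse.exp() * out + block_lse.exp() * block_out) / (lse.exp() + block_lse.exp())
    assert torch.allclose(merged, expected, atol=1e-6)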
|
||||
|
||||
|
||||
def get_half(pack_tensor, cu_seq_lens, first_half):
|
||||
batch_num = cu_seq_lens.shape[0] - 1
|
||||
half_list = []
|
||||
for batch in range(batch_num):
|
||||
if first_half:
|
||||
start = cu_seq_lens[batch]
|
||||
end = (cu_seq_lens[batch] + cu_seq_lens[batch + 1]) // 2
|
||||
else:
|
||||
start = (cu_seq_lens[batch] + cu_seq_lens[batch + 1]) // 2
|
||||
end = cu_seq_lens[batch + 1]
|
||||
half = pack_tensor[start: end]
|
||||
half_list.append(half)
|
||||
half = torch.cat(half_list, dim=0)
|
||||
return half
|
||||
|
||||
|
||||
def update_half(pack_tensor, half_tensor, cu_seq_lens, first_half):
|
||||
half_cu_seq_lens = cu_seq_lens // 2
|
||||
batch_num = cu_seq_lens.shape[0] - 1
|
||||
for batch in range(batch_num):
|
||||
if first_half:
|
||||
start = cu_seq_lens[batch]
|
||||
end = (cu_seq_lens[batch] + cu_seq_lens[batch + 1]) // 2
|
||||
else:
|
||||
start = (cu_seq_lens[batch] + cu_seq_lens[batch + 1]) // 2
|
||||
end = cu_seq_lens[batch + 1]
|
||||
pack_tensor[start: end] = half_tensor[half_cu_seq_lens[batch]: half_cu_seq_lens[batch + 1]]
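# Illustrative sketch (reviewer-added): with packed sequences and
# cu_seq_lens = [0, 4, 10], get_half(..., first_half=True) returns rows
# [0, 1] and [4, 5, 6], i.e. the first half of every sequence, and
# update_half writes such a packed half back in place.
def _half_roundtrip_sketch():
    cu_seq_lens = torch.tensor([0, 4, 10])
    pack = torch.arange(10).unsqueeze(-1).float()
    first = get_half(pack, cu_seq_lens, first_half=True)
    assert first.squeeze(-1).tolist() == [0.0, 1.0, 4.0, 5.0, 6.0]
    update_half(pack, first + 100, cu_seq_lens, first_half=True)
    assert pack.squeeze(-1).tolist()[:2] == [100.0, 101.0]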
|
||||
|
||||
|
||||
def zigzag_ring_attn(self,
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
value: torch.Tensor, # [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: List[torch.Tensor],
|
||||
attn_metadata: MLUFlashAttentionMetadata) -> torch.Tensor:
|
||||
num_tokens, _, _ = query.shape
|
||||
cu_seq_lens = attn_metadata.prefill_metadata.seq_start_loc
|
||||
batch_num = cu_seq_lens.shape[0] - 1
|
||||
block_seq_len = query.shape[0] // 2
|
||||
process_group = get_context_model_parallel_group().device_group
|
||||
comm = RingComm(process_group) # k
|
||||
comm_ = RingComm(process_group) # v
|
||||
comm__ = RingComm(process_group) # slot_mapping
|
||||
|
||||
q, k, v = query, key, value
|
||||
if batch_num == 1:
|
||||
q1 = q[block_seq_len:]
|
||||
else:
|
||||
q1 = get_half(q, cu_seq_lens, False)
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
out = None
|
||||
lse = None
|
||||
next_k, next_v = None, None
|
||||
next_slot_mapping = None
|
||||
|
||||
def forward(q, k, v, causal):
|
||||
if batch_num == 1:
|
||||
seq = q.shape[0]
|
||||
seq_k = k.shape[0]
|
||||
cu_seq_lens_q = torch.arange(0, seq+1, seq, dtype=torch.int32, device=q.device)
|
||||
cu_seq_lens_kv = torch.arange(0, seq_k+1, seq_k, dtype=torch.int32, device=q.device)
|
||||
max_seq_len_q = seq
|
||||
max_seq_len_kv = seq_k
|
||||
else:
|
||||
max_seq_len_q = attn_metadata.prefill_metadata.max_seq_len
|
||||
max_seq_len_kv = attn_metadata.prefill_metadata.max_seq_len
|
||||
cu_seq_lens_q = cu_seq_lens
|
||||
cu_seq_lens_kv = cu_seq_lens
|
||||
if q.shape[0] != cu_seq_lens[-1]:
|
||||
cu_seq_lens_q = cu_seq_lens // 2
|
||||
max_seq_len_q = max_seq_len_q // 2
|
||||
if k.shape[0] != cu_seq_lens[-1]:
|
||||
cu_seq_lens_kv = cu_seq_lens // 2
|
||||
max_seq_len_kv = max_seq_len_kv // 2
|
||||
alibi_slopes = None if self.alibi_slopes is None else \
|
||||
self.alibi_slopes.repeat(attn_metadata.num_prefills, 1)
|
||||
outputs = mlu_ops.flash_attention(q,
|
||||
k,
|
||||
v,
|
||||
None,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_kv,
|
||||
alibi_slopes,
|
||||
None,
|
||||
max_seq_len_q,
|
||||
max_seq_len_kv,
|
||||
self.scale,
|
||||
causal, -1, -1, torch.float, True)
|
||||
block_out, block_lse = outputs[0], outputs[1]
|
||||
|
||||
if block_lse.shape[0] == 1:
|
||||
block_lse = block_lse[0]
|
||||
else:
|
||||
# block_lse shape is [batch, head_num_q, max_seq_q]; the padded part is set to 0.
# We need to drop the padding and reshape it to [head_num_q, total_seq_q].
|
||||
block_lse_list = []
|
||||
for batch in range(block_lse.shape[0]):
|
||||
block_lse_ = block_lse[batch][:, : cu_seq_lens_q[batch + 1] - cu_seq_lens_q[batch]]
|
||||
block_lse_list.append(block_lse_)
|
||||
block_lse = torch.cat(block_lse_list, dim=-1)
|
||||
|
||||
return block_out, block_lse
|
||||
|
||||
for step in range(comm.world_size):
|
||||
if step + 1 != comm.world_size:
|
||||
next_k: torch.Tensor = comm.send_recv(k.contiguous())
|
||||
next_v: torch.Tensor = comm_.send_recv(v.contiguous())
|
||||
next_slot_mapping: torch.Tensor = comm__.send_recv(slot_mapping)
|
||||
comm.commit()
|
||||
comm_.commit()
|
||||
comm__.commit()
|
||||
|
||||
# call mlu_ops.reshape_paged_cache
|
||||
if kv_cache[0].numel() > 0:
|
||||
kv_cache_, kv_cache_scale_ = kv_cache
|
||||
key_cache, value_cache = kv_cache_[0], kv_cache_[1]
|
||||
if isinstance(kv_cache[0], torch.Tensor) and kv_cache[0].dtype == torch.int8:
|
||||
key_cache_scale, value_cache_scale = kv_cache_scale_[0], kv_cache_scale_[1]
|
||||
mlu_ops.quant_to_paged_cache(k,
|
||||
v,
|
||||
key_cache,
|
||||
value_cache,
|
||||
key_cache_scale,
|
||||
value_cache_scale,
|
||||
slot_mapping.flatten())
|
||||
else:
|
||||
mlu_ops.reshape_paged_cache(k,
|
||||
v,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping.flatten())
|
||||
|
||||
if step == 0:
|
||||
block_out, block_lse = forward(q, k, v, causal = True)
|
||||
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
|
||||
elif step <= comm.rank:
|
||||
if batch_num == 1:
|
||||
k0 = k[:block_seq_len]
|
||||
v0 = v[:block_seq_len]
|
||||
else:
|
||||
k0 = get_half(k, cu_seq_lens, True)
|
||||
v0 = get_half(v, cu_seq_lens, True)
|
||||
block_out, block_lse = forward(q, k0, v0, causal = False)
|
||||
out, lse = update_out_and_lse(out, lse, block_out, block_lse)
|
||||
else:
|
||||
block_out, block_lse = forward(q1, k, v, causal = False)
|
||||
if batch_num == 1:
|
||||
out, lse = update_out_and_lse(out, lse, block_out, block_lse,
|
||||
slice_=(slice(block_seq_len, None)),)
|
||||
else:
|
||||
slice_out = get_half(out, cu_seq_lens, False)
|
||||
slice_lse = get_half(lse, cu_seq_lens, False)
|
||||
slice_out, slice_lse = update_out_and_lse(
|
||||
slice_out, slice_lse, block_out, block_lse
|
||||
)
|
||||
update_half(out, slice_out, cu_seq_lens, False)
|
||||
update_half(lse, slice_lse, cu_seq_lens, False)
|
||||
|
||||
if step + 1 != comm.world_size:
|
||||
comm.wait()
|
||||
comm_.wait()
|
||||
comm__.wait()
|
||||
k = next_k
|
||||
v = next_v
|
||||
slot_mapping = next_slot_mapping
|
||||
out = out.to(q.dtype)
|
||||
return out.view(num_tokens, self.num_heads * self.head_size)
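# Illustrative sketch (reviewer-added): per-rank schedule of the zigzag ring
# loop above for a given context-parallel world size. Step 0 runs causal
# attention on the local KV; on later steps a rank either attends its full
# query to the first half of the received KV (step <= rank) or attends only
# the second half of its query to the full received KV.
def _ring_schedule_sketch(world_size: int, rank: int):
    plan = []
    for step in range(world_size):
        if step == 0:
            plan.append("q(all)      x kv(local),      causal")
        elif step <= rank:
            plan.append("q(all)      x kv(first half), non-causal")
        else:
            plan.append("q(2nd half) x kv(all),        non-causal")
    return plan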
|
||||
@@ -0,0 +1 @@
from . import ring_comm
@@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# code references: https://github.com/zhuzilin/ring-flash-attention
|
||||
class RingComm:
|
||||
def __init__(self, process_group: dist.ProcessGroup):
|
||||
self._process_group = process_group
|
||||
self._ops = []
|
||||
self.rank = dist.get_rank(self._process_group)
|
||||
self.world_size = dist.get_world_size(self._process_group)
|
||||
self._reqs = None
|
||||
|
||||
self.send_rank = (self.rank + 1) % self.world_size
|
||||
self.recv_rank = (self.rank - 1) % self.world_size
|
||||
|
||||
if process_group is not None:
|
||||
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
|
||||
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
|
||||
|
||||
def send_recv(
|
||||
self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
if recv_tensor is None:
|
||||
res = torch.empty_like(to_send)
|
||||
else:
|
||||
res = recv_tensor
|
||||
|
||||
send_op = dist.P2POp(
|
||||
dist.isend, to_send, self.send_rank, group=self._process_group
|
||||
)
|
||||
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
|
||||
self._ops.append(send_op)
|
||||
self._ops.append(recv_op)
|
||||
return res
|
||||
|
||||
def commit(self):
|
||||
if self._reqs is not None:
|
||||
raise RuntimeError("commit called twice")
|
||||
self._reqs = dist.batch_isend_irecv(self._ops)
|
||||
|
||||
def wait(self):
|
||||
if self._reqs is None:
|
||||
raise RuntimeError("wait called before commit")
|
||||
for req in self._reqs:
|
||||
req.wait()
|
||||
self._reqs = None
|
||||
self._ops = []
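# Illustrative usage sketch (reviewer-added), assuming torch.distributed is
# already initialized and `group` is the context-parallel device group: each
# rank posts a send to rank + 1 and a receive from rank - 1, overlaps its
# local compute, and only then waits on the transfers.
def _ring_exchange_sketch(tensor: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    comm = RingComm(group)
    received = comm.send_recv(tensor.contiguous())  # queue the isend/irecv pair
    comm.commit()                                   # launch the batched P2P ops
    # ... local attention on `tensor` would run here, hiding the transfer ...
    comm.wait()                                     # block until `received` is valid
    return received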
|
||||
@@ -0,0 +1,2 @@
from . import gpu_executor
from . import ray_mlu_executor
@@ -0,0 +1,40 @@
from typing import Any, Dict, Optional

from vllm.executor.gpu_executor import GPUExecutor
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm_mlu.mlu_hijack_utils import MluHijackObject


def vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs(
        self,
        local_rank: int = 0,
        rank: int = 0,
        distributed_init_method: Optional[str] = None,
) -> Dict[str, Any]:
    """Return worker init args for a given rank."""
    if distributed_init_method is None:
        distributed_init_method = get_distributed_init_method(
            get_ip(), get_open_port())
    '''
    ==========================
    Modify by Context Parallel
    ==========================
    @brief: replace self.parallel_config.tensor_parallel_size with self.parallel_config.world_size.
    '''
    return dict(
        vllm_config=self.vllm_config,
        local_rank=local_rank,
        rank=rank,
        distributed_init_method=distributed_init_method,
        is_driver_worker=(not self.parallel_config)
        or (rank % self.parallel_config.world_size == 0),
    )
    '''
    =======================
    End of Context Parallel
    =======================
    '''


MluHijackObject.apply_hijack(
    GPUExecutor,
    GPUExecutor._get_worker_kwargs,
    vllm__executor__gpu_executor__GPUExecutor___get_worker_kwargs)
@@ -0,0 +1,246 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.ray_utils import RayWorkerWrapper, ray
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
get_vllm_instance_id)
|
||||
from vllm_mlu._mlu_utils import VLLM_LATENCY_DEBUG, VLLM_LATENCY_DEBUG_NO_DEVICE
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.common import init_logger
|
||||
from vllm.executor.ray_mlu_executor import RayMLUExecutor
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
if ray is not None:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray(
|
||||
self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if (self.parallel_config.tensor_parallel_size == 1
|
||||
and self.parallel_config.pipeline_parallel_size == 1):
|
||||
# For single GPU case, we use a ray worker with constrained memory.
|
||||
num_gpus = self.cache_config.gpu_memory_utilization
|
||||
else:
|
||||
# Otherwise, the ray workers are allocated with a full GPU.
|
||||
num_gpus = 1
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Used in ray compiled DAG: indexed first by PP rank,
|
||||
# and then TP rank. In other words, the inner list is
|
||||
# the TP group of workers for a PP rank.
|
||||
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
|
||||
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
|
||||
ray_remote_kwargs)
|
||||
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
worker_wrapper_kwargs = self._get_worker_wrapper_args()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
self.workers.append(worker)
|
||||
else:
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
**worker_wrapper_kwargs)
|
||||
else:
|
||||
# Else, added to the list of workers.
|
||||
self.workers.append(worker)
|
||||
|
||||
logger.debug("workers: %s", self.workers)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node. Consider "
|
||||
"adjusting the Ray placement group or running the driver on a "
|
||||
"GPU node.")
|
||||
|
||||
worker_ips = [
|
||||
ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined]
|
||||
for worker in self.workers
|
||||
]
|
||||
ip_counts: Dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(worker):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
3. Finally, if the worker is on a node with a smaller IP address, it
|
||||
should be placed first.
|
||||
"""
|
||||
ip = ray.get(worker.get_node_ip.remote())
|
||||
return (ip != driver_ip, ip_counts[ip], ip)
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
|
||||
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
|
||||
use_dummy_driver=True)
|
||||
|
||||
node_workers = defaultdict(list) # node id -> list of worker ranks
|
||||
node_gpus = defaultdict(list) # node id -> list of gpu ids
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
# `gpu_ids` can be a list of strings or integers.
|
||||
# convert them to integers for consistency.
|
||||
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
|
||||
# string sorting is not sufficient.
|
||||
# see https://github.com/vllm-project/vllm/issues/5590
|
||||
gpu_ids = [int(x) for x in gpu_ids]
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
all_ips = set(worker_ips + [driver_ip])
|
||||
n_ips = len(all_ips)
|
||||
n_nodes = len(node_workers)
|
||||
|
||||
if n_nodes != n_ips:
|
||||
raise RuntimeError(
|
||||
f"Every node should have a unique IP address. Got {n_nodes}"
|
||||
f" nodes with node ids {list(node_workers.keys())} and "
|
||||
f"{n_ips} unique IP addresses {all_ips}. Please check your"
|
||||
" network configuration. If you set `VLLM_HOST_IP` or "
|
||||
"`HOST_IP` environment variable, make sure it is unique for"
|
||||
" each node.")
|
||||
|
||||
VLLM_INSTANCE_ID = get_vllm_instance_id()
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [({
|
||||
"MLU_VISIBLE_DEVICES":
|
||||
",".join(map(str, node_gpus[node_id])),
|
||||
"VLLM_INSTANCE_ID":
|
||||
VLLM_INSTANCE_ID,
|
||||
"VLLM_TRACE_FUNCTION":
|
||||
str(envs.VLLM_TRACE_FUNCTION),
|
||||
**({
|
||||
"VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND
|
||||
} if envs.VLLM_ATTENTION_BACKEND is not None else {}),
|
||||
"VLLM_LATENCY_DEBUG":
|
||||
'1' if VLLM_LATENCY_DEBUG else '0',
|
||||
"VLLM_LATENCY_DEBUG_NO_DEVICE":
|
||||
'1' if VLLM_LATENCY_DEBUG_NO_DEVICE else '0',
|
||||
}, ) for (node_id, _) in worker_node_and_gpu_ids]
|
||||
|
||||
self._env_vars_for_all_workers = (
|
||||
all_args_to_update_environment_variables)
|
||||
|
||||
self._run_workers("update_environment_variables",
|
||||
all_args=self._get_env_vars_to_be_updated())
|
||||
|
||||
if len(node_gpus) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
init_worker_all_kwargs = [
|
||||
self._get_worker_kwargs(
|
||||
local_rank=node_workers[node_id].index(rank),
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
|
||||
]
|
||||
self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
for tp_rank in range(
|
||||
self.parallel_config.tensor_parallel_size):
|
||||
# PP=2, TP=4
|
||||
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
|
||||
rank = (pp_rank * self.parallel_config.tensor_parallel_size
|
||||
) + tp_rank
|
||||
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
||||
assert pp_rank < len(self.pp_tp_workers)
|
||||
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Enforce rank order for correct rank to return final output.
|
||||
for index, worker in enumerate(self.workers):
|
||||
# The driver worker is rank 0 and not in self.workers.
|
||||
rank = index + 1
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: replace tp size with world_size.
|
||||
'''
|
||||
if rank % self.parallel_config.world_size == 0:
|
||||
self.tp_driver_workers.append(worker)
|
||||
else:
|
||||
self.non_driver_workers.append(worker)
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
|
||||
MluHijackObject.apply_hijack(RayMLUExecutor,
|
||||
RayMLUExecutor._init_workers_ray,
|
||||
vllm__executor__ray_mlu_executor__RayMLUExecutor___init_workers_ray)
|
||||
@@ -0,0 +1,6 @@
print("Apply Context Parallel Demo!")
from . import distributed
from . import attention
from . import model_executor
from . import worker
from . import executor
@@ -0,0 +1,2 @@
from .layers import rotary_embedding
from .layers import logits_processor
@@ -0,0 +1,110 @@
|
||||
from typing import Optional
|
||||
import torch
|
||||
import vllm
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.distributed import get_world_group
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor, _prune_hidden_states, _apply_logits_processors
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_world_size, get_context_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__logits_processor__LogitsProcessor__forward_wrapper(
|
||||
self,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if self.logits_as_input:
|
||||
logits = hidden_states
|
||||
else:
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: context parallel requires special handling of hidden_states and logits
|
||||
'''
|
||||
if self.attn_metadata and get_context_model_parallel_world_size() > 1:
|
||||
hidden_states = _prune_hidden_states_context_parallel(hidden_states, sampling_metadata, self.attn_metadata)
|
||||
else:
|
||||
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
# Get the logits for the next tokens.
|
||||
logits = self._get_logits(hidden_states, lm_head, embedding_bias)
|
||||
if logits is not None:
|
||||
if self.soft_cap is not None:
|
||||
logits = logits / self.soft_cap
|
||||
logits = torch.tanh(logits)
|
||||
logits = logits * self.soft_cap
|
||||
|
||||
if self.scale != 1.0:
|
||||
logits *= self.scale
|
||||
|
||||
# Apply logits processors (if any).
|
||||
if sampling_metadata is not None:
|
||||
logits = _apply_logits_processors(logits, sampling_metadata)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: after padding, the token count is divisible by context_parallel_size * 2 and the
tokens are split across the context-parallel ranks with the zigzag method; here we
locate the last valid token of each sequence and gather its hidden state for the next-token logits.
|
||||
'''
|
||||
def _prune_hidden_states_context_parallel(
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
attn_metadata: AttentionMetadata
|
||||
) -> torch.Tensor:
|
||||
select_hidden_states_list = []
|
||||
seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
|
||||
batch_num = seq_start_loc.shape[0] - 1
|
||||
for batch in range(batch_num):
|
||||
start = seq_start_loc[batch]
|
||||
end = seq_start_loc[batch + 1]
|
||||
hidden_states_ = hidden_states[start : end]
|
||||
split_seq_len = hidden_states_.shape[0] // 2
|
||||
seq_len = attn_metadata.prefill_metadata.seq_lens[batch]
|
||||
last_id = seq_len - 1
|
||||
idx = last_id // split_seq_len
|
||||
select_hidden_states = torch.zeros((1, hidden_states.shape[-1]), dtype = hidden_states.dtype, device = hidden_states.device)
|
||||
if idx < get_context_model_parallel_world_size():
|
||||
target_cp_id = idx
|
||||
src_rank = get_tensor_model_parallel_world_size() * target_cp_id
|
||||
if get_context_model_parallel_rank() == target_cp_id:
|
||||
selected_token_indices = last_id - idx * split_seq_len
|
||||
select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)
|
||||
else:
|
||||
target_cp_id = get_context_model_parallel_world_size() * 2 - 1 - idx
|
||||
src_rank = get_tensor_model_parallel_world_size() * target_cp_id
|
||||
if get_context_model_parallel_rank() == target_cp_id:
|
||||
selected_token_indices = last_id - idx * split_seq_len + split_seq_len
|
||||
select_hidden_states = hidden_states_[selected_token_indices].unsqueeze(0)
|
||||
|
||||
select_hidden_states = get_world_group().broadcast(select_hidden_states, src = src_rank)
|
||||
select_hidden_states_list.append(select_hidden_states)
|
||||
|
||||
select_hidden_states = torch.cat(select_hidden_states_list, dim=0)
|
||||
return select_hidden_states
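# Illustrative sketch (reviewer-added): locate the context-parallel rank and
# local row that own the last valid token of a sequence under the zigzag
# layout used above. E.g. cp_world_size=2, split_seq_len=4 (global chunks
# 0..3, rank 0 holds chunks [0, 3], rank 1 holds chunks [1, 2]): seq_len=14
# -> last token 13 lives in chunk 3, i.e. rank 0, local row 4 + (13 - 12) = 5.
def _last_token_owner_sketch(seq_len: int, split_seq_len: int, cp_world_size: int):
    last_id = seq_len - 1
    idx = last_id // split_seq_len
    if idx < cp_world_size:
        return idx, last_id - idx * split_seq_len
    return cp_world_size * 2 - 1 - idx, last_id - idx * split_seq_len + split_seq_len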
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(LogitsProcessor,
|
||||
LogitsProcessor.forward,
|
||||
vllm__model_executor__layers__logits_processor__LogitsProcessor__forward_wrapper)
|
||||
@@ -0,0 +1,62 @@
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
import vllm
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.rotary_embedding import MLURotaryEmbedding
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_world_size)
|
||||
|
||||
def vllm__model_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wrapper(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _mlu_ops as mlu_ops
|
||||
|
||||
# ops.rotary_embedding()/batched_rotary_embedding()
|
||||
# are in-place operations that update the query and key tensors.
|
||||
if offsets is not None:
|
||||
raise ValueError("tmo.apply_rotary does not support offsets yet.")
|
||||
else:
|
||||
if not MLURotaryEmbedding.set_cos_sin:
|
||||
MLURotaryEmbedding.cos_, MLURotaryEmbedding.sin_ = self._get_cos_sin()
|
||||
MLURotaryEmbedding.set_cos_sin = True
|
||||
interleaved = True
|
||||
if self.is_neox_style:
|
||||
interleaved = False
|
||||
if MLURotaryEmbedding.is_chunked or not MLURotaryEmbedding.is_prompt:
|
||||
position_ids = positions
|
||||
discrete = True
|
||||
else :
|
||||
position_ids = None
|
||||
discrete = False
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: context parallel needs discrete = True
|
||||
'''
|
||||
position_ids = None if (MLURotaryEmbedding.is_prompt and get_context_model_parallel_world_size() == 1) else positions
discrete = False if (MLURotaryEmbedding.is_prompt and get_context_model_parallel_world_size() == 1) else True
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
x = mlu_ops.rotary_embedding(x,
|
||||
MLURotaryEmbedding.sin_,
|
||||
MLURotaryEmbedding.cos_,
|
||||
position_ids,
|
||||
MLURotaryEmbedding.cu_seq_lens,
|
||||
interleaved,
|
||||
discrete,
|
||||
False,
|
||||
MLURotaryEmbedding.max_seq_len)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MLURotaryEmbedding,
|
||||
MLURotaryEmbedding.forward_mlu,
|
||||
vllm__model_executor__layers__rotary_embedding__MLURotaryEmbedding__forward_mlu_wrapper)
|
||||
@@ -0,0 +1,5 @@
from . import mlu_model_runner
from . import model_runner
from . import model_runner_base
from . import worker
from . import worker_base
@@ -0,0 +1,256 @@
|
||||
import torch
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
|
||||
Tuple, Type, TypeVar, Union)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu._mlu_utils import *
|
||||
from vllm.worker.model_runner import (
|
||||
TModelInputForGPU, ModelInputForGPU,
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelInputForGPUBuilder, GPUModelRunnerBase,
|
||||
ModelRunner, CUDAGraphRunner,
|
||||
LORA_WARMUP_RANK, _get_graph_batch_size,
|
||||
_BATCH_SIZES_TO_CAPTURE, _NUM_WARMUP_ITERS
|
||||
)
|
||||
from vllm.worker.mlu_model_runner import MLUModelRunner
|
||||
from vllm.sequence import (IntermediateTensors, SequenceGroupMetadata)
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from ..zigzag_utils import get_context_model_parallel_world_size, zigzag_split
|
||||
import vllm.envs as envs
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
|
||||
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeWithPagedKVCacheWrapper = None
|
||||
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
|
||||
BatchPrefillWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
@torch.inference_mode()
|
||||
def vllm__worker__mlu_model_runner__MLUModelRunner__execute_model(
|
||||
self,
|
||||
model_input: ModelInputForGPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError("num_steps > 1 is not supported in ModelRunner")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
if self.prompt_adapter_config:
|
||||
assert model_input.prompt_adapter_requests is not None
|
||||
assert model_input.prompt_adapter_mapping is not None
|
||||
self.set_active_prompt_adapters(
|
||||
model_input.prompt_adapter_requests,
|
||||
model_input.prompt_adapter_mapping)
|
||||
|
||||
self.attn_state.begin_forward(model_input)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
assert model_input.attn_metadata is not None
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
decode_meta = model_input.attn_metadata.decode_metadata
|
||||
# TODO(andoorve): We can remove this once all
|
||||
# virtual engines share the same kv cache.
|
||||
virtual_engine = model_input.virtual_engine
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
seqlen_agnostic_kwargs = {
|
||||
"finished_requests_ids": model_input.finished_requests_ids,
|
||||
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
|
||||
} if self.has_inner_state else {}
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch.mlu.Event(enable_timing=True)
|
||||
model_forward_end = torch.mlu.Event(enable_timing=True)
|
||||
model_forward_start.record()
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add mlu metrics
|
||||
'''
|
||||
# Add time markers for model_executable+compute_logits
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
use_cuda_graph = ((prefill_meta is None and decode_meta.use_cuda_graph)
|
||||
or use_context_mlugraph)
|
||||
# if use_cuda_graph, the start timestamp will be inserted inside MLUGraphRunner.forward()
|
||||
if not use_cuda_graph:
|
||||
start = torch.mlu.Event(enable_timing=True)
|
||||
start.record()
|
||||
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: under context parallel, split the model input with the zigzag method
|
||||
'''
|
||||
if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
zigzag_input_ids, zigzag_positions, zigzag_attn_metadata = zigzag_split(model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata, _PAD_SLOT_ID)
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=zigzag_input_ids,
|
||||
positions=zigzag_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=zigzag_attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**multi_modal_kwargs,
|
||||
**seqlen_agnostic_kwargs)
|
||||
else:
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device),
|
||||
**seqlen_agnostic_kwargs)
|
||||
|
||||
#################################################################################################
|
||||
# DEBUG #
|
||||
#################################################################################################
|
||||
# import os
|
||||
# from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
|
||||
# from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
# get_context_model_parallel_rank)
|
||||
# from ..zigzag_utils import context_parallel_tensor_all_gather, diff1
|
||||
# if get_context_model_parallel_world_size() > 1 and attn_metadata.prefill_metadata:
|
||||
# hidden_states = context_parallel_tensor_all_gather(hidden_states, zigzag_attn_metadata, dim=0)
|
||||
# if attn_metadata.prefill_metadata and (kv_caches[0] is not None):
|
||||
# file_path = '/workspace/output_base_' + str(hidden_states.shape) + \
|
||||
# '_tp_' + str(get_tensor_model_parallel_world_size()) + '.pth'
|
||||
# if get_context_model_parallel_rank() == 0 and get_tensor_model_parallel_rank() == 0:
|
||||
# if os.path.exists(file_path):
|
||||
# print("##################compare################")
|
||||
# hidden_states_base = torch.load(file_path)
|
||||
# print("########output_diff1: ", diff1(hidden_states, hidden_states_base))
|
||||
# else:
|
||||
# print("##################save base################")
|
||||
# torch.save(hidden_states, file_path)
|
||||
|
||||
'''
|
||||
@brief: under context parallel, the logits_processor needs the attn_metadata param
|
||||
'''
|
||||
if get_context_model_parallel_world_size() > 1 and model_input.attn_metadata.prefill_metadata:
|
||||
setattr(self.model.logits_processor, 'attn_metadata', zigzag_attn_metadata)
|
||||
else:
|
||||
setattr(self.model.logits_processor, 'attn_metadata', None)
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Compute the logits in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
# Add time markers for model_executable+compute_logits
|
||||
if VLLM_LATENCY_DEBUG_WITH_DEVICE_EN:
|
||||
end_marker = torch.mlu.Event(enable_timing=True)
|
||||
end_marker.record()
|
||||
if use_cuda_graph:
|
||||
self.time_markers = (model_executable.start, end_marker)
|
||||
else:
|
||||
self.time_markers = (start, end_marker)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time
|
||||
and output is not None):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
# If there are multiple workers, we are still tracking the latency
|
||||
# from the start time of the driver worker to the end time of the
|
||||
# driver worker. The model forward time will then end up covering
|
||||
# the communication time as well.
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert model_input.sampling_metadata is not None
|
||||
indices = model_input.sampling_metadata.selected_token_indices
|
||||
if model_input.is_prompt:
|
||||
hidden_states = hidden_or_intermediate_states.index_select(
|
||||
0, indices)
|
||||
elif decode_meta.use_cuda_graph:
|
||||
hidden_states = hidden_or_intermediate_states[:len(indices)]
|
||||
else:
|
||||
hidden_states = hidden_or_intermediate_states
|
||||
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
return [output]
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(MLUModelRunner,
|
||||
MLUModelRunner.execute_model,
|
||||
vllm__worker__mlu_model_runner__MLUModelRunner__execute_model)
|
||||
@@ -0,0 +1,35 @@
|
||||
from typing import (Any, Dict, Optional)
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
from examples.cambricon_custom_func.context_parallel.mlu_hijack.worker.model_runner_base import vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
|
||||
from vllm.worker.model_runner_base import _init_attn_metadata_from_tensor_dict
|
||||
|
||||
@classmethod
|
||||
def vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForGPUWithSamplingMetadata":
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: force apply hijacked function.
|
||||
'''
|
||||
tensor_dict = vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
ModelInputForGPUWithSamplingMetadata,
|
||||
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict,
|
||||
vllm__worker__model_runner__ModelInputForGPUWithSamplingMetadata__from_broadcasted_tensor_dict
|
||||
)
|
||||
@@ -0,0 +1,74 @@
|
||||
from typing import (Any, Dict)
|
||||
|
||||
from vllm.model_executor.sampling_metadata import SequenceGroupToSample
|
||||
from vllm.worker import model_runner_base
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
def vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict( # type: ignore
|
||||
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Helper method to initialize SamplingMetadata based on broadcastable
|
||||
SamplingMetadata fields.
|
||||
"""
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
|
||||
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
|
||||
if selected_token_indices is not None:
|
||||
if 'seq_group_metadata' in tensor_dict.keys() and len(tensor_dict['seq_group_metadata']) > 0:
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: construct sampling metadata.
|
||||
'''
|
||||
sequence_group_to_sample_list = []
|
||||
for seq_group_metadata in tensor_dict['seq_group_metadata']:
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
seq_data = seq_group_metadata.seq_data
|
||||
is_prompt = seq_group_metadata.is_prompt
|
||||
if is_prompt:
|
||||
seq_len = query_len = list(seq_data.values())[0].get_prompt_len()
|
||||
else:
|
||||
seq_len = None
|
||||
query_len = 1
|
||||
prompt_logprob_indices = []
|
||||
sample_indices = seq_ids
|
||||
sequence_group_to_sample = SequenceGroupToSample(seq_ids,
|
||||
sampling_params,
|
||||
seq_data,
|
||||
seq_len,
|
||||
query_len,
|
||||
None, # Generator
|
||||
is_prompt,
|
||||
prompt_logprob_indices,
|
||||
sample_indices)
|
||||
sequence_group_to_sample_list.append(sequence_group_to_sample)
|
||||
tensor_dict["sampling_metadata"] = SamplingMetadata(
|
||||
seq_groups=sequence_group_to_sample_list,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
num_prompts=len(sequence_group_to_sample_list),
|
||||
)
|
||||
del tensor_dict['seq_group_metadata']
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
else:
|
||||
# An empty SamplingMetadata to signal that the worker should skip
|
||||
# sampling.
|
||||
tensor_dict["sampling_metadata"] = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
num_prompts=0,
|
||||
)
|
||||
return tensor_dict
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
model_runner_base,
|
||||
model_runner_base._init_sampling_metadata_from_tensor_dict,
|
||||
vllm__worker__model_runner_base___init_sampling_metadata_from_tensor_dict
|
||||
)
|
||||
@@ -0,0 +1,23 @@
from vllm.worker.worker import Worker
from vllm_mlu.mlu_hijack_utils import MluHijackObject


@property
def vllm__worker__worker__Worker__do_metadata_broadcast(self) -> bool:
    '''
    =============================
    Modify by Context Parallel
    =============================
    @brief: broadcast metadata whenever world_size > 1 (i.e. tp, pp or cp > 1).
    '''
    return self.parallel_config.world_size > 1
    '''
    ==========================
    End of Context Parallel
    ==========================
    '''


MluHijackObject.apply_hijack(
    Worker,
    Worker.do_metadata_broadcast,
    vllm__worker__worker__Worker__do_metadata_broadcast)
@@ -0,0 +1,121 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import ObservabilityConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerInputBase)
|
||||
from vllm.worker.worker_base import (extract_previous_hidden_states,
|
||||
LocalOrDistributedWorkerBase,
|
||||
WorkerInput)
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def broadcast_tensor_dict(
|
||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
|
||||
src: int = 0
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
return tensor_dict
|
||||
return get_world_group().broadcast_tensor_dict(tensor_dict, src)
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
|
||||
""" Get the driver input and broadcast it to other workers. """
|
||||
assert self.is_driver_worker
|
||||
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
kwargs = extract_previous_hidden_states(execute_model_req)
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_data.update(kwargs)
|
||||
'''
|
||||
==========================
|
||||
Modify by Context Parallel
|
||||
==========================
|
||||
@brief: add seq_group metadata to broadcast.
|
||||
'''
|
||||
broadcast_data['seq_group_metadata'] = execute_model_req.seq_group_metadata_list
|
||||
'''
|
||||
=======================
|
||||
End of Context Parallel
|
||||
=======================
|
||||
'''
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
if execute_model_req.async_callback:
|
||||
model_input = dataclasses.replace( # type: ignore
|
||||
model_input,
|
||||
async_callback=execute_model_req.async_callback)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast(
|
||||
self
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
|
||||
str, torch.Tensor]]]:
|
||||
""" Get the worker input from the broadcasted tensor dict. """
|
||||
assert self.do_metadata_broadcast
|
||||
assert not self.is_driver_worker
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
|
||||
kwargs = extract_previous_hidden_states(broadcast_data)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
|
||||
|
||||
def vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]]:
|
||||
"""
|
||||
Prepare the inputs to ModelRunner and workers.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
# This signals that there's no more requests to process for
|
||||
# now. All workers are running infinite loop with
|
||||
# broadcast_tensor_dict, and it stops the loop when the
|
||||
# driver broadcasts an empty input. Send an empty input to
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
return self._get_driver_input_and_broadcast(execute_model_req)
|
||||
else:
|
||||
return self._get_worker_input_from_broadcast()
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
LocalOrDistributedWorkerBase,
|
||||
LocalOrDistributedWorkerBase._get_driver_input_and_broadcast,
|
||||
vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_driver_input_and_broadcast)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
LocalOrDistributedWorkerBase,
|
||||
LocalOrDistributedWorkerBase._get_worker_input_from_broadcast,
|
||||
vllm__worker__worker_base__LocalOrDistributedWorkerBase___get_worker_input_from_broadcast)
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
LocalOrDistributedWorkerBase,
|
||||
LocalOrDistributedWorkerBase.prepare_input,
|
||||
vllm__worker__worker_base__LocalOrDistributedWorkerBase__prepare_input)
|
||||
@@ -0,0 +1,149 @@
|
||||
from typing import Dict, Optional, Sequence, List
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from examples.cambricon_custom_func.vllm.mlu_hijack.distributed.parallel_state import (
|
||||
get_context_model_parallel_rank, get_context_model_parallel_world_size, get_context_model_parallel_group)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.attention import AttentionMetadata
|
||||
import copy
|
||||
|
||||
|
||||
def diff1(result: torch.Tensor, baseline: torch.Tensor):
|
||||
result = result.flatten().float().to('cpu')
|
||||
baseline = baseline.flatten().float().to('cpu')
|
||||
assert result.shape == baseline.shape
|
||||
error = torch.abs(baseline - result)
|
||||
denominator = torch.sum(torch.abs(baseline)).item()
|
||||
eps = 0.0 if denominator > 0 else 1e-9
|
||||
diff1 = torch.sum(error) / (denominator + eps)
|
||||
return diff1.item()
|
||||
|
||||
|
||||
def get_pad_seq(seq_len: int, pad: int):
|
||||
return (seq_len // pad + int(seq_len % pad > 0)) * pad
|
||||
|
||||
|
||||
# Gather the partial results of a batch on context parallel groups
|
||||
# together and place them in the order before zigzag splitting
|
||||
def context_parallel_tensor_all_gather_(input_, dim=-1):
|
||||
world_size = get_context_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
|
||||
assert input_size[dim] % 2 == 0, "input tensor size along the split dim must be even"
|
||||
|
||||
gather_list = [torch.empty(input_.shape, dtype=input_.dtype, device=input_.device) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(
|
||||
gather_list, input_, group=get_context_model_parallel_group())
|
||||
|
||||
first = []
|
||||
second = []
|
||||
for i in range(world_size):
|
||||
first_second = torch.split(gather_list[i], gather_list[i].shape[dim] // 2, dim=dim)
|
||||
first.append(first_second[0])
|
||||
second.insert(0, first_second[1])
|
||||
tensor_list = first + second
|
||||
output_tensor = torch.cat(tensor_list, dim = dim).contiguous()
|
||||
return output_tensor
|
||||
|
||||
|
||||
# Gather the partial results of each batch on the context parallel groups together,
|
||||
# place them in the order before zigzag splitting, and remove the pad part.
|
||||
# This function is used for debugging
|
||||
def context_parallel_tensor_all_gather(input, attn_metadata, dim=-1):
|
||||
if dim < 0:
|
||||
dim += input.dim()
|
||||
slice_ = ()
|
||||
for i in range(dim):
|
||||
slice_ = slice_ + (slice(None),)
|
||||
select_list = []
|
||||
seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
|
||||
batch_num = seq_start_loc.shape[0] - 1
|
||||
for batch in range(batch_num):
|
||||
start = seq_start_loc[batch].item()
|
||||
end = seq_start_loc[batch + 1].item()
|
||||
slice1 = slice_ + (slice(start, end), )
|
||||
input_ = input[slice1]
|
||||
gather_ = context_parallel_tensor_all_gather_(input_, dim=dim)
|
||||
slice2 = slice_ + (slice(None, attn_metadata.prefill_metadata.seq_lens[batch]), )
|
||||
select = gather_[slice2]
|
||||
select_list.append(select)
|
||||
output = torch.cat(select_list, dim=dim)
|
||||
return output
|
||||
|
||||
|
||||
# Pad one dimension of a tensor so that it is divisible by context_parallel_size * 2,
|
||||
# and then use zigzag method to split it into different context parallel groups
|
||||
def zigzag_split_(tensor: torch.Tensor, dim = -1, pad_value=0):
|
||||
if dim < 0:
|
||||
dim = tensor.dim() + dim
|
||||
split_num = get_context_model_parallel_world_size() * 2
|
||||
pad_num = get_pad_seq(tensor.shape[dim], split_num) - tensor.shape[dim]
|
||||
pad_param = (0, 0) * (tensor.dim() - dim - 1) + (0, pad_num) + (0, 0) * dim
|
||||
tensor_pad = F.pad(tensor, pad_param, value = pad_value)
|
||||
split_size = divide(tensor_pad.size()[dim], split_num)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor_pad, split_size, dim = dim)
|
||||
first = tensor_list[get_context_model_parallel_rank()]
|
||||
second = tensor_list[split_num - get_context_model_parallel_rank() - 1]
|
||||
output_tensor = torch.cat((first, second), dim=dim).contiguous()
|
||||
return output_tensor
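# Illustrative sketch (reviewer-added): for context_parallel_world_size = 2 a
# 10-token sequence is padded to 12 and cut into 4 chunks of 3; rank 0 keeps
# chunks [0, 3] and rank 1 keeps chunks [1, 2], which balances the causal
# attention work between the ranks. context_parallel_tensor_all_gather_ above
# undoes exactly this ordering.
def _zigzag_layout_sketch(num_tokens: int = 10, cp_world_size: int = 2):
    split_num = cp_world_size * 2
    padded = get_pad_seq(num_tokens, split_num)          # 12
    chunk = padded // split_num                          # 3
    tokens = list(range(padded))                         # pad ids included
    chunks = [tokens[i * chunk:(i + 1) * chunk] for i in range(split_num)]
    return {rank: chunks[rank] + chunks[split_num - 1 - rank]
            for rank in range(cp_world_size)}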
|
||||
|
||||
|
||||
# Split each batch of input_ids, positions, attn_metadata.slot_mapping with zigzag method,
|
||||
# and update prefill_metadata.seq_start_loc and prefill_metadata.max_seq_len
|
||||
def zigzag_split(input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
pad_slot_id: int):
|
||||
zigzag_input_ids: List[int] = []
|
||||
zigzag_positions: List[int] = []
|
||||
zigzag_slot_mapping: List[int] = []
|
||||
zigzag_attn_metadata = copy.deepcopy(attn_metadata)
|
||||
seq_lens: List[int] = []
|
||||
seq_start_loc = attn_metadata.prefill_metadata.seq_start_loc
|
||||
batch_num = seq_start_loc.shape[0] - 1
|
||||
for batch in range(batch_num):
|
||||
start, end = seq_start_loc[batch], seq_start_loc[batch + 1]
|
||||
input_ids_ = input_ids[start : end]
|
||||
positions_ = positions[start : end]
|
||||
zigzag_input_ids_ = zigzag_split_(input_ids_)
|
||||
zigzag_positions_ = zigzag_split_(positions_)
|
||||
zigzag_input_ids.append(zigzag_input_ids_)
|
||||
zigzag_positions.append(zigzag_positions_)
|
||||
seq_lens.append(zigzag_input_ids_.shape[0])
|
||||
slot_mapping_ = attn_metadata.slot_mapping[start : end]
|
||||
zigzag_slot_mapping_ = zigzag_split_(slot_mapping_, pad_value=pad_slot_id)
|
||||
zigzag_slot_mapping.append(zigzag_slot_mapping_)
|
||||
|
||||
zigzag_input_ids = torch.cat(zigzag_input_ids, dim=0)
|
||||
zigzag_positions = torch.cat(zigzag_positions, dim=0)
|
||||
zigzag_slot_mapping = torch.cat(zigzag_slot_mapping, dim=0)
|
||||
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens_tensor = torch.tensor(seq_lens,
|
||||
dtype=torch.int,
|
||||
device=input_ids.device)
|
||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||
dtype=torch.int32,
|
||||
device=input_ids.device)
|
||||
torch.cumsum(seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=seq_start_loc.dtype,
|
||||
out=seq_start_loc[1:])
|
||||
|
||||
zigzag_attn_metadata.prefill_metadata.seq_start_loc = seq_start_loc
|
||||
zigzag_attn_metadata.prefill_metadata.query_start_loc = seq_start_loc
|
||||
zigzag_attn_metadata.prefill_metadata.max_seq_len = max_seq_len
|
||||
zigzag_attn_metadata.slot_mapping = zigzag_slot_mapping
|
||||
|
||||
return zigzag_input_ids, zigzag_positions, zigzag_attn_metadata