From 70068359770b6e8cfcbb9931aa79be50731a274c Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Mon, 10 Feb 2025 19:20:29 +0800
Subject: [PATCH] [attn] fix device of tensors in attention (#25)

### What this PR does / why we need it?
Fix the device of tensors created in `AscendAttentionBackendImpl`.
When a device other than card-0 is specified, a **device conflict** occurs
because tensors such as `attn_mask` are placed on card-0 by default. This PR
creates these tensors on the card corresponding to the input.

### Does this PR introduce _any_ user-facing change?
With this PR, users can specify the device by local rank. A matching change in
vLLM is also required; it will be linked to this PR once it is created.

### How was this patch tested?
Tested locally with the code below. A test case will be added once the
corresponding change in vLLM is completed.

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct",
          device="npu:1")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Signed-off-by: MengqingCao
---
 examples/offline_distributed_inference_npu.py |  3 +--
 vllm_ascend/attention.py                      | 21 ++++++++++---------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/examples/offline_distributed_inference_npu.py b/examples/offline_distributed_inference_npu.py
index f8d5489..8e503ad 100644
--- a/examples/offline_distributed_inference_npu.py
+++ b/examples/offline_distributed_inference_npu.py
@@ -29,11 +29,10 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )

diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 0a014e6..2f9b5e7 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -457,9 +457,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
@@ -520,7 +518,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                attn_metadata.sparse_mode = 2
            attention_mask = gen_input_mask(
                attn_metadata.max_prefill_seq_len, self.sliding_window,
-                num_tokens)
+                num_tokens, query.device)
            attn_metadata.attn_mask = attention_mask
            if (self.alibi_slopes is not None
@@ -531,6 +529,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                    dtype=query.dtype,
                    seq_len=attn_metadata.max_prefill_seq_len,
                    batch_size=num_tokens,
+                    device=query.device,
                )
            if (len(kv_cache) == 0 or attn_metadata.block_tables is None
@@ -571,7 +570,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
            query = query.view(query.shape[0], -1,
                               self.num_heads * self.head_size)
            output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                 dtype=query.dtype)
            # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
            # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        return output


-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,7 +629,7 @@
        global SHARE_MASK_TRIL_PREFIX_CACHE
        if SHARE_MASK_TRIL_PREFIX_CACHE is None:
            SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                diagonal=1,
            )
        attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
@@ -638,7 +637,7 @@
        global SHARE_MASK_TRIL
        if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
            SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
        attention_mask = SHARE_MASK_TRIL
    if sliding_window is not None:
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     # `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
@@ -674,7 +675,7 @@
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
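
Note (illustration, not part of the patch): the sketch below shows the device-propagation pattern this diff applies — derive the device from the incoming `query` and build masks and output buffers there, rather than hard-coding `device="npu"`, which resolves to card-0. The helper names `gen_causal_mask` and `make_output_like` are hypothetical stand-ins for `gen_input_mask` and the `torch.zeros` call in the diff; only `torch` is assumed.

```python
# Minimal sketch of the device-propagation pattern; names are hypothetical
# and this is not the repository's implementation.
import torch


def gen_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    # Build the lower-triangular (causal) mask directly on the caller's
    # device, mirroring gen_input_mask(..., device) in the diff above.
    return ~torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def make_output_like(query: torch.Tensor) -> torch.Tensor:
    # Allocate the output on the device of the input tensor, as the patch
    # does with device=query.device, so nothing lands on card-0 by default.
    return torch.zeros(query.shape, device=query.device, dtype=query.dtype)


if __name__ == "__main__":
    # Runs on CPU as-is; on Ascend the query would live on e.g. "npu:1" and
    # the mask and output would follow it there.
    q = torch.randn(4, 8)
    mask = gen_causal_mask(q.shape[0], q.device)
    out = make_output_like(q)
    print(mask.device, out.device)
```

The `_make_alibi_bias(..., device)` change follows the same idea: move `alibi_slopes` to the input's device once, then allocate the bias tensor on that device.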