[attn] fix device of tensors in attention (#25)
### What this PR does / why we need it?
Fix the device of tensors created in `AscendAttentionBackendImpl`.
When a device other than card-0 is specified, a **device conflict** occurs because tensors created in the backend (such as `attn_mask`) are placed on card-0 by default.
This PR creates these tensors on the card corresponding to the input instead.
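For context, the core pattern of the fix is to derive the device from an incoming tensor instead of hardcoding `device="npu"`. A minimal standalone sketch of that pattern (not the actual patch; `build_causal_mask` is hypothetical and `query` stands in for any input tensor already on the target card):

```python
import torch

def build_causal_mask(seq_len: int, query: torch.Tensor) -> torch.Tensor:
    # Build the mask on the same card as the input rather than on the default
    # device, which would otherwise be card-0.
    return ~torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device))

# CPU stands in for "npu:<local_rank>" in this illustration.
q = torch.randn(2, 8, 64)
mask = build_causal_mask(8, q)
assert mask.device == q.device
```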
### Does this PR introduce _any_ user-facing change?
With this PR, users can specify the device by local rank. A corresponding change on the vLLM side is also needed; it will be linked to this PR once created.
### How was this patch tested?
Tested locally with the code below. A test case will be added once the corresponding vLLM change is complete.
```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct",
          device="npu:1")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
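The device string in the example above is hardcoded to `npu:1`; picking the card from the local rank could look like the following sketch (the `LOCAL_RANK` environment variable and the model name are illustrative assumptions, not something this PR sets up):

```python
import os

from vllm import LLM

# Hypothetical: derive the target card from LOCAL_RANK instead of hardcoding it.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", device=f"npu:{local_rank}")
```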
Signed-off-by: MengqingCao <cmq0113@163.com>
```diff
@@ -29,11 +29,10 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
@@ -457,9 +457,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
@@ -520,7 +518,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             attn_metadata.sparse_mode = 2
             attention_mask = gen_input_mask(
                 attn_metadata.max_prefill_seq_len, self.sliding_window,
-                num_tokens)
+                num_tokens, query.device)
             attn_metadata.attn_mask = attention_mask
 
         if (self.alibi_slopes is not None
@@ -531,6 +529,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 dtype=query.dtype,
                 seq_len=attn_metadata.max_prefill_seq_len,
                 batch_size=num_tokens,
+                device=query.device,
             )
 
         if (len(kv_cache) == 0 or attn_metadata.block_tables is None
@@ -571,7 +570,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
     return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,7 +629,7 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
@@ -638,7 +637,7 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     # `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
@@ -674,7 +675,7 @@ def _make_alibi_bias(
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
```
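For reference, the ALiBi hunks above follow the same idea: the slopes are moved to the caller-supplied device before the bias is built on it. A minimal standalone sketch of that behaviour (not the patched vLLM code; the shapes and the CPU device are illustrative):

```python
import torch

def make_alibi_bias_sketch(alibi_slopes: torch.Tensor, seq_len: int,
                           device: torch.device,
                           dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Move the slopes to the target device first, then build everything there,
    # so nothing is silently allocated on card-0.
    alibi_slopes = alibi_slopes.to(device)
    bias = torch.arange(seq_len, dtype=dtype, device=device)
    bias = bias[None, :] - bias[:, None]                    # (seq_len, seq_len)
    return bias[None, :, :] * alibi_slopes[:, None, None]   # (heads, seq, seq)

# CPU stands in for torch.device("npu", local_rank) in this illustration.
slopes = torch.tensor([0.5, 0.25])
print(make_alibi_bias_sketch(slopes, 4, torch.device("cpu")).shape)  # (2, 4, 4)
```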