[attn] fix device of tensors in attention (#25)

### What this PR does / why we need it?
Fix device of tensors created in `AscendAttentionBackendImpl`.

When a device other than card-0 is specified, a **device conflict**
occurs because tensors created in this backend (such as `attn_mask`)
are placed on card-0 by default.

This PR creates these tensors on the card corresponding to the input,
so they share the input's device.
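
For illustration, a minimal sketch of the pattern (the helper name `make_causal_mask` is hypothetical, not the actual `AscendAttentionBackendImpl` code): the tensor is allocated with the input's `device` instead of the default one.
```python
import torch

def make_causal_mask(seq_len: int, query: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper showing the fix: allocate on the input's device."""
    # Before the fix: torch.full(...) without `device` lands on card-0.
    # After the fix: the mask follows the device of the incoming query.
    mask = torch.full((seq_len, seq_len), float("-inf"), device=query.device)
    return torch.triu(mask, diagonal=1)
```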

### Does this PR introduce _any_ user-facing change?
With this PR, users can specify the device by local rank (e.g.
`device="npu:1"`). A corresponding change in vLLM is also needed; it
will be linked to this PR once created.

### How was this patch tested?
This was tested locally with the following code. A test case will be
added once the corresponding change in vLLM is completed.
```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct", device="npu:1")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
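
For reference, the follow-up test could take roughly this shape (a sketch only, assuming the hypothetical `make_causal_mask` helper above and an environment where `torch_npu` makes `npu:1` available):
```python
import torch

def test_mask_follows_input_device():
    # Assumes an Ascend environment where "npu:1" is a valid device.
    query = torch.zeros(4, 8, device="npu:1")
    mask = make_causal_mask(4, query)
    # The mask must live on the same device as the query that produced it.
    assert mask.device == query.device
```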

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-02-10 19:20:29 +08:00
Committed by: GitHub
Commit: 7006835977 (parent: c59375caff)
2 changed files with 12 additions and 12 deletions

@@ -29,11 +29,10 @@ prompts = [
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )