Remove extra copy in deepseek forward absorb (#5578)
Co-authored-by: saienduri <saimanas.enduri@amd.com>
This commit is contained in:
14
.github/workflows/pr-test-amd.yml
vendored
14
.github/workflows/pr-test-amd.yml
vendored
@@ -38,12 +38,12 @@ jobs:
|
|||||||
else
|
else
|
||||||
DEVICE_FLAG="--device /dev/dri"
|
DEVICE_FLAG="--device /dev/dri"
|
||||||
fi
|
fi
|
||||||
docker pull lmsysorg/sglang:v0.4.5-rocm630
|
docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
||||||
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
||||||
--cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
|
--cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
|
||||||
-w /sglang-checkout --name ci_sglang \
|
-w /sglang-checkout --name ci_sglang \
|
||||||
lmsysorg/sglang:v0.4.5-rocm630
|
lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -82,12 +82,12 @@ jobs:
|
|||||||
else
|
else
|
||||||
DEVICE_FLAG="--device /dev/dri"
|
DEVICE_FLAG="--device /dev/dri"
|
||||||
fi
|
fi
|
||||||
docker pull lmsysorg/sglang:v0.4.5-rocm630
|
docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
||||||
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
||||||
--cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
|
--cap-add=SYS_PTRACE -e HF_TOKEN=${{ secrets.AMD_HF_TOKEN }} --security-opt seccomp=unconfined \
|
||||||
-w /sglang-checkout --name ci_sglang \
|
-w /sglang-checkout --name ci_sglang \
|
||||||
lmsysorg/sglang:v0.4.5-rocm630
|
lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -120,12 +120,12 @@ jobs:
|
|||||||
else
|
else
|
||||||
DEVICE_FLAG="--device /dev/dri"
|
DEVICE_FLAG="--device /dev/dri"
|
||||||
fi
|
fi
|
||||||
docker pull lmsysorg/sglang:v0.4.5-rocm630
|
docker pull lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
docker run -dt --user root --device=/dev/kfd $DEVICE_FLAG \
|
||||||
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
-v ${{ github.workspace }}:/sglang-checkout --ipc=host --group-add video \
|
||||||
--cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
|
--cap-add=SYS_PTRACE -e HF_TOKEN=${HF_TOKEN} --security-opt seccomp=unconfined \
|
||||||
-w /sglang-checkout --name ci_sglang \
|
-w /sglang-checkout --name ci_sglang \
|
||||||
lmsysorg/sglang:v0.4.5-rocm630
|
lmsysorg/sglang:v0.4.5.post2-rocm630
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -149,7 +149,7 @@ jobs:
|
|||||||
finish:
|
finish:
|
||||||
if: always()
|
if: always()
|
||||||
needs: [
|
needs: [
|
||||||
accuracy-test-1-gpu-amd, mla-test-1-gpu-amd
|
accuracy-test-1-gpu-amd, mla-test-1-gpu-amd, bench-test-2-gpu-amd
|
||||||
]
|
]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -665,6 +665,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
offsets: Optional[torch.Tensor] = None,
|
offsets: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""PyTorch-native implementation equivalent to forward()."""
|
"""PyTorch-native implementation equivalent to forward()."""
|
||||||
|
dtype = query.dtype
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
if self.rotary_dim < self.head_size:
|
if self.rotary_dim < self.head_size:
|
||||||
@@ -695,7 +696,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|||||||
else:
|
else:
|
||||||
query = query_rot
|
query = query_rot
|
||||||
key = key_rot
|
key = key_rot
|
||||||
return query, key
|
return query.to(dtype), key.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
class Llama3RotaryEmbedding(RotaryEmbedding):
|
class Llama3RotaryEmbedding(RotaryEmbedding):
|
||||||
|
|||||||
@@ -682,10 +682,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
q_len = hidden_states.shape[0]
|
|
||||||
q_input = hidden_states.new_empty(
|
|
||||||
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
|
|
||||||
)
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
q = self.q_a_proj(hidden_states)[0]
|
q = self.q_a_proj(hidden_states)[0]
|
||||||
q = self.q_a_layernorm(q)
|
q = self.q_a_layernorm(q)
|
||||||
@@ -729,20 +725,20 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
|
||||||
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
|
|
||||||
|
q_nope_out = q_nope_out.transpose(0, 1)
|
||||||
|
|
||||||
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||||
v_input = latent_cache[..., : self.kv_lora_rank]
|
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||||
v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
|
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
|
||||||
k_input = latent_cache.unsqueeze(1)
|
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
||||||
k_input[..., : self.kv_lora_rank] = v_input
|
|
||||||
k_pe = k_input[..., self.kv_lora_rank :]
|
|
||||||
|
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
q_input[..., self.kv_lora_rank :] = q_pe
|
|
||||||
k_input[..., self.kv_lora_rank :] = k_pe
|
|
||||||
|
|
||||||
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
|
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||||
|
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||||
|
|
||||||
|
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
if self.use_deep_gemm_bmm:
|
if self.use_deep_gemm_bmm:
|
||||||
|
|||||||
Reference in New Issue
Block a user