Mirror of https://github.com/handsomezhuzhu/vllm-npu-plugin.git (synced 2026-02-20 19:50:15 +00:00)
fix: revert to _npu_reshape_and_cache (contiguous) and _npu_flash_attention
@@ -322,21 +322,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
         Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
         slot_mapping indices.
         """
+        import torch_npu  # noqa: F401
+
         if kv_cache.numel() > 0:
             if self._key_cache is None:
                 self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
 
-            slots = attn_metadata.slot_mapping
-            key_to_cache = key[:attn_metadata.num_actual_tokens]
-            val_to_cache = value[:attn_metadata.num_actual_tokens]
+            # Ensure contiguous tensors for the NPU op
+            key = key.contiguous()
+            value = value.contiguous()
+            slots = attn_metadata.slot_mapping.long()  # indices must be long
 
-            # Use pure-PyTorch indexing (ATB reshape_and_cache may fail
-            # depending on environment; this is functionally identical)
-            block_size = self._key_cache.shape[1]
-            block_idx = slots // block_size
-            block_offset = slots % block_size
-            self._key_cache[block_idx, block_offset] = key_to_cache
-            self._value_cache[block_idx, block_offset] = val_to_cache
+            torch_npu._npu_reshape_and_cache(
+                key,
+                value,
+                self._key_cache,
+                self._value_cache,
+                slots,
+            )
 
         return key, value
 
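The removed lines compute each token's cache position from its flat slot id. Below is a minimal CPU-only sketch of that indexing, mirroring the dropped block_idx/block_offset arithmetic; the cache shape [num_blocks, block_size, num_kv_heads, head_size] and the sample values are illustrative assumptions, not taken from this diff.

    # Illustrative sketch only: pure-PyTorch paged-KV-cache write, mirroring the
    # removed block_idx/block_offset indexing. Shapes below are assumptions.
    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)
    value_cache = torch.zeros_like(key_cache)

    num_tokens = 5
    key = torch.randn(num_tokens, num_kv_heads, head_size)
    value = torch.randn(num_tokens, num_kv_heads, head_size)
    slots = torch.tensor([3, 9, 10, 17, 25], dtype=torch.long)  # flat slot ids

    block_idx = slots // block_size    # which block each token lands in
    block_offset = slots % block_size  # offset of the token inside that block
    key_cache[block_idx, block_offset] = key      # scatter K vectors
    value_cache[block_idx, block_offset] = value  # scatter V vectors

    # Spot-check: the first token's K vector sits at its computed position.
    assert torch.equal(key_cache[block_idx[0], block_offset[0]], key[0])

The reverted torch_npu._npu_reshape_and_cache call performs the same scatter in a single ATB op, which is why the hunk can drop the manual index arithmetic without changing behavior, provided the inputs are contiguous and the slot indices are long, as the added lines enforce.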
@@ -450,28 +453,33 @@ class AscendAttentionBackendImpl(AttentionImpl):
         output: torch.Tensor,
         num_tokens: int,
     ) -> torch.Tensor:
-        """Prefill attention without KV cache (self-attention)."""
+        """Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
         import torch_npu  # noqa: F401
 
-        cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+        # Huawei uses _npu_flash_attention for prefill
+        # Ensure contiguous inputs
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
 
-        attn_out = torch_npu.npu_fusion_attention(
-            query[:num_tokens],
-            key[:num_tokens],
-            value[:num_tokens],
-            head_num=self.num_heads,
-            input_layout="TND",
-            scale=self.scale,
-            sparse_mode=0,
-            atten_mask=attn_metadata.attn_mask,
-            pre_tockens=2147483647,
-            next_tockens=0,
-            actual_seq_qlen=cum_seq_len,
-            actual_seq_kvlen=cum_seq_len,
+        # mask needs to be contiguous and cast to expected format if needed,
+        # but _npu_flash_attention handles a generic mask.
+        # Huawei code: mask = attn_metadata.attn_mask...
+        # We pass it as is, assuming AscendMetadataBuilder created it correctly.
+
+        torch_npu._npu_flash_attention(
+            query=query,
+            key=key,
+            value=value,
+            mask=attn_metadata.attn_mask,
+            seq_len=attn_metadata.seq_lens,
+            scale_value=self.scale,
+            num_heads=self.num_heads,
+            num_kv_heads=self.num_kv_heads,
+            out=output
         )
 
-        output[:num_tokens] = attn_out[0]
-        return output
+        return output[:num_tokens]
 
     # -----------------------------------------------------------------
     # Chunked prefill — mixed prefill+decode via npu_fusion_attention
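As the hunk shows, the two prefill kernels describe the variable-length batch differently: the removed npu_fusion_attention call packs queries in a "TND" layout and takes cumulative lengths derived from query_start_loc, while the reverted _npu_flash_attention call takes per-sequence lengths (seq_lens) plus the attention mask and writes into a preallocated output. The sketch below shows how the two call shapes relate; argument names are taken from the diff, but the wrapper function and its try/except fallback are illustrative assumptions, not part of this commit.

    # Illustrative sketch only. Argument names follow the diff above; the helper
    # itself and the fallback strategy are assumptions, not code from this commit.
    import torch_npu  # requires an Ascend NPU environment


    def prefill_attention(query, key, value, attn_metadata, output,
                          num_heads, num_kv_heads, scale, num_tokens):
        try:
            # Reverted path: per-sequence lengths + mask, result written to `output`.
            torch_npu._npu_flash_attention(
                query=query.contiguous(),
                key=key.contiguous(),
                value=value.contiguous(),
                mask=attn_metadata.attn_mask,
                seq_len=attn_metadata.seq_lens,
                scale_value=scale,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                out=output,
            )
            return output[:num_tokens]
        except (AttributeError, RuntimeError):
            # Removed path: packed "TND" layout with cumulative sequence lengths.
            cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
            attn_out = torch_npu.npu_fusion_attention(
                query[:num_tokens],
                key[:num_tokens],
                value[:num_tokens],
                head_num=num_heads,
                input_layout="TND",
                scale=scale,
                sparse_mode=0,
                atten_mask=attn_metadata.attn_mask,
                pre_tockens=2147483647,
                next_tockens=0,
                actual_seq_qlen=cum_seq_len,
                actual_seq_kvlen=cum_seq_len,
            )
            output[:num_tokens] = attn_out[0]
            return output[:num_tokens]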