mirror of
https://github.com/handsomezhuzhu/vllm-npu-plugin.git
synced 2026-02-20 19:50:15 +00:00
fix: revert to _npu_reshape_and_cache (contiguous) and _npu_flash_attention
This commit is contained in:
@@ -322,21 +322,24 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
Matches Huawei vllm-ascend: splits kv_cache[0]/[1] and writes via
|
||||
slot_mapping indices.
|
||||
"""
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
if self._key_cache is None:
|
||||
self._key_cache, self._value_cache = kv_cache[0], kv_cache[1]
|
||||
|
||||
slots = attn_metadata.slot_mapping
|
||||
key_to_cache = key[:attn_metadata.num_actual_tokens]
|
||||
val_to_cache = value[:attn_metadata.num_actual_tokens]
|
||||
# Ensure contiguous tensors for the NPU op
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
slots = attn_metadata.slot_mapping.long() # indices must be long
|
||||
|
||||
# Use pure-PyTorch indexing (ATB reshape_and_cache may fail
|
||||
# depending on environment; this is functionally identical)
|
||||
block_size = self._key_cache.shape[1]
|
||||
block_idx = slots // block_size
|
||||
block_offset = slots % block_size
|
||||
self._key_cache[block_idx, block_offset] = key_to_cache
|
||||
self._value_cache[block_idx, block_offset] = val_to_cache
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
self._key_cache,
|
||||
self._value_cache,
|
||||
slots,
|
||||
)
|
||||
|
||||
return key, value
|
||||
|
||||
@@ -450,28 +453,33 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output: torch.Tensor,
|
||||
num_tokens: int,
|
||||
) -> torch.Tensor:
|
||||
"""Prefill attention without KV cache (self-attention)."""
|
||||
"""Prefill attention without KV cache (self-attention) using _npu_flash_attention."""
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
# Huawei uses _npu_flash_attention for prefill
|
||||
# Ensure contiguous inputs
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
attn_out = torch_npu.npu_fusion_attention(
|
||||
query[:num_tokens],
|
||||
key[:num_tokens],
|
||||
value[:num_tokens],
|
||||
head_num=self.num_heads,
|
||||
input_layout="TND",
|
||||
scale=self.scale,
|
||||
sparse_mode=0,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=2147483647,
|
||||
next_tockens=0,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
# mask needs to be contiguous and cast to expected format if needed
|
||||
# but _npu_flash_attention handles generic mask?
|
||||
# Huawei code: mask = attn_metadata.attn_mask...
|
||||
# We'll pass it as is, assuming AscendMetadataBuilder created it correctly.
|
||||
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=attn_metadata.attn_mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=self.scale,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=output
|
||||
)
|
||||
|
||||
output[:num_tokens] = attn_out[0]
|
||||
return output
|
||||
return output[:num_tokens]
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Chunked prefill — mixed prefill+decode via npu_fusion_attention
|
||||
|
||||
Reference in New Issue
Block a user