This commit is contained in:
2026-02-10 23:08:39 +08:00
parent 1baa36026c
commit 6680585975
172 changed files with 52867 additions and 892 deletions

View File

@@ -0,0 +1,539 @@
# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# mypy: ignore-errors
from typing import Optional, Union
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
PAD_SLOT_ID = -1
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
query_start_loc: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
sequences are concatenated from left to right for varlen
weight: (dim, width)
bias: (dim,)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended by 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether should the kernel take the current state as initial
state for the calculations
conv_states: (...,dim,width - 1) itype
updated inplace if provided
activation: either None or "silu" or "swish"
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
out_ref = []
out_ref_b = []
seqlens = query_start_loc[1:] - query_start_loc[:-1]
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)
for i in range(len(seqlens)):
x_s = splits[i]
if cache_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight,
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
initial_states=conv_states[cache_indices[i]]
if has_initial_state[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
x_ptr, # (batch, dim, seqlen)
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
intermediate_conv_window_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
dim: tl.constexpr,
seqlen: tl.constexpr,
state_len: tl.constexpr,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_inter_seq: tl.constexpr,
stride_inter_step: tl.constexpr,
stride_inter_dim: tl.constexpr,
stride_inter_win: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
SAVE_INTERMEDIATE: tl.constexpr,
):
# ruff: noqa: E501
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices).to(
tl.int64)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = tl.load(num_accepted_tokens_ptr +
idx_seq) - 1
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
mask_w = idx_feats < dim
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 3:
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
#col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state updates works in a sliding window manner,
# at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
) # [BLOCK_N]
x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens - VAL >= 0)[:, None]
& (idx_tokens - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (conv_state_ptr +
(conv_state_batch_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_state_ptrs_target = (conv_state_base +
(idx_tokens * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
# STEP 3: init accumulator
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim
# STEP 5: compute each token
for idx_token in tl.static_range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
# mask_1d = (idx_token < seqlen) & (
# idx_feats < dim
# ) # token-index # feature-index
maskL = idx_feats < dim
maskR = tl.full(maskL.shape, False, tl.int1)
mask_1d = tl.where(idx_token < seqlen, maskL, maskR)
o_ptrs = (o_ptr + (idx_seq) * stride_o_seq +
idx_token * stride_o_token + (idx_feats * stride_o_dim))
tl.store(o_ptrs, acc, mask=mask_1d)
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (intermediate_conv_window_ptr +
conv_state_batch_coord * stride_inter_seq +
idx_token * stride_inter_step +
idx_feats * stride_inter_dim)
if KERNEL_WIDTH >= 2:
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
if KERNEL_WIDTH >= 3:
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
if KERNEL_WIDTH >= 4:
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
def causal_conv1d_update_npu(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
intermediate_conv_window: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
[shape=2: single token prediction]
[shape=3: single or multiple tokens prediction]
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
_, width = weight.shape
# conv_state: (..., dim, state_len), where state_len >= width - 1
num_cache_lines, _, state_len = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert (
conv_state.stride(-2) == 1
), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch, ) == conv_state_indices.shape
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
stride_w_dim, stride_w_width = weight.stride()
stride_x_seq, stride_x_dim, stride_x_token = x.stride(
) # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
)
stride_state_indices = (conv_state_indices.stride(0)
if conv_state_indices is not None else 0)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
return (
batch,
triton.cdiv(dim, META["BLOCK_N"]),
)
# prepare intermediate buffer strides if provided
if intermediate_conv_window is not None:
stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
intermediate_conv_window.stride(0),
intermediate_conv_window.stride(1),
intermediate_conv_window.stride(2),
intermediate_conv_window.stride(3),
)
else:
stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
intermediate_conv_window
if intermediate_conv_window is not None else x,
out,
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_inter_seq,
stride_inter_step,
stride_inter_dim,
stride_inter_win,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=128,
SAVE_INTERMEDIATE=intermediate_conv_window is not None,
)
if unsqueeze:
out = out.squeeze(-1)
return out