Compare commits

...

2 Commits
b9110 ... b9112

Author SHA1 Message Date
CrispStrobe
8e1f9d0834 CUDA: handle OW > 65535 in im2col (2D and 3D) (#22944)
`im2col_cuda` and `im2col_3d_cuda` both dispatch with
`block_nums.y = OW`. CUDA caps grid Y at 65535. Conv1d encoders on
raw 16 kHz audio with T > 65535 (~ 4 s) trip the limit -- e.g. SEANet
at 11 s lands at OW = 176000 -- and the launch returns
`invalid configuration argument`.

Clamp `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loop inside the
kernel with stride `MAX_GRIDDIM_Y`. Same in-kernel stride pattern
already used for the z axis (`MAX_GRIDDIM_Z`). Both 2D `im2col_kernel`
and 3D `im2col_3d_kernel` need the same fix. Bit-identical for
OW <= 65535 (single iteration of the new outer loop).

Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s /
16 kHz audio (im2col reaching OW ~ 176000); pre-fix launch returns
`invalid configuration argument`, post-fix runs to completion.
Existing test-backend-ops im2col cases unchanged.
2026-05-11 19:48:29 +02:00
Pascal
e936660760 Ggml/cuda snake fusion hardening (#22912)
* cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan)

* cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review)

* cuda: merge type_ok and types_ok into a single types_ok (address am17an review)

* cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16

bin_bcast only dispatches F32/F16 type triplets, mirror the
vulkan filter so unsupported types fall back through cpy
instead of aborting.

* test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases
2026-05-11 18:42:08 +02:00
3 changed files with 71 additions and 46 deletions

View File

@@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
// Kernel iterates over total = T * C, so x and add must be 2D and
// a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
(add->ne[2] == 1 && add->ne[3] == 1) &&
(a->ne[2] == 1 && a->ne[3] == 1);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];
if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
// x must be in the supported whitelist and every operand / intermediate
// result must share x's type, since launch_snake casts a / inv_b as
// float and templates the kernel on a single T. Mixed precision chains
// fall back to the naive path.
const ggml_tensor * sin1 = cgraph->nodes[i + 1];
const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) &&
(a->type == x->type) && (inv_b->type == x->type) &&
(mul0->type == x->type) && (sin1->type == x->type) &&
(sqr->type == x->type) && (mul1->type == x->type) &&
(add->type == x->type);
if (types_ok && shape_ok && dim_ok && x_in_add == x) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
@@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CLAMP:
case GGML_OP_LOG:
return true;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_SSM_SCAN: {
if (op->src[3]->ne[0] == 1) {
// Mamba2

View File

@@ -1,5 +1,6 @@
#include "im2col.cuh"
#define MAX_GRIDDIM_Y 65535
#define MAX_GRIDDIM_Z 65535
template <typename T>
@@ -18,22 +19,23 @@ static __global__ void im2col_kernel(
const int64_t ikh = rem / KW;
const int64_t ikw = rem - ikh * KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) {
const int64_t in = iz / OH;
const int64_t ioh = iz - in * OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
const int64_t offset_dst =
((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
}
@@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst,
const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
const int64_t N_OH = N * OH;
const int64_t KH_KW = KW*KH;
dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z));
im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
s0, s1, p0, p1, d0, d1);
@@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel(
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) {
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
dst[offset_dst] = src[offset_src];
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
dst[offset_dst] = src[offset_src];
}
}
}
}
@@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst,
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,

View File

@@ -3561,7 +3561,7 @@ struct test_relu_sqr : public test_case {
// and dispatches a single fused kernel.
struct test_snake_fuse : public test_case {
const ggml_type type;
const std::array<int64_t, 2> ne; // [T, C]
const std::array<int64_t, 4> ne; // [T, C, D2, D3]
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -3586,11 +3586,11 @@ struct test_snake_fuse : public test_case {
}
test_snake_fuse(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 2> ne = {256, 192})
std::array<int64_t, 4> ne = {256, 192, 1, 1})
: type(type), ne(ne) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * x = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
ggml_tensor * x = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
ggml_set_name(x, "x");
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, ne[1]);
@@ -7558,11 +7558,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// SNAKE activation fusion: x + sin(a*x)^2 * inv_b
for (ggml_type type : { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192})); // BigVGAN-ish
test_cases.emplace_back(new test_snake_fuse(type, { 5, 7, 1, 1})); // primes sub-block
test_cases.emplace_back(new test_snake_fuse(type, { 33, 32, 1, 1})); // boundary
test_cases.emplace_back(new test_snake_fuse(type, {1025, 13, 1, 1})); // large prime, grid-stride
test_cases.emplace_back(new test_snake_fuse(type, { 128, 16, 1, 1})); // power-of-two
test_cases.emplace_back(new test_snake_fuse(type, { 256, 192, 1, 1})); // BigVGAN-ish
// higher-rank shapes: matcher must reject fusion, fallback to naive chain
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 1})); // ne[2] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 1, 2})); // ne[3] > 1
test_cases.emplace_back(new test_snake_fuse(type, { 64, 32, 2, 3})); // ne[2] > 1 and ne[3] > 1
}
// glu ops
@@ -9093,9 +9097,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
// SNAKE activation fusion at BigVGAN scale (T=7680 = 24 kHz x 320 ms, C=192)
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F32, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_F16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_snake_fuse(GGML_TYPE_BF16, {7680, 192, 1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));