Compare commits

...

2 Commits

Author SHA1 Message Date
Georgi Gerganov
55cf48de1e cuda : fix multi-seq, quantized FA
ggml-ci
2025-07-22 20:48:53 +03:00
Georgi Gerganov
a856a5665d tests : add non-cont K,V FA tests
ggml-ci
2025-07-18 13:36:27 +03:00
2 changed files with 27 additions and 11 deletions

View File

@@ -745,10 +745,14 @@ void launch_fattn(
size_t nb23 = V ? V->nb[3] : nb13;
if (need_f16_K && K->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(K));
K_f16.alloc(ggml_nelements(K));
const int64_t n_seq = K->ne[3];
const int64_t n_eps = (K->nb[3]/ggml_type_size(K->type))*ggml_blck_size(K->type); // elements per sequence
K_f16.alloc(n_seq*n_eps);
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
for (int s = 0; s < n_seq; ++s) {
to_fp16(K_data + s*K->nb[3], K_f16.ptr + s*n_eps, n_eps, main_stream);
}
K_data = (char *) K_f16.ptr;
const size_t bs = ggml_blck_size(K->type);
@@ -760,10 +764,14 @@ void launch_fattn(
}
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
GGML_ASSERT(ggml_is_contiguously_allocated(V));
V_f16.alloc(ggml_nelements(V));
const int64_t n_seq = V->ne[3];
const int64_t n_eps = (V->nb[3]/ggml_type_size(V->type))*ggml_blck_size(V->type);
V_f16.alloc(n_seq*n_eps);
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
for (int s = 0; s < n_seq; ++s) {
to_fp16(V_data + s*V->nb[3], V_f16.ptr + s*n_eps, n_eps, main_stream);
}
V_data = (char *) V_f16.ptr;
const size_t bs = ggml_blck_size(V->type);

View File

@@ -4258,26 +4258,32 @@ struct test_flash_attn_ext : public test_case {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV));
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * {
auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * {
int64_t ne[4] = {ne0, ne1, ne2, ne3};
int64_t ne_perm[4];
for (int i = 0; i < 4; ++i) {
ne_perm[permute[i]] = ne[i];
}
ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
ggml_tensor * t;
if (is_view) {
ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]);
t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0);
} else {
t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]);
}
if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) {
t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]);
}
return t;
};
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]);
ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false);
ggml_set_name(q, "q");
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1]);
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
ggml_set_name(k, "k");
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1]);
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;
@@ -5519,6 +5525,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_timestep_embedding());
test_cases.emplace_back(new test_leaky_relu());
test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 3}, 512, 128, true, 0.0f, 0.0f, GGML_PREC_DEFAULT, GGML_TYPE_Q8_0));
for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;