Compare commits

...

2 Commits

Author SHA1 Message Date
arthw
2985be3324 update hw info 2026-03-31 09:24:40 +08:00
arthw
8dc96153c3 enhance FA stable in UT 2026-03-17 15:57:02 +08:00
5 changed files with 174 additions and 176 deletions

View File

@@ -216,7 +216,8 @@ struct sycl_device_info {
// cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support
size_t total_vram;
//sycl_hw_info hw_info; \\ device id and aarch, currently not used
int max_work_group_sizes;
sycl_hw_info hw_info;
optimize_feature opt_feature;
};
@@ -692,15 +693,20 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
return dpct::pow(base, exph);
}
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
static const sycl::uint3 init_fastdiv_values(uint64_t d_64) {
GGML_ASSERT(d_64 != 0);
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
uint32_t d = (uint32_t)d_64;
// compute L = ceil(log2(d));
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
// pack divisor as well to reduce error surface
return sycl::uint3(mp, L, d);
}
@@ -965,4 +971,23 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
return val;
}
/*
* for Debug
*/
static int get_thread_id() {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
int blockId = item_ct1.get_group(2) +
item_ct1.get_group(1) * item_ct1.get_group_range(2) +
item_ct1.get_group(0) * item_ct1.get_group_range(2) *
item_ct1.get_group_range(1);
int threadsPerBlock = item_ct1.get_local_range(2) *
item_ct1.get_local_range(1) * item_ct1.get_local_range(0);
int threadInBlockId = item_ct1.get_local_id(2) +
item_ct1.get_local_id(1) * item_ct1.get_local_range(2) +
item_ct1.get_local_id(0) * item_ct1.get_local_range(2) *
item_ct1.get_local_range(1);
int id = blockId * threadsPerBlock + threadInBlockId;
return id;
}
#endif // GGML_SYCL_COMMON_HPP

View File

@@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64)
return 0;
}
@@ -130,134 +131,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
return 0;
}
static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if(fast_fp16_available(cc))
return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
@@ -455,7 +328,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
(K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
@@ -505,7 +378,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
}
if (k_KQ_0 + nbatch_K < DKQ) {
item_ct1.barrier(); // Sync not needed on last iteration.
item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration.
}
}
@@ -545,7 +418,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
const int k_VKQ_max,
const int col_Q_0,
float * KQ_max_new_shared) {
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
@@ -620,14 +493,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
if constexpr (np == 1) {
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
} else {
static_assert(cpw == 1, "bad cpw");
if (item_ct1.get_local_id(2) == 0) {
KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
}
@@ -697,7 +570,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
(V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#ifdef SYCL_FAST_FP16
#pragma unroll
@@ -765,7 +638,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
}
}
#endif // SYCL_FAST_FP16
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
}
}
@@ -825,6 +698,9 @@ static void flash_attn_tile(const char * Q,
nb31, nb32, nb33);
return;
}
// int id =get_thread_id();
// if(id==0) sycl::ext::oneapi::experimental::printf("DKQ=%d, DV=%d, ncols1=%d, ncols2=%d\n",
// DKQ, DV, ncols1, ncols2);
static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
@@ -932,7 +808,9 @@ static void flash_attn_tile(const char * Q,
constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
#pragma unroll
for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
if (i0 + np * warp_size * cpy_ne_D <= DKQ ||
i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D <
@@ -972,7 +850,7 @@ static void flash_attn_tile(const char * Q,
}
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
@@ -1051,7 +929,7 @@ static void flash_attn_tile(const char * Q,
return;
}
item_ct1.barrier();
item_ct1.barrier(sycl::access::fence_space::local_space);
#pragma unroll
for (int ip = 1; ip < np; ++ip) {
@@ -1189,30 +1067,58 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
const int id = ggml_sycl_get_device();
const int cc = ggml_sycl_info().devices[id].cc;
const auto arch = ggml_sycl_info().devices[id].hw_info.arch;
const auto gpu_name = ggml_sycl_info().devices[id].hw_info.name;
const int warp_size = WARP_32_SIZE; //can't support WARP_16_SIZE
constexpr size_t nbytes_shared = 0;
if constexpr (DV <= 256) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
// printf("zjy DKQ=%d, DV=%d, ncols2=%d, use_logit_softca=%d Q->ne[1]=%d 4/ncols2=%d\n",
// DKQ, DV, ncols2, use_logit_softcap, Q->ne[1], 4/ncols2);
// printf("zjy name=%s\n", gpu_name.data());
// std::cout << "GPU name "<<gpu_name<< " arch "<< ggml_sycl_info().devices[id].hw_info.arch_name
// <<std::endl;
//same or older iGPU are weak too.
if (arch <= gpu_arch::intel_gpu_adl_s) {
if (Q->ne[1]>=75) {
{
constexpr int cols_per_block = ncols2*2;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
if (arch > gpu_arch::intel_gpu_acm_g10 || Q->ne[1]<75) {
if constexpr (ncols2 <= 32) {
if (Q->ne[1] > 16/ncols2) {
constexpr int cols_per_block = 32;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
if constexpr (ncols2 <= 16) {
if (Q->ne[1] > 8/ncols2) {
constexpr int cols_per_block = 16;
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
launch_fattn<DV, cols_per_block/ncols2, ncols2,
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
}
if constexpr (ncols2 <= 8) {
@@ -1237,6 +1143,8 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
return;
}
}
if constexpr (ncols2 <= 2) {

View File

@@ -104,6 +104,7 @@ static ggml_sycl_device_info ggml_sycl_init() {
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
info.devices[i].hw_info = get_device_hw_info(&device);
}

View File

@@ -1,15 +1,67 @@
#include "sycl_hw.hpp"
// TODO: currently not used
/*
sycl_hw_info get_device_hw_info(sycl::device *device_ptr) {
sycl_hw_info res;
int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
res.device_id = id;
using namespace std;
syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>();
res.arch = arch;
return res;
}
/*defined in
* /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def
*/
static map<gpu_arch, std::pair<const char*, sycl_intel_gpu_family>> arch2name = {
{gpu_arch::intel_gpu_bdw, {"intel_gpu_bdw", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_skl, {"intel_gpu_skl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_kbl, {"intel_gpu_kbl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_cfl, {"intel_gpu_cfl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_apl, {"intel_gpu_apl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_glk, {"intel_gpu_glk", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_whl, {"intel_gpu_whl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_aml, {"intel_gpu_aml", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_cml, {"intel_gpu_cml", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_icllp, {"intel_gpu_icllp", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_ehl, {"intel_gpu_ehl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_tgllp, {"intel_gpu_tgllp", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_rkl, {"intel_gpu_rkl", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_adl_s, {"intel_gpu_adl_s", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_adl_p, {"intel_gpu_adl_p", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_adl_n, {"intel_gpu_adl_n", GPU_FAMILY_IGPU_NON_XE}},
{gpu_arch::intel_gpu_dg1, {"intel_gpu_dg1", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_pvc, {"intel_gpu_pvc", GPU_FAMILY_DGPU_CLOUD}},
{gpu_arch::intel_gpu_pvc_vg, {"intel_gpu_pvc_vg", GPU_FAMILY_DGPU_CLOUD}},
{gpu_arch::intel_gpu_mtl_u, {"intel_gpu_mtl_u", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_mtl_h, {"intel_gpu_mtl_h", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_arl_h, {"intel_gpu_arl_h", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}},
{gpu_arch::intel_gpu_lnl_m, {"intel_gpu_lnl_m", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_ptl_h, {"intel_gpu_ptl_h", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_ptl_u, {"intel_gpu_ptl_u", GPU_FAMILY_IGPU_XE}},
{gpu_arch::intel_gpu_wcl, {"intel_gpu_wcl", GPU_FAMILY_IGPU_XE}}
};
sycl_hw_info get_device_hw_info(sycl::device* device_ptr) {
sycl_hw_info res;
int32_t id =
device_ptr->get_info<sycl::ext::intel::info::device::device_id>();
res.device_id = id;
res.name = device_ptr->get_info<sycl::info::device::name>();
syclex::architecture arch =
device_ptr->get_info<syclex::info::device::architecture>();
res.arch = arch;
map<syclex::architecture,
std::pair<const char*, sycl_intel_gpu_family>>::iterator it =
arch2name.find(res.arch);
if (it != arch2name.end()) {
res.arch_name = it->second.first;
res.gpu_family = it->second.second;
} else {
res.arch_name = "unknown";
res.gpu_family = GPU_FAMILY_UKNOWN;
}
return res;
}

View File

@@ -9,18 +9,30 @@
#include <sycl/sycl.hpp>
namespace syclex = sycl::ext::oneapi::experimental;
using gpu_arch = sycl::ext::oneapi::experimental::architecture;
// TODO: currently not used
/*
struct sycl_hw_info {
syclex::architecture arch;
int32_t device_id;
// It's used to mark the GPU computing capacity in Flash-attention OP
// The value must flow the order of performance.
enum sycl_intel_gpu_family {
GPU_FAMILY_UKNOWN = -1,
// Normal iGPU in Intel Core CPU, before Meteor Lake iGPU(Xe)
GPU_FAMILY_IGPU_NON_XE = 0,
// iGPU with Xe core, Meteor Lake iGPU or newer.
GPU_FAMILY_IGPU_XE = 1,
// dGPU for gaming in client/data center, DG1/FLex 140 or newer.
GPU_FAMILY_DGPU_CLIENT_GAME = 2,
// dGPU for AI in cloud, PVC or newer.
GPU_FAMILY_DGPU_CLOUD = 3
};
bool is_in_vector(std::vector<int> &vec, int item);
struct sycl_hw_info {
syclex::architecture arch;
const char* arch_name;
int32_t device_id;
std::string name;
sycl_intel_gpu_family gpu_family;
};
sycl_hw_info get_device_hw_info(sycl::device *device_ptr);
*/
#endif // SYCL_HW_HPP