Compare commits

..

2 Commits

Author SHA1 Message Date
Georgi Gerganov
697966680b ggml : sync (ggml_conv_2d, fix mul_mat bug, CUDA GLM rope) 2023-07-14 16:36:41 +03:00
Kawrakow
27ad57a69b Metal: faster Q4_0 and Q4_1 matrix x vector kernels (#2212)
* 3-5% faster Q4_0 on Metal

* 7-25% faster Q4_1 on Metal

* Oops, forgot to delete the original Q4_1 kernel

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
2023-07-14 11:46:21 +02:00
4 changed files with 210 additions and 133 deletions

View File

@@ -1667,6 +1667,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
if (col >= half_n_dims) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const float col_theta_scale = powf(theta_scale, col);
const float theta = p*col_theta_scale;
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
const float x0 = x[i + 0];
const float x1 = x[i + half_n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
const float block_theta = block_p*col_theta_scale;
const float sin_block_theta = sinf(block_theta);
const float cos_block_theta = cosf(block_theta);
const float x2 = x[i + half_n_dims * 2];
const float x3 = x[i + half_n_dims * 3];
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -2064,6 +2098,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
}
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 4 == 0);
const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -2618,13 +2660,21 @@ inline void ggml_cuda_op_rope(
const int n_past = ((int32_t *) src1->data)[0];
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
GGML_ASSERT(mode == 0);
const int n_ctx = ((int32_t *) src1->data)[3];
const float theta_scale = powf(10000.0, -2.0f/n_dims);
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
bool is_glm = mode & 4;
// compute
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
if (is_glm) {
const float id_p = min(p, n_ctx - 2.f);
const float block_p = max(p - (n_ctx - 2.f), 0.f);
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else {
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
}
(void) dst;
(void) src0_ddq_i;

View File

@@ -739,12 +739,8 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
if (src0t == GGML_TYPE_Q4_0) {
[encoder dispatchThreadgroups:MTLSizeMake(ne01 / 8+((ne01 % 8) & 0x01), ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_1) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_Q3_K ||

View File

@@ -395,9 +395,12 @@ kernel void kernel_mul_mat_q4_0_f32(
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
@@ -405,39 +408,50 @@ kernel void kernel_mul_mat_q4_0_f32(
// calculate
float d = qb_curr.d;
float2 acc = {0.0f, 0.0f};
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
sumf[row] += d * (acc[0] - 8.f*acc[1]);
sumf[row] += d * acc;
qb_curr = qb_next;
}
}
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
}
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
float d = qb_curr.d;
float2 acc = {0.0f, 0.0f};
for (int i = 0; i < 16; i++) {
acc[0] += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
acc[1] += yl[i] + yl[i+16];
if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * (acc[0] - 8.f*acc[1]);
}
qb_curr = qb_next;
} else {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
float d = qb_curr.d;
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc;
}
qb_curr = qb_next;
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
}
}
@@ -449,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
threadgroup float * sum [[threadgroup(0)]],
constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
const int nb = ne00/QK4_1;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const float * y = (device const float *) src1 + r1*ne10;
block_q4_1 qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
const uint nth = tptg.x*tptg.y;
const uint ith = tptg.y*tpitg.x + tpitg.y;
const int ix = tpitg.y/4; // 0 or 1
const int iy = tpitg.y - 4*ix; // 0...3
const int first = 4 * iy;
float sumf = 0;
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
const float d = (float)x[i].d;
const float m = (float)x[i].m;
device const uint8_t * xl = x[i].qs + first;
device const float * yl = y + i * QK4_1 + first;
float2 acc = {0.0f, 0.0f};
for (int j = 0; j < 4; ++j) {
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
// bootstrap
qb_curr = x[tiisg];
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumf += acc[0] + acc[1];
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
sumf[row] += d * acc + m * sumy;
qb_curr = qb_next;
}
}
sum[ith] = sumf;
if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
} else {
//
// Accumulate the sum from all threads in the threadgroup
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc + m * sumy;
}
qb_curr = qb_next;
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
}
}

99
ggml.c
View File

@@ -10684,6 +10684,8 @@ static void ggml_compute_forward_mul_mat(
const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
@@ -10747,7 +10749,7 @@ static void ggml_compute_forward_mul_mat(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata;
float * const wdata = params->wdata;
ggml_to_float_t const to_float = type_traits[type].to_float;
size_t id = 0;
@@ -10805,7 +10807,7 @@ static void ggml_compute_forward_mul_mat(
// src1 rows
const int64_t nr1 = ne11*ne12*ne13;
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
@@ -10828,7 +10830,15 @@ static void ggml_compute_forward_mul_mat(
const int64_t i3 = i13;
const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
@@ -12982,12 +12992,13 @@ static void ggml_compute_forward_conv_1d(
};
}
// ggml_compute_forward_conv_2d_sk_p0
// ggml_compute_forward_conv_2d
static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
static void ggml_compute_forward_conv_2d_f16_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -13007,28 +13018,37 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
// size of the convolution row - the kernel size unrolled across all channels
const int ew0 = nk0*nk1*ne02;
const int32_t s0 = ((const int32_t*)(opt0->data))[0];
const int32_t s1 = ((const int32_t*)(opt0->data))[1];
const int32_t p0 = ((const int32_t*)(opt0->data))[2];
const int32_t p1 = ((const int32_t*)(opt0->data))[3];
const int32_t d0 = ((const int32_t*)(opt0->data))[4];
const int32_t d1 = ((const int32_t*)(opt0->data))[5];
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_TASK_INIT) {
// TODO: fix this memset (wsize is overestimated)
memset(params->wdata, 0, params->wsize);
// prepare source data (src1)
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
for (int i13 = 0; i13 < ne13; i13++) {
for (int i12 = 0; i12 < ne12; i12++) {
const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
for (int i12 = 0; i12 < ne12; i12++) {
const float * const src = (float *)((char *) src1->data + i12*nb12);
ggml_fp16_t * dst_data = wdata;
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
for (int ik1 = 0; ik1 < nk1; ik1++) {
for (int ik0 = 0; ik0 < nk0; ik0++) {
for (int i1 = 0; i1 < ne1; i1++) {
for (int i0 = 0; i0 < ne0; i0++) {
for (int ik1 = 0; ik1 < nk1; ik1++) {
for (int ik0 = 0; ik0 < nk0; ik0++) {
const int idx0 = i0*s0 + ik0*d0 - p0;
const int idx1 = i1*s1 + ik1*d1 - p1;
if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
}
}
}
@@ -13071,19 +13091,21 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
}
}
static void ggml_compute_forward_conv_2d_sk_p0(
static void ggml_compute_forward_conv_2d(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
const struct ggml_tensor * opt0,
struct ggml_tensor * dst
) {
switch (src0->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
} break;
case GGML_TYPE_F32:
{
//ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
//ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
GGML_ASSERT(false);
} break;
default:
@@ -13093,32 +13115,6 @@ static void ggml_compute_forward_conv_2d_sk_p0(
}
}
// ggml_compute_forward_conv_2d
static void ggml_compute_forward_conv_2d(
const struct ggml_compute_params* params,
const struct ggml_tensor* src0,
const struct ggml_tensor* src1,
const struct ggml_tensor* opt0,
struct ggml_tensor* dst) {
const int32_t s0 = ((const int32_t*)(opt0->data))[0];
const int32_t s1 = ((const int32_t*)(opt0->data))[1];
const int32_t p0 = ((const int32_t*)(opt0->data))[2];
const int32_t p1 = ((const int32_t*)(opt0->data))[3];
const int32_t d0 = ((const int32_t*)(opt0->data))[4];
const int32_t d1 = ((const int32_t*)(opt0->data))[5];
GGML_ASSERT(d0 == 1); // dilation not supported
GGML_ASSERT(d1 == 1);
GGML_ASSERT(p0 == 0); // padding not supported
GGML_ASSERT(p1 == 0);
if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
} else {
GGML_ASSERT(false); // only stride equal to kernel size is supported
}
}
// ggml_compute_forward_pool_1d_sk_p0
static void ggml_compute_forward_pool_1d_sk_p0(
@@ -16575,19 +16571,22 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // C
const int64_t ne0 = node->ne[0];
const int64_t ne1 = node->ne[1];
const int64_t ne2 = node->ne[2];
const int64_t nk = ne00*ne01;
const int64_t ew0 = nk * ne02;
UNUSED(ne02);
UNUSED(ne03);
UNUSED(nk);
UNUSED(ne2);
size_t cur = 0;
if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
} else if (node->src[0]->type == GGML_TYPE_F32 &&
node->src[1]->type == GGML_TYPE_F32) {
node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12);
} else {
GGML_ASSERT(false);