mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-05 13:53:23 +02:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
73a48c9790 | ||
|
|
f696428ce8 | ||
|
|
7cce4f8158 | ||
|
|
8d8862829c | ||
|
|
f77c13b91f | ||
|
|
3cfa9c3f12 | ||
|
|
5d195f17bc | ||
|
|
226f295f4d | ||
|
|
f90b4a8efe | ||
|
|
8423d01931 | ||
|
|
5cca2542ac | ||
|
|
55945d2ef5 | ||
|
|
0bcb40b48c | ||
|
|
69e9ff0103 | ||
|
|
5a91109a5d |
@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
||||
- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
|
||||
- [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
|
||||
- [x] [Jamba](https://huggingface.co/ai21labs)
|
||||
- [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
|
||||
- [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
|
||||
- [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)
|
||||
|
||||
@@ -742,6 +742,12 @@ class TextModel(ModelBase):
|
||||
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
|
||||
self.gguf_writer.add_expert_used_count(n_experts_used)
|
||||
logger.info(f"gguf: experts used count = {n_experts_used}")
|
||||
if (n_expert_groups := self.hparams.get("n_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_count(n_expert_groups)
|
||||
logger.info(f"gguf: expert groups count = {n_expert_groups}")
|
||||
if (n_group_used := self.hparams.get("topk_group")) is not None:
|
||||
self.gguf_writer.add_expert_group_used_count(n_group_used)
|
||||
logger.info(f"gguf: expert groups used count = {n_group_used}")
|
||||
|
||||
if (head_dim := self.hparams.get("head_dim")) is not None:
|
||||
self.gguf_writer.add_key_length(head_dim)
|
||||
@@ -1497,6 +1503,17 @@ class MmprojModel(ModelBase):
|
||||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
|
||||
|
||||
def prepare_metadata(self, vocab_only: bool):
|
||||
super().prepare_metadata(vocab_only=vocab_only)
|
||||
|
||||
output_type: str = self.ftype.name.partition("_")[2]
|
||||
|
||||
if self.fname_out.is_dir():
|
||||
fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
|
||||
self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
|
||||
else:
|
||||
self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
@@ -8222,8 +8239,6 @@ class BailingMoeV2Model(TextModel):
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_count(hparams["num_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
|
||||
self.gguf_writer.add_expert_group_count(hparams["n_group"])
|
||||
self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
|
||||
if hparams["score_function"] == "sigmoid":
|
||||
@@ -8943,6 +8958,13 @@ class SmolLM3Model(LlamaModel):
|
||||
class GptOssModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GPT_OSS
|
||||
|
||||
# TODO: remove once MXFP4 is supported more generally
|
||||
def dequant_model(self):
|
||||
quant_config = self.hparams.get("quantization_config")
|
||||
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
|
||||
return
|
||||
return super().dequant_model()
|
||||
|
||||
def transform_nibble_layout(self, tensor):
|
||||
assert tensor.dtype == torch.uint8
|
||||
assert tensor.shape[-1] == 16
|
||||
@@ -9722,10 +9744,6 @@ def main() -> None:
|
||||
|
||||
logger.info(f"Loading model: {dir_model.name}")
|
||||
|
||||
if args.mmproj:
|
||||
if "mmproj" not in fname_out.name:
|
||||
fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
|
||||
|
||||
is_mistral_format = args.mistral_format
|
||||
if is_mistral_format and not _mistral_common_installed:
|
||||
raise ImportError(_mistral_import_error_msg)
|
||||
|
||||
@@ -138,7 +138,7 @@ if model_path is None:
|
||||
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
|
||||
)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
print("Model type: ", config.model_type)
|
||||
print("Vocab size: ", config.vocab_size)
|
||||
@@ -148,8 +148,8 @@ print("BOS token id: ", config.bos_token_id)
|
||||
print("EOS token id: ", config.eos_token_id)
|
||||
|
||||
print("Loading model and tokenizer using AutoTokenizer:", model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
if unreleased_model_name:
|
||||
model_name_lower = unreleased_model_name.lower()
|
||||
@@ -171,7 +171,7 @@ if unreleased_model_name:
|
||||
exit(1)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, device_map="auto", offload_folder="offload"
|
||||
model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -1,5 +1,81 @@
|
||||
#include "argsort.cuh"
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
# include <cub/cub.cuh>
|
||||
using namespace cub;
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
|
||||
const int col = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int row = blockIdx.y;
|
||||
|
||||
if (col < ncols && row < nrows) {
|
||||
indices[row * ncols + col] = col;
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
|
||||
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx <= nrows) {
|
||||
offsets[idx] = idx * ncols;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
|
||||
const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
|
||||
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
|
||||
|
||||
int * temp_indices = temp_indices_alloc.get();
|
||||
float * temp_keys = temp_keys_alloc.get();
|
||||
int * d_offsets = offsets_alloc.get();
|
||||
|
||||
static const int block_size = 256;
|
||||
const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
|
||||
init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
|
||||
|
||||
const dim3 offset_grid((nrows + block_size - 1) / block_size);
|
||||
init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
|
||||
|
||||
cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
|
||||
|
||||
size_t temp_storage_bytes = 0;
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
|
||||
temp_indices, dst, // values (indices)
|
||||
ncols * nrows, nrows, // num items, num segments
|
||||
d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
|
||||
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
|
||||
sizeof(float) * 8, stream);
|
||||
}
|
||||
|
||||
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
|
||||
void * d_temp_storage = temp_storage_alloc.get();
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
|
||||
ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
|
||||
stream);
|
||||
} else {
|
||||
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
|
||||
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
|
||||
0, sizeof(float) * 8, stream);
|
||||
}
|
||||
}
|
||||
#endif // GGML_CUDA_USE_CUB
|
||||
|
||||
// Bitonic sort implementation
|
||||
template<typename T>
|
||||
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
|
||||
T tmp = a;
|
||||
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
|
||||
return n;
|
||||
}
|
||||
|
||||
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
|
||||
static void argsort_f32_i32_cuda_bitonic(const float * x,
|
||||
int * dst,
|
||||
const int ncols,
|
||||
const int nrows,
|
||||
ggml_sort_order order,
|
||||
cudaStream_t stream) {
|
||||
// bitonic sort requires ncols to be power of 2
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
|
||||
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
||||
GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
|
||||
|
||||
if (order == GGML_SORT_ORDER_ASC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else if (order == GGML_SORT_ORDER_DESC) {
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
|
||||
<<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||
|
||||
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
|
||||
#ifdef GGML_CUDA_USE_CUB
|
||||
const int ncols_pad = next_power_of_2(ncols);
|
||||
const size_t shared_mem = ncols_pad * sizeof(int);
|
||||
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
|
||||
|
||||
if (shared_mem > max_shared_mem || ncols > 1024) {
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
} else {
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
}
|
||||
#else
|
||||
argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
|
||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||
|
||||
if (block_nums.z > 65535) {
|
||||
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||
|
||||
@@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context {
|
||||
return pool(device);
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_cuda_mm_fusion_args_host {
|
||||
const ggml_tensor * x_bias = nullptr;
|
||||
const ggml_tensor * gate = nullptr;
|
||||
const ggml_tensor * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
struct ggml_cuda_mm_fusion_args_device {
|
||||
const void * x_bias = nullptr;
|
||||
const void * gate = nullptr;
|
||||
const void * gate_bias = nullptr;
|
||||
ggml_glu_op glu_op;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
@@ -2007,6 +2007,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
|
||||
const ggml_tensor * ffn_gate,
|
||||
const ggml_tensor * glu,
|
||||
const ggml_tensor * ffn_up_bias = nullptr,
|
||||
const ggml_tensor * ffn_gate_bias = nullptr) {
|
||||
const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
|
||||
|
||||
if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
|
||||
const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
|
||||
|
||||
GGML_ASSERT(ffn_up && ffn_gate && glu);
|
||||
|
||||
if (!is_mul_mat && !is_mul_mat_id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (has_bias) {
|
||||
if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (expected_bias_op == GGML_OP_ADD) {
|
||||
const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
|
||||
const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
|
||||
if (!up_has_mul || !gate_has_mul) {
|
||||
return false;
|
||||
}
|
||||
} else { // GGML_OP_ADD_ID
|
||||
if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
|
||||
return false;
|
||||
}
|
||||
if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
|
||||
!ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[1] != ffn_gate->src[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
|
||||
|
||||
if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
|
||||
ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
|
||||
|
||||
//TODO: add support for fusion for split buffers
|
||||
if (split) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
|
||||
|
||||
bool use_mul_mat_vec_f =
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
|
||||
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
|
||||
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
return use_mul_mat_vec_f;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
|
||||
ggml_tensor * src0 = tensor->src[0];
|
||||
ggml_tensor * src1 = tensor->src[1];
|
||||
const ggml_tensor * dst = tensor;
|
||||
|
||||
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
|
||||
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
|
||||
src0->view_src;
|
||||
|
||||
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||
|
||||
// fusion is not universally faster on Pascal
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
if (cc <= GGML_CUDA_CC_PASCAL) {
|
||||
return false;
|
||||
}
|
||||
//we only support fusion for ncols_dst = 1
|
||||
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return use_mul_mat_vec_q;
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||
|
||||
@@ -2745,7 +2886,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
||||
}
|
||||
}
|
||||
|
||||
if (node->op == GGML_OP_SCALE &&
|
||||
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
|
||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||
return false;
|
||||
}
|
||||
@@ -2854,6 +2995,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
||||
}
|
||||
}
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
|
||||
|
||||
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
|
||||
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
|
||||
|
||||
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
|
||||
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
|
||||
const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
|
||||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
|
||||
|
||||
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
|
||||
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||
return false;
|
||||
}
|
||||
@@ -3004,6 +3177,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
||||
}
|
||||
}
|
||||
|
||||
bool fused_mul_mat_vec = false;
|
||||
int fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 4];
|
||||
ggml_tensor * gate_bias_n = glu->src[0];
|
||||
ggml_tensor * up_bias_n = glu->src[1];
|
||||
|
||||
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
|
||||
ggml_tensor * gate_n = nullptr;
|
||||
ggml_tensor * up_n = nullptr;
|
||||
|
||||
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
|
||||
gate_n = cgraph->nodes[i];
|
||||
up_n = cgraph->nodes[i + 2];
|
||||
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
|
||||
gate_n = cgraph->nodes[i + 2];
|
||||
up_n = cgraph->nodes[i];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
|
||||
if (op_bias == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mul_node) {
|
||||
return bias_node->src[1];
|
||||
}
|
||||
if (bias_node->src[1] == mul_node) {
|
||||
return bias_node->src[0];
|
||||
}
|
||||
return (ggml_tensor *) nullptr;
|
||||
}
|
||||
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
|
||||
GGML_ASSERT(bias_node->src[0] == mul_node);
|
||||
return bias_node->src[1];
|
||||
};
|
||||
|
||||
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
|
||||
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
|
||||
|
||||
if (!up_bias_tensor || !gate_bias_tensor) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = up_n->src[0];
|
||||
const ggml_tensor * src1 = up_n->src[1];
|
||||
const ggml_tensor * ids = up_n->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate_n->src[0];
|
||||
fusion_data.x_bias = up_bias_tensor;
|
||||
fusion_data.gate_bias = gate_bias_tensor;
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 5;
|
||||
break;
|
||||
}
|
||||
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
|
||||
ggml_tensor * glu = cgraph->nodes[i + 2];
|
||||
ggml_tensor * gate = glu->src[0];
|
||||
ggml_tensor * up = glu->src[1];
|
||||
|
||||
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|
||||
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
|
||||
|
||||
if (!ok) continue;
|
||||
|
||||
const ggml_tensor * src0 = up->src[0];
|
||||
const ggml_tensor * src1 = up->src[1];
|
||||
const ggml_tensor * ids = up->src[2];
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.gate = gate->src[0];
|
||||
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 3;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
fused_mul_mat_vec = false;
|
||||
fused_node_count = 0;
|
||||
|
||||
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||
|
||||
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor * mm_node = cgraph->nodes[i];
|
||||
ggml_tensor * bias_node = cgraph->nodes[i + 1];
|
||||
|
||||
ggml_tensor * bias_tensor = nullptr;
|
||||
if (bias_op == GGML_OP_ADD) {
|
||||
if (bias_node->src[0] == mm_node) {
|
||||
bias_tensor = bias_node->src[1];
|
||||
} else if (bias_node->src[1] == mm_node) {
|
||||
bias_tensor = bias_node->src[0];
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
if (bias_node->src[0] != mm_node) {
|
||||
continue;
|
||||
}
|
||||
bias_tensor = bias_node->src[1];
|
||||
}
|
||||
|
||||
const ggml_tensor * src0 = mm_node->src[0];
|
||||
const ggml_tensor * src1 = mm_node->src[1];
|
||||
const ggml_tensor * ids = mm_node->src[2];
|
||||
|
||||
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||
fusion_data.x_bias = bias_tensor;
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
|
||||
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
|
||||
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||
fused_mul_mat_vec = true;
|
||||
fused_node_count = 2;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fused_mul_mat_vec) {
|
||||
i += fused_node_count - 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||
@@ -3642,8 +3993,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_SUM:
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_ARGSORT:
|
||||
// TODO: Support arbitrary column width
|
||||
#ifndef GGML_CUDA_USE_CUB
|
||||
return op->src[0]->ne[0] <= 1024;
|
||||
#else
|
||||
return true;
|
||||
#endif
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
#include "ggml.h"
|
||||
#include "common.cuh"
|
||||
#include "convert.cuh"
|
||||
#include "unary.cuh"
|
||||
#include "mmvf.cuh"
|
||||
#include "convert.cuh"
|
||||
|
||||
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
|
||||
static __global__ void mul_mat_vec_f(
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
|
||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
bool use_gate = false;
|
||||
bool use_bias = false;
|
||||
bool use_gate_bias = false;
|
||||
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
|
||||
const T * gate_x = nullptr;
|
||||
const float * x_bias = nullptr;
|
||||
const float * gate_bias = nullptr;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
use_gate = fusion.gate != nullptr;
|
||||
use_bias = fusion.x_bias != nullptr;
|
||||
use_gate_bias = fusion.gate_bias != nullptr;
|
||||
glu_op = fusion.glu_op;
|
||||
|
||||
if (use_gate) {
|
||||
gate_x = static_cast<const T *>(fusion.gate);
|
||||
}
|
||||
if (use_bias) {
|
||||
x_bias = static_cast<const float *>(fusion.x_bias);
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias = static_cast<const float *>(fusion.gate_bias);
|
||||
use_gate_bias = use_gate;
|
||||
} else {
|
||||
use_gate_bias = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (use_gate) {
|
||||
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
}
|
||||
if constexpr (has_fusion) {
|
||||
const int channel_bias = ids ? channel_x : channel_dst;
|
||||
if (use_bias) {
|
||||
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||
}
|
||||
}
|
||||
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[];
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
float * buf_iw_gate = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
|
||||
}
|
||||
|
||||
if (block_size > warp_size) {
|
||||
if (tid < warp_size) {
|
||||
buf_iw[tid] = 0.0f;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
buf_iw_gate[tid] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
float sumf[ncols_dst] = {0.0f};
|
||||
float sumf_gate[ncols_dst];
|
||||
if constexpr (has_fusion) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf_gate[j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (std::is_same_v<T, float>) {
|
||||
const float2 * x2 = (const float2 *) x;
|
||||
const float2 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const float2 *) gate_x;
|
||||
}
|
||||
}
|
||||
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = x2[col2];
|
||||
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if constexpr (std::is_same_v<T, half>) {
|
||||
const half2 * x2 = (const half2 *) x;
|
||||
const half2 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const half2 *) gate_x;
|
||||
}
|
||||
}
|
||||
|
||||
if (std::is_same_v<type_acc, float>) {
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const float2 tmpx = __half22float2(x2[col2]);
|
||||
|
||||
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = __half22float2(gate_x2[col2]);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#ifdef FP16_AVAILABLE
|
||||
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
||||
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
|
||||
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const half2 tmpx = x2[col2];
|
||||
|
||||
half2 tmpx_gate = make_half2(0.0f, 0.0f);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
||||
}
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
NO_DEVICE_CODE;
|
||||
#endif // FP16_AVAILABLE
|
||||
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
|
||||
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||
#if defined(GGML_USE_HIP)
|
||||
const int * x2 = (const int *) x;
|
||||
const int * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const int *) gate_x;
|
||||
}
|
||||
}
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const int tmpx = x2[col2];
|
||||
int tmpx_gate = 0;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
|
||||
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
|
||||
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||
const nv_bfloat162 * gate_x2 = nullptr;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
gate_x2 = (const nv_bfloat162 *) gate_x;
|
||||
}
|
||||
}
|
||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
const nv_bfloat162 tmpx = x2[col2];
|
||||
nv_bfloat162 tmpx_gate;
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmpx_gate = gate_x2[col2];
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||
}
|
||||
}
|
||||
|
||||
if (block_size > warp_size) {
|
||||
buf_iw[tid/warp_size] = sumf[j];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
buf_iw_gate[tid/warp_size] = sumf_gate[j];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
if (tid < warp_size) {
|
||||
sumf[j] = buf_iw[tid];
|
||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
sumf_gate[j] = buf_iw_gate[tid];
|
||||
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (j < ncols_dst) {
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -139,12 +313,70 @@ static __global__ void mul_mat_vec_f(
|
||||
return;
|
||||
}
|
||||
|
||||
dst[tid*stride_col_dst + row] = sumf[tid];
|
||||
float value = sumf[tid];
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
value += x_bias[tid*stride_col_dst + row];
|
||||
}
|
||||
|
||||
if (use_gate) {
|
||||
float gate_value = sumf_gate[tid];
|
||||
if (use_gate_bias) {
|
||||
gate_value += gate_bias[tid*stride_col_dst + row];
|
||||
}
|
||||
switch (glu_op) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
value *= ggml_cuda_op_silu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
value *= ggml_cuda_op_gelu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst[tid*stride_col_dst + row] = value;
|
||||
}
|
||||
|
||||
template<typename T, typename type_acc, int ncols_dst, int block_size>
|
||||
static void mul_mat_vec_f_switch_fusion(
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
|
||||
}
|
||||
|
||||
template <typename T, typename type_acc, int ncols_dst>
|
||||
static void launch_mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
void launch_mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
@@ -176,57 +408,59 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
const int nbytes_shared = warp_size*sizeof(float);
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
|
||||
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
|
||||
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
||||
const dim3 block_dims(block_size_best, 1, 1);
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
|
||||
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||
} break;
|
||||
default: {
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -236,7 +470,7 @@ static void launch_mul_mat_vec_f_cuda(
|
||||
|
||||
template <typename T, typename type_acc>
|
||||
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
@@ -246,49 +480,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 2:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 3:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 4:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 5:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 6:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 7:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case 8:
|
||||
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
@@ -300,29 +534,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||
|
||||
template<typename T>
|
||||
static void mul_mat_vec_f_cuda(
|
||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||
enum ggml_prec prec, cudaStream_t stream) {
|
||||
|
||||
if constexpr(std::is_same_v<T, half>) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
return;
|
||||
}
|
||||
}
|
||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
@@ -348,6 +584,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
|
||||
if (fusion) {
|
||||
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||
if (fusion->x_bias) {
|
||||
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.x_bias = fusion->x_bias->data;
|
||||
}
|
||||
if (fusion->gate) {
|
||||
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||
fusion_local.gate = fusion->gate->data;
|
||||
}
|
||||
if (fusion->gate_bias) {
|
||||
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||
}
|
||||
fusion_local.glu_op = fusion->glu_op;
|
||||
}
|
||||
|
||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||
const int64_t s1 = dst->nb[1] / ts_dst;
|
||||
@@ -370,19 +630,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||
} break;
|
||||
@@ -409,7 +669,6 @@ void ggml_cuda_op_mul_mat_vec_f(
|
||||
const int cc = ggml_cuda_info().devices[id].cc;
|
||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||
|
||||
|
||||
// ggml_cuda_op provides single, contiguous matrices
|
||||
const int64_t stride_row = ne00;
|
||||
const int64_t stride_col_y = ne10;
|
||||
@@ -426,22 +685,23 @@ void ggml_cuda_op_mul_mat_vec_f(
|
||||
const int64_t stride_sample_y = 0;
|
||||
const int64_t stride_sample_dst = 0;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device empty{};
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: {
|
||||
const float * src0_d = (const float *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_F16: {
|
||||
const half * src0_d = (const half *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
case GGML_TYPE_BF16: {
|
||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||
} break;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_f(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "mmvq.cuh"
|
||||
#include "quantize.cuh"
|
||||
#include "unary.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
|
||||
#include <cstdint>
|
||||
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
return MMVQ_PARAMETERS_GENERIC;
|
||||
}
|
||||
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_dst>
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
template <ggml_type type, int ncols_dst, bool has_fusion>
|
||||
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||
@@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q(
|
||||
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
||||
const uint32_t sample_y = sample_dst;
|
||||
|
||||
bool use_gate = false;
|
||||
bool use_bias = false;
|
||||
bool use_gate_bias = false;
|
||||
const void * vgate = nullptr;
|
||||
const float * x_bias = nullptr;
|
||||
const float * gate_bias = nullptr;
|
||||
ggml_glu_op active_glu;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
use_gate = fusion.gate != nullptr;
|
||||
use_bias = fusion.x_bias != nullptr;
|
||||
use_gate_bias = fusion.gate_bias != nullptr && use_gate;
|
||||
vgate = fusion.gate;
|
||||
x_bias = (const float *) fusion.x_bias;
|
||||
gate_bias = (const float *) fusion.gate_bias;
|
||||
active_glu = fusion.glu_op;
|
||||
}
|
||||
|
||||
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
}
|
||||
if (use_gate_bias) {
|
||||
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||
}
|
||||
}
|
||||
|
||||
// partial sum for each thread
|
||||
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||
|
||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
||||
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
||||
@@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q(
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp[j][i] += vec_dot_q_cuda(
|
||||
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] += vec_dot_q_cuda(
|
||||
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||
if constexpr (!has_fusion) {
|
||||
(void) tmp_shared_gate;
|
||||
} else if (!use_gate) {
|
||||
(void) tmp_shared_gate;
|
||||
}
|
||||
|
||||
if (threadIdx.y > 0) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols_dst; ++j) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,12 +265,49 @@ static __global__ void mul_mat_vec_q(
|
||||
#pragma unroll
|
||||
for (int l = 0; l < nwarps-1; ++l) {
|
||||
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
|
||||
}
|
||||
}
|
||||
}
|
||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||
if constexpr (has_fusion) {
|
||||
if (use_gate) {
|
||||
tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
|
||||
float result = tmp[j][threadIdx.x];
|
||||
if constexpr (has_fusion) {
|
||||
if (use_bias) {
|
||||
result += x_bias[j*stride_col_dst + threadIdx.x];
|
||||
}
|
||||
if (use_gate) {
|
||||
float gate_value = tmp_gate[j][threadIdx.x];
|
||||
if (use_gate_bias) {
|
||||
gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
|
||||
}
|
||||
switch (active_glu) {
|
||||
case GGML_GLU_OP_SWIGLU:
|
||||
result *= ggml_cuda_op_silu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_GEGLU:
|
||||
result *= ggml_cuda_op_gelu_single(gate_value);
|
||||
break;
|
||||
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||
result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
result = result * gate_value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
dst[j*stride_col_dst + threadIdx.x] = result;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -235,9 +321,37 @@ static std::pair<dim3, dim3> calc_launch_params(
|
||||
return {block_nums, block_dims};
|
||||
}
|
||||
|
||||
template<ggml_type type, int c_ncols_dst>
|
||||
static void mul_mat_vec_q_switch_fusion(
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
|
||||
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
if constexpr (c_ncols_dst == 1) {
|
||||
if (has_fusion) {
|
||||
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||
|
||||
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
||||
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
@@ -256,80 +370,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
||||
|
||||
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||
|
||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
||||
switch (ncols_dst) {
|
||||
case 1: {
|
||||
constexpr int c_ncols_dst = 1;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 2: {
|
||||
constexpr int c_ncols_dst = 2;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 3: {
|
||||
constexpr int c_ncols_dst = 3;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 4: {
|
||||
constexpr int c_ncols_dst = 4;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 5: {
|
||||
constexpr int c_ncols_dst = 5;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 6: {
|
||||
constexpr int c_ncols_dst = 6;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 7: {
|
||||
constexpr int c_ncols_dst = 7;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
case 8: {
|
||||
constexpr int c_ncols_dst = 8;
|
||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
dims.first, dims.second, 0, stream);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
GGML_UNUSED(has_fusion);
|
||||
}
|
||||
static void mul_mat_vec_q_switch_type(
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
|
||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||
@@ -339,143 +456,123 @@ static void mul_mat_vec_q_switch_type(
|
||||
switch (type_x) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_1:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q8_0:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_MXFP4:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ1_M:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_S:
|
||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||
stream);
|
||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -484,7 +581,8 @@ static void mul_mat_vec_q_switch_type(
|
||||
}
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
||||
@@ -508,6 +606,31 @@ void ggml_cuda_mul_mat_vec_q(
|
||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
|
||||
if (fusion) {
|
||||
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||
|
||||
if (fusion->x_bias) {
|
||||
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.x_bias = fusion->x_bias->data;
|
||||
}
|
||||
if (fusion->gate) {
|
||||
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||
fusion_local.gate = fusion->gate->data;
|
||||
}
|
||||
if (fusion->gate_bias) {
|
||||
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||
}
|
||||
fusion_local.glu_op = fusion->glu_op;
|
||||
}
|
||||
|
||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||
const size_t size_data = ggml_nbytes(src0);
|
||||
@@ -549,10 +672,10 @@ void ggml_cuda_mul_mat_vec_q(
|
||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
|
||||
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
|
||||
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
ne03, ne3, s03, s13, s3, stream);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
@@ -578,8 +701,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
||||
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
||||
const int stride_col_y = src1_padded_row_size / QK8_1;
|
||||
|
||||
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||
mul_mat_vec_q_switch_type(
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
||||
|
||||
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||
|
||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||
|
||||
void ggml_cuda_op_mul_mat_vec_q(
|
||||
ggml_backend_cuda_context & ctx,
|
||||
|
||||
@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu(float x) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
||||
return ggml_cuda_op_gelu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_silu(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
return ggml_cuda_op_silu_single(x);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float op_tanh(float x) {
|
||||
@@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
|
||||
|
||||
float xi = x[j0];
|
||||
float gi = g[j1];
|
||||
xi = fminf(xi, limit);
|
||||
gi = fmaxf(fminf(gi, limit), -limit);
|
||||
|
||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||
out_glu = out_glu * (1.0f + gi);
|
||||
|
||||
dst[i] = out_glu;
|
||||
dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
#define CUDA_NEG_BLOCK_SIZE 256
|
||||
@@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||
return x / (1.0f + expf(-x));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
|
||||
const float GELU_COEF_A = 0.044715f;
|
||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
|
||||
x = fminf(x, limit);
|
||||
g = fmaxf(fminf(g, limit), -limit);
|
||||
|
||||
float out_glu = x / (1.0f + expf(-x * alpha));
|
||||
out_glu = out_glu * (1.0f + g);
|
||||
return out_glu;
|
||||
}
|
||||
|
||||
@@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
||||
|
||||
#define GGML_VK_MAX_NODES 8192
|
||||
|
||||
#define MAX_VK_BUFFERS 256
|
||||
|
||||
#define VK_CHECK(err, msg) \
|
||||
do { \
|
||||
vk::Result err_ = (err); \
|
||||
@@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector {
|
||||
std::vector<vk_semaphore> tl_semaphores;
|
||||
std::vector<vk_semaphore> semaphores;
|
||||
std::vector<vk::Event> events;
|
||||
std::vector<vk_buffer> temp_buffers;
|
||||
std::vector<vk_context> contexts;
|
||||
};
|
||||
|
||||
@@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context {
|
||||
// and set to true after the buffer contents are consumed.
|
||||
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
|
||||
|
||||
vk_buffer buffer_pool[MAX_VK_BUFFERS];
|
||||
|
||||
vk_context_ref compute_ctx;
|
||||
vk_context_ref transfer_ctx;
|
||||
|
||||
@@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
|
||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
} else {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
|
||||
|
||||
@@ -4733,7 +4733,14 @@ static void ggml_vk_instance_init() {
|
||||
vk::PhysicalDeviceIDProperties old_id;
|
||||
old_props.pNext = &old_id;
|
||||
devices[k].getProperties2(&old_props);
|
||||
return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
|
||||
|
||||
bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
|
||||
equals = equals || (
|
||||
old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
|
||||
std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
|
||||
);
|
||||
|
||||
return equals;
|
||||
}
|
||||
);
|
||||
if (old_device == vk_instance.device_indices.end()) {
|
||||
@@ -4771,6 +4778,7 @@ static void ggml_vk_instance_init() {
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
driver_priorities[vk::DriverId::eMesaDozen] = 100;
|
||||
|
||||
if (driver_priorities.count(old_driver.driverID)) {
|
||||
old_priority = driver_priorities[old_driver.driverID];
|
||||
@@ -5144,71 +5152,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
||||
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
|
||||
}
|
||||
|
||||
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
|
||||
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
|
||||
VK_LOG_MEMORY("ggml_vk_pool_malloc");
|
||||
|
||||
int best_i = -1;
|
||||
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
|
||||
int worst_i = -1;
|
||||
size_t worst_size = 0; //largest unused buffer seen so far
|
||||
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
|
||||
vk_buffer &b = ctx->buffer_pool[i];
|
||||
if (b != nullptr && b->size >= size && b->size < best_size) {
|
||||
best_i = i;
|
||||
best_size = b->size;
|
||||
}
|
||||
if (b != nullptr && b->size > worst_size) {
|
||||
worst_i = i;
|
||||
worst_size = b->size;
|
||||
}
|
||||
}
|
||||
if(best_i != -1) {
|
||||
//found the smallest buffer that fits our needs
|
||||
vk_buffer b = ctx->buffer_pool[best_i];
|
||||
ctx->buffer_pool[best_i].reset();
|
||||
return b;
|
||||
}
|
||||
if(worst_i != -1) {
|
||||
//no buffer that fits our needs, resize largest one to save memory
|
||||
vk_buffer& b = ctx->buffer_pool[worst_i];
|
||||
ggml_vk_destroy_buffer(b);
|
||||
}
|
||||
|
||||
return ggml_vk_create_buffer_device(ctx->device, size);
|
||||
}
|
||||
|
||||
static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
|
||||
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
|
||||
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
|
||||
vk_buffer& b = ctx->buffer_pool[i];
|
||||
if (b == nullptr) {
|
||||
b = buffer;
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
|
||||
ggml_vk_destroy_buffer(buffer);
|
||||
}
|
||||
|
||||
// Returns an available temporary buffer that may only be used temporarily, it will be reused
|
||||
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
|
||||
// Try to find existing temp buffer with enough capacity
|
||||
for (auto& buffer : ctx->gc.temp_buffers) {
|
||||
if (buffer->size >= size) {
|
||||
return buffer;
|
||||
}
|
||||
}
|
||||
|
||||
VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
|
||||
|
||||
// Otherwise create new buffer
|
||||
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
|
||||
ctx->gc.temp_buffers.push_back(buf);
|
||||
|
||||
return buf;
|
||||
}
|
||||
|
||||
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
|
||||
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
|
||||
vk_buffer buf = ggml_vk_create_buffer(device, size,
|
||||
@@ -11789,10 +11732,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
// Clean up after graph processing is done
|
||||
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
|
||||
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
|
||||
for (auto& buffer : ctx->gc.temp_buffers) {
|
||||
ggml_vk_pool_free(ctx, buffer);
|
||||
}
|
||||
ctx->gc.temp_buffers.clear();
|
||||
ctx->prealloc_y_last_pipeline_used = {};
|
||||
|
||||
ctx->unsynced_nodes_written.clear();
|
||||
@@ -11835,10 +11774,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
||||
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
|
||||
ctx->prealloc_y_last_pipeline_used = nullptr;
|
||||
|
||||
for (auto& buffer : ctx->buffer_pool) {
|
||||
ggml_vk_destroy_buffer(buffer);
|
||||
}
|
||||
|
||||
ctx->prealloc_size_x = 0;
|
||||
ctx->prealloc_size_y = 0;
|
||||
ctx->prealloc_size_split_k = 0;
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if USE_SUBGROUP_ADD
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
@@ -84,35 +87,47 @@ void main() {
|
||||
}
|
||||
|
||||
barrier();
|
||||
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
|
||||
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
||||
const uint k = (tid % (w >> 1)) +
|
||||
(D_STATE * (tid / (w >> 1))) +
|
||||
j * D_STATE * (D_STATE / (w >> 1));
|
||||
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
|
||||
stateC[k] += stateC[k + (w >> 1)];
|
||||
[[unroll]]
|
||||
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
|
||||
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
|
||||
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
|
||||
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
|
||||
stateC[k] += stateC[k + w];
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
|
||||
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
|
||||
const uint idx = (tid % SUBGROUP_SIZE) +
|
||||
D_STATE * (tid / SUBGROUP_SIZE) +
|
||||
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
||||
const uint max_idx = SUBGROUP_SIZE - 1 +
|
||||
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
|
||||
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
|
||||
|
||||
uint lane = tid % SUBGROUP_SIZE;
|
||||
|
||||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
if (idx + offset < SPLIT_H * D_STATE) {
|
||||
stateC[idx] += stateC[idx + offset];
|
||||
if (idx < SPLIT_H * D_STATE ||
|
||||
max_idx < SPLIT_H * D_STATE) {
|
||||
float sc;
|
||||
#if USE_SUBGROUP_ADD
|
||||
sc = stateC[idx];
|
||||
sc = subgroupAdd(sc);
|
||||
#else
|
||||
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
|
||||
if (idx + offset < SPLIT_H * D_STATE) {
|
||||
stateC[idx] += stateC[idx + offset];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid % SUBGROUP_SIZE == 0) {
|
||||
sc = stateC[idx];
|
||||
}
|
||||
#endif
|
||||
|
||||
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
|
||||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
||||
d[y_base_idx + i * stride_y + k] = stateC[idx];
|
||||
if (tid % SUBGROUP_SIZE == 0) {
|
||||
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
|
||||
d[y_base_idx + i * stride_y + k] = sc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -916,7 +916,8 @@ void process_shaders() {
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
|
||||
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
|
||||
|
||||
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
|
||||
|
||||
|
||||
@@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
@@ -1006,10 +1009,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
|
||||
cb(weights_sum, "ffn_moe_weights_sum", il);
|
||||
|
||||
if (arch == LLM_ARCH_BAILINGMOE2) {
|
||||
weights_sum = ggml_scale_bias(ctx0, weights_sum, 1.0, 1e-20);
|
||||
cb(weights_sum, "ffn_moe_weights_sum_biased", il);
|
||||
}
|
||||
// Avoid division by zero, clamp to smallest number representable by F16
|
||||
weights_sum = ggml_clamp(ctx0, weights_sum, 6.103515625e-5, INFINITY);
|
||||
cb(weights_sum, "ffn_moe_weights_sum_clamped", il);
|
||||
|
||||
weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights_norm", il);
|
||||
@@ -1091,6 +1093,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
//expand here so that we can fuse ffn gate
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
|
||||
@@ -6369,6 +6369,8 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
|
||||
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
|
||||
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
|
||||
@@ -6469,8 +6471,6 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||
LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups);
|
||||
LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
@@ -17965,6 +17965,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
@@ -19337,6 +19339,7 @@ struct llm_build_smallthinker : public llm_graph_context{
|
||||
|
||||
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
@@ -4721,6 +4721,140 @@ struct test_topk_moe: public test_case {
|
||||
}
|
||||
};
|
||||
|
||||
struct test_mul_mat_vec_fusion : public test_case {
|
||||
const ggml_type type;
|
||||
const ggml_glu_op glu_op;
|
||||
const int64_t m;
|
||||
const int64_t n;
|
||||
const int64_t k;
|
||||
const bool use_id;
|
||||
const int n_mats;
|
||||
const int n_used;
|
||||
const bool b; // broadcast b matrix (only for use_id)
|
||||
const bool with_bias;
|
||||
const bool with_gate;
|
||||
|
||||
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
|
||||
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
|
||||
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
|
||||
if (use_id) {
|
||||
GGML_ASSERT(n_used <= n_mats);
|
||||
}
|
||||
}
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
|
||||
}
|
||||
|
||||
std::string op_desc(ggml_tensor * t) override {
|
||||
GGML_UNUSED(t);
|
||||
return "MUL_MAT_VEC_FUSION";
|
||||
}
|
||||
|
||||
bool run_whole_graph() override { return true; }
|
||||
|
||||
ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
|
||||
ggml_tensor * out = nullptr;
|
||||
if (with_gate) {
|
||||
if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
constexpr float alpha = 1.702f;
|
||||
constexpr float limit = 7.0f;
|
||||
out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
|
||||
} else {
|
||||
out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
std::array<int64_t, 4> ne = {k, m, 1, 1};
|
||||
std::array<int64_t, 4> ne0 = {k, n, 1, 1};
|
||||
|
||||
ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||
ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
|
||||
ggml_tensor * up = ggml_new_tensor(ctx, type, 4, ne0.data());
|
||||
|
||||
ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
|
||||
if (with_bias) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
|
||||
ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_up = ggml_add(ctx, ffn_up, up_bias);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
|
||||
if (with_bias && with_gate) {
|
||||
std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
|
||||
ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||
ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
} else {
|
||||
ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||
ggml_tensor * ups = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);
|
||||
|
||||
if (n_used != n_mats) {
|
||||
ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
|
||||
}
|
||||
|
||||
ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
|
||||
ggml_set_name(cur, "cur");
|
||||
|
||||
ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
|
||||
if (with_bias) {
|
||||
ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
|
||||
ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
|
||||
if (with_bias && with_gate) {
|
||||
ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
|
||||
ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
|
||||
}
|
||||
|
||||
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||
ggml_set_name(out, "out");
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_tensors(ggml_context * ctx) override {
|
||||
if (!use_id) {
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
} else {
|
||||
std::random_device rd;
|
||||
std::default_random_engine rng(rd());
|
||||
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
if (t->type == GGML_TYPE_I32) {
|
||||
if (ggml_is_view_op(t->op)) { continue; }
|
||||
// ids
|
||||
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||
std::vector<int32_t> data(t->ne[0]);
|
||||
for (int i = 0; i < t->ne[0]; i++) {
|
||||
data[i] = i % n_mats;
|
||||
}
|
||||
std::shuffle(data.begin(), data.end(), rng);
|
||||
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||
}
|
||||
} else {
|
||||
init_tensor_uniform(t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double max_nmse_err() override {
|
||||
return 5e-3;
|
||||
}
|
||||
};
|
||||
|
||||
// GGML_OP_SUM
|
||||
struct test_sum : public test_case {
|
||||
const ggml_type type;
|
||||
@@ -6407,6 +6541,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
add_test_bin_bcast(type, {1, 1, 640, 1}, {32, 32, 1, 1});
|
||||
add_test_bin_bcast(type, {5120, 1, 1, 1}, {1, 256, 1, 1});
|
||||
add_test_bin_bcast(type, {640, 1, 1, 1}, {1, 1, 1, 1});
|
||||
add_test_bin_bcast(type, {64, 262144, 1, 1}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {1, 1, 1, 1});
|
||||
//add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
|
||||
}
|
||||
@@ -6982,6 +7117,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||
|
||||
for (ggml_type type : base_types) {
|
||||
for (bool with_gate : {false, true}) {
|
||||
for (bool use_id : {false, true}) {
|
||||
for (bool b : {false, true}) {
|
||||
if (!use_id && b) {
|
||||
continue;
|
||||
}
|
||||
for (bool with_bias : {false, true}) {
|
||||
if (!with_gate && !with_bias) {
|
||||
continue;
|
||||
}
|
||||
for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {
|
||||
if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||
continue;
|
||||
}
|
||||
if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
|
||||
continue;
|
||||
}
|
||||
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||
use_id, 16, 8, b, with_bias, with_gate));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (bool with_norm : {false, true}) {
|
||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||
|
||||
Binary file not shown.
@@ -2,6 +2,9 @@
|
||||
import { ChatScreen } from '$lib/components/app';
|
||||
import { chatStore, isInitialized } from '$lib/stores/chat.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { page } from '$app/state';
|
||||
|
||||
let qParam = $derived(page.url.searchParams.get('q'));
|
||||
|
||||
onMount(async () => {
|
||||
if (!isInitialized) {
|
||||
@@ -9,6 +12,11 @@
|
||||
}
|
||||
|
||||
chatStore.clearActiveConversation();
|
||||
|
||||
if (qParam !== null) {
|
||||
await chatStore.createConversation();
|
||||
await chatStore.sendMessage(qParam);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user