Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2026-02-26 14:23:22 +02:00

Compare commits

14 Commits
| SHA1 |
|---|
| da348c9dfb |
| e6267a9359 |
| 2bf318fd2f |
| c78e682245 |
| c5897995a7 |
| 03fd9d3bb4 |
| 8004f3a8d1 |
| eacb4b67a2 |
| c0d0430340 |
| 3bb2fcc856 |
| 27326bfce1 |
| ad9f692f8f |
| 8a70973557 |
| e7f2f95c9a |
@@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
     } else if (!content_parts.empty()) {
         if (concat_typed_text) {
             std::string text;
+            bool last_was_media_marker = false;
+            // join parts with newline, do not add newline before or after media markers
             for (const auto & part : content_parts) {
-                if (part.type != "text") {
+                bool add_new_line = true;
+                if (part.type == "text") {
+                    add_new_line = !last_was_media_marker && !text.empty();
+                    last_was_media_marker = false;
+                } else if (part.type == "media_marker") {
+                    add_new_line = false;
+                    last_was_media_marker = true;
+                } else {
                     LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
                     continue;
                 }
-                if (!text.empty()) {
+                if (add_new_line) {
                     text += '\n';
                 }

                 text += part.text;
             }
             jmsg["content"] = text;
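As an aside, here is a minimal standalone sketch of the joining rule the new code implements; the Part struct and join_parts helper are illustrative stand-ins, not the llama.cpp API:

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for a chat content part, for illustration only.
    struct Part { std::string type, text; };

    // Join text parts with '\n', but never add a newline before or after a media marker.
    static std::string join_parts(const std::vector<Part> & parts) {
        std::string text;
        bool last_was_media_marker = false;
        for (const auto & part : parts) {
            bool add_new_line = true;
            if (part.type == "text") {
                add_new_line = !last_was_media_marker && !text.empty();
                last_was_media_marker = false;
            } else if (part.type == "media_marker") {
                add_new_line = false;
                last_was_media_marker = true;
            } else {
                continue; // unsupported part types are skipped
            }
            if (add_new_line) {
                text += '\n';
            }
            text += part.text;
        }
        return text;
    }

    int main() {
        std::vector<Part> parts = {
            {"text", "caption:"}, {"media_marker", "<media>"}, {"text", "after image"},
        };
        std::cout << join_parts(parts) << "\n"; // prints "caption:<media>after image"
    }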
@@ -319,7 +330,7 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
                 throw std::invalid_argument("Missing content part type: " + part.dump());
             }
             const auto & type = part.at("type");
-            if (type != "text") {
+            if (type != "text" && type != "media_marker") {
                 throw std::invalid_argument("Unsupported content part type: " + type.dump());
             }
             common_chat_msg_content_part msg_part;

@@ -3307,7 +3318,7 @@ static common_chat_params common_chat_templates_apply_legacy(
     for (const auto & msg : inputs.messages) {
         auto content = msg.content;
         for (const auto & part : msg.content_parts) {
-            if (part.type != "text") {
+            if (part.type != "text" && part.type != "media_marker") {
                 LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
                 continue;
             }
@@ -4,6 +4,7 @@
 // for converting from JSON to jinja values
 #include <nlohmann/json.hpp>

+#include <sstream>
 #include <string>
 #include <cctype>
 #include <vector>

@@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const {
         return args.get_pos(0);
     }},
     {"tojson", tojson},
-    {"indent", [](const func_args &) -> value {
-        throw not_implemented_exception("String indent builtin not implemented");
+    {"indent", [](const func_args &args) -> value {
+        args.ensure_count(1, 4);
+        value val_input = args.get_pos(0);
+        value val_width = args.get_kwarg_or_pos("width", 1);
+        const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false
+        const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false
+        if (!is_val<value_string>(val_input)) {
+            throw raised_exception("indent() first argument must be a string");
+        }
+        std::string indent;
+        if (is_val<value_int>(val_width)) {
+            indent.assign(val_width->as_int(), ' ');
+        } else if (is_val<value_string>(val_width)) {
+            indent = val_width->as_string().str();
+        } else {
+            indent = " ";
+        }
+        std::string indented;
+        std::string input = val_input->as_string().str();
+        std::istringstream iss = std::istringstream(input);
+        std::string line;
+        while (std::getline(iss, line)) {
+            if (!indented.empty()) {
+                indented.push_back('\n');
+            }
+            if ((indented.empty() ? first : (!line.empty() || blank))) {
+                indented += indent;
+            }
+            indented += line;
+        }
+        if (!input.empty() && input.back() == '\n') {
+            indented.push_back('\n');
+            if (blank) {
+                indented += indent;
+            }
+        }
+
+        auto res = mk_val<value_string>(indented);
+        res->val_str.mark_input_based_on(val_input->as_string());
+        return res;
     }},
     {"join", [](const func_args &) -> value {
         throw not_implemented_exception("String join builtin not implemented");
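For reference, the loop above matches Jinja's indent filter semantics: indent every line after the first unless first is set, and skip blank lines unless blank is set. A standalone sketch of the same algorithm outside the engine's value machinery (indent_str is an illustrative name, not the engine's API):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Minimal re-implementation of the indent loop above, for illustration.
    static std::string indent_str(const std::string & input, const std::string & indent,
                                  bool first = false, bool blank = false) {
        std::string indented;
        std::istringstream iss(input);
        std::string line;
        while (std::getline(iss, line)) {
            if (!indented.empty()) {
                indented.push_back('\n');
            }
            // indent the first line only if `first`; later lines unless empty (or `blank`)
            if (indented.empty() ? first : (!line.empty() || blank)) {
                indented += indent;
            }
            indented += line;
        }
        if (!input.empty() && input.back() == '\n') {
            indented.push_back('\n');
            if (blank) {
                indented += indent;
            }
        }
        return indented;
    }

    int main() {
        std::cout << indent_str("a\nb\n", "  ") << "---\n"; // prints "a\n  b\n---"
    }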
@@ -1163,6 +1163,9 @@ class TextModel(ModelBase):
         if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901":
             # ref: https://huggingface.co/core42/jais-13b
             res = "jais"
+        if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a":
+            # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat
+            res = "jais-2"
         if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f":
             # ref: https://huggingface.co/WisdomShell/CodeShell-7B
             res = "codeshell"

@@ -8633,6 +8636,17 @@ class T5EncoderModel(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("Jais2ForCausalLM")
+class Jais2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.JAIS2
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"])
+        self.gguf_writer.add_rope_dimension_count(head_dim)
+
+
 @ModelBase.register("JAISLMHeadModel")
 class JaisModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAIS

@@ -10726,7 +10740,7 @@ class LFM2Model(TextModel):
     def set_gguf_parameters(self):
         # set num_key_value_heads only for attention layers
         self.hparams["num_key_value_heads"] = [
-            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            self.hparams["num_key_value_heads"] if layer_type != "conv" else 0
             for layer_type in self.hparams["layer_types"]
         ]
@@ -10912,6 +10926,28 @@ class LFM2AudioModel(ConformerAudioModel):
         yield from super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("Lfm25AudioTokenizer")
+class LFM25AudioTokenizer(LFM2Model):
+    model_arch = gguf.MODEL_ARCH.LFM2
+
+    def set_vocab(self):
+        self._set_vocab_none()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_embedding_length_out(self.hparams["output_size"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "istft.window" or name.startswith("emb.emb"):
+            return
+
+        if name.startswith("lin"):
+            name = name.replace("lin", "dense_2_out")
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("SmallThinkerForCausalLM")
 class SmallThinkerModel(TextModel):
     model_arch = gguf.MODEL_ARCH.SMALLTHINKER

@@ -11003,13 +11039,17 @@ class ModernBertModel(BertModel):
         self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # these layers act as MLM head, so we don't need them
         if name.startswith("decoder."):
             return

+        if name.startswith("model."):
+            name = name[6:]
+
         if self.cls_out_labels:
             # For BertForSequenceClassification (direct projection layer)
             if name == "classifier.weight":
                 name = "classifier.out_proj.weight"

             if name == "classifier.bias":
                 name = "classifier.out_proj.bias"

         yield from super().modify_tensors(data_torch, name, bid)

@@ -114,6 +114,7 @@ models = [
     {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
     {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
+    {"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", },
     {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
     {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
     {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
@@ -1,333 +0,0 @@
#pragma once

typedef vector unsigned char vec_t;
typedef __vector_quad acc_t;

template <typename TA>
class tinyBLAS_Q0_PPC {
  public:
    tinyBLAS_Q0_PPC(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth);

    void matmul(int64_t m, int64_t n);
    void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
        vec_t A_pack[mc*kc*2];
        vec_t B_pack[nc*kc*2];
        int comparray[mc*kc];
        constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
        int64_t ytiles = m / mc;
        int64_t xtiles = n / nc;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles) {
            end = tiles;
        }
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = (job / xtiles) * mc;
            int64_t jj = (job % xtiles) * nc;
            for (int64_t kk = 0; kk < k; kk += kc) {
                if constexpr(is_Ablock_q4) {
                    packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray);
                } else {
                    packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray);
                }
                packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true);
                KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray);
            }
        }
    }

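The job partitioning above splits the (m/mc) x (n/nc) grid of output tiles evenly across nth threads. A minimal standalone sketch of the same duty/start/end arithmetic (illustrative, not the kernel code):

    #include <cstdint>
    #include <cstdio>

    // Tile `job` maps to output tile origin (ii, jj); thread `ith` of `nth`
    // handles a contiguous slice of the job range, as in matmul_tiled_q0.
    int main() {
        const int64_t m = 256, n = 128, mc = 64, nc = 64;
        const int64_t ytiles = m / mc, xtiles = n / nc;
        const int64_t tiles = xtiles * ytiles;
        const int nth = 3;
        for (int ith = 0; ith < nth; ith++) {
            int64_t duty  = (tiles + nth - 1) / nth; // ceil(tiles / nth)
            int64_t start = duty * ith;
            int64_t end   = start + duty > tiles ? tiles : start + duty;
            for (int64_t job = start; job < end; ++job) {
                int64_t ii = (job / xtiles) * mc; // tile row origin
                int64_t jj = (job % xtiles) * nc; // tile column origin
                std::printf("thread %d -> tile (%lld, %lld)\n", ith, (long long)ii, (long long)jj);
            }
        }
    }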
  private:
    inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
            }
        }
    }

    inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I);
                *c_ptr += *((float*)&fin_res[idx+I]+J);
            }
        }
    }

    template<typename ArrayType>
    inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) {
        vector signed int vec_C[4];
        vector float CA[4] = {0};
        vector float res[4] = {0};
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int i = 0; i < 4; i++) {
            CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
            fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
        }
    }

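A note on the `* -128.0` in compute(): B is packed with flip = true, which XORs each signed byte with 0x80 and so biases it by +128 for the unsigned side of the xvi8ger4pp products. The accumulator therefore holds sum(a*(b+128)) = sum(a*b) + 128*sum(a); comparray caches sum(a) per row of A, and subtracting 128*sum(a) recovers the signed dot product. A scalar sanity check of that identity (illustrative, not kernel code):

    #include <cstdio>

    int main() {
        // signed int8 inputs
        int a[4] = {-3, 7, -1, 5};
        int b[4] = {2, -6, 4, -8};
        int true_dot = 0, biased_acc = 0, row_sum = 0;
        for (int i = 0; i < 4; i++) {
            true_dot   += a[i] * b[i];
            biased_acc += a[i] * (b[i] + 128); // b flipped into unsigned range (XOR 0x80)
            row_sum    += a[i];                // what comparray stores per row of A
        }
        // correction applied by compute(): acc + (-128) * sum(a)
        std::printf("%d == %d\n", true_dot, biased_acc - 128 * row_sum); // -92 == -92
    }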
    inline void process_q4_elements(vector signed char (&c)[2], int* ca) {
        const vector signed char lowMask = vec_splats((signed char)0xF);
        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
        const vector signed char v8 = vec_splats((signed char)0x8);
        vector signed int vsum = {0};
        vector signed int vsum2 = {0};
        c[0] = vec_and(c[1], lowMask);
        c[1] = vec_sr(c[1], v4);
        c[0] = vec_sub(c[0], v8);
        c[1] = vec_sub(c[1], v8);
        vsum = vec_sum4s(c[0], vsum);
        vsum2 = vec_sum4s(c[1], vsum2);
        vsum = vec_add(vsum, vsum2);
        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

    template <typename V1, typename V2>
    inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) {
        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
        V2 t1, t2, t3, t4, t5, t6, t7, t8;
        vector unsigned char xor_vector;
        uint8_t flip_vec = 0x80;
        xor_vector = vec_splats(flip_vec);
        t1 = vec_perm(s1, s2, swiz1);
        t2 = vec_perm(s1, s2, swiz2);
        t3 = vec_perm(s3, s4, swiz1);
        t4 = vec_perm(s3, s4, swiz2);
        t5 = vec_perm(t1, t3, swiz3);
        t6 = vec_perm(t1, t3, swiz4);
        t7 = vec_perm(t2, t4, swiz3);
        t8 = vec_perm(t2, t4, swiz4);
        if (flip == true) {
            t5 = vec_xor(t5, xor_vector);
            t6 = vec_xor(t6, xor_vector);
            t7 = vec_xor(t7, xor_vector);
            t8 = vec_xor(t8, xor_vector);
        }
        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset+16);
        vec_xst(t7, 0, vecOffset+32);
        vec_xst(t8, 0, vecOffset+48);
    }

    template<int RM, int RN>
    inline void kernel(int64_t ii, int64_t jj) {
        if constexpr(RM == 4 && RN == 8) {
            KERNEL_4x8(ii,jj);
        } else if constexpr(RM == 8 && RN == 4) {
            KERNEL_8x4(ii,jj);
        } else if constexpr(RM == 8 && RN == 8) {
            KERNEL_8x8(ii,jj);
        } else {
            assert(false && "RN/RM values not supported");
        }
    }
    template<int size>
    void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray);
    template<typename VA, typename VB>
    void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip);
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n);
    void KERNEL_4x8(int64_t ii, int64_t jj);
    void KERNEL_8x4(int64_t ii, int64_t jj);
    void KERNEL_8x8(int64_t ii, int64_t jj);
    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN);
    template <int RM, int RN>
    void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n);

    void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){
        for (int I = 0; I<8; I++) {
            float a_scale = unhalf((A+((ii+I)*lda)+blk)->d);
            for (int J = 0; J<4; J++) {
                *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d));
                *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d));
            }
        }
    }

    inline void process_q8_elements(const int8_t *qs, int *ca) {
        vector signed char c1 = vec_xl(0, qs);
        vector signed char c2 = vec_xl(16, qs);
        vector signed int vsum1 = {0};
        vector signed int vsum2 = {0};
        vsum1 = vec_sum4s(c1, vsum1);
        vsum2 = vec_sum4s(c2, vsum2);
        vector signed int vsum = vec_add(vsum1, vsum2);
        *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

    template<typename VA, typename VB>
    void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) {
        int64_t i, j;
        block_q8_0 *aoffset = NULL;
        VA *vecOffset = NULL;
        block_q8_0* aoffsets[8];
        __vector_pair arr[8];
        VB c[8][2] = {0};
        VB c1[8] = {0}; VB c2[8] = {0};
        aoffset = const_cast<block_q8_0*>(a);
        vecOffset = vec;
        j = (rows >> 3);
        int index = 0;
        if (j > 0) {
            do {
                for (int it = 0; it < 8; it++)
                    aoffsets[it] = aoffset + it*lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    for (int it = 0; it < 8; it++) {
                        arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs);
                        __builtin_vsx_disassemble_pair(c[it], &arr[it]);
                        c1[it] = c[it][0];
                        c2[it] = c[it][1];
                        if (comparray){
                            process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]);
                        }
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
                    vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
                    vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
                    vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while(j > 0);
        }

    }

    void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) {
        int64_t i, j;
        TA *aoffset = NULL;
        int8_t *vecOffset = NULL;
        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
        aoffset = const_cast<TA*>(a);
        vecOffset = vec;
        int index = 0;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffset1 = aoffset;
                aoffset2 = aoffset1 + lda;
                aoffset3 = aoffset2 + lda;
                aoffset4 = aoffset3 + lda;
                aoffset5 = aoffset4 + lda;
                aoffset6 = aoffset5 + lda;
                aoffset7 = aoffset6 + lda;
                aoffset8 = aoffset7 + lda;
                aoffset += 8 * lda;
                for (int blk = 0; blk < kc; blk++) {
                    c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs));
                    c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs));
                    c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs));
                    c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs));
                    c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs));
                    c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs));
                    c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs));
                    c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs));

                    process_q4_elements(c1, &comparray[index + 8*blk+0]);
                    process_q4_elements(c2, &comparray[index + 8*blk+1]);
                    process_q4_elements(c3, &comparray[index + 8*blk+2]);
                    process_q4_elements(c4, &comparray[index + 8*blk+3]);
                    process_q4_elements(c5, &comparray[index + 8*blk+4]);
                    process_q4_elements(c6, &comparray[index + 8*blk+5]);
                    process_q4_elements(c7, &comparray[index + 8*blk+6]);
                    process_q4_elements(c8, &comparray[index + 8*blk+7]);
                    vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
                    vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
                    vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
                    vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
                    vecOffset += 256;
                }
                j--;
                index += 8*kc;
            } while (j > 0);
        }
    }

    void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) {
        acc_t acc[8];
        for (int i = 0; i < mc ; i += 8) {
            for (int j = 0; j < nc; j += 8) {
                vector float fin_res[16] = {0};
                vector float vs[16] = {0};
                for (int64_t kk = 0; kk < kc; kk+=2) {
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xxsetaccz(&acc[x]);
                    }
                    int A_block_idx = (i/8)*(16*kc) + kk*16;
                    int B_block_idx = (j/8)*(16*kc)+ kk*16;
                    vec_t *A_block = &vec_A[A_block_idx];
                    vec_t *B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk, vs);
                    int c_index = (i/8)*(8*kc)+ kk*8;
                    int* c_block = &comparray[c_index];
                    compute(&acc[0], 0, 0, c_block, vs, fin_res);
                    compute(&acc[1], 4, 4, c_block, vs, fin_res);
                    compute(&acc[2], 0, 8, c_block, vs, fin_res);
                    compute(&acc[3], 4, 12, c_block, vs, fin_res);

                    A_block_idx = (i/8)*(16*kc) + (kk+1)*16;
                    B_block_idx = (j/8)*(16*kc)+ (kk+1)*16;
                    A_block = &vec_A[A_block_idx];
                    B_block = &vec_B[B_block_idx];
                    for (int x = 0; x < 8; x++) {
                        __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]);
                        __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]);
                        __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]);
                    }
                    compute_scale(ii+i, jj+j, l+kk+1, vs);
                    c_index = (i/8)*(8*kc)+ (kk+1)*8;
                    c_block = &comparray[c_index];
                    compute(&acc[4], 0, 0, c_block, vs, fin_res);
                    compute(&acc[5], 4, 4, c_block, vs, fin_res);
                    compute(&acc[6], 0, 8, c_block, vs, fin_res);
                    compute(&acc[7], 4, 12, c_block, vs, fin_res);

                }
                if (l == 0) {
                    save_res(ii+i, jj+j, 0, fin_res);
                    save_res(ii+i+4, jj+j, 4, fin_res);
                    save_res(ii+i, jj+j+4, 8, fin_res);
                    save_res(ii+i+4, jj+j+4, 12, fin_res);
                } else {
                    add_save_res(ii+i, jj+j, 0, fin_res);
                    add_save_res(ii+i+4, jj+j, 4, fin_res);
                    add_save_res(ii+i, jj+j+4, 8, fin_res);
                    add_save_res(ii+i+4, jj+j+4, 12, fin_res);
                }
            }
        }
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *C;
    const int64_t k;
    int64_t kc;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
@@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); }
 #endif

 #if defined(__MMA__)
-#include "sgemm-ppc.h"
+typedef vector unsigned char vec_t;
+typedef __vector_quad acc_t;
 #endif
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD

@@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC {
             packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B);
             for (int x = 0; x < 4; x++) {
                 mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]);
-                mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]);
+                mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]);
             }
         }
         SAVE_ACC(&acc_0, ii, jj);
@@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC {
     const int nth;
 };

-template <typename TA>
-tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k,
-                                     const TA *A, int64_t lda,
-                                     const block_q8_0 *B, int64_t ldb,
-                                     float *C, int64_t ldc,
-                                     int ith, int nth)
+template <typename TA>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                    const TA * A, int64_t lda,
+                    const block_q8_0 * B, int64_t ldb,
+                    float * C, int64_t ldc,
+                    int ith, int nth)
     : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
        kc = 64;
    }

-template<typename TA>
-void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) {
-    int mc = 64; int nc = 64;
-    if (n % 8 == 0 && n < nc) {
-        nc = n;
-        mc = 32 ;
-        kc = 32;
+    void matmul(int64_t m, int64_t n) {
+        const int64_t mc = 64;
+        const int64_t kc = 64;
+        int64_t nc = 64;
+        int64_t n_aligned = 0;
+        if (n % 64 == 0) {
+            n_aligned = n;
+        } else if (n == 4) {
+            n_aligned = 4;
+        } else if (n < 64) {
+            n_aligned = (n / 8) * 8;
+        } else {
+            n_aligned = (n / 64) * 64;
+        }
-    const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0);
-    if (is_aligned) {
-        this->matmul_tiled_q0(m, n, mc, nc, kc);
+        if (n_aligned > 0) {
+            if (n_aligned % 64 == 0) nc = 64;
+            else if (n_aligned == n) nc = n;
+            else if (n_aligned % 32 == 0) nc = 32;
+            else if (n_aligned % 24 == 0) nc = 24;
+            else if (n_aligned % 16 == 0) nc = 16;
+            else nc = 8;
+        }
+        bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0);
+        if (can_use_tiled) {
+            matmul_tiled(m, n_aligned, mc, nc, kc);
+            if (n > n_aligned) {
+                mnpack(0, m, n_aligned, n);
+            }
         } else {
             mnpack(0, m, 0, n);
         }
     }
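The nc selection above picks the widest tile width that divides the aligned part of n, with the tail columns handed to mnpack. A small standalone check of that mapping (illustrative, mirroring the source logic exactly, including the n_aligned == n case):

    #include <cstdint>
    #include <cstdio>

    // Mirror of the n_aligned / nc selection in matmul(): align n down to a
    // multiple of 8 (or 64), then choose the widest supported tile width.
    int main() {
        for (int64_t n : {4, 24, 40, 64, 100, 200}) {
            int64_t n_aligned;
            if (n % 64 == 0)      n_aligned = n;
            else if (n == 4)      n_aligned = 4;
            else if (n < 64)      n_aligned = (n / 8) * 8;
            else                  n_aligned = (n / 64) * 64;

            int64_t nc = 64;
            if (n_aligned > 0) {
                if (n_aligned % 64 == 0)      nc = 64;
                else if (n_aligned == n)      nc = n;
                else if (n_aligned % 32 == 0) nc = 32;
                else if (n_aligned % 24 == 0) nc = 24;
                else if (n_aligned % 16 == 0) nc = 16;
                else                          nc = 8;
            }
            std::printf("n=%lld -> n_aligned=%lld nc=%lld\n",
                        (long long)n, (long long)n_aligned, (long long)nc);
        }
    }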
-template<typename TA>
-template<int size>
-void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) {
  private:
    inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) {
        for (int I = 0; I < RM; I++) {
            for (int J = 0; J < RN; J++) {
                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J);
            }
        }
    }

    inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J);
            }
        }
    }

    inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) {
        vec_t vec_C[4];
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int I = 0; I < 4; I++) {
            for (int J = 0; J < 4; J++) {
                float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I);
                *c_ptr += *((float *)&vec_C[I] + J);
            }
        }
    }

    template<typename ArrayType>
    inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) {
        vector signed int vec_C[4];
        vector float CA[4] = {0};
        vector float res[4] = {0};
        __builtin_mma_disassemble_acc(vec_C, ACC);
        for (int i = 0; i < 4; i++) {
            CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0));
            res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
            fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]);
        }
    }

    inline void process_q4_elements(vector signed char (&c)[2], int * ca) {
        const vector signed char lowMask = vec_splats((signed char)0xF);
        const vector unsigned char v4 = vec_splats((unsigned char)0x4);
        const vector signed char v8 = vec_splats((signed char)0x8);
        vector signed int vsum = {0};
        vector signed int vsum2 = {0};
        c[0] = vec_and(c[1], lowMask);
        c[1] = vec_sr(c[1], v4);
        c[0] = vec_sub(c[0], v8);
        c[1] = vec_sub(c[1], v8);
        vsum = vec_sum4s(c[0], vsum);
        vsum2 = vec_sum4s(c[1], vsum2);
        vsum = vec_add(vsum, vsum2);
        *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3];
    }

    template <typename V1, typename V2>
    inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) {
        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
        V2 t1, t2, t3, t4, t5, t6, t7, t8;
        vector unsigned char xor_vector;
        uint8_t flip_vec = 0x80;
        xor_vector = vec_splats(flip_vec);
        t1 = vec_perm(s1, s2, swiz1);
        t2 = vec_perm(s1, s2, swiz2);
        t3 = vec_perm(s3, s4, swiz1);
        t4 = vec_perm(s3, s4, swiz2);
        t5 = vec_perm(t1, t3, swiz3);
        t6 = vec_perm(t1, t3, swiz4);
        t7 = vec_perm(t2, t4, swiz3);
        t8 = vec_perm(t2, t4, swiz4);
        if (flip == true) {
            t5 = vec_xor(t5, xor_vector);
            t6 = vec_xor(t6, xor_vector);
            t7 = vec_xor(t7, xor_vector);
            t8 = vec_xor(t8, xor_vector);
        }
        vec_xst(t5, 0, vecOffset);
        vec_xst(t6, 0, vecOffset + 16);
        vec_xst(t7, 0, vecOffset + 32);
        vec_xst(t8, 0, vecOffset + 48);
    }

    inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) {
        const vector signed char lowMask = vec_splats((signed char)0x0F);
        const vector signed char v8 = vec_splats((signed char)0x08);
        const vector unsigned char v4 = vec_splats((unsigned char)4);
        lo = vec_and(packed, lowMask);
        hi = vec_sr(packed, v4);
        lo = vec_sub(lo, v8);
        hi = vec_sub(hi, v8);
    }

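unpack_q4_to_q8 splits each packed byte into two 4-bit quants and recenters them from [0, 15] to [-8, 7]. A scalar reference of the same transform (illustrative, not the kernel code):

    #include <cstdint>
    #include <cstdio>

    // Scalar equivalent of unpack_q4_to_q8: the low and high nibble of each
    // packed byte become two signed values in [-8, 7].
    int main() {
        uint8_t packed = 0x2B;                   // low nibble 0xB, high nibble 0x2
        int8_t lo = (int8_t)(packed & 0x0F) - 8; // 11 - 8 = 3
        int8_t hi = (int8_t)(packed >> 4) - 8;   //  2 - 8 = -6
        std::printf("lo=%d hi=%d\n", lo, hi);
    }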
    inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) {
        vec_t t[8], s[8];
        vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
        for (int i = 0; i < 4; i += 2) {
            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
        }
        for (int i = 4; i < 8; i += 2) {
            t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1);
            t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2);
        }
        s[0] = vec_perm(t[0], t[2], swiz3);
        s[1] = vec_perm(t[0], t[2], swiz4);
        s[2] = vec_perm(t[1], t[3], swiz3);
        s[3] = vec_perm(t[1], t[3], swiz4);
        s[4] = vec_perm(t[4], t[6], swiz3);
        s[5] = vec_perm(t[4], t[6], swiz4);
        s[6] = vec_perm(t[5], t[7], swiz3);
        s[7] = vec_perm(t[5], t[7], swiz4);
        for (int i = 0; i < 8; ++i) {
            vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16));
        }
    }

    static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) {
        vector signed short i16_hi = vec_unpackh(raw);
        vector signed short i16_lo = vec_unpackl(raw);

        vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0);
        vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0);
        vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0);
        vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0);
        out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale));
        out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale));
    }

    void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
        unsigned char * vecOffset = vec;
        for (int i = 0; i < rows; i += 8) {
            const block_q4_0 * rows_base[8];
            for (int r = 0; r < 8; r++) {
                rows_base[r] = a + (i + r) * lda;
            }
            for (int blk = 0; blk < blocks; blk++) {
                vector unsigned short hp_res[8][4];
                for (int r = 0; r < 8; r++) {
                    const block_q4_0 * current_blk = rows_base[r] + blk;
                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d));
                    vector signed char v_qs = reinterpret_cast<vector signed char>(vec_xl(0, current_blk->qs));
                    vector signed char c1, c2;
                    unpack_q4_to_q8(v_qs, c1, c2);
                    convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]);
                    convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]);
                }
                for (int c = 0; c < 4; c++) {
                    vector unsigned char c_arr[8];
                    for (int r = 0; r < 8; r++) {
                        c_arr[r] = (vector unsigned char)hp_res[r][c];
                    }
                    vector_permute_store_fp16((vec_t *)c_arr, vecOffset);
                    vecOffset += 128;
                }
            }
        }
    }

    template <int chunk_size>
    static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
        unsigned char * vecOffset = vec;
        const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23};
        const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
        const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
        const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};

        for (int i = 0; i < rows; i += chunk_size) {
            const block_q8_0 * rows_base[chunk_size];
            for (int r = 0; r < chunk_size; r++) {
                rows_base[r] = a + (i + r) * lda;
            }
            for (int blk = 0; blk < blocks; blk++) {
                vector unsigned short hp_res[chunk_size][4];
                for (int r = 0; r < chunk_size; r++) {
                    const block_q8_0 * b = rows_base[r] + blk;
                    vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d));
                    vector signed char c[2];
                    __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs);
                    __builtin_vsx_disassemble_pair(c, & pair);
                    convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]);
                    convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]);
                }
                for (int col = 0; col < 4; col++) {
                    if constexpr (chunk_size == 8) {
                        vec_t t[8];
                        t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
                        t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
                        t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
                        t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);
                        t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1);
                        t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2);
                        t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1);
                        t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2);

                        vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0));
                        vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16));
                        vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32));
                        vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48));
                        vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64));
                        vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80));
                        vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96));
                        vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112));
                        vecOffset += 128;
                    } else {
                        vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1);
                        vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2);
                        vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1);
                        vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2);

                        vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0));
                        vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16));
                        vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32));
                        vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48));
                        vecOffset += 64;
                    }
                }
            }
        }
    }

    void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) {
        if (rows == 4) {
            pack_q8_block<4>(a, lda, rows, blocks, vec);
        } else {
            pack_q8_block<8>(a, lda, rows, blocks, vec);
        }
    }

    template<int size>
    void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) {
        int64_t i, j;
-       TA *aoffset = NULL;
-       int8_t *vecOffset = NULL;
-       TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-       TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+       TA * aoffset = NULL;
+       int8_t * vecOffset = NULL;
+       TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL;
+       TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL;
        vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
        vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
-       aoffset = const_cast<TA*>(a);
+       aoffset = const_cast<TA *>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {

@@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC {
            c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs));
            c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs));

-           process_q4_elements(c1, &comparray[0]);
-           process_q4_elements(c2, &comparray[1]);
-           process_q4_elements(c3, &comparray[2]);
-           process_q4_elements(c4, &comparray[3]);
-           process_q4_elements(c5, &comparray[4]);
-           process_q4_elements(c6, &comparray[5]);
-           process_q4_elements(c7, &comparray[6]);
-           process_q4_elements(c8, &comparray[7]);
+           process_q4_elements(c1, & comparray[0]);
+           process_q4_elements(c2, & comparray[1]);
+           process_q4_elements(c3, & comparray[2]);
+           process_q4_elements(c4, & comparray[3]);
+           process_q4_elements(c5, & comparray[4]);
+           process_q4_elements(c6, & comparray[5]);
+           process_q4_elements(c7, & comparray[6]);
+           process_q4_elements(c8, & comparray[7]);
            vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
-           vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false);
-           vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false);
+           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
+           vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false);
+           vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false);
            aoffset1 += lda;
            aoffset2 += lda;
            aoffset3 += lda;

@@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC {
            c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs));
            c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs));

-           process_q4_elements(c1, &comparray[0]);
-           process_q4_elements(c2, &comparray[1]);
-           process_q4_elements(c3, &comparray[2]);
-           process_q4_elements(c4, &comparray[3]);
+           process_q4_elements(c1, & comparray[0]);
+           process_q4_elements(c2, & comparray[1]);
+           process_q4_elements(c3, & comparray[2]);
+           process_q4_elements(c4, & comparray[3]);
            vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
            aoffset1 += lda;
            aoffset2 += lda;
            aoffset3 += lda;

@@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC {
                case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs));
                    break;
            }
-           process_q4_elements(c1, &comparray[0]);
-           process_q4_elements(c2, &comparray[1]);
-           process_q4_elements(c3, &comparray[2]);
-           process_q4_elements(c4, &comparray[3]);
+           process_q4_elements(c1, & comparray[0]);
+           process_q4_elements(c2, & comparray[1]);
+           process_q4_elements(c3, & comparray[2]);
+           process_q4_elements(c4, & comparray[3]);
            vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false);
-           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false);
+           vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false);
            aoffset1 += lda;
            aoffset2 += lda;
            aoffset3 += lda;

@@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC {
        }
    }

-   template<typename TA>
    template<typename VA, typename VB>
-   void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+   void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) {
        int64_t i, j;
-       block_q8_0 *aoffset = NULL;
-       VA *vecOffset = NULL;
-       block_q8_0* aoffsets[8];
+       block_q8_0 * aoffset = NULL;
+       VA * vecOffset = NULL;
+       block_q8_0 * aoffsets[8];
        __vector_pair arr[8];
        VB c[8][2] = {0};
        VB c1[8] = {0}; VB c2[8] = {0};
-       aoffset = const_cast<block_q8_0*>(a);
+       aoffset = const_cast<block_q8_0 *>(a);
        vecOffset = vec;
        j = (rows >> 3);
        if (j > 0) {
            do {
                aoffsets[0] = aoffset;
                for (int it = 1; it < 8; it++)
-                   aoffsets[it] = aoffsets[it-1] + lda;
+                   aoffsets[it] = aoffsets[it - 1] + lda;
                aoffset += 8 * lda;

                i = (cols >> 3);
                if (i > 0) {
                    do {
                        for (int it = 0; it < 8; it++) {
-                           arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                           __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                           arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                           __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }
                        vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                       vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
-                       vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip);
-                       vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip);
+                       vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
+                       vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip);
+                       vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip);
                        for (int it = 0; it < 8; it++)
                            aoffsets[it] += lda;
                        vecOffset += 256;

@@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC {
                if (i > 0) {
                    do {
                        for (int it = 0; it < 4; it++) {
-                           arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs);
-                           __builtin_vsx_disassemble_pair(c[it], &arr[it]);
+                           arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs);
+                           __builtin_vsx_disassemble_pair(c[it], & arr[it]);
                            c1[it] = c[it][0];
                            c2[it] = c[it][1];
                        }
                        vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                       vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                       vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                        for (int it = 0; it < 4; it++) {
                            aoffsets[it] += lda;
                        }

@@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC {
        if (rows & 3) {
            aoffsets[0] = aoffset;
            for (int it = 1; it < 3; it++ )
-               aoffsets[it] = aoffsets[it-1] + lda;
+               aoffsets[it] = aoffsets[it - 1] + lda;
            i = (cols >> 3);
            if (i > 0) {
                do {
                    switch(rows) {
-                       case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs);
-                           __builtin_vsx_disassemble_pair(c[2], &arr[2]);
+                       case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs);
+                           __builtin_vsx_disassemble_pair(c[2], & arr[2]);
                            c1[2] = c[2][0]; c2[2] = c[2][1];
-                       case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs);
-                           __builtin_vsx_disassemble_pair(c[1], &arr[1]);
+                       case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs);
+                           __builtin_vsx_disassemble_pair(c[1], & arr[1]);
                            c1[1] = c[1][0]; c2[1] = c[1][1];
-                       case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs);
-                           __builtin_vsx_disassemble_pair(c[0], &arr[0]);
+                       case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs);
+                           __builtin_vsx_disassemble_pair(c[0], & arr[0]);
                            c1[0] = c[0][0]; c2[0] = c[0][1];
                            break;
                    }
                    vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip);
-                   vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip);
+                   vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip);
                    for (int it = 0; it < 3; it++)
                        aoffsets[it] += lda;
                    vecOffset += 128;
@@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int m_rem = MIN(m - m0, 16);
|
||||
int n_rem = MIN(n - n0, 16);
|
||||
|
||||
@@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC {
|
||||
}
|
||||
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 4> comparray {};
|
||||
@@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]);
|
||||
}
|
||||
for (int I = 0; I<4; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 0, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 0, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii, jj+4, 4, fin_res);
|
||||
save_res(ii, jj + 4, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[8] = {0};
|
||||
acc_t acc_0, acc_1;
|
||||
std::array<int, 8> comparray {};
|
||||
@@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[8] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
|
||||
void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16] = {0};
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
acc_t acc_4, acc_5, acc_6, acc_7;
|
||||
@@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC {
|
||||
vector float vs[16] = {0};
|
||||
bool isAblock_q4 = std::is_same_v<TA, block_q4_0>;
|
||||
for (int l = 0; l < k; l++) {
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
__builtin_mma_xxsetaccz(& acc_0);
|
||||
__builtin_mma_xxsetaccz(& acc_1);
|
||||
__builtin_mma_xxsetaccz(& acc_2);
|
||||
__builtin_mma_xxsetaccz(& acc_3);
|
||||
if (std::is_same_v<TA, block_q4_0>) {
|
||||
packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray);
|
||||
packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray);
|
||||
} else {
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
|
||||
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false);
|
||||
}
|
||||
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
|
||||
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true);
|
||||
for(int x = 0; x < 8; x++) {
|
||||
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]);
|
||||
__builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]);
|
||||
}
|
||||
for (int I = 0; I<8; I++) {
|
||||
for (int J = 0; J<4; J++) {
|
||||
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
|
||||
*((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
|
||||
for (int I = 0; I < 8 ; I++) {
|
||||
for (int J = 0; J < 4; J++) {
|
||||
*((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
|
||||
*((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d));
|
||||
}
|
||||
}
|
||||
if (!isAblock_q4) {
|
||||
auto aoffset = A+(ii*lda)+l;
|
||||
auto aoffset = A + (ii * lda) + l;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
comparray[i] = 0;
|
||||
int ca = 0;
|
||||
@@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC {
|
||||
aoffset += lda;
|
||||
}
|
||||
}
|
||||
compute(&acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(&acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(&acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(&acc_3, 4, 12, comparray, vs, fin_res);
|
||||
compute(& acc_0, 0, 0, comparray, vs, fin_res);
|
||||
compute(& acc_1, 4, 4, comparray, vs, fin_res);
|
||||
compute(& acc_2, 0, 8, comparray, vs, fin_res);
|
||||
compute(& acc_3, 4, 12, comparray, vs, fin_res);
|
||||
}
|
||||
save_res(ii, jj, 0, fin_res);
|
||||
save_res(ii+4, jj, 4, fin_res);
|
||||
save_res(ii, jj+4, 8, fin_res);
|
||||
save_res(ii+4, jj+4, 12, fin_res);
|
||||
save_res(ii + 4, jj, 4, fin_res);
|
||||
save_res(ii, jj + 4, 8, fin_res);
|
||||
save_res(ii + 4, jj + 4, 12, fin_res);
|
||||
}
|
||||
|
||||
template<typename TA>
void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) {
acc_t acc[8];
for (int i = 0; i < mc ; i += 16) {
for (int j = 0; j < nc; j += 8) {
int A0_base = (i / 16) * (2 * 32 * kc);
int B0_base = (j / 8) * (32 * kc);
for (int x = 0; x < 8; x++) {
__builtin_mma_xxsetaccz(&acc[x]);
}
for (int64_t kk = 0; kk < kc; kk++) {
int A0_block_idx = A0_base + kk * 32;
int B0_block_idx = B0_base + kk * 32;
int A1_block_idx = A0_block_idx + 32 * kc;
int B1_block_idx = B0_block_idx + 32 * kc;
vec_t * A0_block = & vec_A[A0_block_idx];
vec_t * B0_block = & vec_B[B0_block_idx];
vec_t * A1_block = & vec_A[A1_block_idx];
for (int it = 0; it < 4; it++) {
for (int x = 0; x < 4; x++) {
__builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
__builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]);
__builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]);
__builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]);
}
}
}
if (l == 0) {
save_acc(& acc[0], ii + i, jj + j);
save_acc(& acc[1], ii + i, jj + j + 4);
save_acc(& acc[2], ii + i + 4, jj + j);
save_acc(& acc[3], ii + i + 4, jj + j + 4);
save_acc(& acc[4], ii + i + 8, jj + j);
save_acc(& acc[5], ii + i + 8, jj + j + 4);
save_acc(& acc[6], ii + i + 12, jj + j);
save_acc(& acc[7], ii + i + 12, jj + j + 4);
} else {
add_save_acc(& acc[0], ii + i, jj + j);
add_save_acc(& acc[1], ii + i, jj + j + 4);
add_save_acc(& acc[2], ii + i + 4, jj + j);
add_save_acc(& acc[3], ii + i + 4, jj + j + 4);
add_save_acc(& acc[4], ii + i + 8, jj + j);
add_save_acc(& acc[5], ii + i + 8, jj + j + 4);
add_save_acc(& acc[6], ii + i + 12, jj + j);
add_save_acc(& acc[7], ii + i + 12, jj + j + 4);
}
}
}
}
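The HP16 kernel above drives eight MMA accumulator tiles covering a 16x8 output block. For readers without POWER10 hardware, here is a minimal, self-contained scalar sketch of the outer-product-accumulate pattern each `__builtin_mma_xvf16ger2pp` call expresses (the real instruction consumes fp16 element pairs in a rank-2 update; this sketch is a rank-1 float simplification, and all names are illustrative, not from the ggml source):

```cpp
#include <cstdio>

// Illustrative stand-in for one 4x4 MMA accumulator tile.
struct AccTile {
    float m[4][4] = {};
};

// Scalar analogue of an outer-product accumulate step:
// acc += a_slice (4x1) * b_slice (1x4) -- the rank update that the
// hardware instruction performs on vector registers in one shot.
static void ger_accumulate(AccTile & acc, const float a[4], const float b[4]) {
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            acc.m[r][c] += a[r] * b[c];
        }
    }
}

int main() {
    AccTile acc;
    const float a[4] = {1, 2, 3, 4};
    const float b[4] = {0.5f, 0.25f, 0.125f, 0.0625f};
    // K accumulation steps fold K outer products into the same tile.
    for (int k = 0; k < 8; ++k) {
        ger_accumulate(acc, a, b);
    }
    printf("acc[0][0] = %g\n", acc.m[0][0]); // 1 * 0.5 * 8 = 4
    return 0;
}
```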
void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) {
vec_t A_pack[mc * kc * 4];
vec_t B_pack[nc * kc * 4];
constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>;
int64_t ytiles = m / mc;
int64_t xtiles = n / nc;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles) {
end = tiles;
}
for (int64_t job = start; job < end; ++job) {
int64_t ii = (job / xtiles) * mc;
int64_t jj = (job % xtiles) * nc;
for (int64_t kk = 0; kk < k; kk += kc) {
if constexpr(is_Ablock_q4) {
packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
} else {
packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack);
}
packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack);
KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack);
}
}
}
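matmul_tiled splits the output into mc x nc tiles and hands each thread a contiguous range of tile indices via the ceil-divide "duty" formula. A standalone sketch of that partitioning (names illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the tile partitioning above: the tile index space is split
// into equal "duty" ranges, one per thread; the last thread's range is
// clamped to the tile count.
static void thread_range(int64_t tiles, int ith, int nth, int64_t & start, int64_t & end) {
    const int64_t duty = (tiles + nth - 1) / nth; // ceil(tiles / nth)
    start = duty * ith;
    end   = start + duty;
    if (end > tiles) {
        end = tiles;
    }
}

int main() {
    int64_t start, end;
    for (int ith = 0; ith < 3; ++ith) {
        thread_range(10, ith, 3, start, end);
        printf("thread %d: tiles [%lld, %lld)\n", ith, (long long) start, (long long) end);
    }
    // thread 0: [0,4)  thread 1: [4,8)  thread 2: [8,10)
    return 0;
}
```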

void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC {
vector float fin_res[4] = {0};
vector float vs[4] = {0};
vector float CA[4] = {0};
__builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value
__builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value
for (int l = 0; l < k; l++) {
__builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(&acc_0);
__builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead
__builtin_mma_xxsetaccz(& acc_0);
if (isAblock_q4) {
packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray);
packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray);
} else {
packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false);
}
packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
for(int x = 0; x < 8; x+=4) {
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
__builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true);
for (int x = 0; x < 8; x += 4) {
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]);
__builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]);
}
for (int I = 0; I<RM; I++) {
for (int J = 0; J<RN; J++) {
*((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
for (int I = 0; I < RM; I++) {
for (int J = 0; J < RN; J++) {
*((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d));
}
}
__builtin_mma_disassemble_acc(vec_C, &acc_0);
__builtin_mma_disassemble_acc(vec_C, & acc_0);
if (!isAblock_q4) {
auto aoffset = A+(ii*lda)+l;
auto aoffset = A + (ii * lda) + l;
for (int i = 0; i < RM; i++) {
comparray[i] = 0;
int ca = 0;
@@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC {
}
}

template<typename TA>
template<int RM, int RN>
inline void kernel(int64_t ii, int64_t jj) {
if constexpr(RM == 4 && RN == 8) {
KERNEL_4x8(ii,jj);
} else if constexpr(RM == 8 && RN == 4) {
KERNEL_8x4(ii,jj);
} else if constexpr(RM == 8 && RN == 8) {
KERNEL_8x8(ii,jj);
} else {
assert(false && "RN/RM values not supported");
}
}
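The kernel<RM, RN> dispatcher above resolves the tile shape at compile time, so each instantiation contains a single direct call and unsupported combinations fail loudly. A minimal sketch of the same `if constexpr` dispatch pattern (C++17):

```cpp
#include <cassert>
#include <cstdio>

// Compile-time kernel selection: RM/RN are template parameters, so each
// instantiation keeps exactly one branch; the others are discarded.
template <int RM, int RN>
void kernel_dispatch() {
    if constexpr (RM == 4 && RN == 8) {
        printf("KERNEL_4x8\n");
    } else if constexpr (RM == 8 && RN == 4) {
        printf("KERNEL_8x4\n");
    } else if constexpr (RM == 8 && RN == 8) {
        printf("KERNEL_8x8\n");
    } else {
        assert(false && "RN/RM values not supported");
    }
}

int main() {
    kernel_dispatch<4, 8>();
    kernel_dispatch<8, 8>();
    return 0;
}
```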

template <int RM, int RN>
NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
@@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC {
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
this->kernel<RM, RN>(ii, jj);
kernel<RM, RN>(ii, jj);
}
}

template class tinyBLAS_Q0_PPC<block_q4_0>;
template class tinyBLAS_Q0_PPC<block_q8_0>;
const TA * const A;
const block_q8_0 * const B;
float * C;
const int64_t k;
int64_t kc;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};

class tinyBLAS_PPC {
public:
@@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];

// On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
// However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;

if constexpr (DV == 512) {
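The CUDA change above tightens when the batch-size cap on the GQA-optimized flash-attention path applies: on NVIDIA it now only kicks in for head sizes DV <= 256, since for DV == 512 only the GQA variant exists anyway. A sketch of the predicate (the K-length stride check is omitted; values illustrative):

```cpp
#include <climits>
#include <cstdio>

// Gating logic: on NVIDIA with a small GQA ratio and DV <= 256, the
// GQA-optimized path is restricted to small batches (n_q <= 16);
// otherwise the limit is effectively disabled.
static bool use_gqa_opt(bool nvidia, int gqa_ratio, int dv, int n_q, bool mask, float max_bias) {
    const int gqa_limit = nvidia && gqa_ratio <= 4 && dv <= 256 ? 16 : INT_MAX;
    return mask && max_bias == 0.0f && n_q <= gqa_limit;
}

int main() {
    printf("%d\n", use_gqa_opt(true, 2, 128, 32, true, 0.0f)); // 0: batch too large for DV=128
    printf("%d\n", use_gqa_opt(true, 2, 512, 32, true, 0.0f)); // 1: DV=512 always allowed
    return 0;
}
```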
@@ -1121,7 +1121,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
uint32_t batches = dst->ne[2] * dst->ne[3];
uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg);
uint32_t total_wg = output_groups * batches;
wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
// TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups
wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension);
} else if (use_fast) {
auto decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get());
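The WebGPU fix above replaces a modulo (which silently dropped workgroups past the per-dimension limit) with a 2D split: x is clamped to the limit and y covers the remainder. A standalone sketch:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Fold a 1D workgroup count into a 2D grid bounded per dimension:
// x is clamped to the limit, y is the ceil-divide remainder, so
// x * y always covers total_wg.
static void split_workgroups(uint32_t total_wg, uint32_t max_per_dim,
                             uint32_t & wg_x, uint32_t & wg_y) {
    wg_x = std::min(total_wg, max_per_dim);
    wg_y = CEIL_DIV(total_wg, max_per_dim);
}

int main() {
    uint32_t x, y;
    split_workgroups(100000, 65535, x, y);
    printf("wg_x=%u wg_y=%u (covers %u >= 100000)\n", x, y, x * y);
    return 0;
}
```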
@@ -435,6 +435,7 @@ class MODEL_ARCH(IntEnum):
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
JAIS2 = auto()
NEMOTRON = auto()
NEMOTRON_H = auto()
NEMOTRON_H_MOE = auto()
@@ -652,6 +653,7 @@ class MODEL_TENSOR(IntEnum):
ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_OUT = auto() # classifier output projection
CLS_NORM = auto()
CONV1D = auto()
CONVNEXT_DW = auto()
CONVNEXT_NORM = auto()
@@ -873,6 +875,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.JAIS2: "jais2",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe",
@@ -1088,6 +1091,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_OUT: "cls.output",
MODEL_TENSOR.CLS_NORM: "cls.norm",
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm",
@@ -1507,6 +1511,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
MODEL_TENSOR.CLS_NORM,
],
MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -2814,6 +2819,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.JAIS2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -1240,6 +1240,10 @@ class TensorNameMap:
MODEL_TENSOR.CLS_OUT: (
"classifier.out_proj", # roberta
),

MODEL_TENSOR.CLS_NORM: (
"head.norm", # modern-bert
),
#############################################################################

MODEL_TENSOR.CONVNEXT_DW: (
@@ -84,6 +84,7 @@ add_library(llama
models/hunyuan-moe.cpp
models/internlm2.cpp
models/jais.cpp
models/jais2.cpp
models/jamba.cpp
models/kimi-linear.cpp
models/lfm2.cpp
@@ -79,6 +79,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_JAIS2, "jais2" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
{ LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" },
@@ -367,6 +368,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
{ LLM_TENSOR_CLS_NORM, "cls.norm" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
@@ -828,6 +830,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_CLS_NORM,
};
case LLM_ARCH_JINA_BERT_V2:
return {
@@ -1789,6 +1792,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_JAIS2:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_ATTN_NORM,
LLM_TENSOR_ATTN_Q,
LLM_TENSOR_ATTN_K,
LLM_TENSOR_ATTN_V,
LLM_TENSOR_ATTN_OUT,
LLM_TENSOR_FFN_NORM,
LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_DOWN,
};
case LLM_ARCH_NEMOTRON_H:
return {
LLM_TENSOR_TOKEN_EMBD,
@@ -2518,6 +2535,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
@@ -83,6 +83,7 @@ enum llm_arch {
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_JAIS2,
LLM_ARCH_NEMOTRON,
LLM_ARCH_NEMOTRON_H,
LLM_ARCH_NEMOTRON_H_MOE,
@@ -497,6 +498,7 @@ enum llm_tensor {
LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_CLS_NORM,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM,
@@ -712,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
}

float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;

output_reorder();

try {
@@ -721,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) {
throw std::runtime_error("no logits");
}

// TODO: use output_resolve_row()
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
j = output_ids[i];
}

if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

const int64_t j = output_resolve_row(i);
return logits.data + j*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
@@ -763,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens() const{
}

float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1;

output_reorder();

try {
@@ -772,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
throw std::runtime_error("no embeddings");
}

// TODO: use output_resolve_row()
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
j = output_ids[i];
}

if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

const int64_t j = output_resolve_row(i);
const uint32_t n_embd_out = model.hparams.n_embd_out();
return embd.data + j*n_embd_out;
} catch (const std::exception & err) {
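Both getters above now delegate to output_resolve_row(). A sketch of what the removed inline logic did, and what the shared helper presumably implements (simplified; the real helper lives in llama-context.cpp):

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Row resolution: negative i counts from the end of the outputs,
// non-negative i is mapped through output_ids, and out-of-range or
// never-requested rows throw.
static int64_t resolve_row(int32_t i, int32_t n_outputs, const std::vector<int32_t> & output_ids) {
    int64_t j;
    if (i < 0) {
        j = n_outputs + i;
        if (j < 0) {
            throw std::runtime_error("negative index out of range");
        }
    } else if ((size_t) i >= output_ids.size()) {
        throw std::runtime_error("index out of range");
    } else {
        j = output_ids[i];
    }
    if (j < 0) {
        throw std::runtime_error("row was not requested as output");
    }
    return j;
}

int main() {
    const std::vector<int32_t> output_ids = {-1, 0, -1, 1}; // only tokens 1 and 3 output
    printf("%lld\n", (long long) resolve_row(3, 2, output_ids));  // 1
    printf("%lld\n", (long long) resolve_row(-1, 2, output_ids)); // 1 (last output row)
    return 0;
}
```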
@@ -2761,6 +2719,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
llama_set_param(model->cls_norm, param_filter, param_filter_ud);

for (struct llama_layer & layer : model->layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
@@ -185,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
}

void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
if (cparams.embeddings &&
(cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN ||
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) {

const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
@@ -1125,8 +1128,8 @@ ggml_tensor * llm_graph_context::build_ffn(

if (down) {
cur = build_lora_mm(down, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
@@ -1721,7 +1724,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

ggml_tensor * cur;

if (cparams.flash_attn && kq_b == nullptr) {
const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr;
if (use_flash_attn) {
GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

if (v_trans) {
@@ -1981,8 +1985,8 @@ ggml_tensor * llm_graph_context::build_attn(

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
// GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
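ggml_mul_mat_set_prec(cur, GGML_PREC_F32) forces F32 accumulation for these matmuls. As an analogue of why this matters, here is the same failure mode one precision level up, float vs double (with fp16 accumulators the tail of fp32 products vanishes the same way):

```cpp
#include <cstdio>

// Summing many small terms into a low-precision accumulator loses
// the increments entirely once they fall below the accumulator's
// epsilon relative to its current value.
int main() {
    const int n = 1 << 24;
    float  acc_lo = 1.0f; // low-precision accumulator
    double acc_hi = 1.0;  // high-precision accumulator
    for (int i = 0; i < n; ++i) {
        acc_lo += 1e-8f;
        acc_hi += 1e-8;
    }
    printf("low : %.9f\n", acc_lo); // stays ~1.0, increments vanish
    printf("high: %.9f\n", acc_hi); // ~1.167772
    return 0;
}
```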
@@ -2414,8 +2418,9 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()

void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_2_b,
ggml_tensor * dense_3) const {
if (!cparams.embeddings || !(dense_2 || dense_3)) {
if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) {
return;
}
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
@@ -2424,6 +2429,9 @@ void llm_graph_context::build_dense_out(
if (dense_2) {
cur = ggml_mul_mat(ctx0, dense_2, cur);
}
if (dense_2_b) {
cur = ggml_add(ctx0, cur, dense_2_b);
}
if (dense_3) {
cur = ggml_mul_mat(ctx0, dense_3, cur);
}
@@ -2437,7 +2445,8 @@ void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const {
ggml_tensor * cls_out_b,
ggml_tensor * cls_norm) const {
if (!cparams.embeddings) {
return;
}
@@ -2476,8 +2485,15 @@ void llm_graph_context::build_pooling(
} break;
case LLAMA_POOLING_TYPE_RANK:
{
ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
if (arch == LLM_ARCH_MODERN_BERT) {
// modern bert gte reranker builds mean first then applies prediction head and classifier
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411
ggml_tensor * inp_mean = build_inp_mean();
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean);
} else {
ggml_tensor * inp_cls = build_inp_cls();
cur = ggml_get_rows(ctx0, inp, inp_cls);
}

// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
@@ -2486,7 +2502,15 @@ void llm_graph_context::build_pooling(
if (cls_b) {
cur = ggml_add(ctx0, cur, cls_b);
}
cur = ggml_tanh(ctx0, cur);
if (arch == LLM_ARCH_MODERN_BERT) {
cur = ggml_gelu(ctx0, cur);
} else {
cur = ggml_tanh(ctx0, cur);
}
if (cls_norm) {
// head norm
cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1);
}
}

// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
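Assembled end to end, the RANK path above is: pool, cls dense, activation (tanh, or GELU for ModernBERT), optional head norm, then cls_out. A scalar sketch of that order of operations, with the norm omitted and all shapes illustrative:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar sketch of the reranker classification head: pooled embedding ->
// cls dense layer -> activation -> cls_out projection to a single score.
static float rank_score(const std::vector<float> & pooled,
                        const std::vector<std::vector<float>> & cls_w,
                        const std::vector<float> & cls_out_w,
                        bool use_gelu) {
    const size_t n = pooled.size();
    std::vector<float> h(n, 0.0f);
    for (size_t r = 0; r < n; ++r) {            // cls dense layer
        for (size_t c = 0; c < n; ++c) {
            h[r] += cls_w[r][c] * pooled[c];
        }
        h[r] = use_gelu ? 0.5f * h[r] * (1.0f + erff(h[r] / sqrtf(2.0f)))
                        : tanhf(h[r]);          // GELU or tanh activation
    }
    float score = 0.0f;                          // cls_out: n -> 1
    for (size_t c = 0; c < n; ++c) {
        score += cls_out_w[c] * h[c];
    }
    return score;
}

int main() {
    std::vector<std::vector<float>> cls_w = {{0.5f, 0.1f}, {-0.2f, 0.3f}};
    std::vector<float> pooled = {1.0f, -1.0f}, cls_out_w = {0.7f, -0.4f};
    printf("score = %f\n", rank_score(pooled, cls_w, cls_out_w, false));
    return 0;
}
```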
@@ -1000,7 +1000,8 @@ struct llm_graph_context {
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;
ggml_tensor * cls_out_b,
ggml_tensor * cls_norm) const;

//
// sampling (backend sampling)
@@ -1014,6 +1015,7 @@ struct llm_graph_context {

void build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_2_b,
ggml_tensor * dense_3) const;
};
@@ -271,6 +271,7 @@ void llama_model_saver::add_tensors_from_model() {
add_tensor(model.cls_b);
add_tensor(model.cls_out);
add_tensor(model.cls_out_b);
add_tensor(model.cls_norm);

for (const struct llama_layer & layer : model.layers) {
for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
@@ -908,7 +908,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {

ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
hparams.set_swa_pattern(swa_period);
hparams.set_swa_pattern(swa_period, true);
} else {
hparams.swa_type = LLAMA_SWA_TYPE_NONE;
}
@@ -1937,6 +1937,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_JAIS2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

switch (hparams.n_layer) {
case 32: type = LLM_TYPE_8B; break;
case 68: type = LLM_TYPE_70B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_NEMOTRON:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2348,6 +2358,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case 10752: type = LLM_TYPE_2_6B; break;
default: type = LLM_TYPE_UNKNOWN;
}
if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il];
}
}
} break;
case LLM_ARCH_LFM2MOE:
{
@@ -3513,9 +3529,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
}

cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED);

} break;
case LLM_ARCH_NEO_BERT:
@@ -5368,6 +5385,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
}
} break;
case LLM_ARCH_JAIS2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (!output) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);

// attention biases - all have shape n_embd (output dimension of projections)
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);

// Jais-2 uses simple MLP (no gate) with biases
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_CHATGLM:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -6895,7 +6951,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}

// for LFM2-ColBert-350M
dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED);
dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED);
dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED);
} break;
case LLM_ARCH_SMALLTHINKER:
{
@@ -8553,6 +8610,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_jais>(*this, params);
} break;
case LLM_ARCH_JAIS2:
{
llm = std::make_unique<llm_build_jais2>(*this, params);
} break;
case LLM_ARCH_NEMOTRON:
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
@@ -8671,7 +8732,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
llm = std::make_unique<llm_build_lfm2>(*this, params);
if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
llm = std::make_unique<llm_build_lfm2<true>>(*this, params);
} else {
llm = std::make_unique<llm_build_lfm2<false>>(*this, params);
}
} break;
case LLM_ARCH_SMALLTHINKER:
{
@@ -8734,7 +8799,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
}

// add on pooling layer
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm);

// add backend sampling layers (if any)
llm->build_sampling();
@@ -8743,7 +8808,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
// there will be two additional dense projection layers
// dense linear projections are applied after pooling
// TODO: move reranking logic here and generalize
llm->build_dense_out(dense_2_out_layers, dense_3_out_layers);
llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers);

llm->res->set_outputs();

@@ -8961,6 +9026,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_BAILINGMOE2:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
case LLM_ARCH_JAIS2:
case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_HUNYUAN_DENSE:
case LLM_ARCH_LFM2:
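The LFM2 change above turns a runtime hparam (standard SWA or not) into a compile-time template argument, so each graph builder instantiation is specialized for exactly one mode. A minimal sketch of the dispatch pattern:

```cpp
#include <cstdio>
#include <memory>

// Runtime-to-template dispatch: a runtime flag selects between two
// template instantiations, and each instantiation compiles only its
// own branch via if constexpr.
struct builder_base {
    virtual ~builder_base() = default;
    virtual void build() = 0;
};

template <bool iswa>
struct lfm2_builder : builder_base {
    void build() override {
        if constexpr (iswa) {
            printf("building LFM2 graph with SWA inputs\n");
        } else {
            printf("building LFM2 graph without SWA inputs\n");
        }
    }
};

static std::unique_ptr<builder_base> make_lfm2(bool swa_standard) {
    if (swa_standard) {
        return std::make_unique<lfm2_builder<true>>();
    }
    return std::make_unique<lfm2_builder<false>>();
}

int main() {
    make_lfm2(true)->build();
    make_lfm2(false)->build();
    return 0;
}
```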
@@ -475,6 +475,7 @@ struct llama_model {
struct ggml_tensor * cls_b = nullptr;
struct ggml_tensor * cls_out = nullptr;
struct ggml_tensor * cls_out_b = nullptr;
struct ggml_tensor * cls_norm = nullptr;

struct ggml_tensor * conv1d = nullptr;
struct ggml_tensor * conv1d_b = nullptr;
@@ -491,8 +492,9 @@ struct llama_model {
//Dense linear projections for SentenceTransformers models like embeddinggemma
// For Sentence Transformers models structure see
// https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models
struct ggml_tensor * dense_2_out_layers = nullptr;
struct ggml_tensor * dense_3_out_layers = nullptr;
struct ggml_tensor * dense_2_out_layers = nullptr;
struct ggml_tensor * dense_2_out_layers_b = nullptr;
struct ggml_tensor * dense_3_out_layers = nullptr;

// gguf metadata
std::unordered_map<std::string, std::string> gguf_kv;
@@ -289,6 +289,15 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_JAIS2:
regex_exprs = {
// original regex from tokenizer.json
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}",

// adapted: same as llama3 but with cascading whitespace pattern
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}",
};
break;
case LLAMA_VOCAB_PRE_TYPE_DBRX:
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
regex_exprs = {
@@ -1921,8 +1930,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "jina-v2-de" ||
tokenizer_pre == "a.x-4.0" ||
tokenizer_pre == "mellum" ||
tokenizer_pre == "modern-bert" ) {
tokenizer_pre == "modern-bert") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "jais-2") {
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
} else if (
tokenizer_pre == "jina-v1-en" ||
tokenizer_pre == "jina-v2-code" ||

@@ -57,6 +57,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46,
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
};

struct LLM_KV;
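The "cascading whitespace" alternation in the JAIS-2 pretokenizer splits long whitespace runs greedily into power-of-two pieces, while the (?!\S) lookahead leaves one space free to bind to the following word. A runnable demo of just the whitespace cascade, truncated to a short chain so std::regex can handle it (the full pattern's \p{...} classes are not supported by std::regex):

```cpp
#include <cstdio>
#include <regex>
#include <string>

int main() {
    // Truncated cascade: 8-space piece, 4-space piece, 1-2 with lookahead,
    // then a bare single-space fallback.
    const std::regex re("\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s");
    const std::string text = "a          b"; // 'a', 10 spaces, 'b'
    for (auto it = std::sregex_iterator(text.begin(), text.end(), re);
         it != std::sregex_iterator(); ++it) {
        printf("match: %zu space(s) at offset %zu\n",
               it->str().size(), (size_t) it->position());
    }
    // Demo output: 8, then 1, then 1. In the full pattern the final space
    // would instead attach to "b" via the word alternatives, which come
    // earlier in the alternation.
    return 0;
}
```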
@@ -27,6 +27,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne

const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];
const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k);

GGML_ASSERT(S_k == S_v);
GGML_ASSERT(H_v % H_k == 0);
@@ -35,9 +36,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

const float scale = 1.0f / sqrtf(S_k);

@@ -52,8 +54,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs]
v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs]
g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs]
b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs]
g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs]
b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs]

const int CS = CHUNK_SIZE;

@@ -78,33 +80,76 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs);
v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs);

g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs);
b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);
g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs);
b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs);

// [CS, 1, n_chunks, H_v * n_seqs]
ggml_tensor * g_cs = ggml_cumsum(ctx0, g);
// [CS, g_0, n_chunks, H_v * n_seqs]
// TODO: extend ggml_cumsum with axis parameter to avoid transpose
ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g)));
cb(g_cs, "g_cs", il);

ggml_tensor * g_cs_i = g_cs;
ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);
ggml_tensor * kb = nullptr;
ggml_tensor * kq = nullptr;
if (kda) {
const int64_t CHB = n_chunks * H_k * n_seqs;

g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);
ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB]

// [CS, CS, n_chunks, H_v * n_seqs]
ggml_tensor * decay_mask;
decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
decay_mask = ggml_exp(ctx0, decay_mask);
cb(decay_mask, "decay_mask", il);
g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]

// [CS, CS, n_chunks, H_k * n_seqs]
ggml_tensor * kb;
kb = ggml_mul_mat(ctx0, k, k_b);
kb = ggml_mul (ctx0, kb, decay_mask);
// decay_mask [chunk_size,chunk_size,S_k,CHB]
ggml_tensor * decay_mask;
decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
decay_mask = ggml_exp(ctx0, decay_mask);
cb(decay_mask, "decay_mask", il);

// decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB);

ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB);
ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB);
ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB);

ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i);
ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

// decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j);
kq = ggml_mul_mat(ctx0, decay_q_i, k_j);

kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs)));
kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs)));
} else {
ggml_tensor * g_cs_i = g_cs;
ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs);

g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs);

// [CS, CS, n_chunks, H_v * n_seqs]
ggml_tensor * decay_mask;
decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i);
decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG);
decay_mask = ggml_exp(ctx0, decay_mask);
cb(decay_mask, "decay_mask", il);

// [CS, CS, n_chunks, H_k * n_seqs]
kb = ggml_mul_mat(ctx0, k, k_b);
kb = ggml_mul (ctx0, kb, decay_mask);

// [CS, CS, n_chunks, H_k * n_seqs]
kq = ggml_mul_mat(ctx0, k, q);
kq = ggml_mul(ctx0, kq, decay_mask);
}

kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
cb(kq, "kq", il);

// [CS, CS, n_chunks, H_k * n_seqs]
ggml_tensor * attn;
attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER);
cb(attn, "attn", il);

ggml_tensor * identity;
identity = ggml_view_1d(ctx0, attn, CS, 0);
@@ -115,6 +160,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
cb(lhs, "dnet_add_ch_lhs", il);

attn = ggml_neg(ctx0, attn);
cb(attn, "attn_pre_solve", il);

ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_add(ctx0, lin_solve, identity);
@@ -123,7 +169,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
// [S_v, CS, n_chunks, H_v * n_seqs]
v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn);

// [CS, 1, n_chunks, H_v * n_seqs]
// [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
ggml_tensor * g_exp = ggml_exp(ctx0, g_cs);

k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b));
@@ -136,16 +182,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn);
cb(k_cd, "k_cumdecay", il);

// [S_k, CS, n_chunks, H_k * n_seqs]
ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp);
// [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs]
ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp));
ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t);

// [CS, CS, n_chunks, H_k * n_seqs]
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
kq = ggml_mul(ctx0, kq, decay_mask);
kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG);
cb(kq, "kq", il);

// vectorized calculation of key_gdiff
// improved from the chunked version:
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
@@ -156,8 +196,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne

// get last element in g_cumsum along CS dimension (ne0)
// example: [[x, y, z, ..., last], ...] -> [[last], ...]
// [1, 1, n_chunks, H_v * n_seqs]
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3],
// [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs]
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3],
g_cs->nb[1],
g_cs->nb[2],
g_cs->nb[3],
@@ -167,16 +207,15 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
// TODO: remove this cont when CUDA supports non-cont unary ops
g_last = ggml_cont(ctx0, g_last);

// [1, 1, n_chunks, H_v * n_seqs]
ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il);
// [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs]
ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last));
cb(g_last_exp_t, "g_last_exp_t", il);

// [CS, 1, n_chunks, H_v * n_seqs]
// [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs]
ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last));
cb(g_diff, "g_diff", il);

ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp);
ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff)));

// [S_k, CS, n_chunks, H_v * n_seqs]
ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t);
@@ -227,8 +266,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs]

// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk);
s_t = ggml_mul(ctx0, s_t, ch_g_last_exp);
ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);

s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
s_t = ggml_add(ctx0, s_t, kgv);
cb(s_t, "dnet_add_ch_state", il);
}
@@ -241,9 +281,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
ggml_row_size(v->type, S_v),
ggml_row_size(v->type, S_v * CS * n_chunks),
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);

o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
s = ggml_transpose(ctx0, s_t);
cb(s, "output_state", il);

return {o, s};
}
@@ -273,9 +313,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs);

GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v);
GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs);
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs);
GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);

const float scale = 1.0f / sqrtf(S_k);

@@ -291,8 +332,10 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
cb(b, "b_in", il);
cb(g, "g_in", il);

g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs);
b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);
// GDA: [1, 1, H_v, n_seqs]
// KDA: [1, S_k, H_v, n_seqs]
g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs);
b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs);

// [S_v, S_v, H_v, n_seqs]
g = ggml_exp(ctx0, g);
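The decay masks built above all follow the same recipe: cumulative-sum the per-step log-decays within a chunk, take pairwise differences, keep the causal (lower) triangle, and exponentiate. A scalar reference of that construction:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Given per-step log-decays g[t] within one chunk, the cumulative sum gc
// turns the product of decays between steps s and t into a single exp,
// and the lower-triangular cut enforces causality (s <= t).
static std::vector<std::vector<float>> decay_mask(const std::vector<float> & g) {
    const size_t cs = g.size();
    std::vector<float> gc(cs);
    float run = 0.0f;
    for (size_t t = 0; t < cs; ++t) { // cumulative log-decay
        run += g[t];
        gc[t] = run;
    }
    std::vector<std::vector<float>> mask(cs, std::vector<float>(cs, 0.0f));
    for (size_t t = 0; t < cs; ++t) {
        for (size_t s = 0; s <= t; ++s) {
            mask[t][s] = expf(gc[t] - gc[s]); // product of decays in (s, t]
        }
    }
    return mask;
}

int main() {
    auto m = decay_mask({-0.1f, -0.2f, -0.3f});
    printf("mask[2][0] = %f (expect exp(-0.5))\n", m[2][0]);
    return 0;
}
```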
src/models/jais2.cpp (new file, +123 lines)
@@ -0,0 +1,123 @@
#include "models.h"
|
||||
|
||||
// JAIS-2 model graph builder
|
||||
// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings
|
||||
llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KV input for attention
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// Pre-attention LayerNorm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm,
|
||||
model.layers[il].attn_norm_b,
|
||||
LLM_NORM, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// Self-attention with separate Q, K, V projections
|
||||
{
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur_bias", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur_bias", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur_bias", il);
|
||||
|
||||
// Reshape for attention
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// Apply RoPE
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur_rope", il);
|
||||
cb(Kcur, "Kcur_rope", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
// Residual connection
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// Pre-FFN LayerNorm
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm,
|
||||
model.layers[il].ffn_norm_b,
|
||||
LLM_NORM, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// FFN with relu2 activation (ReLU squared) - no gate projection
|
||||
// up -> relu2 -> down
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
NULL, NULL, NULL, // no gate
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
// Residual connection
|
||||
inpL = ggml_add(ctx0, cur, ffn_inp);
|
||||
inpL = build_cvec(inpL, il);
|
||||
cb(inpL, "l_out", il);
|
||||
}
|
||||
|
||||
// Final LayerNorm
|
||||
cur = build_norm(inpL,
|
||||
model.output_norm,
|
||||
model.output_norm_b,
|
||||
LLM_NORM, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
res->t_embd = cur;
|
||||
|
||||
// Output projection
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
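The FFN wired up at the end of the JAIS-2 layer is a plain two-matrix MLP with biases and relu2 (squared ReLU), no gate. A scalar sketch with illustrative dimensions:

```cpp
#include <cstdio>
#include <vector>

// up-projection with bias -> relu2 -> down-projection with bias.
static std::vector<float> ffn_relu2(const std::vector<float> & x,
                                    const std::vector<std::vector<float>> & w_up,
                                    const std::vector<float> & b_up,
                                    const std::vector<std::vector<float>> & w_down,
                                    const std::vector<float> & b_down) {
    std::vector<float> h(w_up.size());
    for (size_t r = 0; r < w_up.size(); ++r) {
        float v = b_up[r];
        for (size_t c = 0; c < x.size(); ++c) {
            v += w_up[r][c] * x[c];
        }
        const float relu = v > 0.0f ? v : 0.0f;
        h[r] = relu * relu; // relu2: ReLU then square
    }
    std::vector<float> y(w_down.size());
    for (size_t r = 0; r < w_down.size(); ++r) {
        float v = b_down[r];
        for (size_t c = 0; c < h.size(); ++c) {
            v += w_down[r][c] * h[c];
        }
        y[r] = v;
    }
    return y;
}

int main() {
    auto y = ffn_relu2({1.0f}, {{2.0f}, {-1.0f}}, {0.0f, 0.0f}, {{0.5f, 0.5f}}, {0.1f});
    printf("y[0] = %f\n", y[0]); // relu2(2)=4, relu2(-1)=0 -> 0.5*4 + 0.1 = 2.1
    return 0;
}
```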
@@ -3,8 +3,6 @@
|
||||
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#define CHUNK_SIZE 64
|
||||
|
||||
// Causal Conv1d function for Q,K,V
|
||||
// When qkv is 0, it is Q, 1 is K, 2 is V
|
||||
static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
|
||||
@@ -67,7 +65,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t
|
||||
}
|
||||
|
||||
llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_build_mamba_base(params), model(model) {
|
||||
llm_build_delta_net_base(params), model(model) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
@@ -86,17 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
// Output ids for selecting which tokens to output
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
ggml_tensor * chunked_causal_mask =
|
||||
ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
|
||||
GGML_TRI_TYPE_LOWER);
|
||||
|
||||
ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
|
||||
ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
|
||||
|
||||
ggml_build_forward_expand(gf, chunked_causal_mask);
|
||||
ggml_build_forward_expand(gf, chunked_identity);
|
||||
ggml_build_forward_expand(gf, chunked_diag_mask);
|
||||
|
||||
// Kimi dimension constants
|
||||
const int64_t n_head = hparams.n_head();
|
||||
const int64_t head_dim = hparams.n_embd_head_kda;
|
||||
@@ -162,27 +149,35 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
g1 = ggml_mul(ctx0, g1, A);
|
||||
cb(g1, "kda_g1", il);
|
||||
|
||||
g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||
|
||||
// Compute beta (mixing coefficient)
|
||||
ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
|
||||
beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
|
||||
beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs);
|
||||
cb(beta, "kda_beta", il);
|
||||
|
||||
beta = ggml_sigmoid(ctx0, beta);
|
||||
|
||||
// Reshape for KDA recurrence
|
||||
// {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
|
||||
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
|
||||
|
||||
g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
|
||||
|
||||
// Get SSM state and compute KDA recurrence using ggml_kda_scan
|
||||
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
||||
ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
|
||||
state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
|
||||
// Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
|
||||
build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
|
||||
build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
|
||||
|
||||
ggml_tensor * output = attn_out.first;
|
||||
const float eps_norm = hparams.f_norm_rms_eps;
|
||||
|
||||
Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm);
|
||||
Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm);
|
||||
|
||||
// Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens
|
||||
std::pair<ggml_tensor *, ggml_tensor *> attn_out = n_seq_tokens == 1 ?
|
||||
build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
|
||||
build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il);
|
||||
|
||||
ggml_tensor * output = ggml_cont(ctx0, attn_out.first);
|
||||
ggml_tensor * new_state = attn_out.second;
|
||||
cb(output, "attn_output", il);
|
||||
cb(new_state, "new_state", il);
|
||||
@@ -393,385 +388,3 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
/*
|
||||
This is a ggml implementation of the naive_chunk_kda function of
|
||||
https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
|
||||
*/
std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * gk,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il) {
    GGML_ASSERT(ggml_is_contiguous(state));

    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

    // TODO: can this ever be false?
    const bool use_qk_l2norm = true;

    if (use_qk_l2norm) {
        const float eps_norm = hparams.f_norm_rms_eps;

        q = ggml_l2_norm(ctx0, q, eps_norm);
        k = ggml_l2_norm(ctx0, k, eps_norm);
    }

    const float scale = 1.0f / sqrtf(S_v);

    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(gk, "gk_in", il);

    q  = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
    k  = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
    v  = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);

    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    cb(q, "q_perm", il);
    cb(k, "k_perm", il);
    cb(v, "v_perm", il);
    cb(beta, "beta_perm", il);
    cb(gk, "gk_perm", il);
    cb(state, "state_in", il);

    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

    // Do padding
    const int64_t chunk_size = CHUNK_SIZE;

    const int64_t pad      = (chunk_size - n_tokens % chunk_size) % chunk_size;
    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
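    // worked example (assuming CHUNK_SIZE = 64, as in the qwen35 file below):
    //   n_tokens = 100 -> pad = (64 - 100 % 64) % 64 = 28, n_chunks = (100 + 28) / 64 = 2
    //   n_tokens = 128 -> pad = 0,                         n_chunks = 2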

    q    = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k    = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v    = ggml_pad(ctx0, v, 0, pad, 0, 0);
    gk   = ggml_pad(ctx0, gk, 0, pad, 0, 0);
    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);

    cb(q, "q_pad", il);
    cb(k, "k_pad", il);
    cb(v, "v_pad", il);
    cb(beta, "beta_pad", il);
    cb(gk, "gk_pad", il);

    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);

    cb(v_beta, "v_beta", il);
    cb(k_beta, "k_beta", il);

    const int64_t HB = H_k * n_seqs;

    q      = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB);
    k      = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB);
    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
    v      = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB);
    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);

    gk   = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);

    // transpose so the cumulative sum runs along the chunk (token) dimension
    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
    cb(gk, "gk", il);
    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
    cb(gk_cumsum, "gk_cumsum", il);

    /*
        Compute Akk and Aqk loop together
        Akk loop:
            for i in range(BT):
                k_i = k[..., i, :]      # k_i [B,H,NT,S]
                g_i = g[..., i:i+1, :]  # g_i [B,H,NT,1,S]
                A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
        Aqk loop:
            for j in range(BT):
                k_j = k[:, :, i, j]
                g_j = g[:, :, i, j:j+1, :]
                A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
    */
    const int64_t CHB = n_chunks * H_k * n_seqs;
    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB]
    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);    // [1, chunk_size, S_k, CHB]

    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
    // decay_mask [chunk_size,chunk_size,S_k,CHB]
    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
    cb(decay_mask, "decay_mask", il);

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    cb(decay_mask, "decay_masked", il);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);

    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);

    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);

    // decay_k_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
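    // note: ggml_mul_mat(a, b) contracts along ne0 of both operands, giving
    // out->ne[0] == a->ne[1] and out->ne[1] == b->ne[1] (with broadcast over
    // the higher dims) -- hence the transposes/reshapes around these calls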
    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
    cb(Akk, "Akk", il);
    cb(Aqk, "Aqk", il);

    Akk = ggml_mul(ctx0, Akk, beta);
    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
    cb(Akk, "attn_pre_solve", il);

    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
    cb(Aqk, "Aqk_masked", il);

    // for i in range(1, chunk_size):
    //     row = attn[..., i, :i].clone()
    //     sub = attn[..., :i, :i].clone()
    //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    //
    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
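    // Sketch of the reduction (our derivation): the loop above is forward
    // substitution. It computes X = B + B @ X over the strictly lower
    // triangle, i.e. (I - B) X = B with B = tril(Akk, -1). lhs below is
    // exactly I - B (unit lower triangular), ggml_solve_tri yields
    // X = (I - B)^{-1} B, and adding the identity afterwards gives
    // (I - B)^{-1} itself, matching the "+ torch.eye" step.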
    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
    Akk = ggml_mul(ctx0, lin_solve, causal_mask);
    Akk = ggml_add(ctx0, Akk, identity);

    cb(Akk, "attn_solved", il);

    // switch back for downstream
    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
    ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum);
    cb(gk_cumsum, "gk_cumsum", il);

    // u = (A*beta[..., None, :]) @ v aka U_[t]
    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);

    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
    cb(kbeta_gkexp, "kbeta_gkexp", il);

    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
    cb(k_cumdecay, "k_cumdecay", il);

    ggml_tensor * core_attn_out = nullptr;
    ggml_tensor * new_state = ggml_dup(ctx0, state);

    cb(new_state, "new_state", il);

    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        // extract one chunk worth of data
        auto chunkify = [=](ggml_tensor * t) {
            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
                    t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
        };
        auto chunkify_A = [=](ggml_tensor * t) {
            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
                    t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
        };

        // k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
        ggml_tensor * k_chunk  = chunkify(k);
        ggml_tensor * q_chunk  = chunkify(q);
        ggml_tensor * vb_chunk = chunkify(vb);

        // gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
        ggml_tensor * gk_cs_chunk      = chunkify(gk_cumsum);
        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
        ggml_tensor * gkexp_chunk      = ggml_exp(ctx0, gk_cs_chunk);
        ggml_tensor * Aqk_chunk        = chunkify_A(Aqk);

        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);

        // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
        ggml_tensor * v_new   = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));

        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        // or Gamma_[t]*Q_[t] @ S
        ggml_tensor * q_gk_exp   = ggml_mul(ctx0, q_chunk, gkexp_chunk);
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q

        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
        // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);

        // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);

        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);

        ggml_tensor * gk_cum_last =
            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
                    gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
                    gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));

        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));

        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));

        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);

        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);

        // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));

        new_state = ggml_add(ctx0,
            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
    }

    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);

    // truncate padded tokens
    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(core_attn_out->type, S_v),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
    output_tokens = ggml_cont(ctx0, output_tokens);
    // permute back to (S_v, H_v, n_tokens, n_seqs)
    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
    output_tokens = ggml_cont(ctx0, output_tokens);

    cb(new_state, "output_state", il);

    return {output_tokens, new_state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_kimi_linear::build_kda_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * gk,
        ggml_tensor * beta,
        ggml_tensor * state,
        int il) {
    GGML_ASSERT(ggml_is_contiguous(v));
    GGML_ASSERT(ggml_is_contiguous(gk));

    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1);
    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q    = ggml_scale(ctx0, q, scale);
    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(gk, "gk_in", il);

    // g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
    // gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
    // beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
    gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
    ggml_tensor * gk_t   = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);

    // Apply exponential to gk_t
    gk_t = ggml_exp(ctx0, gk_t);
    // Apply the gated delta rule for the single timestep
    // last_recurrent_state = last_recurrent_state * gk_t
    // S = S * g_i[..., None].exp()
    state = ggml_mul(ctx0, state, gk_t);

    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));

    // state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
    k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);

    // v_i - (k_i[..., None] * S).sum(-2)
    v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
    ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);

    // b_i[..., None] * k_i
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);

    // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
    // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));

    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
    cb(core_attn_out, "output_tokens", il);
    cb(state, "new_state", il);

    return {core_attn_out, state};
}


@@ -1,18 +1,149 @@
#include "models.h"

#include "../llama-memory-hybrid-iswa.h"
#include "../llama-memory-hybrid.h"

template <bool iswa>
llm_build_lfm2<iswa>::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params) {
    using inp_hybrid_type = std::conditional_t<iswa, llm_graph_input_mem_hybrid_iswa, llm_graph_input_mem_hybrid>;
    using inp_attn_type   = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
    using mem_hybrid_ctx  = std::conditional_t<iswa, llama_memory_hybrid_iswa_context, llama_memory_hybrid_context>;
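    // (The iswa template parameter switches between the SWA-aware and plain
    //  hybrid-memory input/context types at compile time; as far as we can
    //  tell, everything else in the graph construction below is shared
    //  between the two instantiations.)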

llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params),
    model(model) {
    // lambda helpers for readability
    auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
        GGML_ASSERT(!model.layers[il].ffn_up_b);
        GGML_ASSERT(!model.layers[il].ffn_gate_b);
        GGML_ASSERT(!model.layers[il].ffn_down_b);
        return build_ffn(cur,
            model.layers[il].ffn_up, NULL, NULL,
            model.layers[il].ffn_gate, NULL, NULL,
            model.layers[il].ffn_down, NULL, NULL,
            NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
    };
    auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * {
        return build_moe_ffn(cur,
            model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
            model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
            model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
            static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
    };
    auto build_attn_block = [&model, this](ggml_tensor * cur,
            ggml_tensor * inp_pos,
            inp_attn_type * inp_attn,
            int il) -> ggml_tensor * {
        GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
        const auto n_embd_head = hparams.n_embd_head_v;
        const auto n_head_kv   = hparams.n_head_kv(il);

        auto * q = build_lora_mm(model.layers[il].wq, cur);
        cb(q, "model.layers.{}.self_attn.q_proj", il);
        auto * k = build_lora_mm(model.layers[il].wk, cur);
        cb(k, "model.layers.{}.self_attn.k_proj", il);
        auto * v = build_lora_mm(model.layers[il].wv, cur);
        cb(v, "model.layers.{}.self_attn.v_proj", il);

        q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
        k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
        v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);

        // qk norm
        q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
        cb(q, "model.layers.{}.self_attn.q_layernorm", il);
        k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
        cb(k, "model.layers.{}.self_attn.k_layernorm", il);

        // RoPE
        q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
                attn_factor, beta_fast, beta_slow);
        k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
                attn_factor, beta_fast, beta_slow);

        cur = build_attn(inp_attn,
            model.layers[il].wo, NULL,
            q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);

        cb(cur, "model.layers.{}.self_attn.out_proj", il);

        return cur;
    };
    auto build_shortconv_block = [&model, this](ggml_tensor * cur,
            llm_graph_input_rs * inp_recr,
            int il) -> ggml_tensor * {
        const auto * mctx_cur = static_cast<const mem_hybrid_ctx *>(mctx)->get_recr();
        const uint32_t kv_head = mctx_cur->get_head();
        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
        const int64_t n_seqs = ubatch.n_seqs;
        GGML_ASSERT(n_seqs != 0);
        GGML_ASSERT(ubatch.equal_seqs());
        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

        GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
        const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;
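        // (Our understanding of the layout: the recurrent cache keeps the
        //  last kernel_size - 1 columns of conv input per sequence, so
        //  prepending them below makes the causal convolution seamless
        //  across ubatch boundaries.)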

        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
        cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

        auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
        cb(bcx, "model.layers.{}.conv.in_proj", il);

        constexpr auto n_chunks = 3;
        GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
        const auto chunk_size = bcx->ne[0] / n_chunks;
        auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
                0 * chunk_size * ggml_element_size(bcx));
        auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
                1 * chunk_size * ggml_element_size(bcx));
        auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
                2 * chunk_size * ggml_element_size(bcx));

        auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));

        // read conv state
        auto * conv_state = mctx_cur->get_r_l(il);
        auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
        auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);

        bx = ggml_concat(ctx0, conv, bx, 0);
        GGML_ASSERT(bx->ne[0] > conv->ne[0]);

        // the last d_conv columns are the new conv state
        auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
                (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
        GGML_ASSERT(ggml_are_same_shape(conv, new_conv));

        // write the new conv state
        ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
                ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
                        kv_head * d_conv * n_embd * ggml_element_size(new_conv))));

        auto * conv_kernel = model.layers[il].shortconv.conv;
        auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
        cb(conv_out, "model.layers.{}.conv.conv", il);

        auto * y = ggml_mul(ctx0, c, conv_out);
        y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
        cb(y, "model.layers.{}.conv.out_proj", il);
        // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
        y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);

        return y;
    };

    // actual graph construction starts here
    ggml_tensor * cur = build_inp_embd(model.tok_embd);
    cb(cur, "model.embed_tokens", -1);

    ggml_build_forward_expand(gf, cur);

    inp_hybrid_type * inp_hybrid = nullptr;
    if constexpr (iswa) {
        inp_hybrid = build_inp_mem_hybrid_iswa();
    } else {
        inp_hybrid = build_inp_mem_hybrid();
    }

    ggml_tensor * inp_pos = build_inp_pos();
    auto * inp_hybrid = build_inp_mem_hybrid();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
@@ -54,122 +185,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
    ggml_build_forward_expand(gf, cur);
}

ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const {
    return build_moe_ffn(cur,
        model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
        model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
        model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0,
        static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
}

ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const {
    GGML_ASSERT(!model.layers[il].ffn_up_b);
    GGML_ASSERT(!model.layers[il].ffn_gate_b);
    GGML_ASSERT(!model.layers[il].ffn_down_b);
    return build_ffn(cur,
        model.layers[il].ffn_up, NULL, NULL,
        model.layers[il].ffn_gate, NULL, NULL,
        model.layers[il].ffn_down, NULL, NULL,
        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
}

ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur,
        ggml_tensor * inp_pos,
        llm_graph_input_attn_kv * inp_attn,
        int il) const {
    GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
    const auto n_embd_head = hparams.n_embd_head_v;
    const auto n_head_kv   = hparams.n_head_kv(il);

    auto * q = build_lora_mm(model.layers[il].wq, cur);
    cb(q, "model.layers.{}.self_attn.q_proj", il);
    auto * k = build_lora_mm(model.layers[il].wk, cur);
    cb(k, "model.layers.{}.self_attn.k_proj", il);
    auto * v = build_lora_mm(model.layers[il].wv, cur);
    cb(v, "model.layers.{}.self_attn.v_proj", il);

    q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens);
    k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens);
    v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens);

    // qk norm
    q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
    cb(q, "model.layers.{}.self_attn.q_layernorm", il);
    k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
    cb(k, "model.layers.{}.self_attn.k_layernorm", il);

    // RoPE
    q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
            attn_factor, beta_fast, beta_slow);
    k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor,
            attn_factor, beta_fast, beta_slow);

    cur = build_attn(inp_attn,
        model.layers[il].wo, NULL,
        q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);

    cb(cur, "model.layers.{}.self_attn.out_proj", il);

    return cur;
}

ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) {
    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
    const uint32_t kv_head = mctx_cur->get_head();
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
    const int64_t n_seqs = ubatch.n_seqs;
    GGML_ASSERT(n_seqs != 0);
    GGML_ASSERT(ubatch.equal_seqs());
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

    GGML_ASSERT(hparams.n_shortconv_l_cache > 1);
    const uint32_t d_conv = hparams.n_shortconv_l_cache - 1;

    // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
    cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);

    auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur);
    cb(bcx, "model.layers.{}.conv.in_proj", il);

    constexpr auto n_chunks = 3;
    GGML_ASSERT(bcx->ne[0] % n_chunks == 0);
    const auto chunk_size = bcx->ne[0] / n_chunks;
    auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
            0 * chunk_size * ggml_element_size(bcx));
    auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
            1 * chunk_size * ggml_element_size(bcx));
    auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2],
            2 * chunk_size * ggml_element_size(bcx));

    auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x));

    // read conv state
    auto * conv_state = mctx_cur->get_r_l(il);
    auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
    auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);

    bx = ggml_concat(ctx0, conv, bx, 0);
    GGML_ASSERT(bx->ne[0] > conv->ne[0]);

    // the last d_conv columns are the new conv state
    auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2],
            (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx));
    GGML_ASSERT(ggml_are_same_shape(conv, new_conv));

    // write the new conv state
    ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv,
            ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv),
                    kv_head * d_conv * n_embd * ggml_element_size(new_conv))));

    auto * conv_kernel = model.layers[il].shortconv.conv;
    auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel);
    cb(conv_out, "model.layers.{}.conv.conv", il);

    auto * y = ggml_mul(ctx0, c, conv_out);
    y = build_lora_mm(model.layers[il].shortconv.out_proj, y);
    cb(y, "model.layers.{}.conv.out_proj", il);
    // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
    y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs);

    return y;
}
// Explicit template instantiations
template struct llm_build_lfm2<true>;
template struct llm_build_lfm2<false>;

@@ -316,12 +316,15 @@ struct llm_build_jais : public llm_graph_context {
    llm_build_jais(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_jais2 : public llm_graph_context {
    llm_build_jais2(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_jamba : public llm_build_mamba_base {
    llm_build_jamba(const llama_model & model, const llm_graph_params & params);
};

// TODO: derive llm_build_delta_net_base instead
struct llm_build_kimi_linear : public llm_build_mamba_base {
struct llm_build_kimi_linear : public llm_build_delta_net_base {
    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);

    std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive(
@@ -348,15 +351,9 @@ struct llm_build_kimi_linear : public llm_build_mamba_base {
    const llama_model & model;
};

template <bool iswa>
struct llm_build_lfm2 : public llm_graph_context {
    const llama_model & model;

    llm_build_lfm2(const llama_model & model, const llm_graph_params & params);
    ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const;
    ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const;
    ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const;
    ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il);

};

struct llm_build_llada : public llm_graph_context {
@@ -542,8 +539,7 @@ private:
    const llama_model & model;
};

// TODO: derive llm_build_delta_net_base instead
struct llm_build_qwen35 : public llm_graph_context {
struct llm_build_qwen35 : public llm_build_delta_net_base {
    llm_build_qwen35(const llama_model & model, const llm_graph_params & params);
private:
    ggml_tensor * build_layer_attn(
@@ -556,39 +552,12 @@ private:
    ggml_tensor * build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor * cur,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);


    ggml_tensor * build_layer_ffn(
        ggml_tensor * cur,
        int il);

    // returns pair of output and new state
    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);

    // returns pair of output and new state
    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        int il);

    ggml_tensor * build_norm_gated(
        ggml_tensor * input,
        ggml_tensor * weights,
@@ -604,7 +573,7 @@ private:
};

// TODO: derive llm_build_delta_net_base instead
struct llm_build_qwen35moe : public llm_graph_context {
struct llm_build_qwen35moe : public llm_build_delta_net_base {
    llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params);
private:
    ggml_tensor * build_layer_attn(
@@ -617,38 +586,12 @@ private:
    ggml_tensor * build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor * cur,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);

    ggml_tensor * build_layer_ffn(
        ggml_tensor * cur,
        int il);

    // returns pair of output and new state
    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il);

    // returns pair of output and new state
    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        int il);

    ggml_tensor * build_norm_gated(
        ggml_tensor * input,
        ggml_tensor * weights,

@@ -104,13 +104,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll
        LLM_NORM, -1);
    cb(cur, "final_norm_out", -1);

    if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
        // extracting cls token
        cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0);
        cb(cur, "cls_pooled_embd", -1);
    }

    cb(cur, "res_embd", -1);
    res->t_embd = cur;
    ggml_build_forward_expand(gf, cur);
}

@@ -2,10 +2,8 @@

#include "llama-memory-recurrent.h"

#define CHUNK_SIZE 64

llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params), model(model) {
    llm_build_delta_net_base(params), model(model) {
    const int64_t n_embd_head = hparams.n_embd_head_v;

    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -25,17 +23,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
    ggml_tensor * inp_pos = build_inp_pos();
    ggml_tensor * inp_out_ids = build_inp_out_ids();

    ggml_tensor * causal_mask =
        ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
            GGML_TRI_TYPE_LOWER);

    ggml_tensor * identity  = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
    ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

    ggml_build_forward_expand(gf, causal_mask);
    ggml_build_forward_expand(gf, identity);
    ggml_build_forward_expand(gf, diag_mask);

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

@@ -45,7 +32,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
        // Determine layer type and build appropriate attention mechanism
        if (hparams.is_recurrent(il)) {
            // Linear attention layer (gated delta net)
            cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
            cur = build_layer_attn_linear(inp->get_recr(), cur, il);
        } else {
            // Full attention layer
            cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
@@ -95,361 +82,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
    ggml_build_forward_expand(gf, cur);
}

// utility to get one slice from the third dimension
// input dim:  [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
            t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_chunking(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q = ggml_scale(ctx0, q, scale);

    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
    g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);

    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    cb(q, "q_perm", il);
    cb(k, "k_perm", il);
    cb(v, "v_perm", il);
    cb(beta, "beta_perm", il);
    cb(g, "g_perm", il);
    cb(state, "state_in", il);

    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

    // Do padding
    const int64_t chunk_size = CHUNK_SIZE;

    const int64_t pad      = (chunk_size - n_tokens % chunk_size) % chunk_size;
    const int64_t n_chunks = (n_tokens + pad) / chunk_size;

    q    = ggml_pad(ctx0, q, 0, pad, 0, 0);
    k    = ggml_pad(ctx0, k, 0, pad, 0, 0);
    v    = ggml_pad(ctx0, v, 0, pad, 0, 0);
    g    = ggml_pad(ctx0, g, pad, 0, 0, 0);
    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);

    cb(q, "q_pad", il);
    cb(k, "k_pad", il);
    cb(v, "v_pad", il);
    cb(beta, "beta_pad", il);
    cb(g, "g_pad", il);

    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);

    cb(v_beta, "v_beta", il);
    cb(k_beta, "k_beta", il);

    q      = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
    k      = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
    k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
    v      = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
    v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);

    g    = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
    beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

    ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
    ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * gcs_j_broadcast =
        ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);

    ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
    decay_mask = ggml_exp(ctx0, decay_mask);
    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

    ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);

    ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
    ggml_tensor * attn    = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
    ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

    ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
    attn = ggml_mul(ctx0, lin_solve, causal_mask);
    attn = ggml_add(ctx0, attn, identity);
    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);

    ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
    ggml_tensor * gexp       = ggml_exp(ctx0, g_cumsum_t);

    ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * k_cumdecay =
        ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)


    // vectorized calculation of key_gdiff
    // improved from the chunked version:
    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
    //   key_gdiff = key * g_diff.unsqueeze(-1)
    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
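    // Intuition (our reading): exp(g_last - g_cum) carries each key's
    // contribution forward to the end of its chunk, so the per-chunk state
    // update collapses to one decay-scale of the old state plus one matmul,
    // instead of a token-by-token recurrence.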

    // get last element in g_cumsum along chunk_size dimension (ne0)
    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
            g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
            (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
    g_last = ggml_cont(ctx0, g_last);
    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

    ggml_tensor * g_diff_exp   = ggml_exp(ctx0, g_diff);
    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
            1, chunk_size, n_chunks, g_diff_exp->ne[3]);

    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)

    // state to be updated per chunk
    ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)

    // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
    ggml_tensor * core_attn_out = nullptr;

    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul

        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat

        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul

        // shape: (chunk_size, 1, H_v * n_seqs)
        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat

        // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        // replaced by precomputed attn_kq
        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
        cb(attn_chunk, "attn_chunk", il);

        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)

        // v_new = v_i - v_prime
        ggml_tensor * v_new   = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
        cb(v_new, "v_new_chunk", il);

        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        ggml_tensor * q_g_exp    = ggml_mul(ctx0, q_chunk, gexp_chunk);
        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
        cb(attn_inter, "attn_inter_chunk", il);

        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
        cb(v_attn, "v_attn_chunk", il);

        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

        core_attn_out = core_attn_out == nullptr
            ? core_attn_out_chunk
            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);

        // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
        //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);

        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
        new_state = ggml_add(ctx0,
            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
    }

    // truncate padded tokens
    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
            S_v, n_tokens, H_v, n_seqs,
            ggml_row_size(core_attn_out->type, S_v),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
    output_tokens = ggml_cont(ctx0, output_tokens);
    cb(output_tokens, "output_tokens", il);

    // permute back to (S_v, H_v, n_tokens, n_seqs)
    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
    output_tokens = ggml_cont(ctx0, output_tokens);

    return {output_tokens, new_state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_delta_net_autoregressive(
        ggml_tensor * q,
        ggml_tensor * k,
        ggml_tensor * v,
        ggml_tensor * g,
        ggml_tensor * beta,
        ggml_tensor * state,
        int il) {
    const int64_t S_k      = q->ne[0];
    const int64_t H_k      = q->ne[1];
    const int64_t n_tokens = q->ne[2];
    const int64_t n_seqs   = q->ne[3];

    const int64_t S_v = v->ne[0];
    const int64_t H_v = v->ne[1];

    GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
    GGML_ASSERT(v->ne[2] == n_tokens);
    GGML_ASSERT(k->ne[2] == n_tokens);
    GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

    GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

    const float eps_norm = hparams.f_norm_rms_eps;

    q = ggml_l2_norm(ctx0, q, eps_norm);
    k = ggml_l2_norm(ctx0, k, eps_norm);

    const float scale = 1.0f / sqrtf(S_v);

    q    = ggml_scale(ctx0, q, scale);
    beta = ggml_sigmoid(ctx0, beta);

    cb(q, "q_in", il);
    cb(k, "k_in", il);
    cb(v, "v_in", il);
    cb(beta, "beta_in", il);
    cb(g, "g_in", il);

    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

    ggml_tensor * g_t    = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);

    // Apply exponential to g_t
    g_t = ggml_exp(ctx0, g_t);

    // Apply the gated delta rule for the single timestep
    // last_recurrent_state = last_recurrent_state * g_t
    state = ggml_mul(ctx0, state, g_t);

    // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
    ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
    // we need to sum over dim=-2, so we transpose, sum, then transpose again
    kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));
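    // (ggml_sum_rows reduces along ne0, producing ne0 == 1; transposing first
    //  maps the dim=-2 axis onto ne0, and transposing back restores the layout.)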

    // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
    ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
    // delta = (v_t - kv_mem) * beta_t
    ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
    ggml_tensor * delta  = ggml_mul(ctx0, v_diff, beta_t);

    // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
    ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
    state = ggml_add(ctx0, state, k_t_delta);

    // Compute the attention output
    // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
    ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
    // again, since it's over dim = -2, transpose, sum, transpose back
    ggml_tensor * core_attn_out =
        ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
    cb(core_attn_out, "output_tokens", il);
    cb(state, "new_state", il);

    return {core_attn_out, state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
        ggml_tensor * input,
        int il) {
@@ -561,9 +193,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
        llm_graph_input_rs * inp,
        ggml_tensor * cur,
        ggml_tensor * causal_mask,
        ggml_tensor * identity,
        ggml_tensor * diag_mask,
        int il) {
    const auto * mctx_cur = inp->mctx;

@@ -587,8 +216,11 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
    ggml_tensor * z = qkvz.second;

    ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
    beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
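    // (Layout change: beta now carries a singleton leading dim with
    //  num_v_heads in ne1 rather than the reverse, presumably to match the
    //  shapes expected by the shared delta-net helpers.)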
|
||||
cb(beta, "beta", il);
|
||||
|
||||
beta = ggml_sigmoid(ctx0, beta);
|
||||
|
||||
ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
|
||||
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
|
||||
cb(alpha, "alpha", il);
|
||||
@@ -596,15 +228,16 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
|
||||
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
|
||||
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
|
||||
cb(alpha_softplus, "a_softplus", il);
|
||||
|
||||
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
|
||||
cb(gate, "gate", il);
|
||||
|
||||
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);

// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

// bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();

// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
@@ -613,11 +246,12 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);

conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);

qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);

ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
@@ -637,7 +271,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);

// Apply SSM convolution
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
cb(state, "state_predelta", il);

ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);

@@ -651,31 +288,41 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);

// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
0);

ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));

ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_v_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));

cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);

// Unsqueeze them
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;

ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
cb(state, "state_predelta", il);
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);

// if head keys and value keys are different, repeat Q/K to match V's head count
// V heads are in tiled order (from conversion), so simple tiled repeat works
//q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}
@@ -689,7 +336,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
}
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -698,19 +345,15 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(

// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// Reshape both attn_out_final and z to 2D tensors for normalization
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// Apply gated normalization: self.norm(core_attn_out, z)
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);

// Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -721,7 +364,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(cur, "linear_attn_out", il);

// Reshape back to original dimensions
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);

return cur;
}


@@ -2,10 +2,8 @@

#include "llama-memory-recurrent.h"

#define CHUNK_SIZE 64

llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params), model(model) {
llm_build_delta_net_base(params), model(model) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -25,17 +23,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();

ggml_tensor * causal_mask =
ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
GGML_TRI_TYPE_LOWER);

ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity);

ggml_build_forward_expand(gf, causal_mask);
ggml_build_forward_expand(gf, identity);
ggml_build_forward_expand(gf, diag_mask);

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

@@ -45,7 +32,7 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
cur = build_layer_attn_linear(inp->get_recr(), cur, il);
} else {
// Full attention layer
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
@@ -95,362 +82,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
ggml_build_forward_expand(gf, cur);
}

// utility to get one slice from the third dimension
// input dim: [x, y, c, b]
// output dim: [x, y, 1, b]
static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
}
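// usage sketch (values hypothetical): inside the chunk loop further below this is called as
//   ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // [S_k, chunk_size, 1, H_k * n_seqs]
// i.e. it selects the chunk-th slice along ne[2] as a view, without copying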

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];

const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];

GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

const float eps_norm = hparams.f_norm_rms_eps;

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);

const float scale = 1.0f / sqrtf(S_v);

q = ggml_scale(ctx0, q, scale);

beta = ggml_sigmoid(ctx0, beta);

cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);

q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs);

beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

cb(q, "q_perm", il);
cb(k, "k_perm", il);
cb(v, "v_perm", il);
cb(beta, "beta_perm", il);
cb(g, "g_perm", il);
cb(state, "state_in", il);

GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);

// Do padding
const int64_t chunk_size = CHUNK_SIZE;

const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
const int64_t n_chunks = (n_tokens + pad) / chunk_size;
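// worked example: with n_tokens = 100 and chunk_size = 64, pad = (64 - 100 % 64) % 64 = 28,
// so the tensors are padded to 128 tokens and processed as n_chunks = 2 chunks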

q = ggml_pad(ctx0, q, 0, pad, 0, 0);
k = ggml_pad(ctx0, k, 0, pad, 0, 0);
v = ggml_pad(ctx0, v, 0, pad, 0, 0);
g = ggml_pad(ctx0, g, pad, 0, 0, 0);
beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);

cb(q, "q_pad", il);
cb(k, "k_pad", il);
cb(v, "v_pad", il);
cb(beta, "beta_pad", il);
cb(g, "g_pad", il);

ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);

cb(v_beta, "v_beta", il);
cb(k_beta, "k_beta", il);

q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs);
k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs);
k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs);
v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs);
v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs);

g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);

ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);

ggml_tensor * gcs_j_broadcast =
ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);

ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
decay_mask = ggml_exp(ctx0, decay_mask);
decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);

ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta);

ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);

ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
attn = ggml_mul(ctx0, lin_solve, causal_mask);
attn = ggml_add(ctx0, attn, identity);
cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);

ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum));
ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);

ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

ggml_tensor * k_cumdecay =
ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)

ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)


// vectorized calculation of key_gdiff
// improved from the chunked version:
// g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
// g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
// key_gdiff = key * g_diff.unsqueeze(-1)
// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew

// get last element in g_cumsum along chunk_size dimension (ne0)
// example: [[x, y, z, ..., last], ...] -> [[last], ...]
ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
(g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
g_last = ggml_cont(ctx0, g_last);
cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)

ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)

ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
1, chunk_size, n_chunks, g_diff_exp->ne[3]);

ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)

ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)


// state to be updated per chunk
ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)

// shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
ggml_tensor * core_attn_out = nullptr;

for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
// shape: (S_k, chunk_size, 1, H_k * n_seqs)
ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul

// shape: (S_v, chunk_size, 1, H_v * n_seqs)
ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat

// shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul

// shape: (chunk_size, 1, H_v * n_seqs)
ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat

// attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
// replaced by precomputed attn_kq
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
cb(attn_chunk, "attn_chunk", il);

ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);

// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)

// v_new = v_i - v_prime
ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
cb(v_new, "v_new_chunk", il);

// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
cb(attn_inter, "attn_inter_chunk", il);

// core_attn_out[:, :, i] = attn_inter + attn @ v_new
ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
cb(v_attn, "v_attn_chunk", il);

ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)

core_attn_out = core_attn_out == nullptr
? core_attn_out_chunk
: ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);

// kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
//ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);

// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
new_state = ggml_add(ctx0,
ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
}
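// per-chunk recurrence in compact form (a sketch, up to layout/transposition details):
//   v_new = v_i - k_cumdecay_i @ S                        (remove what the state already explains)
//   out_i = (q_i * exp(g_cum_i)) @ S + attn_kq_i @ v_new  (inter-chunk + intra-chunk terms)
//   S     = S * exp(g_last_i) + key_gdiff_i^T @ v_new     (decay, then accumulate)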

// truncate padded tokens
ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
S_v, n_tokens, H_v, n_seqs,
ggml_row_size(core_attn_out->type, S_v),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
output_tokens = ggml_cont(ctx0, output_tokens);
cb(output_tokens, "output_tokens", il);

// permute back to (S_v, H_v, n_tokens, n_seqs)
output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
output_tokens = ggml_cont(ctx0, output_tokens);

return {output_tokens, new_state};
}

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_delta_net_autoregressive(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
int il) {
const int64_t S_k = q->ne[0];
const int64_t H_k = q->ne[1];
const int64_t n_tokens = q->ne[2];
const int64_t n_seqs = q->ne[3];

const int64_t S_v = v->ne[0];
const int64_t H_v = v->ne[1];

GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing
GGML_ASSERT(v->ne[2] == n_tokens);
GGML_ASSERT(k->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs);
GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs);

GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);

GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case

const float eps_norm = hparams.f_norm_rms_eps;

q = ggml_l2_norm(ctx0, q, eps_norm);
k = ggml_l2_norm(ctx0, k, eps_norm);

const float scale = 1.0f / sqrtf(S_v);

q = ggml_scale(ctx0, q, scale);
beta = ggml_sigmoid(ctx0, beta);

cb(q, "q_in", il);
cb(k, "k_in", il);
cb(v, "v_in", il);
cb(beta, "beta_in", il);
cb(g, "g_in", il);

state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);

ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs);
ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);

// Apply exponential to g_t
g_t = ggml_exp(ctx0, g_t);

// Apply the gated delta rule for the single timestep
// last_recurrent_state = last_recurrent_state * g_t
state = ggml_mul(ctx0, state, g_t);

// kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs);
ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed);
// we need to sum over dim=-2, so we transpose, sum, then transpose again
kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem))));

// v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v)
ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
// delta = (v_t - kv_mem) * beta_t
ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs]
ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t);

// last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta
ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta);
state = ggml_add(ctx0, state, k_t_delta);

// Compute the attention output
// core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t
ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed);
// again, since it's over dim = -2, transpose, sum, transpose back
ggml_tensor * core_attn_out =
ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q))));

// core_attn_out should be [S_v, 1, H_v, n_seqs] after this
cb(core_attn_out, "output_tokens", il);
cb(state, "new_state", il);

return {core_attn_out, state};
}
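// compact form of the single-step gated delta rule above (a sketch; S is one head's S_v x S_v state,
// written up to transposition conventions):
//   S     = exp(g_t) * S
//   delta = beta_t * (v_t - S^T k_t)
//   S     = S + k_t delta^T
//   o_t   = S^T q_t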

std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
ggml_tensor * input,
int il) {
@@ -562,9 +193,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn(
ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * causal_mask,
ggml_tensor * identity,
ggml_tensor * diag_mask,
int il) {
const auto * mctx_cur = inp->mctx;

@@ -588,8 +216,11 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
ggml_tensor * z = qkvz.second;

ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);
beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
cb(beta, "beta", il);

beta = ggml_sigmoid(ctx0, beta);

ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
cb(alpha, "alpha", il);
@@ -597,15 +228,16 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased);
cb(alpha_softplus, "a_softplus", il);

ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
cb(gate, "gate", il);

gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);

// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

// bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state();

// Build the convolution states tensor
ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
cb(conv_states, "conv_states", il);
@@ -614,11 +246,12 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
const int64_t conv_kernel_size = conv_kernel->ne[0];
const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state;
conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);

conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
cb(conv_states, "conv_states_reshaped", il);

qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
cb(qkv_mixed, "qkv_mixed_permuted", il);
qkv_mixed = ggml_transpose(ctx0, qkv_mixed);
cb(qkv_mixed, "qkv_mixed_transposed", il);

ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
cb(conv_input, "conv_input", il);
@@ -638,7 +271,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);

// Apply SSM convolution
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);
cb(state, "state_predelta", il);

ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
cb(conv_output_proper, "conv_output_raw", il);

@@ -652,31 +288,41 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);

// Extract the convolved Q, K, V from conv_output
ggml_tensor * q_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
0);

ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_k_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));

ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs,
ggml_row_size(conv_qkv_mix->type, head_v_dim),
nb1_qkv,
nb1_qkv * n_seq_tokens,
ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads));

cb(q_conv, "q_conv", il);
ggml_tensor * k_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(k_conv, "k_conv", il);
ggml_tensor * v_conv =
ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
cb(v_conv, "v_conv", il);

// Unsqueeze them
q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
const float eps_norm = hparams.f_norm_rms_eps;

ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
cb(state, "state_predelta", il);
q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm);
k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm);

// if head keys and value keys are different, repeat Q/K to match V's head count
// V heads are in tiled order (from conversion), so simple tiled repeat works
//q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}
@@ -690,7 +336,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
if (n_seq_tokens == 1) {
attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
} else {
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il);
}
ggml_tensor * output = attn_out.first;
ggml_tensor * new_state = attn_out.second;
@@ -699,19 +345,15 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(

// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// Reshape both attn_out_final and z to 2D tensors for normalization
// attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));

// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

// Apply gated normalization: self.norm(core_attn_out, z)
ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il);

// Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
@@ -722,7 +364,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
cb(cur, "linear_attn_out", il);

// Reshape back to original dimensions
cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs);

return cur;
}


@@ -306,8 +306,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(

ggml_tensor * beta = ggml_sigmoid(ctx0, b);

beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs);

// Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);

@@ -318,6 +316,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
cb(gate, "gate", il);

beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs);

// Get convolution states from cache
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

@@ -691,6 +691,48 @@ static void test_filters(testing & t) {
"{\n \"a\": 1,\n \"b\": [\n 1,\n 2\n ]\n}"
);

test_template(t, "indent",
"{{ data|indent(2) }}",
{{ "data", "foo\nbar" }},
"foo\n  bar"
);

test_template(t, "indent first only",
"{{ data|indent(width=3,first=true) }}",
{{ "data", "foo\nbar" }},
"   foo\n   bar"
);

test_template(t, "indent blank lines and first line",
"{{ data|indent(width=5,blank=true,first=true) }}",
{{ "data", "foo\n\nbar" }},
"     foo\n     \n     bar"
);

test_template(t, "indent with default width",
"{{ data|indent() }}",
{{ "data", "foo\nbar" }},
"foo\n    bar"
);

test_template(t, "indent with no newline",
"{{ data|indent }}",
{{ "data", "foo" }},
"foo"
);

test_template(t, "indent with trailing newline",
"{{ data|indent(blank=true) }}",
{{ "data", "foo\n" }},
"foo\n    "
);

test_template(t, "indent with string",
"{{ data|indent(width='>>>>') }}",
{{ "data", "foo\nbar" }},
"foo\n>>>>bar"
);
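// reference for the expected strings above (Jinja2 semantics assumed): indent(width=4, first=false, blank=false)
// prefixes every line after the first with width spaces, or with the literal string when width is a string;
// first=true also indents the first line, blank=true also indents whitespace-only lines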

test_template(t, "chained filters",
"{{ ' HELLO '|trim|lower }}",
json::object(),

@@ -628,9 +628,6 @@ ggml_tensor * clip_graph::build_attn(
ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
v = ggml_cont(ctx0, v);

const auto n_tokens = q->ne[1];
const auto n_head = q->ne[2];

ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
// F32 may not be needed for vision encoders?
// ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
@@ -639,7 +636,7 @@ ggml_tensor * clip_graph::build_attn(

ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]);
}

cb(cur, "kqv_out", il);

@@ -175,7 +175,7 @@ struct mtmd_context {

clip_context_params ctx_clip_params {
/* use_gpu */ ctx_params.use_gpu,
/* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
/* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type),
/* image_min_tokens */ ctx_params.image_min_tokens,
/* image_max_tokens */ ctx_params.image_max_tokens,
/* warmup */ ctx_params.warmup,

@@ -28,6 +28,14 @@ if [ "${1:-}" = "huge" ]; then
echo "Include BIG and HUGE models..."
fi

# Check if an argument is "flash_off", then disable flash attention
# This is useful to test that flash attention off works correctly
FLASH_ATTN="on"
if [ "${2:-}" = "flash_off" ] || [ "${1:-}" = "flash_off" ]; then
FLASH_ATTN="off"
echo "Flash attention disabled..."
fi
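# usage sketch (script name hypothetical):
#   ./test-models.sh flash_off        # run with flash attention disabled
#   ./test-models.sh huge flash_off   # include BIG and HUGE models, flash attention disabled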

###############

arr_prefix=()
@@ -143,6 +151,7 @@ for i in "${!arr_hf[@]}"; do
-hf $(printf %q "$hf") \
--image $(printf %q "$SCRIPT_DIR/$inp_file") \
--temp 0 -n 128 \
--flash-attn $(printf %q "$FLASH_ATTN") \
${extra_args}"

# if extra_args does not contain -p, we add a default prompt

Binary file not shown.
@@ -916,8 +916,7 @@ json oaicompat_chat_params_parse(
json image_url = json_value(p, "image_url", json::object());
handle_media(out_files, image_url, opt.media_path);

// replace this chunk with a marker
p["type"] = "text";
p["type"] = "media_marker";
p["text"] = mtmd_default_marker();
p.erase("image_url");

@@ -938,8 +937,7 @@ json oaicompat_chat_params_parse(

// TODO: add audio_url support by reusing handle_media()

// replace this chunk with a marker
p["type"] = "text";
p["type"] = "media_marker";
p["text"] = mtmd_default_marker();
p.erase("input_audio");


@@ -498,7 +498,8 @@ class ChatStore {
MessageRole.USER,
content,
MessageType.TEXT,
parentIdForUserMessage ?? '-1'
parentIdForUserMessage ?? '-1',
extras
);
if (isNewConversation && content)
await conversationsStore.updateConversationName(currentConv.id, content.trim());