Compare commits

...

6 Commits

Author SHA1 Message Date
Vaibhavs10
996195299e up.
Some checks failed
Python check requirements.txt / check-requirements (push) Has been cancelled
Python Type-Check / pyright type-check (push) Has been cancelled
2025-07-07 23:42:40 +02:00
Vaibhavs10
97c64a0974 up. 2025-07-04 14:15:34 +02:00
Vaibhavs10
6201b43814 Update the graph. 2025-06-19 17:13:28 +02:00
Vaibhavs10
02ff085071 fix errors in conversion. 2025-06-17 16:01:53 +02:00
Vaibhavs10
32ea9c5fc1 Model -> ModelBase. 2025-06-17 15:09:15 +02:00
Vaibhavs10
024bd29445 Init - first pass. 2025-06-17 15:03:34 +02:00
7 changed files with 214 additions and 9 deletions

View File

@@ -6298,6 +6298,17 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
super().set_gguf_parameters()
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("SmolLM3ForCausalLM")
class SmolLM3Model(LlamaModel):
    """Conversion handler for SmolLM3 checkpoints.

    SmolLM3 is Llama-style; the only extra metadata is the NoPE interval,
    which tells the runtime which layers skip rotary position embeddings.
    """
    model_arch = gguf.MODEL_ARCH.SMOLLM3

    def set_gguf_parameters(self):
        """Write the standard Llama parameters, plus the NoPE interval if the
        HF config provides one."""
        super().set_gguf_parameters()
        interval = self.hparams.get("no_rope_layer_interval")
        if interval is not None:
            # raw key string — must match the key read in llama_model::load_hparams
            self.gguf_writer.add_uint32("no_rope_layer_interval", interval)
###### CONVERSION LOGIC ######

View File

@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
### 2. Define the model architecture in `llama.cpp`
The model params and tensors layout must be defined in `llama.cpp`:
1. Define a new `llm_arch`
2. Define the tensors layout in `LLM_TENSOR_NAMES`
3. Add any non-standard metadata in `llm_load_hparams`
4. Create the tensors for inference in `llm_load_tensors`
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
The model params and tensors layout must be defined in `llama.cpp` source files:
1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
2. In `src/llama-arch.cpp`:
- Add the architecture name to the `LLM_ARCH_NAMES` map.
- Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
### 3. Build the GGML graph implementation
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.

View File

@@ -346,6 +346,7 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE = auto()
DOTS1 = auto()
ARCEE = auto()
SMOLLM3 = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
@@ -629,6 +630,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BAILINGMOE: "bailingmoe",
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.SMOLLM3: "smollm3",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2101,6 +2103,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SMOLLM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

View File

@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -1625,6 +1626,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
    LLM_ARCH_SMOLLM3,
    {
        // NOTE(review): base names only — the ".weight"/".bias" suffixes are
        // appended by the tensor-name formatter, consistent with the other
        // architectures in this map (e.g. "token_embd" above).
        { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
        { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
        { LLM_TENSOR_OUTPUT,         "output" },
        { LLM_TENSOR_ROPE_FREQS,     "rope_freqs" },
        { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
        { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q" },
        { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k" },
        { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v" },
        { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
        { LLM_TENSOR_ATTN_ROT_EMBD,  "blk.%d.attn_rot_embd" },
        // FFN_NORM was missing although the graph builder reads
        // model.layers[il].ffn_norm and the Python MODEL_TENSORS list declares it
        { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm" },
        { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate" },
        { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down" },
        { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up" },
    },
},
};
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {

View File

@@ -79,6 +79,7 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_SMOLLM3,
LLM_ARCH_UNKNOWN,
};

View File

@@ -186,6 +186,9 @@ struct llama_hparams {
// dimension of the recurrent state embeddings
uint32_t n_embd_v_s() const;
// for NoPE interval
uint32_t no_rope_layer_interval = 0;
bool is_swa(uint32_t il) const;
};

View File

@@ -443,6 +443,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
return;
}
if (arch == LLM_ARCH_SMOLLM3) {
ml.get_key("no_rope_layer_interval", hparams.no_rope_layer_interval);
}
ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer);
@@ -13734,6 +13738,147 @@ struct llm_build_arcee : public llm_graph_context {
}
};
struct llm_build_smollm3 : public llm_graph_context {
llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
const uint32_t interval = hparams.no_rope_layer_interval;
// token embeddings
ggml_tensor * inpL = build_inp_embd(model.tok_embd);
// positional ids
ggml_tensor * inp_pos = build_inp_pos();
// attention helper (unified KV cache)
auto * inp_attn = build_attn_inp_kv_unified();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
ggml_tensor * cur = nullptr;
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
// attention norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// ---- self-attention ----
{
// fused QKV projection
ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
cb(qkv, "wqkv", il);
if (model.layers[il].bqkv) {
qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
cb(qkv, "bqkv", il);
}
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0));
ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd)));
ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd + n_embd_gqa)));
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
if (interval == 0 || il % interval != 0) {
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
}
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}
// skip padded tokens for final layer
if (il == n_layer - 1) {
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// ---- feed-forward ----
if (hparams.use_par_res) {
// parallel residual
ggml_tensor * ffn_cur = build_norm(inpL,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(ffn_cur, "ffn_norm", il);
ffn_cur = build_ffn(
ffn_cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(ffn_cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_cur);
cb(cur, "par_res", il);
} else {
// sequential residual
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(
cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
nullptr,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
}
// post-processing
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
inpL = cur;
}
// final RMSNorm
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
};
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
@@ -14085,6 +14230,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
} break;
case LLM_ARCH_SMOLLM3:
{
llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@@ -14235,9 +14384,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE:
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_SMOLLM3:
case LLM_ARCH_ARCEE:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
case LLM_ARCH_FALCON:
case LLM_ARCH_GROK: