Compare commits


36 Commits

Author SHA1 Message Date
Georgi Gerganov
7a3c178d78 speculative : adapt to new llama API
ggml-ci
2025-03-18 22:05:44 +02:00
Xuan Son Nguyen
dc4bb64290 Merge branch 'master' into xsn/private_batch_api 2025-03-18 15:45:22 +01:00
Xuan-Son Nguyen
eab5606d7b Apply suggestions from code review 2025-03-17 12:17:14 +01:00
Xuan-Son Nguyen
de788e071b Update examples/tts/tts.cpp
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2025-03-17 12:05:23 +01:00
Xuan Son Nguyen
624a683c6f fix compile 2025-03-14 22:30:29 +01:00
Xuan Son Nguyen
116b9a1662 rename to init_from_text 2025-03-14 22:17:07 +01:00
Xuan Son Nguyen
eaffba0f2e llama_batch_ext_ptr::from_text/embd 2025-03-14 17:12:03 +01:00
Xuan Son Nguyen
8e7714fa77 fix compile 2025-03-14 11:28:15 +01:00
Xuan Son Nguyen
a363251fac qwen2vl: use llama_batch_ext_set_pos 2025-03-14 11:25:36 +01:00
Xuan Son Nguyen
ba79369615 fix llama_batch_ext_init_from_embd 2025-03-14 11:17:22 +01:00
Xuan Son Nguyen
07d84fa3c2 fix missing n_past in various places
this is actually a revert of cda0e4b648
2025-03-14 10:47:08 +01:00
Xuan Son Nguyen
32940369d3 fix gemma3-cli 2025-03-14 10:33:28 +01:00
Xuan Son Nguyen
5e6a6d4e1c fix llama-run n_past 2025-03-14 10:32:43 +01:00
Xuan Son Nguyen
bfdddbc150 bring back mistakenly deleted llama_batch_init/free 2025-03-14 00:22:28 +01:00
Xuan Son Nguyen
54566ad95d correct comment 2025-03-14 00:21:06 +01:00
Xuan Son Nguyen
04f8641815 rm redundant llama_batch_ext_set_output_last 2025-03-13 23:14:16 +01:00
Xuan Son Nguyen
c3dd79007b fix llama_batch_ext_init_from_text 2025-03-13 23:09:27 +01:00
Xuan Son Nguyen
65f0184517 compile ok 2025-03-13 22:56:35 +01:00
Xuan Son Nguyen
9fb2d81eab fix common_batch missing seq_id 2025-03-13 22:38:04 +01:00
Xuan Son Nguyen
47086fa82d apply to the rest 2025-03-13 22:36:27 +01:00
Xuan Son Nguyen
4aabf4e8f4 return output ID from llama_batch_ext_add/set 2025-03-13 17:47:07 +01:00
Xuan Son Nguyen
86973cb14a fix merge errors 2025-03-13 17:32:36 +01:00
Xuan Son Nguyen
17f954c8e2 Merge branch 'master' into xsn/private_batch_api 2025-03-13 15:55:18 +01:00
Xuan Son Nguyen
46596caf6d apply various in places 2025-03-01 20:42:18 +01:00
Xuan Son Nguyen
1d6ba97789 remove token_info API 2025-03-01 16:21:16 +01:00
Xuan Son Nguyen
1170135dfb llama_batch_ext_add_text 2025-03-01 14:00:14 +01:00
Xuan Son Nguyen
40989f4116 correct llama_decode_ext 2025-03-01 14:00:05 +01:00
Xuan Son Nguyen
9e75c49d35 Merge branch 'master' into xsn/private_batch_api 2025-03-01 12:13:03 +01:00
Xuan Son Nguyen
f0ffd81130 adapt common 2025-03-01 12:12:52 +01:00
Xuan Son Nguyen
a1b1dea33b Merge branch 'master' into xsn/private_batch_api 2025-02-24 17:01:30 +01:00
Xuan Son Nguyen
4bf7ca3943 llama_decode_ext 2025-02-24 17:01:20 +01:00
Xuan Son Nguyen
aed4a8e980 fix server 2025-02-16 11:36:50 +01:00
Xuan Son Nguyen
85ef80cbe9 server : use llama_batch_ext 2025-02-16 00:06:48 +01:00
Xuan Son Nguyen
17d3658b5f move to llama_batch_ext 2025-02-16 00:02:53 +01:00
Xuan Son Nguyen
f2e59a8eb9 rework, targeting llama-server 2025-02-14 18:16:49 +01:00
Xuan Son Nguyen
4ed4fe75ed first proposal for private llama_batch 2025-02-14 00:48:12 +01:00
88 changed files with 1854 additions and 3610 deletions
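The file diffs below repeat a single migration pattern: `llama_batch` plus `llama_batch_init`/`common_batch_add`/`llama_decode` is replaced by the new `llama_batch_ext` API. As an orientation aid, here is a minimal sketch of that lifecycle, assembled only from the calls that appear in these hunks (the context `ctx` and the token list are placeholders, not code from the PR):

```cpp
#include <vector>
#include "llama.h"

// minimal sketch of the llama_batch_ext lifecycle used throughout this PR
static bool decode_tokens(llama_context * ctx, const std::vector<llama_token> & tokens) {
    // raw handle, must be released with llama_batch_ext_free
    llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens_max*/ (int32_t) tokens.size(), /*n_seq_max*/ 1);

    // not required right after init; shown because the hunks reuse one batch across iterations
    llama_batch_ext_clear(batch);

    const llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); ++i) {
        // add one token at position i to sequence 0, without requesting logits
        llama_batch_ext_add_text(batch, tokens[i], (llama_pos) i, &seq_id, 1, false);
    }
    // request logits only for the last token (replaces `batch.logits[batch.n_tokens - 1] = true`)
    llama_batch_ext_set_output_last(batch);

    const bool ok = llama_decode_ext(ctx, batch) == 0;

    llama_batch_ext_free(batch);
    return ok;
}
```

Several hunks use the one-shot RAII form instead, e.g. `llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));` followed by `llama_decode_ext(ctx, batch.get());`, which avoids the explicit free.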

View File

@@ -676,35 +676,6 @@ jobs:
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-cmake-visionos:
runs-on: macos-latest
steps:
- name: Clone
id: checkout
uses: actions/checkout@v4
- name: Dependencies
id: depends
continue-on-error: true
run: |
brew update
- name: Build
id: cmake_build
run: |
sysctl -a
cmake -B build -G Xcode \
-DGGML_METAL_USE_BF16=ON \
-DGGML_METAL_EMBED_LIBRARY=ON \
-DLLAMA_BUILD_EXAMPLES=OFF \
-DLLAMA_BUILD_TESTS=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_DEPLOYMENT_TARGET=1.0 \
-DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO
macOS-latest-swift:
runs-on: macos-latest

View File

@@ -432,8 +432,8 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xros \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos --config Release -- -quiet
@@ -445,8 +445,8 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_SYSTEM_NAME=visionOS \
-DCMAKE_OSX_SYSROOT=xrsimulator \
-DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \
-DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 -Du_int=unsigned\ int -Du_char=unsigned\ char -Du_short=unsigned\ short ${COMMON_CXX_FLAGS}" \
-S .
cmake --build build-visionos-sim --config Release -- -quiet

View File

@@ -582,41 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
return buf.str();
}
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
std::stringstream buf;
buf << "[ ";
bool first = true;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
return buf.str();
}
void string_process_escapes(std::string & input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
@@ -1051,7 +1016,8 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
llama_encode_ext(lctx, batch.get());
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
@@ -1060,7 +1026,8 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
llama_decode_ext(lctx, batch.get());
}
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
@@ -1613,10 +1580,12 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
// Batch utils
//
// DEPRECATED
void common_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}
// DEPRECATED
void common_batch_add(
struct llama_batch & batch,
llama_token id,

View File

@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
std::string string_from(bool value);
std::string string_from(const std::vector<int> & values);
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
//
// Filesystem utils
@@ -570,8 +569,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
// Batch utils
//
// DEPRECATED
void common_batch_clear(struct llama_batch & batch);
// DEPRECATED
void common_batch_add(
struct llama_batch & batch,
llama_token id,
@@ -579,6 +580,66 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
struct common_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id; // only support single seq for now
bool logits;
};
std::vector<batch_token> tokens;
int n_outputs = 0;
common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
if (logits) {
n_outputs++;
}
}
void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
tokens.push_back({token, seq_ids[0], logits});
if (logits) {
n_outputs++;
}
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_output_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
llama_batch_ext * get() {
return batch.get();
}
common_batch get_view(int32_t offset, int32_t n_tokens) {
common_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
if (tokens[offset + i].logits) {
view.n_outputs++;
}
}
return view;
}
};
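A short usage sketch of this temporary wrapper, assuming it is declared in the common header shown here and that `ctx`/`prompt_tokens` come from the usual `common_init_from_params`/`common_tokenize` setup:

```cpp
#include <vector>
#include "common.h"   // assumed location of the common_batch wrapper above
#include "llama.h"

// hypothetical helper: feed a prompt through the temporary common_batch wrapper
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt_tokens) {
    common_batch batch((int32_t) prompt_tokens.size(), /*n_seq_max*/ 1);

    for (size_t i = 0; i < prompt_tokens.size(); ++i) {
        // single sequence 0; no logits for intermediate prompt tokens
        batch.add_text(prompt_tokens[i], (llama_pos) i, /*seq_id*/ 0, /*logits*/ false);
    }
    batch.set_logits_last(); // mirrors llama_batch_ext_set_output_last on the wrapped batch

    // common_batch exposes the underlying llama_batch_ext via get()
    return llama_decode_ext(ctx, batch.get()) == 0;
}
```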
//
// Token utils
//

View File

@@ -14,7 +14,7 @@ struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
llama_batch batch;
llama_batch_ext_ptr batch;
llama_tokens prompt;
};
@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
/* .prompt = */ {},
};
@@ -69,8 +69,6 @@ void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
llama_batch_free(spec->batch);
delete spec;
}
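The `llama_batch_free(spec->batch)` call goes away because the member is now owned by `llama_batch_ext_ptr`. A hypothetical approximation of such a wrapper, assuming it mirrors the other smart-pointer aliases in `llama-cpp.h` (the real type also carries the static `init_from_text`/`init_from_embd` helpers used in later hunks):

```cpp
#include <memory>
#include "llama.h"

// hypothetical approximation only: an RAII owner whose deleter calls llama_batch_ext_free,
// which is why the explicit llama_batch_free in common_speculative_free is dropped
struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * batch) const { llama_batch_ext_free(batch); }
};

struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
    using std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>::unique_ptr;
};
```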
@@ -151,6 +149,8 @@ llama_tokens common_speculative_gen_draft(
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
const llama_seq_id seq_id = 0;
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
@@ -206,40 +206,40 @@ llama_tokens common_speculative_gen_draft(
}
// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
llama_batch_ext_clear(batch.get());
llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
common_sampler_sample(smpl, ctx, 0, true);
@@ -266,10 +266,10 @@ llama_tokens common_speculative_gen_draft(
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
prompt.push_back(id);
}

View File

@@ -180,8 +180,7 @@ class Model:
extra = sorted(tensor_names_from_parts.difference(self.tensor_names))
missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map))
if len(extra) == 0 and len(missing_files) > 0:
raise ValueError(f"Missing or incomplete model files: {missing_files}\n"
f"Missing tensors: {missing}")
raise ValueError(f"Missing or incomplete model files: {missing_files}")
else:
raise ValueError("Mismatch between weight map and model parts for tensor names:\n"
f"Missing tensors: {missing}\n"
@@ -529,8 +528,6 @@ class Model:
reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
added_vocab = tokenizer.get_added_vocab()
added_tokens_decoder = tokenizer.added_tokens_decoder
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
@@ -540,13 +537,13 @@ class Model:
if token in added_vocab:
# The tokenizer in llama.cpp assumes the CONTROL and USER_DEFINED tokens are pre-normalized.
# To avoid unexpected issues - we make sure to normalize non-normalized tokens
if not added_tokens_decoder[i].normalized:
if not tokenizer.added_tokens_decoder[i].normalized:
previous_token = token
token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
if previous_token != token:
logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
if added_tokens_decoder[i].special or self.does_token_look_special(token):
if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
toktypes.append(gguf.TokenType.CONTROL)
else:
# NOTE: this was added for Gemma.
@@ -1102,6 +1099,13 @@ class BloomModel(Model):
tensors.append((self.map_tensor_name(name), data_torch))
if name == "word_embeddings.weight":
assert self.tensor_names is not None
# TODO: tie them at runtime, don't duplicate in the model file
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors
@@ -1743,25 +1747,6 @@ class LlamaModel(Model):
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = Model.load_hparams(kwargs["dir_model"])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("language_model.", "")
if "multi_modal_projector" in name or "vision_tower" in name:
return []
return super().modify_tensors(data_torch, name, bid)
@Model.register("DeciLMForCausalLM")
class DeciModel(Model):
model_arch = gguf.MODEL_ARCH.DECI
@@ -2419,6 +2404,10 @@ class GPT2Model(Model):
tensors.append((new_name, data_torch))
# note: GPT2 output is tied to (same as) wte in original model
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors
@@ -2748,26 +2737,21 @@ class CodeShellModel(Model):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(1.0)
_has_tok_embd = False
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD)
new_name = self.map_tensor_name(name)
# assuming token_embd.weight is seen before output.weight
if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
# even though the tensor file(s) does not contain the word embeddings they are still in the weight map
if self.tensor_names and "transformer.wte.weight" in self.tensor_names:
logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied")
self.tensor_names.remove("transformer.wte.weight")
elif new_name == tok_embd_name:
self._has_tok_embd = True
tensors: list[tuple[str, Tensor]] = [(new_name, data_torch)]
return [(new_name, data_torch)]
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
assert self.tensor_names is not None
if all(s not in self.tensor_names for s in ("lm_head.weight", "output.weight")):
# copy tok_embd.weight to output.weight
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch))
return tensors
@Model.register("InternLM2ForCausalLM")

View File

@@ -237,15 +237,6 @@ cmake -B buildWithCublas -DCMAKE_CXX_COMPILER=icpx -DCMAKE_C_COMPILER=icx -DENAB
cmake --build buildWithCublas --config Release
```
**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
```sh
git clone https://github.com/oneapi-src/oneDNN.git
cd oneDNN
cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
cmake --build build-nvidia --config Release
```
- **Adding support to AMD GPUs**
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
@@ -336,10 +327,10 @@ export CPLUS_INCLUDE_DIR=/path/to/oneMKL/include:$CPLUS_INCLUDE_DIR
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
# Option 1: Use FP32 (recommended for better performance in most cases)
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
# Option 2: Use FP16
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
# build all binary
cmake --build build --config Release -j -v

View File

@@ -9,13 +9,6 @@ brew install llama.cpp
```
The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668
## MacPorts
```sh
sudo port install llama.cpp
```
see also: https://ports.macports.org/port/llama.cpp/details/
## Nix
On Mac and Linux, the Nix package manager can be used via

View File

@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
const int32_t n_kv_max = llama_n_ctx(ctx);
llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
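Both this `decode_helper` and the parallel example further down split one large batch into views before decoding; the only difference is ownership (an `llama_batch_ext_ptr` here, a raw view plus `llama_batch_ext_free` there). A hedged sketch of that chunking loop, assuming a view behaves like any other `llama_batch_ext` handle and can be freed while the parent batch stays alive, as both hunks do:

```cpp
#include <algorithm>
#include "llama.h"
#include "llama-cpp.h" // llama_batch_ext_ptr

// sketch: decode a large llama_batch_ext in chunks of at most n_batch tokens
static bool decode_in_chunks(llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
    const int32_t n_total = llama_batch_ext_get_n_tokens(batch);
    for (int32_t i = 0; i < n_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_total - i);
        // the view is released at the end of each iteration; the parent batch is untouched
        llama_batch_ext_ptr view(llama_batch_ext_get_view(batch, i, n_tokens));
        if (llama_decode_ext(ctx, view.get()) != 0) {
            return false;
        }
    }
    return true;
}
```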
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
// warm up
{
for (int i = 0; i < 16; ++i) {
common_batch_add(batch, 0, i, { 0 }, false);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
continue;
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
common_batch_add(batch, 0, i, { j }, false);
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
const auto t_pp_start = ggml_time_us();
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();
for (int i = 0; i < tg; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int j = 0; j < pl; ++j) {
common_batch_add(batch, 0, pp + i, { j }, true);
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true);
}
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
LOG("\n");
llama_perf_context_print(ctx);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_free(ctx);
llama_model_free(model);

View File

@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
if (llama_encode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
decoder_start_token_id = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
llama_batch_ext_clear(batch);
llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_cur = batch.n_tokens;
int n_cur = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;
const auto t_main_start = ggml_time_us();
while (n_cur <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
streams[i] += common_token_to_piece(ctx, new_token_id);
i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_cur, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true);
n_decode += 1;
}
// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);

View File

@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_self_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}

View File

@@ -26,14 +26,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
return lines;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
batch.add_text(tokens[i], i, seq_id, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
@@ -41,21 +41,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.tokens[i].logits) {
continue;
}
@@ -69,8 +69,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
@@ -171,7 +171,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
struct common_batch batch = common_batch(n_batch, 1);
// count number of embeddings
int n_embd_count = 0;
@@ -198,12 +198,12 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (batch.get_n_tokens() + n_toks > n_batch) {
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
s = 0;
common_batch_clear(batch);
batch.clear();
}
// add to batch
@@ -319,7 +319,6 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
// clean up
llama_batch_free(batch);
llama_backend_free();
return 0;

View File

@@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

View File

@@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
for (uint64_t i = 0; i < sentences.size(); i++) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);
const std::string input_string = instruction + sentences[i];
@@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
}
// clear previous kv_cache values (irrelevant for embeddings)
@@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
llama_set_causal_attn(ctx, false);
// run model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch);
// get embedding dimensions
uint64_t n_embd = llama_model_n_embd(model);
@@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
#endif
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
return result;
}
@@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;
while (true) {
common_batch_clear(bat);
llama_batch_ext_clear(bat);
{
const int32_t n_inputs = inputs.size();
for (int32_t i = 0; i < n_inputs; i++) {
common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
}
}
inputs.clear();
llama_decode(ctx, bat);
llama_decode_ext(ctx, bat);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
if (token == eos_token) {
break;
@@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::printf("\n");
}
llama_batch_free(bat);
llama_batch_ext_free(bat);
return result;
}

View File

@@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -511,14 +511,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return false;
}
@@ -531,7 +532,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();

View File

@@ -353,7 +353,8 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}

View File

@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
}
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1444,14 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
llama_decode_ext(ctx, batch.get());
n_processed += n_tokens;
}
llama_synchronize(ctx);
}
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx);
@@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1));
auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true);
llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
@@ -1608,13 +1610,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, t.n_threads);
test_gen(ctx, 1, 0, t.n_threads);
}
for (int i = 0; i < params.reps; i++) {
@@ -1627,14 +1629,14 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_threads);
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}
uint64_t t_ns = get_time_ns() - t_start;

View File

@@ -5,6 +5,7 @@
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "llama-cpp.h"
#include "ggml.h"
#include "console.h"
@@ -63,7 +64,7 @@ struct gemma3_context {
llama_model * model;
llama_context * lctx;
const llama_vocab * vocab;
llama_batch batch;
llama_batch_ext_ptr batch;
int n_threads = 1;
llama_pos n_past = 0;
@@ -73,7 +74,7 @@ struct gemma3_context {
lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model);
n_threads = params.cpuparams.n_threads;
batch = llama_batch_init(params.n_batch, 0, 1);
batch.reset(llama_batch_ext_init(params.n_batch, 1));
init_clip_model(params);
}
@@ -87,50 +88,18 @@ struct gemma3_context {
}
};
struct decode_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
common_batch_clear(ctx.batch);
llama_batch_ext_clear(ctx.batch.get());
for (llama_token & t : tokens) {
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false);
}
if (logits_last) {
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(ctx.batch.get());
}
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
if (llama_decode(ctx.lctx, ctx.batch)) {
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("Failed to decode text\n");
return 1;
}
@@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
int64_t t1 = ggml_time_ms();
eval_text(ctx, "<start_of_image>");
llama_set_causal_attn(ctx.lctx, false);
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
if (llama_decode(ctx.lctx, batch_img.batch)) {
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
LOG_ERR("failed to decode image\n");
return 1;
}
@@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
fflush(stdout);
// eval the token
common_batch_clear(ctx.batch);
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
if (llama_decode(ctx.lctx, ctx.batch)) {
llama_batch_ext_clear(ctx.batch.get());
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("failed to decode token\n");
return 1;
}

View File

@@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -2,6 +2,7 @@
#include "llava.h"
#include "llama.h"
#include "llama-cpp.h"
#include <algorithm>
#include <cerrno>
@@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true;
}
struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
@@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}

View File

@@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -66,17 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
llama_batch batch = {
int32_t(n_eval), // n_tokens
nullptr, // token
(image_embed->embed+i*n_embd), // embed
batch_mrope_pos.data(), // pos
nullptr, // n_seq_id
nullptr, // seq_id
nullptr, // logits
};
float * batch_embd = image_embed->embed+i*n_embd;
auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
if (llama_decode(ctx_llama, batch)) {
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
@@ -95,16 +89,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
auto batch = llama_batch_get_one(&tokens[i], n_eval);
// TODO: add mrope pos ids somewhere else
pos.resize(batch.n_tokens * 4);
std::fill(pos.begin(), pos.end(), 0);
for (int j = 0; j < batch.n_tokens * 3; j ++) {
pos[j] = *st_pos_id + (j % batch.n_tokens);
}
batch.pos = pos.data();
if (llama_decode(ctx_llama, batch)) {
// TODO: add mrope pos ids somewhere else
int n_tokens = n_eval;
pos.resize(n_tokens * 4);
std::fill(pos.begin(), pos.end(), 0);
for (int j = 0; j < n_tokens * 3; j ++) {
pos[j] = *st_pos_id + (j % n_tokens);
}
llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
for (int j = 0; j < n_eval; j++) {
llama_token token = tokens[i + j];
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false);
}
llama_batch_ext_set_output_last(batch.get());
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}

View File

@@ -92,8 +92,10 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
@@ -115,7 +117,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
// target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@@ -204,10 +206,10 @@ int main(int argc, char ** argv) {
// V V V V V V
// id
{
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// current token - first token of the first level
common_batch_add(batch, id, n_past, seq_id_all, true);
llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
{
@@ -230,9 +232,10 @@ int main(int argc, char ** argv) {
const llama_token t = ngrams_observed.tokens[idx + j];
ngrams_cur[g].tokens [j + 1] = t;
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
llama_seq_id seq_id = W + 1 + g;
llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
}
}
}
@@ -244,18 +247,20 @@ int main(int argc, char ** argv) {
seq_id_look[j] = i + j + 1;
}
common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
seq_id_look.data(), seq_id_look.size(), false);
}
// fill the rest of the levels
for (int j = 1; j < N - 1; j++) {
for (int i = 0; i < W; i++) {
common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
llama_seq_id seq_id = i + 1;
llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
}
}
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
return 1;
}
@@ -475,7 +480,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_view_free(&kvc_view);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_backend_free();

View File

@@ -91,8 +91,10 @@ int main(int argc, char ** argv){
const auto t_enc_start = ggml_time_us();
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get());
const auto t_enc_end = ggml_time_us();
@@ -108,7 +110,7 @@ int main(int argc, char ** argv){
std::vector<llama_token> draft;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -194,8 +196,9 @@ int main(int argc, char ** argv){
// clean the cache of draft tokens that weren't accepted
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
const llama_seq_id seq_id = 0;
llama_batch_ext_clear(batch_tgt);
llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
// Draft already contains a single token sampled from the model:
GGML_ASSERT(draft.size() == 1);
@@ -205,13 +208,13 @@ int main(int argc, char ** argv){
common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
for (size_t i = 1; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
t_draft_us += ggml_time_us() - t_start_draft_us;
n_drafted += draft.size() - 1;
llama_decode(ctx, batch_tgt);
llama_decode_ext(ctx, batch_tgt);
++n_past;
draft.erase(draft.begin());
@@ -243,7 +246,7 @@ int main(int argc, char ** argv){
common_sampler_free(smpl);
llama_batch_free(batch_tgt);
llama_batch_ext_free(batch_tgt);
llama_backend_free();

View File

@@ -548,7 +548,8 @@ int main(int argc, char ** argv) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) {
auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -668,7 +669,8 @@ int main(int argc, char ** argv) {
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}

View File

@@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
int32_t n_total_prompt = 0;
int32_t n_total_gen = 0;
@@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
for (int32_t i = 0; i < n_tokens_system; ++i) {
common_batch_add(batch, tokens_system[i], i, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
common_kv_cache_dump_view_seqs(kvc_view, 40);
}
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// decode any currently ongoing sequences
for (auto & client : clients) {
@@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
continue;
}
client.i_batch = batch.n_tokens;
client.i_batch = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
client.n_decoded += 1;
}
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) {
llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
}
// insert new sequences for decoding
if (cont_batching || batch.n_tokens == 0) {
if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id;
@@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
tokens_prompt = common_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
}
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
if (llama_batch_ext_get_n_tokens(batch) > 0) {
llama_batch_ext_set_output_last(batch);
}
client.n_prompt = tokens_prompt.size();
client.n_decoded = 0;
client.i_batch = batch.n_tokens - 1;
client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1;
LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
@@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
}
}
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
// experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2;
@@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
// continue;
//}
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
const int ret = llama_decode_ext(ctx, batch_view);
llama_batch_ext_free(batch_view);
if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size
@@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
// TODO: print sampling/grammar timings for all clients
llama_perf_context_print(ctx);
llama_batch_free(batch);
llama_batch_ext_free(batch);
llama_backend_free();

View File

@@ -2,6 +2,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "llama-cpp.h"
#include <cmath>
#include <cstdio>
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
LOG_INF("prompt tokens: %d\n", n_tokens_all);
//LOG_INF("prompt: %s\n", params.prompt.c_str());
llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
int n_past = 0;
@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_INF("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
}
if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
}
if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// sample the next token
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
n_decode += 1;
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch.get());
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl);
llama_batch_free(batch);
llama_free(ctx);
llama_model_free(model);
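
A compact sketch of the fill-and-decode step above, assuming the RAII wrapper from llama-cpp.h and the llama_batch_ext calls shown in this diff; the helper is illustrative only:

// Sketch: evaluate one window of the prompt with the RAII batch pointer.
#include "llama-cpp.h"
#include <vector>

static int eval_window(llama_context * ctx, const std::vector<llama_token> & tokens,
                       int i0, int n_batch, int & n_past, bool is_last_chunk) {
    llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1)); // freed automatically on scope exit
    llama_seq_id seq_id = 0;
    for (int j = 0; j < n_batch && i0 + j < (int) tokens.size(); j++) {
        llama_batch_ext_add_text(batch.get(), tokens[i0 + j], n_past++, &seq_id, 1, false);
    }
    if (is_last_chunk) {
        llama_batch_ext_set_output_last(batch.get()); // logits only for the final token
    }
    return llama_decode_ext(ctx, batch.get());
}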

View File

@@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// clear the KV cache
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
}
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
//LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return {tokens, -1, logit_history, prob_history};
}
@@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
@@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
std::vector<float> logits;
if (num_batches > 1) {
@@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
int n_outputs = 0;
batch.n_tokens = 0;
batch.clear();
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
@@ -568,22 +565,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
}
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
batch.token [idx] = tokens[seq_start + k];
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
const llama_pos pos = j*n_batch + k;
bool output = pos >= first;
batch.add_text(tokens[seq_start + k], pos, seq, output);
n_outputs += batch.logits[idx] != 0;
n_outputs += output ? 1 : 0;
}
batch.n_tokens += batch_size;
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_INF("%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@@ -653,36 +646,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
}
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history};
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i);
for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
}
int n_outputs = batch_view.n_outputs;
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -863,7 +843,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 4);
common_batch batch(n_ctx, 4);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -879,7 +859,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@@ -895,9 +875,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 4; ++s) {
@@ -905,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
// TODO: don't evaluate the last token of each sequence
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@@ -992,8 +972,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
i0 = i1 - 1;
}
llama_batch_free(batch);
LOG("\n");
}
@@ -1147,7 +1125,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 2);
common_batch batch(n_ctx, 2);
std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size
@@ -1166,7 +1144,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
size_t i1 = i0;
size_t i_logits = 0;
common_batch_clear(batch);
batch.clear();
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
int n_logits = 0;
@@ -1176,15 +1154,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch.get());
n_logits += 1;
for (int s = 0; s < 2; ++s) {
// TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
n_logits += 1;
}
}
@@ -1501,7 +1479,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
common_batch batch(n_ctx, max_seq);
std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1521,7 +1499,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch);
batch.clear();
// batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending
@@ -1544,9 +1522,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
for (size_t i = 0; i < cur_task.common_prefix; ++i) {
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@@ -1554,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
// TODO: don't evaluate the last token of each sequence
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits;
}
}
@@ -1653,8 +1631,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
i0 = i1 - 1;
}
llama_batch_free(batch);
if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
float p = 1.f*n_correct/n_done;
@@ -1767,7 +1743,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1);
common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -1781,14 +1757,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
batch.clear();
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return;
}
@@ -1801,8 +1776,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
}
}
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
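
The decode_helper change above boils down to the following pattern, sketched under the assumption that common_batch (the wrapper this branch adds to common/) exposes get_view(), get_n_tokens() and a per-view n_outputs counter as shown in the hunk; the helper name is illustrative:

// Sketch: decode a common_batch in n_batch-sized views and gather the logits of the output tokens.
#include "common.h"   // assumed location of common_batch in this branch
#include <algorithm>
#include <cstring>
#include <vector>

static bool decode_collect_logits(llama_context * ctx, common_batch & batch,
                                  std::vector<float> & batch_logits, int n_batch, int n_vocab) {
    int prev_outputs = 0;
    for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
        const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
        common_batch batch_view = batch.get_view(i, n_tokens);
        if (llama_decode_ext(ctx, batch_view.get()) != 0) {
            return false;
        }
        const int n_outputs = batch_view.n_outputs; // tokens flagged for output in this view
        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab,
               llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
        prev_outputs += n_outputs;
    }
    return true;
}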

View File

@@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
return chunks;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
batch.add_text(tokens[i], i, seq_id, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.tokens[i].logits) {
continue;
}
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
if (embd == NULL) {
const float * embd = nullptr;
int embd_pos = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) {
LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i);
continue;
}
embd_pos = i;
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd, 2);
float * out = output + embd_pos * n_embd;
common_embd_normalize(embd, out, n_embd, embd_norm);
}
}
@@ -214,7 +230,7 @@ int main(int argc, char ** argv) {
// initialize batch
const int n_chunks = chunks.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
struct common_batch batch = common_batch(n_batch, 1);
// allocate output
const int n_embd = llama_model_n_embd(model);
@@ -231,10 +247,10 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
batch.clear();
p += s;
s = 0;
}
@@ -255,7 +271,7 @@ int main(int argc, char ** argv) {
chunks[i].tokens.clear();
}
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
struct common_batch query_batch = common_batch(n_batch, 1);
// start loop, receive query and return top k similar chunks based on cosine similarity
std::string query;
@@ -269,7 +285,7 @@ int main(int argc, char ** argv) {
std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);
query_batch.clear();
// compute cosine similarities
{
@@ -299,6 +315,5 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
// clean up
llama_batch_free(query_batch);
llama_backend_free();
}
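
A sketch of the embedding extraction logic above: without pooling the embedding is read per output token, otherwise per sequence. The batch.tokens[i] bookkeeping is the common_batch field used in this hunk; the helper name is illustrative:

// Sketch: copy one normalized embedding per output token (no pooling) or per sequence (pooled).
static void copy_embeddings(llama_context * ctx, common_batch & batch,
                            float * output, int n_embd, int embd_norm) {
    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
        if (!batch.tokens[i].logits) {
            continue; // only output tokens carry embeddings
        }
        const float * embd = nullptr;
        int embd_pos = 0;
        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
            embd     = llama_get_embeddings_ith(ctx, i);                      // token-level embedding
            embd_pos = i;
        } else {
            embd     = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); // pooled per sequence
            embd_pos = batch.tokens[i].seq_id;
        }
        GGML_ASSERT(embd != nullptr && "failed to get embeddings");
        common_embd_normalize(embd, output + embd_pos*n_embd, n_embd, embd_norm);
    }
}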

View File

@@ -640,6 +640,7 @@ class LlamaData {
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
llama_pos n_past = 0;
int init(Opt & opt) {
model = initialize_model(opt);
@@ -950,10 +951,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
}
// Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) {
static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) {
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n");
return 1;
@@ -991,15 +992,17 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
}
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size());
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
llama_token new_token_id;
while (true) {
check_context_size(llama_data.context, batch);
if (llama_decode(llama_data.context.get(), batch)) {
if (llama_decode_ext(llama_data.context.get(), batch.get())) {
printe("failed to decode\n");
return 1;
}
llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
// sample the next token, check is it an end of generation?
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_vocab_is_eog(vocab, new_token_id)) {
@@ -1014,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
print_word_and_concatenate_to_response(piece, response);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
}
printf(LOG_COL_DEFAULT);
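
The generation loop above reduces to the following sketch: positions are now explicit in the batch, so n_past is advanced by the number of decoded tokens instead of being implied by llama_batch_get_one(). Signatures follow this diff; the helper is illustrative:

// Sketch: prompt + token-by-token generation with explicit positions.
static int generate_sketch(llama_context * ctx, llama_sampler * smpl, const llama_vocab * vocab,
                           std::vector<llama_token> & tokens, llama_pos & n_past) {
    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(),
                                                     n_past, /*seq_id=*/0, /*output_last=*/true);
    while (true) {
        if (llama_decode_ext(ctx, batch.get())) {
            return 1; // decode failed
        }
        n_past += llama_batch_ext_get_n_tokens(batch.get()); // advance by what was just evaluated
        llama_token id = llama_sampler_sample(smpl, ctx, -1);
        if (llama_vocab_is_eog(vocab, id)) {
            return 0;
        }
        // next iteration: a single-token batch at the current position
        batch.reset(llama_batch_ext_init_from_text(&id, 1, n_past, 0, true));
    }
}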

View File

@@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;
llama_decode_ext(ctx, batch);
n_past += llama_batch_ext_get_n_tokens(batch);
// save state (rng, logits, embedding and kv_cache) to file
{
@@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result0 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result1 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx2, batch)) {
if (llama_decode_ext(ctx2, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result2 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {1}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1; // seq 1 instead of 0
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx3, batch)) {
if (llama_decode_ext(ctx3, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl2);
llama_sampler_free(smpl3);
llama_batch_free(batch);
llama_batch_ext_free(batch);
if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);

Binary file not shown.

View File

@@ -830,11 +830,6 @@ struct server_task_result_cmpl_final : server_task_result {
ret.push_back({"timings", timings.to_json()});
}
// extra fields for debugging purposes
if (verbose) {
ret["__verbose"] = to_json_non_oaicompat();
}
return ret;
}
};
@@ -1229,7 +1224,7 @@ struct server_slot {
// only used for completion/embedding/infill/rerank
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
llama_batch batch_spec = {};
common_batch batch_spec;
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
@@ -1801,7 +1796,7 @@ struct server_context {
llama_context_params cparams_dft;
llama_batch batch = {};
common_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
@@ -1834,11 +1829,7 @@ struct server_context {
common_speculative_free(slot.spec);
slot.spec = nullptr;
llama_batch_free(slot.batch_spec);
}
llama_batch_free(batch);
}
bool load_model(const common_params & params) {
@@ -1931,7 +1922,7 @@ struct server_context {
slot.n_predict = params_base.n_predict;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) {
@@ -1956,7 +1947,7 @@ struct server_context {
slot.reset();
slots.push_back(slot);
slots.push_back(std::move(slot));
}
default_generation_settings_for_props = slots[0].to_json();
@@ -1967,7 +1958,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
}
metrics.init();
@@ -2102,9 +2093,7 @@ struct server_context {
}
if (slot.ctx_dft) {
llama_batch_free(slot.batch_spec);
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
}
slot.state = SLOT_STATE_STARTED;
@@ -2412,7 +2401,7 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_embedding(const server_slot & slot, const llama_batch & batch) {
void send_embedding(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_embd>();
res->id = slot.id_task;
res->index = slot.index;
@@ -2423,18 +2412,19 @@ struct server_context {
std::vector<float> embd_res(n_embd, 0.0f);
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < batch.get_n_tokens(); ++i) {
auto tok = batch.tokens[i];
if (!tok.logits || tok.seq_id != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
@@ -2455,24 +2445,25 @@ struct server_context {
queue_results.send(std::move(res));
}
void send_rerank(const server_slot & slot, const llama_batch & batch) {
void send_rerank(const server_slot & slot, common_batch & batch) {
auto res = std::make_unique<server_task_result_rerank>();
res->id = slot.id_task;
res->index = slot.index;
res->n_tokens = slot.n_prompt_tokens;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
for (int i = 0; i < batch.get_n_tokens(); ++i) {
auto tok = batch.tokens[i];
if (!tok.logits || tok.seq_id != slot.id) {
continue;
}
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
if (embd == NULL) {
embd = llama_get_embeddings_ith(ctx, i);
}
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
res->score = -1e6;
continue;
@@ -2863,7 +2854,7 @@ struct server_context {
}
// start populating the batch for this iteration
common_batch_clear(batch);
batch.clear();
// track if given slot can be batched with slots already in the batch
server_slot * slot_batched = nullptr;
@@ -2885,9 +2876,9 @@ struct server_context {
continue;
}
slot.i_batch = batch.n_tokens;
slot.i_batch = batch.get_n_tokens();
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
batch.add_text(slot.sampled, slot.n_past, slot.id, true);
slot.n_past += 1;
@@ -2904,7 +2895,7 @@ struct server_context {
int32_t n_ubatch = llama_n_ubatch(ctx);
// next, batch any pending prompts without exceeding n_batch
if (params_base.cont_batching || batch.n_tokens == 0) {
if (params_base.cont_batching || batch.get_n_tokens() == 0) {
for (auto & slot : slots) {
// check if we can batch this slot with the previous one
if (slot.is_processing()) {
@@ -3070,7 +3061,7 @@ struct server_context {
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
continue;
}
}
@@ -3090,11 +3081,11 @@ struct server_context {
slot.cache_tokens.resize(slot.n_past);
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3104,13 +3095,13 @@ struct server_context {
slot.n_past++;
}
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT(batch.get_n_tokens() > 0);
common_sampler_reset(slot.smpl);
@@ -3120,27 +3111,27 @@ struct server_context {
}
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;
batch.set_logits_last();
slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
slot.i_batch = batch.get_n_tokens() - 1;
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
}
}
if (batch.n_tokens >= n_batch) {
if (batch.get_n_tokens() >= n_batch) {
break;
}
}
}
if (batch.n_tokens == 0) {
if (batch.get_n_tokens() == 0) {
SRV_WRN("%s", "no tokens to decode\n");
return;
}
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
if (slot_batched) {
// make sure we're in the right embedding mode
@@ -3150,20 +3141,12 @@ struct server_context {
}
// process the created batch of tokens
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
common_batch batch_view = batch.get_view(i, n_tokens);
const int ret = llama_decode(ctx, batch_view);
const int ret = llama_decode_ext(ctx, batch_view.get());
metrics.on_decoded(slots);
if (ret != 0) {
@@ -3298,16 +3281,16 @@ struct server_context {
}
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
slot.batch_spec.clear();
slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
}
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
llama_decode(ctx, slot.batch_spec);
llama_decode_ext(ctx, slot.batch_spec.get());
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
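
A sketch of the speculation-batch construction above: the last accepted token plus the draft tokens are queued at consecutive positions on the slot's sequence, all with output enabled so each draft position can be verified. The helper name is illustrative:

// Sketch: build the speculation batch for one slot, as in the hunk above.
// common_batch::add_text(token, pos, seq_id, output) is the wrapper call used throughout this diff.
static void build_spec_batch(common_batch & batch_spec, llama_token id_last,
                             const std::vector<llama_token> & draft,
                             llama_pos n_past, llama_seq_id slot_id) {
    batch_spec.clear();
    batch_spec.add_text(id_last, n_past, slot_id, true);                          // token to re-verify
    for (size_t i = 0; i < draft.size(); ++i) {
        batch_spec.add_text(draft[i], n_past + 1 + (llama_pos) i, slot_id, true); // drafted continuation
    }
}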

View File

@@ -99,9 +99,13 @@ export default function ChatScreen() {
canvasData,
replaceMessageAndGenerate,
} = useAppContext();
const textarea = useOptimizedTextarea(prefilledMsg.content());
const [inputMsg, setInputMsg] = useState(prefilledMsg.content());
const inputRef = useRef<HTMLTextAreaElement>(null);
const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
const { extraContext, clearExtraContext } = useVSCodeContext(
inputRef,
setInputMsg
);
// TODO: improve this when we have "upload file" feature
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
@@ -131,10 +135,9 @@ export default function ChatScreen() {
};
const sendNewMessage = async () => {
const lastInpMsg = textarea.value();
if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
return;
textarea.setValue('');
if (inputMsg.trim().length === 0 || isGenerating(currConvId ?? '')) return;
const lastInpMsg = inputMsg;
setInputMsg('');
scrollToBottom(false);
setCurrNodeId(-1);
// get the last message node
@@ -143,13 +146,13 @@ export default function ChatScreen() {
!(await sendMessage(
currConvId,
lastMsgNodeId,
lastInpMsg,
inputMsg,
currExtra,
onChunk
))
) {
// restore the input message if failed
textarea.setValue(lastInpMsg);
setInputMsg(lastInpMsg);
}
// OK
clearExtraContext();
@@ -192,13 +195,16 @@ export default function ChatScreen() {
// send the prefilled message if needed
sendNewMessage();
} else {
// otherwise, focus on the input
textarea.focus();
// otherwise, focus on the input and move the cursor to the end
if (inputRef.current) {
inputRef.current.focus();
inputRef.current.selectionStart = inputRef.current.value.length;
}
}
prefilledMsg.clear();
// no need to keep track of sendNewMessage
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [textarea.ref]);
}, [inputRef]);
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
const pendingMsgDisplay: MessageDisplay[] =
@@ -252,7 +258,9 @@ export default function ChatScreen() {
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
ref={textarea.ref}
ref={inputRef}
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
if (e.key === 'Enter' && e.shiftKey) return;
@@ -272,7 +280,11 @@ export default function ChatScreen() {
Stop
</button>
) : (
<button className="btn btn-primary ml-2" onClick={sendNewMessage}>
<button
className="btn btn-primary ml-2"
onClick={sendNewMessage}
disabled={inputMsg.trim().length === 0}
>
Send
</button>
)}
@@ -286,43 +298,3 @@ export default function ChatScreen() {
</div>
);
}
export interface OptimizedTextareaValue {
value: () => string;
setValue: (value: string) => void;
focus: () => void;
ref: React.RefObject<HTMLTextAreaElement>;
}
// This is a workaround to prevent the textarea from re-rendering when the inner content changes
// See https://github.com/ggml-org/llama.cpp/pull/12299
function useOptimizedTextarea(initValue: string): OptimizedTextareaValue {
const [savedInitValue, setSavedInitValue] = useState<string>(initValue);
const textareaRef = useRef<HTMLTextAreaElement>(null);
useEffect(() => {
if (textareaRef.current && savedInitValue) {
textareaRef.current.value = savedInitValue;
setSavedInitValue('');
}
}, [textareaRef, savedInitValue, setSavedInitValue]);
return {
value: () => {
return textareaRef.current?.value ?? savedInitValue;
},
setValue: (value: string) => {
if (textareaRef.current) {
textareaRef.current.value = value;
}
},
focus: () => {
if (textareaRef.current) {
// focus and move the cursor to the end
textareaRef.current.focus();
textareaRef.current.selectionStart = textareaRef.current.value.length;
}
},
ref: textareaRef,
};
}

View File

@@ -1,6 +1,5 @@
import { useEffect, useState } from 'react';
import { MessageExtraContext } from './types';
import { OptimizedTextareaValue } from '../components/ChatScreen';
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
@@ -15,7 +14,10 @@ interface SetTextEvData {
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
*/
export const useVSCodeContext = (textarea: OptimizedTextareaValue) => {
export const useVSCodeContext = (
inputRef: React.RefObject<HTMLTextAreaElement>,
setInputMsg: (text: string) => void
) => {
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
null
);
@@ -25,20 +27,20 @@ export const useVSCodeContext = (textarea: OptimizedTextareaValue) => {
const handleMessage = (event: MessageEvent) => {
if (event.data?.command === 'setText') {
const data: SetTextEvData = event.data;
textarea.setValue(data?.text);
setInputMsg(data?.text);
if (data?.context && data.context.length > 0) {
setExtraContext({
type: 'context',
content: data.context,
});
}
textarea.focus();
inputRef.current?.focus();
}
};
window.addEventListener('message', handleMessage);
return () => window.removeEventListener('message', handleMessage);
}, [textarea]);
}, [inputRef, setInputMsg]);
// Add a keydown listener that sends the "escapePressed" message to the parent window
useEffect(() => {

View File

@@ -108,19 +108,22 @@ int main(int argc, char ** argv) {
}
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_pos n_past = 0;
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
n_past += llama_batch_ext_get_n_tokens(batch);
llama_token new_token_id;
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_kv_self_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) {
if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");
exit(0);
}
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
GGML_ABORT("failed to decode\n");
}
@@ -144,9 +147,14 @@ int main(int argc, char ** argv) {
response += piece;
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
n_past++;
}
llama_batch_ext_free(batch);
return response;
};

View File

@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
// main loop
@@ -151,14 +151,14 @@ int main(int argc, char ** argv) {
int n_decode = 0;
llama_token new_token_id;
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
n_pos += batch.n_tokens;
n_pos += llama_batch_ext_get_n_tokens(batch);
// sample the next token
{
@@ -180,7 +180,9 @@ int main(int argc, char ** argv) {
fflush(stdout);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
n_decode += 1;
}
@@ -198,6 +200,7 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);
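
The continuation step above follows this sketch: after the prompt batch, each sampled token is decoded as a one-token batch whose position is tracked explicitly, since llama_batch_ext does not infer positions the way llama_batch_get_one() did. The helper is illustrative:

// Sketch: decode one sampled token as a single-token batch with an explicit position.
static bool decode_one(llama_context * ctx, llama_batch_ext * batch, llama_token token, int & n_pos) {
    llama_batch_ext_clear(batch);
    llama_seq_id seq_id = 0;
    llama_batch_ext_add_text(batch, token, n_pos, &seq_id, 1, /*output=*/true);
    if (llama_decode_ext(ctx, batch)) {
        return false;
    }
    n_pos += llama_batch_ext_get_n_tokens(batch); // advances by one token here
    return true;
}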

View File

@@ -113,7 +113,8 @@ int main(int argc, char ** argv) {
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
// eval the prompt
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
auto batch = llama_batch_ext_ptr::init_from_text(inp.data(), inp.size() - 1, 0, 0, true);
llama_decode_ext(ctx_tgt, batch.get());
// note: keep the last token separate!
llama_token id_last = inp.back();
@@ -132,7 +133,7 @@ int main(int argc, char ** argv) {
struct common_speculative * spec = common_speculative_init(ctx_dft);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
const auto t_enc_end = ggml_time_us();
@@ -151,8 +152,9 @@ int main(int argc, char ** argv) {
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
@@ -162,12 +164,12 @@ int main(int argc, char ** argv) {
}
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
}
// sample from the full target batch and return the accepted tokens based on the target sampler
@@ -253,6 +255,7 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl);
common_speculative_free(spec);
llama_batch_ext_free(batch_tgt);
llama_backend_free();
LOG("\n\n");

View File

@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
}
common_init();
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
@@ -166,9 +165,12 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
llama_decode_ext(ctx_tgt, batch0.get());
llama_decode_ext(ctx_tgt, batch1.get());
llama_decode_ext(ctx_dft, batch2.get());
const auto t_enc_end = ggml_time_us();
@@ -199,8 +201,8 @@ int main(int argc, char ** argv) {
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
}
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
const auto t_dec_start = ggml_time_us();
@@ -331,7 +333,7 @@ int main(int argc, char ** argv) {
}
active_seqs.erase(s);
for (int i = 0; i < n_seq_dft; i++) {
for(int i = 0; i < n_seq_dft; i++) {
if (i == s) {
continue;
}
@@ -441,12 +443,13 @@ int main(int argc, char ** argv) {
drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
common_batch_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_batch_ext_clear(batch_dft);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_dft;
}
@@ -471,12 +474,19 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
struct batch_info {
llama_token id;
llama_pos pos;
std::vector<llama_seq_id> seq_id;
};
std::vector<batch_info> batch_tgt_data;
batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
llama_batch_ext_clear(batch_dft);
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
@@ -507,11 +517,10 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
if (batch_tgt_data[t].seq_id[p] == s) {
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
break;
}
}
@@ -553,32 +562,30 @@ int main(int argc, char ** argv) {
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
drafts[s].drafting = false;
}
}
}
// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
break;
}
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
break;
}
}
@@ -590,8 +597,15 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
}
llama_batch_ext_clear(batch_tgt);
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
const auto & data = batch_tgt_data[i];
llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@@ -634,7 +648,8 @@ int main(int argc, char ** argv) {
common_sampler_free(drafts[s].smpl);
}
llama_batch_free(batch_dft);
llama_batch_ext_free(batch_dft);
llama_batch_ext_free(batch_tgt);
llama_backend_free();
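
The restructuring above can be summarized as follows: since seq_id lists can no longer be grown in place on the batch, target-batch entries are first collected in a plain vector and the llama_batch_ext is rebuilt once per step. Sketch based on the code in this hunk; the helper name is illustrative:

// Sketch: collect target-batch entries first, then rebuild the llama_batch_ext in one pass.
struct batch_info {
    llama_token id;
    llama_pos   pos;
    std::vector<llama_seq_id> seq_id;
};

static void build_target_batch(llama_batch_ext * batch_tgt, const std::vector<batch_info> & entries) {
    llama_batch_ext_clear(batch_tgt);
    for (const auto & data : entries) {
        // each entry may belong to several draft sequences at once
        llama_batch_ext_add_text(batch_tgt, data.id, data.pos,
                                 data.seq_id.data(), data.seq_id.size(), true);
    }
}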

View File

@@ -571,10 +571,6 @@ int main(int argc, char ** argv) {
model_ttc = llama_init_ttc.model.get();
ctx_ttc = llama_init_ttc.context.get();
if (model_ttc == nullptr || ctx_ttc == nullptr) {
return ENOENT;
}
const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
// TODO: refactor in a common struct
@@ -590,10 +586,6 @@ int main(int argc, char ** argv) {
model_cts = llama_init_cts.model.get();
ctx_cts = llama_init_cts.context.get();
if (model_cts == nullptr || ctx_cts == nullptr) {
return ENOENT;
}
std::vector<common_sampler *> smpl(n_parallel);
for (int i = 0; i < n_parallel; ++i) {
params.sampling.no_perf = (i != 0);
@@ -826,7 +818,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -835,14 +827,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// evaluate the initial prompt
for (size_t i = 0; i < prompt_inp.size(); ++i) {
common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);
if (llama_decode(ctx_ttc, batch) != 0) {
if (llama_decode_ext(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -861,16 +853,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_past = batch.n_tokens;
int n_past = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;
bool next_token_uses_guide_token = true;
while (n_decode <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -926,14 +918,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
//LOG_CNT("%d", i);
}
i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true);
}
// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}
@@ -941,13 +933,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
n_past += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx_ttc, batch)) {
if (llama_decode_ext(ctx_ttc, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
llama_batch_free(batch);
llama_batch_ext_free(batch);
LOG("\n");
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
@@ -1016,14 +1008,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const int n_codes = codes.size();
llama_batch batch = llama_batch_init(n_codes, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1);
for (size_t i = 0; i < codes.size(); ++i) {
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits?
}
GGML_ASSERT(batch.n_tokens == n_codes);
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes);
if (llama_decode(ctx_cts, batch) != 0) {
if (llama_decode_ext(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -1087,6 +1080,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
retval = ENOENT;
}
llama_batch_ext_free(batch);
llama_backend_free();
return retval;

View File

@@ -76,11 +76,7 @@ if (GGML_CCACHE)
set(GGML_CCACHE_VARIANT sccache)
endif()
# TODO: should not be set globally
if (GGML_SYCL AND GGML_CCACHE_FOUND AND WIN32)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "ccache compiler_type=icl")
else ()
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
endif ()
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
set(ENV{CCACHE_SLOPPINESS} time_macros)
message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
else()
@@ -329,10 +325,6 @@ if (CMAKE_SYSTEM_NAME MATCHES "Android")
target_link_libraries(ggml-base PRIVATE dl)
endif()
if(CMAKE_SYSTEM_NAME MATCHES "visionOS")
target_compile_definitions(ggml-base PUBLIC _DARWIN_C_SOURCE)
endif()
if (BUILD_SHARED_LIBS)
foreach (target ggml-base ggml)
set_target_properties(${target} PROPERTIES POSITION_INDEPENDENT_CODE ON)

File diff suppressed because it is too large

View File

@@ -3110,17 +3110,17 @@ static void ggml_compute_forward_dup_same_cont(
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by blocks
const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
const int dr = (nk + nth - 1) / nth;
const int k0 = dr * ith;
const int k1 = MIN(k0 + dr, nk);
// parallelize by elements
const int ne = ggml_nelements(dst);
const int dr = (ne + nth - 1) / nth;
const int ie0 = dr * ith;
const int ie1 = MIN(ie0 + dr, ne);
if (k0 < k1) {
if (ie0 < ie1) {
memcpy(
((char *) dst->data + k0*nb0),
((char *) src0->data + k0*nb0),
(k1 - k0) * nb0);
((char *) dst->data + ie0*nb0),
((char *) src0->data + ie0*nb0),
(ie1 - ie0) * nb0);
}
}
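
Both variants in this hunk split the copy across threads the same way: a unit count (elements or blocks) is divided evenly over nth workers and each thread copies its contiguous range. A standalone sketch of that partitioning, with illustrative names:

// Sketch of the even work split: divide n_units across nth threads,
// thread ith copies its contiguous [start, end) range.
#include <algorithm>
#include <cstring>

static void copy_units_for_thread(char * dst, const char * src, size_t unit_size,
                                  int n_units, int ith, int nth) {
    const int dr    = (n_units + nth - 1) / nth;   // units per thread, rounded up
    const int start = dr * ith;
    const int end   = std::min(start + dr, n_units);
    if (start < end) {
        memcpy(dst + (size_t) start*unit_size,
               src + (size_t) start*unit_size,
               (size_t) (end - start)*unit_size);
    }
}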
@@ -4055,6 +4055,7 @@ static void ggml_compute_forward_dup_f32(
static void ggml_compute_forward_dup_bytes(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
@@ -4068,10 +4069,10 @@ static void ggml_compute_forward_dup_bytes(
}
const size_t type_size = ggml_type_size(src0->type);
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
@@ -4081,10 +4082,10 @@ static void ggml_compute_forward_dup_bytes(
const int ir1 = MIN(ir0 + dr, nr);
if (src0->type == dst->type &&
ggml_are_same_shape(src0, dst) &&
ne00 == ne0 &&
nb00 == type_size && nb0 == type_size) {
// copy by rows
const size_t rs = ggml_row_size(src0->type, ne00);
const size_t rs = ne00 * type_size;
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ir0; i01 < ir1; i01++) {
@@ -4139,20 +4140,17 @@ static void ggml_compute_forward_dup_bytes(
}
// dst counters
int64_t k10 = 0;
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
// number of blocks in a row
const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
k10 += nk00 * ir0;
while (k10 >= nk0) {
k10 -= nk0;
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
@@ -4164,14 +4162,14 @@ static void ggml_compute_forward_dup_bytes(
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t k00 = 0; k00 < nk00; k00++) {
const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, type_size);
if (++k10 == nk0) {
k10 = 0;
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
@@ -4184,9 +4182,9 @@ static void ggml_compute_forward_dup_bytes(
}
}
}
k10 += nk00 * (ne01 - ir1);
while (k10 >= nk0) {
k10 -= nk0;
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
@@ -14310,9 +14308,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
}
// extra_buffer op?
if (ggml_cpu_extra_compute_forward(params, tensor)) {
return;
}
if (ggml_cpu_extra_compute_forward(params, tensor)) return;
switch (tensor->op) {
case GGML_OP_DUP:

View File

@@ -41,17 +41,14 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
// AMD
// GCN/CNDA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
@@ -73,17 +70,8 @@
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
// Moore Threads
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
#define GGML_CUDA_CC_QY1 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
@@ -221,21 +209,21 @@ typedef float2 dfloat2;
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
}
static bool fast_fp16_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return fp16_available(cc) && cc != 610;
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}
// Any FP16 tensor core instructions are available for ggml code.
@@ -243,20 +231,20 @@ static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
return false;
#else
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA ||
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
}
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}
static bool cp_async_available(const int cc) {


@@ -606,47 +606,48 @@ static __global__ void flash_attn_stream_k_fixup(
*dst = dst_val / rowsum;
}
template<int D> // D == head size
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_combine_results(
const float * __restrict__ VKQ_parts,
const float2 * __restrict__ VKQ_meta,
float * __restrict__ dst,
const int parallel_blocks) {
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
dst += D * gridDim.z*blockIdx.x;
float * __restrict__ dst) {
VKQ_parts += parallel_blocks*D * gridDim.y*blockIdx.x;
VKQ_meta += parallel_blocks * gridDim.y*blockIdx.x;
dst += D * gridDim.y*blockIdx.x;
const int tid = threadIdx.x;
__builtin_assume(tid < D);
extern __shared__ float2 meta[];
__shared__ float2 meta[parallel_blocks];
if (tid < 2*parallel_blocks) {
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.y*(2*parallel_blocks) + tid];
}
__syncthreads();
float kqmax = meta[0].x;
#pragma unroll
for (int l = 1; l < parallel_blocks; ++l) {
kqmax = max(kqmax, meta[l].x);
}
float VKQ_numerator = 0.0f;
float VKQ_denominator = 0.0f;
#pragma unroll
for (int l = 0; l < parallel_blocks; ++l) {
const float diff = meta[l].x - kqmax;
const float KQ_max_scale = expf(diff);
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.y*D + blockIdx.y*D + tid];
VKQ_denominator += KQ_max_scale * meta[l].y;
}
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
}
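Aside: flash_attn_combine_results merges the partial results produced when several blocks process different KV slices of the same Q columns. Each block contributes its running softmax maximum and denominator (the float2 meta) plus an unnormalized accumulator; the merge rescales everything to the global maximum before dividing, which keeps the reduction numerically stable. A scalar host-side sketch of the same arithmetic (the flush-to-zero masking of tiny exponentials is omitted):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct PartialResult {
        float m; // running softmax maximum for this block (meta.x)
        float s; // running softmax denominator for this block (meta.y)
        float v; // unnormalized V*softmax(KQ) accumulator for one output element
    };

    // Combine per-block partials into one normalized value: rescale each block
    // by exp(m_l - m) so all partials share the global maximum m, then divide.
    static float combine_partials(const std::vector<PartialResult> & parts) {
        float m = parts[0].m;
        for (const PartialResult & p : parts) {
            m = std::fmax(m, p.m);
        }
        float numerator   = 0.0f;
        float denominator = 0.0f;
        for (const PartialResult & p : parts) {
            const float scale = std::exp(p.m - m);
            numerator   += scale * p.v;
            denominator += scale * p.s;
        }
        return numerator / denominator;
    }

    int main() {
        std::vector<PartialResult> parts = {{1.0f, 2.0f, 0.5f}, {3.0f, 1.5f, 0.25f}};
        std::printf("%f\n", combine_partials(parts));
        return 0;
    }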
static void on_no_fattn_vec_case(const int D) {
@@ -670,10 +671,12 @@ static void on_no_fattn_vec_case(const int D) {
}
}
template <int D, int ncols1, int ncols2, int KQ_stride>
// parallel_blocks == 0 is stream-k decomposition
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
const int warp_size = WARP_SIZE
) {
constexpr int ncols = ncols1 * ncols2;
@@ -745,14 +748,12 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}
int parallel_blocks = 1;
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (stream_k) {
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int max_blocks = 2*nsm;
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
@@ -768,43 +769,9 @@ void launch_fattn(
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
} else {
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
// parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
// Test whether parallel_blocks can be set to a higher value for better efficiency.
const int blocks_per_wave = nsm * max_blocks_per_sm;
int nwaves_best = 0;
int efficiency_percent_best = 0;
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
const int nblocks_total = ntiles_total * parallel_blocks_test;
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
// Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
break;
}
if (efficiency_percent > efficiency_percent_best) {
nwaves_best = nwaves;
efficiency_percent_best = efficiency_percent;
parallel_blocks = parallel_blocks_test;
}
}
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = Q->ne[2]*Q->ne[3];
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
blocks_num.z = Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
@@ -836,7 +803,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -848,7 +815,7 @@ void launch_fattn(
);
CUDA_CHECK(cudaGetLastError());
if (stream_k) {
if constexpr (parallel_blocks == 0) {
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
@@ -857,14 +824,13 @@ void launch_fattn(
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
} else if (parallel_blocks > 1) {
} else if constexpr (parallel_blocks > 1) {
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
flash_attn_combine_results<D>
<<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
}
CUDA_CHECK(cudaGetLastError());
}
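Aside: one side of this hunk sizes parallel_blocks with an occupancy-driven search: start with enough splits to fill one wave of the GPU, then try larger values as long as they improve wave efficiency. A pure-arithmetic sketch of that search, with the cudaOccupancyMaxActiveBlocksPerMultiprocessor query folded into a blocks_per_wave parameter since the real query needs a kernel handle:

    #include <algorithm>
    #include <cstdio>

    // Choose how many blocks cooperate on one tile of Q columns ("parallel_blocks").
    // ntiles_total   : number of independent output tiles
    // ntiles_KQ      : upper bound allowed by the K/V sequence length
    // blocks_per_wave: nsm * max_blocks_per_sm, i.e. blocks that can run concurrently
    static int choose_parallel_blocks(int ntiles_total, int ntiles_KQ, int blocks_per_wave) {
        // Start with enough splitting to occupy a single wave, clamped to the tensor size.
        int parallel_blocks = std::max(blocks_per_wave / ntiles_total, 1);
        parallel_blocks     = std::min(parallel_blocks, ntiles_KQ);

        // Try larger values if they reduce tail effects (idle SMs in the last wave).
        int nwaves_best             = 0;
        int efficiency_percent_best = 0;
        for (int pb = parallel_blocks; pb <= ntiles_KQ; ++pb) {
            const int nblocks_total      = ntiles_total * pb;
            const int nwaves             = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
            const int efficiency_percent = 100 * nblocks_total / (nwaves * blocks_per_wave);
            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
                break; // already efficient; more waves would only add overhead
            }
            if (efficiency_percent > efficiency_percent_best) {
                nwaves_best             = nwaves;
                efficiency_percent_best = efficiency_percent;
                parallel_blocks         = pb;
            }
        }
        return parallel_blocks;
    }

    int main() {
        std::printf("parallel_blocks = %d\n",
                    choose_parallel_blocks(/*ntiles_total=*/7, /*ntiles_KQ=*/64, /*blocks_per_wave=*/132));
        return 0;
    }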


@@ -970,8 +970,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
}
launch_fattn<D, ncols1, ncols2, KQ_per_iter>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
}


@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,17 +58,18 @@ static __global__ void flash_attn_tile_ext_f16(
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -104,7 +105,8 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];
@@ -269,16 +271,16 @@ static __global__ void flash_attn_tile_ext_f16(
const int i0 = i00 + 2*threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (gridDim.y == 1) {
if (parallel_blocks == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
@@ -286,7 +288,7 @@ static __global__ void flash_attn_tile_ext_f16(
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int cols_per_block, bool use_logit_softcap>
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
@@ -294,17 +296,15 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,22 +324,37 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
}
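Aside: in the parallel_blocks variants above, a single grid dimension encodes both the Q-column tile and the split index via ic0 = (blockIdx.x / parallel_blocks) * ncols and ip = blockIdx.x % parallel_blocks, and each split then strides over the KV positions starting at ip * FATTN_KQ_STRIDE_TILE_F16. A small sketch that prints the resulting mapping (the numbers are made up for illustration):

    #include <cstdio>

    // Reproduce the blockIdx.x -> (column tile, split index) mapping used by the
    // parallel_blocks kernels, plus the slice of KV offsets each split iterates over.
    int main() {
        const int ncols           = 16;  // Q columns per block (cols_per_block)
        const int parallel_blocks = 4;   // blocks cooperating on the same columns
        const int kv_stride       = 64;  // FATTN_KQ_STRIDE_TILE_F16
        const int ne11            = 256; // number of KV positions

        for (int block_x = 0; block_x < 8; ++block_x) {
            const int ic0 = (block_x / parallel_blocks) * ncols; // first Q column
            const int ip  =  block_x % parallel_blocks;          // split index
            std::printf("block %d: ic0 = %d, ip = %d, KV offsets:", block_x, ic0, ip);
            for (int k = ip * kv_stride; k < ne11; k += parallel_blocks * kv_stride) {
                std::printf(" %d", k);
            }
            std::printf("\n");
        }
        return 0;
    }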


@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F32 32
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,17 +58,18 @@ static __global__ void flash_attn_tile_ext_f32(
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -102,7 +103,8 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
@@ -267,17 +269,17 @@ static __global__ void flash_attn_tile_ext_f32(
const int i0 = i00 + 2*threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (gridDim.y == 1) {
if (parallel_blocks == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
@@ -285,7 +287,7 @@ static __global__ void flash_attn_tile_ext_f32(
#endif // FLASH_ATTN_AVAILABLE
}
template <int cols_per_block, bool use_logit_softcap>
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
@@ -293,17 +295,15 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,22 +320,37 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
}


@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,16 +55,17 @@ static __global__ void flash_attn_vec_ext_f16(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio);
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -171,7 +172,8 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ[ncols] = {{0.0f, 0.0f}};
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -281,29 +283,29 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
if (gridDim.y == 1) {
if (parallel_blocks == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>
@@ -323,48 +325,65 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}


@@ -1,7 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,15 +55,16 @@ static __global__ void flash_attn_vec_ext_f32(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
@@ -166,7 +167,8 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new_arr[ncols];
@@ -266,29 +268,29 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
float dst_val = VKQ[j_VKQ];
if (gridDim.y == 1) {
if (parallel_blocks == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
}
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>
@@ -305,48 +307,65 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}


@@ -18,7 +18,7 @@ namespace wmma = rocwmma;
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
@@ -67,7 +67,8 @@ static __global__ void flash_attn_ext_f16(
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
@@ -90,16 +91,16 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
@@ -175,7 +176,7 @@ static __global__ void flash_attn_ext_f16(
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -394,7 +395,7 @@ static __global__ void flash_attn_ext_f16(
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
@@ -410,13 +411,13 @@ static __global__ void flash_attn_ext_f16(
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (gridDim.y == 1) {
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (gridDim.y == 1 || threadIdx.x != 0) {
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
@@ -427,7 +428,7 @@ static __global__ void flash_attn_ext_f16(
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
@@ -461,26 +462,60 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
}
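Aside: the three near-identical branches above implement a simple rule: if the unsplit grid (blocks_num_pb1) would leave most of the device idle, split each column tile over 4 or 2 blocks, otherwise over 1. A condensed sketch of that decision (standalone C++, example numbers are illustrative):

    #include <cstdio>

    // Decide how many blocks should cooperate on each column tile for the WMMA kernel.
    // blocks_num_pb1: grid size without splitting; nsm: number of SMs on the device.
    static int wmma_parallel_blocks(int blocks_num_pb1, int nsm) {
        if (4 * blocks_num_pb1 < 2 * nsm) return 4; // grid far too small, split aggressively
        if (2 * blocks_num_pb1 < 2 * nsm) return 2; // still too small, split in half
        return 1;                                   // enough blocks to fill the device
    }

    int main() {
        std::printf("%d %d %d\n",
                    wmma_parallel_blocks(10, 80),   // tiny batch  -> 4
                    wmma_parallel_blocks(50, 80),   // medium      -> 2
                    wmma_parallel_blocks(200, 80)); // large batch -> 1
        return 0;
    }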
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {


@@ -253,7 +253,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
@@ -281,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
if (!fp16_mma_available(cc)) {
if (prec == GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
}
} else {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
@@ -296,17 +296,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool can_use_vector_kernel = (Q->ne[0] % (2*warp_size) == 0) && (prec == GGML_PREC_DEFAULT || Q->ne[0] <= 128);
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
const int gqa_ratio = Q->ne[2] / K->ne[2];
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
return;
} else if(Q->ne[0] <= 128) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
return;
}
return;
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:


@@ -262,11 +262,9 @@ static ggml_cuda_device_info ggml_cuda_init() {
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#else
@@ -1188,11 +1186,11 @@ static void ggml_cuda_op_mul_mat_cublas(
// ldc == nrows of the matrix that cuBLAS writes into
int64_t ldc = id == ctx.device ? ne0 : row_diff;
const int cc = ggml_cuda_info().devices[id].cc;
const int compute_capability = ggml_cuda_info().devices[id].cc;
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
if (((cc >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
@@ -1216,7 +1214,7 @@ static void ggml_cuda_op_mul_mat_cublas(
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
if (GGML_CUDA_CC_IS_CDNA(cc)) {
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
@@ -3230,9 +3228,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}


@@ -28,7 +28,7 @@ void ggml_cuda_op_mul_mat_q(
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
GGML_CUDA_CC_IS_NVIDIA(cc) && src1_ncols == ne11;
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
switch (src0->type) {
@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true;
#endif //GGML_CUDA_FORCE_MMQ
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
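Aside: both spellings of this check rely on the same encoding of the compute capability: NVIDIA devices keep their raw value (e.g. 700 for Volta) while other vendors are shifted by a large offset, so an ordinary comparison against the offsets recovers the vendor. A small sketch of the decoding, with constants mirroring the GGML_CUDA_CC_OFFSET_* defines shown earlier (helper names are illustrative):

    #include <cstdio>

    // Offsets used to pack the vendor into the compute-capability integer.
    constexpr int CC_OFFSET_MTHREADS = 0x0100000;
    constexpr int CC_OFFSET_AMD      = 0x1000000;

    enum class Vendor { NVIDIA, MTHREADS, AMD };

    static Vendor vendor_from_cc(int cc) {
        if (cc >= CC_OFFSET_AMD)      return Vendor::AMD;
        if (cc >= CC_OFFSET_MTHREADS) return Vendor::MTHREADS;
        return Vendor::NVIDIA;
    }

    int main() {
        const int ccs[] = {700, 890, CC_OFFSET_MTHREADS + 0x210, CC_OFFSET_AMD + 0x906};
        for (int cc : ccs) {
            std::printf("cc = 0x%x -> vendor %d\n", cc, (int) vendor_from_cc(cc));
        }
        return 0;
    }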


@@ -90,7 +90,7 @@ struct tile_x_sizes {
static int get_mmq_x_max_host(const int cc) {
return new_mma_available(cc) ? 128 :
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc) ?
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
#ifdef GGML_CUDA_FORCE_MMQ
128 : 64;
#else
@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
}
static int get_mmq_y_host(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
((ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc)) ? 128 : 64);
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
}
static constexpr __device__ int get_mmq_y_device() {
@@ -2772,14 +2772,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shmem_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
shmem_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
@@ -2832,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && GGML_CUDA_CC_IS_NVIDIA(cc);
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
int mmq_x_best = 0;
int nparts_best = INT_MAX;


@@ -129,7 +129,6 @@
#define cudaGraph_t hipGraph_t
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED


@@ -134,6 +134,5 @@
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture musaStreamBeginCapture
#define cudaStreamEndCapture musaStreamEndCapture
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
typedef mt_bfloat16 nv_bfloat16;


@@ -297,27 +297,8 @@ static int ggml_backend_opencl_n_devices = 0;
struct ProfilingInfo {
std::string op_name;
std::string kernel_name;
cl_kernel kernel;
cl_event evt;
cl_ulong cmd_queued;
cl_ulong cmd_submit;
cl_ulong cmd_start;
cl_ulong cmd_end;
cl_ulong overhead_start;
cl_ulong overhead_end;
// For the times below, see spec for clGetEventProfilingInfo
// The time kernel spent in cmd queue - SUBMIT - QUEUED
cl_ulong cmd_queued_duration_ns;
// The time kernel spent for submission - START - SUBMIT
cl_ulong cmd_submit_duration_ns;
// Kernel execution time in nanoseconds - END - START
cl_ulong cmd_duration_ns;
// The time for the kernel to complete - COMPLETE - END
cl_ulong cmd_complete_duration_ns;
// Total time to finish the kernel - COMPLETE - QUEUED
cl_ulong cmd_total_duration_ns;
// Kernel execution time in nanoseconds.
cl_ulong duration_ns;
// Global and local work sizes.
size_t global_size[3];
size_t local_size[3];
@@ -922,56 +903,12 @@ static void ggml_cl2_free(void) {
return;
}
// Populate profiling info
for (ProfilingInfo & info : g_profiling_info) {
cl_ulong cmd_queued;
cl_ulong cmd_submit;
cl_ulong cmd_start;
cl_ulong cmd_end;
cl_ulong cmd_complete;
CL_CHECK(clWaitForEvents(1, &info.evt));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_QUEUED, sizeof(cl_ulong), &cmd_queued, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_SUBMIT, sizeof(cl_ulong), &cmd_submit, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &cmd_start, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &cmd_end, NULL));
CL_CHECK(clGetEventProfilingInfo(
info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL));
CL_CHECK(clReleaseEvent(info.evt));
char kernel_name[512];
CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME,
sizeof(kernel_name), kernel_name, NULL));
info.kernel_name = kernel_name;
info.cmd_queued = cmd_queued;
info.cmd_submit = cmd_submit;
info.cmd_start = cmd_start;
info.cmd_end = cmd_end;
info.cmd_queued_duration_ns = cmd_submit - cmd_queued;
info.cmd_submit_duration_ns = cmd_start - cmd_submit;
info.cmd_duration_ns = cmd_end - cmd_start;
info.cmd_complete_duration_ns = cmd_complete - cmd_end;
info.cmd_total_duration_ns = cmd_complete - cmd_queued;
}
// Dump a csv
float total_kernel_time = 0;
fprintf(fperf, "op name, kernel name, queued duration (ms), submit duration(ms), exec duration (ms), complete duration (ms), total duration (ms), global size, local size, output size\n");
fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
for (const ProfilingInfo & info : g_profiling_info) {
total_kernel_time += info.cmd_duration_ns/1.e6f;
fprintf(fperf, "%s,%s,%f,%f,%f,%f,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
info.op_name.c_str(), info.kernel_name.c_str(),
info.cmd_queued_duration_ns/1.e6f,
info.cmd_submit_duration_ns/1.e6f,
info.cmd_duration_ns/1.e6f,
info.cmd_complete_duration_ns/1.e6f,
info.cmd_total_duration_ns/1.e6f,
total_kernel_time += info.duration_ns/1.e6f;
fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
info.global_size[0], info.global_size[1], info.global_size[2],
info.local_size[0], info.local_size[1], info.local_size[2],
info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
@@ -979,27 +916,6 @@ static void ggml_cl2_free(void) {
fclose(fperf);
GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
// Dump a simple chrome trace
FILE* ftrace = fopen("cl_trace.json", "w");
if (!ftrace) {
GGML_LOG_ERROR("Failed to open cl_trace.json\n");
return;
}
fprintf(ftrace, "[\n");
for (const ProfilingInfo & info : g_profiling_info) {
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_queued/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Host\"},\n",
info.kernel_name.c_str(), info.cmd_submit/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_start/1000);
fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %lu, \"pid\": \"\", \"tid\": \"Device\"},\n",
info.kernel_name.c_str(), info.cmd_end/1000);
}
fclose(ftrace);
#endif
}
@@ -2146,14 +2062,25 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso
// Profiling utility
//------------------------------------------------------------------------------
#ifdef GGML_OPENCL_PROFILING
static void populateProfilingInfo(
void populateProfilingInfo(
ProfilingInfo& info, cl_event evt, cl_kernel kernel,
size_t global_size[3], size_t local_size[3],
const ggml_tensor * tensor) {
info.op_name = tensor->name;
info.kernel = kernel;
info.evt = evt;
cl_ulong start;
cl_ulong end;
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clGetEventProfilingInfo(
evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
CL_CHECK(clGetEventProfilingInfo(
evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
char kernel_name[512];
CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
sizeof(kernel_name), kernel_name, NULL));
info.duration_ns = end - start;
info.op_name = tensor->name;
info.kernel_name = kernel_name;
info.local_size[0] = local_size[0];
info.local_size[1] = local_size[1];
info.local_size[2] = local_size[2];
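For orientation, a minimal sketch of a call site for this helper; the queue, kernel, work sizes and dst tensor names are placeholders, and it is assumed here that g_profiling_info is a std::vector<ProfilingInfo> — this is illustrative, not part of the diff:
// Hypothetical enqueue + profiling hookup (names are illustrative)
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size, local_size, 0, NULL, &evt));
g_profiling_info.emplace_back();
populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size, local_size, dst);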

View File

@@ -23,38 +23,6 @@ ggml_add_backend_library(ggml-sycl
../../include/ggml-sycl.h
)
find_package(DNNL)
set(GGML_SYCL_DNNL 0)
if(DNNL_FOUND)
if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
# Assuming the oneDNN build packaged with the oneAPI release is used, which
# supports only the Intel target
set(DNNL_GPU_VENDOR "INTEL")
if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
endif()
endif()
# Verify oneDNN was compiled for the same target as llama
if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
set(GGML_SYCL_DNNL 1)
get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
foreach(CONFIG ${CONFIGS})
get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
message(STATUS "Found oneDNN: ${DNNL_LIB}")
endforeach()
else()
message(WARNING
"oneDNN must be compiled for the same target as llama.cpp.
llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
Disabling oneDNN support.")
endif()
else()
message(STATUS "oneDNN not found, disabling oneDNN support")
endif()
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
if (GGML_SYCL_F16)
if (GGML_SYCL_TARGET STREQUAL "AMD")
message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -80,6 +48,18 @@ file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
find_package(DNNL)
message("-- DNNL found:" ${DNNL_FOUND})
if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
else()
add_compile_definitions(GGML_SYCL_DNNL=0)
endif()
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
endif()
if (WIN32)
find_package(IntelSYCL REQUIRED)

View File

@@ -170,6 +170,7 @@ static size_t g_scratch_offset = 0;
int get_current_device_id();
inline dpct::err0 ggml_sycl_set_device(const int device) try {
int current_device_id;
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
@@ -241,14 +242,6 @@ struct ggml_sycl_pool_alloc {
}
}
T * realloc(size_t size) {
GGML_ASSERT(pool != nullptr);
if (ptr)
pool->free(ptr, actual_size);
ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
return ptr;
}
// size is in number of elements
T * alloc(size_t size) {
GGML_ASSERT(pool != nullptr);
@@ -378,29 +371,10 @@ struct ggml_backend_sycl_context {
dnnl::stream stream_dnnl() {
return stream_dnnl(device, 0);
}
dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
const dnnl::engine & eng, const queue_ptr q) {
ggml_sycl_pool_alloc<uint8_t> * pool;
auto it = scratchpad_map.find(q);
if (it == scratchpad_map.end()) {
scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
pool = scratchpad_map[q].get();
} else {
pool = it->second.get();
}
size_t scratchpad_size = scratchpad_md.get_size();
if (scratchpad_size > pool->actual_size) {
pool->realloc(scratchpad_size);
}
void * mem_ptr = pool->get();
return dnnl::memory(scratchpad_md, eng, mem_ptr);
}
#endif
// pool
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];

View File

@@ -13,6 +13,9 @@
#ifndef GGML_SYCL_GEMM_HPP
#define GGML_SYCL_GEMM_HPP
#include <fstream>
#include <iostream>
#include "ggml-sycl.h"
#if GGML_SYCL_DNNL
@@ -32,34 +35,62 @@ public:
else static_assert(0);
}
static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
auto stream = ctx.stream_dnnl(q);
auto eng = ctx.engine_dnnl(q);
static inline void row_gemm(sycl::queue& q, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
// Get the device associated with the queue
sycl::device dev = q.get_device();
// Get the context associated with the queue
sycl::context ctx = q.get_context();
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
auto scratchpad_md = matmul_pd.scratchpad_desc();
auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_prim.execute(stream, matmul_args);
}
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
bool b_trans, int m, int n, int k,
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
{
auto const eng = stream.get_engine();
dnnl::memory::dims a_dims = { m, k };
dnnl::memory::dims b_dims = { k, n };
dnnl::memory::dims c_dims = { m, n };
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
// Create the primitive.
auto matmul_prim = dnnl::matmul(matmul_pd);
// Primitive arguments.
std::unordered_map<int, dnnl::memory> matmul_args;
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
matmul_args.insert({ DNNL_ARG_DST, c_mem });
matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
matmul_prim.execute(stream, matmul_args);
}

View File

@@ -2058,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
auto dnnl_stream = ctx.stream_dnnl(stream);
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
#endif
@@ -2099,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
dst_dd_i, ldc)));
# endif
#else
DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
auto dnnl_stream = ctx.stream_dnnl(stream);
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
#endif
}
GGML_UNUSED(dst);

View File

@@ -149,7 +149,6 @@ class vk_perf_logger;
static void ggml_vk_destroy_buffer(vk_buffer& buf);
static constexpr uint32_t mul_mat_vec_max_cols = 8;
static constexpr uint32_t p021_max_gqa_ratio = 8;
enum vk_device_architecture {
OTHER,
@@ -232,7 +231,6 @@ struct vk_device_struct {
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
bool subgroup_add;
bool subgroup_size_control;
uint32_t subgroup_min_size;
@@ -279,7 +277,7 @@ struct vk_device_struct {
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -2267,13 +2265,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
if (device->subgroup_add && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
}
}
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2289,21 +2281,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
} else {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2487,15 +2471,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceDriverProperties driver_props;
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
vk::PhysicalDeviceVulkan11Properties vk11_props;
vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
props2.pNext = &props3;
props3.pNext = &subgroup_props;
subgroup_props.pNext = &driver_props;
driver_props.pNext = &vk11_props;
vk11_props.pNext = &vk12_props;
driver_props.pNext = &vk12_props;
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
@@ -2559,9 +2541,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
}
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -4648,15 +4627,9 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
const uint64_t d_sz = sizeof(float) * d_ne;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
gqa_ratio = 1;
}
if (dryrun) {
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
return;
}
@@ -4680,15 +4653,8 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
// compute
const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
uint32_t workgroups_z = (uint32_t)ne12;
// When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
if (gqa_ratio > 1) {
workgroups_z /= gqa_ratio;
}
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
}
static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -8470,12 +8436,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
@@ -8496,27 +8458,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool first_node_in_batch = true; // true if next node will be first node in a batch
int submit_node_idx = 0; // index to first node in a batch
// Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
// Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
// (and scaled down based on model size, so smaller models submit earlier).
// Also submit at least every 100 nodes, in case there are workloads without as much matmul.
int nodes_per_submit = 100;
// Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
// Start with a smaller count to get work submitted right away, and increase it after each submit.
int nodes_per_submit = 20;
int submitted_nodes = 0;
int submit_count = 0;
uint64_t mul_mat_bytes = 0;
uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
for (int i = 0; i < cgraph->n_nodes; i++) {
if (first_node_in_batch) {
submit_node_idx = i;
}
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
(i == last_node);
bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
@@ -8533,9 +8485,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
if (submit) {
first_node_in_batch = true;
submitted_nodes = 0;
mul_mat_bytes = 0;
if (submit_count < 3) {
mul_mat_bytes_per_submit *= 2;
switch (submit_count) {
case 0:
nodes_per_submit = 50;
break;
default:
nodes_per_submit = 100;
break;
}
submit_count++;
}

View File

@@ -1,10 +1,5 @@
#version 450
#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16
#include "types.comp"
#include "generic_unary_head.comp"

View File

@@ -82,8 +82,8 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
const i8vec2 v0 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2]);
const i8vec2 v1 = unpack8(data_a_packed16[a_offset + ib].qs[iqs/2 + 1]);
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif

View File

@@ -311,8 +311,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib32 = idx / 32;
const uint ib8 = idx / 8;
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
@@ -330,20 +330,14 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
block_iq1_m block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
const uint idx = coordInBlock[1];
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const uint ib8 = idx / 8;
const uint ib16 = idx / 16;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];

View File

@@ -105,16 +105,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
// If the K dimension is odd, we need lastiter==true on the last iteration
// so OOB is computed correctly. Skip some unrolling to make that happen.
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
@@ -123,18 +113,8 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {

View File

@@ -19,8 +19,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const float db = d * (0.5 + scale) * 0.25;
const uint qh = data_a[ibi].qh[ib32];
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
const u8vec2 qs16 = unpack8(data_a_packed16[ibi].qs[itid]);
const u8vec2 sign16 = unpack8(data_a_packed16[ibi].qs[QUANT_K / 16 + itid]);
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint8_t sign = sign16[l];
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);

View File

@@ -21,7 +21,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
sum[j] = 0.0;
}
[[unroll]] for (uint l = 0; l < 4; ++l) {
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
const u8vec2 qs = unpack8(data_a_packed16[ibi].qs[4 * ib32 + l]);
const uint sign = data_a[ibi].signs[4 * ib32 + l];
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));

View File

@@ -12,9 +12,6 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
@@ -40,65 +37,24 @@ void main() {
const uint idst = channel*nrows_dst + row_dst;
FLOAT_TYPE temp = 0.0f;
tmp[tid] = 0.0f;
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
// Unroll 2x and do vec4 loads if aligned
const uint unroll_count = 2;
if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
[[unroll]] for (uint i = 0; i < unroll_count; ++i) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
}
// do vec4 loads if aligned
} else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
col_x0 += BLOCK_SIZE;
if (col_x >= p.ncols_x) {
break;
}
}
tmp[tid] = temp;
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
}
// sum up partial sums and write back result
barrier();

View File

@@ -2,25 +2,16 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;
layout (push_constant) uniform parameter
{
uint ncols_x;
@@ -31,124 +22,52 @@ layout (push_constant) uniform parameter
uint d_offset;
} p;
#if !USE_SUBGROUP_ADD
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
uint channel, channel_x;
// When gqa_ratio > 1, each invocation does multiple rows.
// The row in the A matrix starts at channel / gqa_ratio and the
// rows in the B matrix are [channel, channel+gqa_ratio).
// When gqa_ratio is 1, each invocation does one row.
if (gqa_ratio > 1) {
channel_x = gl_GlobalInvocationID.z;
channel = channel_x * gqa_ratio;
} else {
channel = gl_GlobalInvocationID.z;
channel_x = channel / (p.nchannels_y / p.nchannels_x);
}
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
FLOAT_TYPE temp[8];
[[unroll]] for (uint i = 0; i < 8; ++i) {
temp[i] = FLOAT_TYPE(0.0f);
}
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
tmp[tid] = FLOAT_TYPE(0.0f);
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
// Use vec4 loads if aligned
if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
vec4 bv4 = data_b_v4[iy / 4];
temp[c] += dot(av4, bv4);
}
col_x0 += 3*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
}
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
// y is not transposed but permuted
const uint iy = channel*nrows_y + row_y;
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
}
#if USE_SUBGROUP_ADD
// reduce vec4 at a time
vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
t = subgroupAdd(t);
temp[0] = t[0];
temp[1] = t[1];
temp[2] = t[2];
temp[3] = t[3];
if (gqa_ratio > 4) {
t = vec4(temp[4], temp[5], temp[6], temp[7]);
t = subgroupAdd(t);
temp[4] = t[0];
temp[5] = t[1];
temp[6] = t[2];
temp[7] = t[3];
}
#else
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
tmp[c][tid] = temp[c];
}
// dst is not transposed and not permuted
const uint idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] += tmp[c][tid + s];
tmp[c][tid] = temp[c];
}
tmp[tid] += tmp[tid + s];
}
barrier();
}
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] = tmp[c][tid];
}
#endif
if (tid == 0) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
dst[idst] = temp[c];
}
dst[idst] = tmp[0];
}
}

View File

@@ -336,8 +336,8 @@ void main() {
const uint iqs = idx & 0x07;
const float d = float(data_a_packed16[ib].d);
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const i8vec2 v0 = unpack8(data_a_packed16[ib].qs[2*iqs]);
const i8vec2 v1 = unpack8(data_a_packed16[ib].qs[2*iqs + 1]);
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
@@ -544,7 +544,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -564,7 +564,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -586,7 +586,7 @@ void main() {
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -611,7 +611,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@@ -631,7 +631,7 @@ void main() {
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);

View File

@@ -2,7 +2,6 @@
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
@@ -313,12 +312,6 @@ struct block_iq1_m {
uint16_t scales[QUANT_K_IQ1_M/64];
};
struct block_iq1_m_packed64 {
uint64_t qs[QUANT_K_IQ1_M/8/8];
uint64_t qh[QUANT_K_IQ1_M/16/8];
uint64_t scales;
};
#if defined(DATA_A_IQ1_S)
#define QUANT_K QUANT_K_IQ1_S
#define QUANT_R QUANT_R_IQ1_S

View File

@@ -426,9 +426,8 @@ void process_shaders() {
}
}
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
@@ -446,7 +445,6 @@ void process_shaders() {
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}

View File

@@ -1113,7 +1113,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
],
MODEL_ARCH.GEMMA3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,

View File

@@ -154,12 +154,7 @@ class SpecialVocab:
return True
with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
chat_template_alt = None
chat_template_file = path / 'chat_template.json'
if chat_template_file.is_file():
with open(chat_template_file, encoding = 'utf-8') as f:
chat_template_alt = json.load(f).get('chat_template')
chat_template = tokenizer_config.get('chat_template', chat_template_alt)
chat_template = tokenizer_config.get('chat_template')
if chat_template is None or isinstance(chat_template, (str, list)):
self.chat_template = chat_template
else:

View File

@@ -24,7 +24,34 @@ struct llama_adapter_lora_deleter {
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
};
struct llama_batch_ext_deleter {
void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); }
};
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
llama_batch_ext_ptr() : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>() {}
llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
// convenience function to create a batch from text tokens, without worrying about manually freeing it
static llama_batch_ext_ptr init_from_text(llama_token * tokens,
int32_t n_tokens,
int32_t pos0,
int32_t seq_id,
bool output_last) {
return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last));
}
// convenience function to create a batch from text embeddings, without worrying about manually freeing it
static llama_batch_ext_ptr init_from_embd(float * embd,
size_t n_tokens,
size_t n_embd,
int32_t pos0,
int32_t seq_id) {
return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
}
};
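A minimal sketch of how this wrapper might be used; ctx and tokens are placeholders and error handling is elided, so treat this as illustrative rather than a definitive pattern:
// tokens: std::vector<llama_token> holding the prompt; ctx: a valid llama_context *
auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), (int32_t) tokens.size(),
                                                 /*pos0=*/0, /*seq_id=*/0, /*output_last=*/true);
if (llama_decode_ext(ctx, batch.get()) != 0) {
    // decode failed; handle the error
}
// the batch is freed automatically when the ptr goes out of scope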

View File

@@ -234,6 +234,9 @@ extern "C" {
typedef bool (*llama_progress_callback)(float progress, void * user_data);
// Input data for llama_decode
//
// WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead
//
// A llama_batch object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
//
@@ -257,6 +260,10 @@ extern "C" {
int8_t * logits; // TODO: rename this to "output"
} llama_batch;
// Input data for llama_decode / llama_encode
// It can contain text tokens and embeddings for one or many sequences
struct llama_batch_ext;
enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_TYPE_INT,
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@@ -891,9 +898,9 @@ extern "C" {
//
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
//
LLAMA_API struct llama_batch llama_batch_get_one(
DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens);
int32_t n_tokens), "use llama_batch_ext_init_from_text instead");
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
@@ -902,13 +909,98 @@ extern "C" {
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
// The rest of the llama_batch members are allocated with size n_tokens
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(
int32_t n_tokens,
int32_t embd,
int32_t n_seq_max);
DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
int32_t n_tokens,
int32_t embd,
int32_t n_seq_max), "use llama_batch_ext_init instead");
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
"use llama_batch_ext API instead");
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
// The batch has to be freed with llama_batch_ext_free()
LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
int32_t n_tokens,
int32_t n_seq_max);
// Same as llama_batch_init, but initializes the batch with the provided text tokens
// First token will be at position pos0
// The sequence ID will be fixed to seq_id
// If output_last is true, the last token will have output set
// The batch has to be freed with llama_batch_ext_free()
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
llama_token * tokens,
int32_t n_tokens,
int32_t pos0,
int32_t seq_id,
bool output_last);
// Same as llama_batch_init, but initializes the batch with the provided raw embeddings
// Size of embd should be n_tokens * n_embd
// n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
// First token will be at position pos0
// The sequence ID will be fixed to seq_id
// The batch has to be freed with llama_batch_ext_free()
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
float * embd,
size_t n_tokens,
size_t n_embd,
int32_t pos0,
int32_t seq_id);
// Set arbitrary positions for the tokens in an embeddings batch
// Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
// n_pos must match the n_tokens of the batch
// Returns -1 if n_pos does not match the n_tokens of the batch
LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);
// Get the number of tokens in the batch
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
// Add text tokens to the batch
// Return values:
// -1 : not enough space in the batch
// -2 : embd is already set, cannot add text tokens
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_add_text(
struct llama_batch_ext * batch,
llama_token token,
llama_pos pos,
const llama_seq_id * seq_ids,
size_t n_seq_ids,
bool output);
// Set output (logits/embeddings) for the token at the given position in the given sequence
// If pos == -1, output will be set for all tokens in the sequence
// Return values:
// -1 : the token is not in the batch
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_set_output(
struct llama_batch_ext * batch,
llama_pos pos,
llama_seq_id seq_id);
// Set output (logits/embeddings) for the last added token
// Return values:
// -1 : the batch is empty
// otherwise, returns the output ID
LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);
// Get a "view" from a number of tokens offset
// The returned batch must be freed with llama_batch_ext_free()
LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view(
struct llama_batch_ext * batch,
int32_t offset,
int32_t n_tokens);
// Remove everything from the batch
LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch);
// Frees a batch of tokens allocated with llama_batch_ext_init()
// If this is a view, the original batch is not freed
LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch);
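Taken together, the declarations above suggest the following lifecycle when calling the C API directly (llama_decode_ext is declared further down in this header); ctx, prompt_tokens and n_prompt are placeholders in this sketch:
// Single-sequence prompt processing: fill the batch, mark the last token for output, decode, free.
llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens=*/512, /*n_seq_max=*/1);
const llama_seq_id seq_id = 0;
for (int32_t i = 0; i < n_prompt; ++i) {
    llama_batch_ext_add_text(batch, prompt_tokens[i], /*pos=*/i, &seq_id, 1, /*output=*/false);
}
llama_batch_ext_set_output_last(batch);
if (llama_decode_ext(ctx, batch) != 0) {
    // no KV slot found or other error
}
llama_batch_ext_free(batch);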
// Processes a batch of tokens with the encoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
@@ -918,13 +1010,21 @@ extern "C" {
struct llama_context * ctx,
struct llama_batch batch);
LLAMA_API int32_t llama_encode_ext(
struct llama_context * ctx,
struct llama_batch_ext * batch);
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_decode(
struct llama_context * ctx,
struct llama_batch batch);
struct llama_batch batch);
LLAMA_API int32_t llama_decode_ext(
struct llama_context * ctx,
struct llama_batch_ext * batch);
// Set the number of threads used for decoding
// n_threads is the number of threads used for generation (single token)

View File

@@ -778,7 +778,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },

View File

@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch;
}
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
@@ -273,46 +273,60 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
);
}
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
batch = new llama_batch_ext{
/*n_tokens =*/ in_batch.n_tokens,
/*max_tokens =*/ in_batch.n_tokens,
/*is_view =*/ false,
/*tokens =*/ in_batch.token,
/*embd =*/ in_batch.embd,
/*pos =*/ in_batch.pos,
/*n_seq_id =*/ in_batch.n_seq_id,
/*seq_id =*/ in_batch.seq_id,
/*logits =*/ in_batch.logits,
};
GGML_ASSERT(batch->n_tokens > 0);
if (!in_batch.pos) {
pos.resize(batch->n_tokens);
for (int32_t i = 0; i < batch->n_tokens; i++) {
pos[i] = i + p0;
}
batch.pos = pos.data();
batch->pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch->n_seq_id) {
n_seq_id.resize(batch->n_tokens);
for (int32_t i = 0; i < batch->n_tokens; i++) {
n_seq_id[i] = seq_id_0.size();
}
batch.n_seq_id = n_seq_id.data();
batch->n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
for (int32_t i = 0; i < batch.n_tokens; i++) {
if (!batch->seq_id) {
seq_id.resize(batch->n_tokens + 1);
seq_id[batch->n_tokens] = NULL;
for (int32_t i = 0; i < batch->n_tokens; i++) {
seq_id[i] = seq_id_0.data();
}
batch.seq_id = seq_id.data();
batch->seq_id = seq_id.data();
}
if (!batch.logits) {
logits.resize(batch.n_tokens);
if (!batch->logits) {
logits.resize(batch->n_tokens);
logits[logits.size() - 1] = true;
batch.logits = logits.data();
batch->logits = logits.data();
}
}
llama_batch_allocr::~llama_batch_allocr() {
delete batch;
}
//
// interface implementation
//
struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens) {
return {
llama_token * tokens,
int32_t n_tokens) {
return llama_batch{
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
@@ -323,6 +337,183 @@ struct llama_batch llama_batch_get_one(
};
}
struct llama_batch_ext * llama_batch_ext_init_from_text(
llama_token * tokens,
int32_t n_tokens,
int32_t pos0,
int32_t seq_id,
bool output_last) {
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
for (int32_t i = 0; i < n_tokens; i++) {
llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
}
if (output_last) {
llama_batch_ext_set_output_last(batch);
}
return batch;
}
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
llama_batch_ext * batch = new llama_batch_ext{
/*n_tokens =*/ 0,
/*max_tokens =*/ n_tokens_alloc,
/*is_view =*/ false,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
/*pos =*/ nullptr,
/*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr,
/*logits =*/ nullptr,
};
if (n_embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
} else {
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
for (int i = 0; i < n_tokens_alloc; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch->seq_id[n_tokens_alloc] = nullptr;
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
return batch;
}
struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
}
struct llama_batch_ext * llama_batch_ext_init_from_embd(
float * embd,
size_t n_tokens,
size_t n_embd,
int32_t pos0,
int32_t seq_id) {
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
for (size_t i = 0; i < n_tokens; i++) {
batch->pos [i] = pos0 + i;
batch->n_seq_id[i] = 1;
batch->seq_id [i][0] = seq_id;
}
return batch;
}
int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
if ((size_t) batch->n_tokens != n_pos) {
return -1;
}
memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
return 0;
}
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
return batch->n_tokens;
}
int32_t llama_batch_ext_add_text(
struct llama_batch_ext * batch,
llama_token token,
llama_pos pos,
const llama_seq_id * seq_ids,
size_t n_seq_ids,
bool output) {
if (batch->n_tokens + 1 > batch->max_tokens) {
return -1; // llama_batch size exceeded
}
if (batch->embd) {
return -2; // embd is already set, cannot add text tokens
}
const int32_t output_id = batch->n_tokens;
batch->token [output_id] = token;
batch->pos [output_id] = pos;
batch->n_seq_id[output_id] = n_seq_ids;
for (size_t j = 0; j < n_seq_ids; j++) {
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
}
batch->logits [output_id] = output;
batch->n_tokens++;
return output_id;
}
int32_t llama_batch_ext_set_output(
struct llama_batch_ext * batch,
llama_pos pos,
llama_seq_id seq_id) {
for (int32_t i = 0; i < batch->n_tokens; i++) {
// find the token having seq_id
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
if (batch->seq_id[i][j] == seq_id) {
// found the sequence
if (pos == -1 || pos == batch->pos[i]) {
batch->logits[i] = true;
return i;
}
}
}
}
return -1; // not found
}
int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
if (batch->n_tokens == 0) {
return -1;
}
const int32_t output_id = batch->n_tokens - 1;
batch->logits[output_id] = true;
return output_id;
}
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
batch->n_tokens = 0;
}
struct llama_batch_ext * llama_batch_ext_get_view(
struct llama_batch_ext * batch,
int32_t offset,
int32_t n_tokens) {
if (batch->embd) {
return nullptr; // not yet supported
}
llama_batch_ext * batch_view = new llama_batch_ext{
/*n_tokens =*/ n_tokens,
/*max_tokens =*/ n_tokens,
/*is_view =*/ true,
/*tokens =*/ batch->token + offset,
/*embd =*/ nullptr,
/*pos =*/ batch->pos + offset,
/*n_seq_id =*/ batch->n_seq_id + offset,
/*seq_id =*/ batch->seq_id + offset,
/*logits =*/ batch->logits + offset,
};
return batch_view;
}
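As an illustration of what views enable, a sketch of decoding a large batch in chunks; ctx and n_batch are placeholders, and freeing a view is assumed (per the code above) not to free the parent batch:
// Decode a big batch in chunks of n_batch tokens using views over the parent batch.
const int32_t n_tokens_all = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < n_tokens_all; i += n_batch) {
    const int32_t n_view = (n_tokens_all - i < n_batch) ? (n_tokens_all - i) : n_batch;
    llama_batch_ext * view = llama_batch_ext_get_view(batch, i, n_view);
    const int32_t ret = llama_decode_ext(ctx, view);
    llama_batch_ext_free(view); // frees only the view, not the parent
    if (ret != 0) {
        break;
    }
}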
void llama_batch_ext_free(struct llama_batch_ext * batch) {
// do not free the members if it's a view
if (!batch->is_view) {
if (batch->token) free(batch->token);
if (batch->embd) free(batch->embd);
if (batch->pos) free(batch->pos);
if (batch->n_seq_id) free(batch->n_seq_id);
if (batch->seq_id) {
for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
free(batch->seq_id[i]);
}
free(batch->seq_id);
}
if (batch->logits) free(batch->logits);
}
delete batch;
}
// deprecated
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch batch = {
/*n_tokens =*/ 0,
@@ -353,6 +544,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
return batch;
}
// deprecated
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);

View File

@@ -5,6 +5,32 @@
#include <array>
#include <vector>
// Input data for llama_decode / llama_encode
// A llama_batch_ext object can contain input about one or many sequences
// The provided arrays (i.e. token, embd, pos, etc.) must be of size n_tokens
//
// - token : the token ids of the input (used when embd is NULL)
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for the last token will be returned)
//
struct llama_batch_ext {
int32_t n_tokens;
int32_t max_tokens;
bool is_view;
llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
};
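// --- usage sketch (example, not part of the diff) ---------------------------
// Illustrates the seq_id/n_seq_id fields above: one token may belong to several
// sequences at once (e.g. a prompt prefix shared by two parallel completions).
// The struct is private to the library, so callers populate it through
// llama_batch_ext_add_text; `batch`, `tok` and `pos` are caller-supplied
// placeholders, not names from the PR.
static void add_shared_token(struct llama_batch_ext * batch, llama_token tok, llama_pos pos) {
    const llama_seq_id seq_ids[2] = { 0, 1 };   // token is visible to sequences 0 and 1
    llama_batch_ext_add_text(batch, tok, pos,
                             seq_ids, 2,        // n_seq_id = 2 for this token
                             /*output=*/false); // logits[i] stays 0 -> no output for it
}
// -----------------------------------------------------------------------------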
// very similar to llama_batch,
// but has more metadata about sequences
struct llama_ubatch {
@@ -47,7 +73,7 @@ struct llama_sbatch {
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;
const llama_batch_ext * batch = nullptr;
// buffers for the ubatch
std::vector<llama_token> ubatch_token;
@@ -70,12 +96,12 @@ struct llama_sbatch {
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// temporarily allocate memory for the input batch if needed
struct llama_batch_allocr {
struct llama_batch batch;
struct llama_batch_ext * batch;
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
std::vector<llama_pos> pos;
@@ -84,5 +110,7 @@ struct llama_batch_allocr {
std::vector<int8_t> logits;
// optionally fulfill the batch returned by llama_batch_get_one
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0);
~llama_batch_allocr();
};

View File

@@ -4,6 +4,7 @@
#include "llama-io.h"
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-batch.h"
#include "llama-kv-cache.h"
#include <cassert>
@@ -1000,16 +1001,26 @@ bool llama_context::apply_adapter_cvec(
}
int llama_context::encode(llama_batch & inp_batch) {
// temporarily allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return encode(*batch_allocr.batch);
}
int llama_context::decode(llama_batch & inp_batch) {
// temporarily allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return decode(*batch_allocr.batch);
}
int llama_context::encode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
llama_batch_ext & batch = inp_batch;
const int32_t n_tokens = batch.n_tokens;
const auto & hparams = model.hparams;
@@ -1143,8 +1154,6 @@ int llama_context::encode(llama_batch & inp_batch) {
if (model.arch == LLM_ARCH_T5 && t_embd) {
//cross.t_embd = t_embd;
synchronize();
cross.n_embd = t_embd->ne[0];
cross.n_enc = t_embd->ne[1];
cross.v_embd.resize(cross.n_embd*cross.n_enc);
@@ -1153,7 +1162,6 @@ int llama_context::encode(llama_batch & inp_batch) {
// remember the sequence ids used during the encoding - needed for cross attention later
cross.seq_ids_enc.resize(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
@@ -1164,17 +1172,13 @@ int llama_context::encode(llama_batch & inp_batch) {
return 0;
}
int llama_context::decode(llama_batch & inp_batch) {
int llama_context::decode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporarily allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
llama_batch_ext & batch = inp_batch;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
@@ -2750,26 +2754,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla
///
// deprecated
int32_t llama_encode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->encode(batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
return ret;
struct llama_context * ctx,
struct llama_batch inp_batch) {
return ctx->encode(inp_batch);
}
// deprecated
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
    return ret;
    struct llama_context * ctx,
    struct llama_batch inp_batch) {
    return ctx->decode(inp_batch);
}
int32_t llama_encode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->encode(*inp_batch);
}
int32_t llama_decode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->decode(*inp_batch);
}
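// --- usage sketch (example, not part of the diff) ---------------------------
// End-to-end flow with the new entry points, replacing llama_batch + llama_decode.
// The index returned by llama_batch_ext_set_output_last is what
// llama_get_logits_ith expects after a successful decode. `ctx`, `tokens` and
// `n_prompt` are caller-supplied placeholders, not names from the PR.
static const float * eval_prompt(struct llama_context * ctx,
                                 const llama_token * tokens, size_t n_prompt) {
    struct llama_batch_ext * batch = llama_batch_ext_init((int32_t) n_prompt, /*n_seq_max=*/1);

    const llama_seq_id seq_id = 0;
    for (size_t i = 0; i < n_prompt; ++i) {
        llama_batch_ext_add_text(batch, tokens[i], (llama_pos) i, &seq_id, 1, false);
    }
    const int32_t out_id = llama_batch_ext_set_output_last(batch); // logits for the last token

    const float * logits = nullptr;
    if (out_id >= 0 && llama_decode_ext(ctx, batch) == 0) {
        logits = llama_get_logits_ith(ctx, out_id);
    }

    llama_batch_ext_free(batch);
    return logits; // nullptr on failure
}
// -----------------------------------------------------------------------------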
//

View File

@@ -82,9 +82,13 @@ struct llama_context {
int32_t il_start,
int32_t il_end);
// deprecated
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
int encode(llama_batch_ext & inp_batch);
int decode(llama_batch_ext & inp_batch);
//
// state save/load
//

View File

@@ -1378,7 +1378,7 @@ ggml_tensor * llm_graph_context::build_attn(
// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
ggml_tensor * v_cache_view = nullptr;

View File

@@ -487,9 +487,9 @@ struct llm_graph_context {
ggml_tensor * build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
bool v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
@@ -516,9 +516,9 @@ struct llm_graph_context {
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;
@@ -530,9 +530,9 @@ struct llm_graph_context {
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
float kq_scale,
int il) const;

View File

@@ -476,7 +476,7 @@ struct llama_mlock::impl {
char* errmsg = std::strerror(errno);
bool suggest = (errno == ENOMEM);
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX)
#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV)
// visionOS/tvOS don't support RLIMIT_MEMLOCK
// Skip resource limit checks on visionOS/tvOS
suggest = false;

File diff suppressed because it is too large

View File

@@ -259,10 +259,6 @@ static std::string var_to_str(ggml_type type) {
return ggml_type_name(type);
}
static std::string var_to_str(ggml_prec prec) {
return prec == GGML_PREC_F32 ? "f32" : "def";
}
static std::string var_to_str(ggml_op_pool pool) {
switch (pool) {
case GGML_OP_POOL_AVG: return "avg";
@@ -1463,13 +1459,11 @@ struct test_cpy : public test_case {
const ggml_type type_src;
const ggml_type type_dst;
const std::array<int64_t, 4> ne;
const std::array<int64_t, 4> permute_src;
const std::array<int64_t, 4> permute_dst;
const std::array<int64_t, 4> permute;
bool _src_use_permute;
bool _dst_use_permute;
std::string vars() override {
return VARS_TO_STR5(type_src, type_dst, ne, permute_src, permute_dst);
return VARS_TO_STR4(type_src, type_dst, ne, permute);
}
double max_nmse_err() override {
@@ -1482,11 +1476,9 @@ struct test_cpy : public test_case {
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 1},
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0})
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0) {}
std::array<int64_t, 4> permute = {0, 0, 0, 0})
: type_src(type_src), type_dst(type_dst), ne(ne), permute(permute),
_src_use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
@@ -1494,18 +1486,13 @@ struct test_cpy : public test_case {
ggml_set_name(src, "src");
if (_src_use_permute) {
src = ggml_permute(ctx, src, permute_src[0], permute_src[1], permute_src[2], permute_src[3]);
src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
ggml_set_name(src, "src_permuted");
}
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
ggml_set_name(dst, "dst");
if (_dst_use_permute) {
dst = ggml_permute(ctx, dst, permute_dst[0], permute_dst[1], permute_dst[2], permute_dst[3]);
ggml_set_name(dst, "dst_permuted");
}
ggml_tensor * out = ggml_cpy(ctx, src, dst);
ggml_set_name(out, "out");
@@ -1973,10 +1960,9 @@ struct test_mul_mat : public test_case {
const std::array<int64_t, 2> bs; // dims 3 and 4
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions
const bool v; // whether a is a non-contiguous view
std::string vars() override {
return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
return VARS_TO_STR8(type_a, type_b, m, n, k, bs, nr, per);
}
double max_nmse_err() override {
@@ -1996,9 +1982,8 @@ struct test_mul_mat : public test_case {
int64_t m = 32, int64_t n = 32, int64_t k = 32,
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
std::array<int64_t, 4> per = {0, 1, 2, 3},
bool v = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
std::array<int64_t, 4> per = {0, 1, 2, 3})
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -2008,7 +1993,6 @@ struct test_mul_mat : public test_case {
const int npermuted = (per[0] != 0) + (per[1] != 1) + (per[2] != 2) + (per[3] != 3);
if (npermuted > 0) {
GGML_ASSERT(npermuted == 2);
GGML_ASSERT(!v); // not handled
GGML_ASSERT(!ggml_is_quantized(type_a) || per[0] == 0);
GGML_ASSERT(!ggml_is_quantized(type_b) || per[0] == 0);
@@ -2032,13 +2016,7 @@ struct test_mul_mat : public test_case {
ggml_set_name(a, "a_permuted");
ggml_set_name(b, "b_permuted");
} else {
if (v) {
a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
} else {
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
}
a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
if (!ggml_is_quantized(type_a)) {
if (bs[1] == 1 && nr[1] == 1) {
@@ -3228,12 +3206,11 @@ struct test_flash_attn_ext : public test_case {
const float max_bias; // ALiBi
const float logit_softcap; // Gemma 2
const ggml_prec prec;
const ggml_type type_KV;
std::array<int32_t, 4> permute;
std::string vars() override {
return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
}
double max_nmse_err() override {
@@ -3248,9 +3225,9 @@ struct test_flash_attn_ext : public test_case {
}
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
std::array<int32_t, 4> permute = {0, 1, 2, 3})
: hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3284,7 +3261,6 @@ struct test_flash_attn_ext : public test_case {
}
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
ggml_flash_attn_ext_set_prec(out, prec);
ggml_set_name(out, "out");
return out;
@@ -4013,25 +3989,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
}
// same-type copy
for (ggml_type type : all_types) {
const auto nk = ggml_blck_size(type);
for (int k = 1; k < 4; ++k) {
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
}
}
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (ggml_type type_dst : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
}
}
for (ggml_type type_src : all_types) {
for (ggml_type type_dst : {GGML_TYPE_F32}) {
for (ggml_type type_dst : {GGML_TYPE_F32}) {
for (ggml_type type_src : all_types) {
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
}
@@ -4204,19 +4169,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
for (auto bs : {1,2,4,8}) {
for (auto nr : {1,4}) {
for (uint32_t m = 0; m < 2; ++m) {
for (uint32_t k = 0; k < 2; ++k) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
}
}
}
}
// sycl backend will limit task global_range < MAX_INT
// test case for f16-type-convert-to-fp32 kernel with large k under fp32 compute dtype (occurs in stable-diffusion)
@@ -4424,16 +4376,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (int kv : { 512, 1024, }) {
if (nr != 1 && kv != 512) continue;
for (int nb : { 1, 3, 32, 35, }) {
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(
hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
}
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
// run fewer test cases permuted
if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
}
}
}
@@ -4486,9 +4433,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32000, 512, 1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, true));
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {