mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-26 14:23:22 +02:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c830f99cfa | ||
|
|
aa6f918c1c | ||
|
|
8c2c0108dd | ||
|
|
3ea5360c00 | ||
|
|
39fb81f875 | ||
|
|
5eb0ea32f0 | ||
|
|
b68a83e641 | ||
|
|
d8aeb65cee | ||
|
|
9051663d5d | ||
|
|
72b44c0d21 | ||
|
|
bc160d3582 | ||
|
|
2b6dfe824d |
@@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const {
|
||||
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
|
||||
return r;
|
||||
}
|
||||
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to replay last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state) {
|
||||
const int n_eval = tokens.size();
|
||||
if (n_eval == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (save_state && n_eval > 1) {
|
||||
const int n_tokens_before_last = n_eval - 1;
|
||||
|
||||
GGML_ASSERT(n_eval <= n_batch);
|
||||
|
||||
// Decode all but the last token so we can save the memory state before decoding the last token.
|
||||
// This is done so we can restore the session state later and replay the last token.
|
||||
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last);
|
||||
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last);
|
||||
|
||||
llama_token last_token = tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
int32_t pos = n_past;
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
n_past += n_eval;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -804,6 +804,23 @@ void common_batch_add(
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits);
|
||||
|
||||
// decodes a single batch of tokens for a prompt and manages session tokens
|
||||
//
|
||||
// Note: We save state before the last token so that we can replay it to ensure
|
||||
// compatibility with all memory types. Recurrent/hybrid models cannot remove
|
||||
// tokens from memory, so this approach works across all model architectures.
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & embd,
|
||||
int & n_past,
|
||||
int n_batch,
|
||||
std::string_view state_path,
|
||||
bool save_state);
|
||||
|
||||
// replays the last token after loading state to regenerate logits
|
||||
// used after loading session state to ensure the sampling context has valid logits
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
||||
@@ -77,7 +77,10 @@ causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-em
|
||||
@./scripts/causal/compare-embeddings-logits.sh
|
||||
|
||||
causal-inspect-original-model:
|
||||
@./scripts/utils/inspect-org-model.py
|
||||
@./scripts/utils/inspect-org-model.py --list-all -s
|
||||
|
||||
causal-list-original-model-tensors:
|
||||
@./scripts/utils/inspect-org-model.py --list-all-short -s
|
||||
|
||||
causal-inspect-converted-model:
|
||||
@./scripts/utils/inspect-converted-model.sh
|
||||
@@ -153,7 +156,7 @@ embedding-verify-logits-st: embedding-run-original-model-st embedding-run-conver
|
||||
|
||||
embedding-inspect-original-model:
|
||||
$(call validate_embedding_model_path,embedding-inspect-original-model)
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH}
|
||||
@EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} --list-all -s
|
||||
|
||||
embedding-inspect-converted-model:
|
||||
@CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL}
|
||||
|
||||
@@ -1,67 +1,290 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
from collections import defaultdict
|
||||
|
||||
parser = argparse.ArgumentParser(description='Process model with specified path')
|
||||
parser.add_argument('--model-path', '-m', help='Path to the model')
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = os.environ.get('MODEL_PATH', args.model_path)
|
||||
if model_path is None:
|
||||
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
# Check if there's an index file (multi-file model)
|
||||
index_path = os.path.join(model_path, "model.safetensors.index.json")
|
||||
single_file_path = os.path.join(model_path, "model.safetensors")
|
||||
DTYPE_SIZES = {
|
||||
"F64": 8, "I64": 8, "U64": 8,
|
||||
"F32": 4, "I32": 4, "U32": 4,
|
||||
"F16": 2, "BF16": 2, "I16": 2, "U16": 2,
|
||||
"I8": 1, "U8": 1, "BOOL": 1,
|
||||
"F8_E4M3": 1, "F8_E5M2": 1,
|
||||
}
|
||||
|
||||
if os.path.exists(index_path):
|
||||
# Multi-file model
|
||||
print("Multi-file model detected")
|
||||
SIZE_UNITS = ['B', 'KB', 'MB', 'GB', 'TB']
|
||||
|
||||
with open(index_path, 'r') as f:
|
||||
index_data = json.load(f)
|
||||
|
||||
# Get the weight map (tensor_name -> file_name)
|
||||
weight_map = index_data.get("weight_map", {})
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
# Group tensors by file for efficient processing
|
||||
file_tensors = defaultdict(list)
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_tensors[file_name].append(tensor_name)
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
print("Tensors in model:")
|
||||
return None
|
||||
|
||||
# Process each shard file
|
||||
for file_name, tensor_names in file_tensors.items():
|
||||
file_path = os.path.join(model_path, file_name)
|
||||
print(f"\n--- From {file_name} ---")
|
||||
|
||||
with safe_open(file_path, framework="pt") as f:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
elif os.path.exists(single_file_path):
|
||||
# Single file model (original behavior)
|
||||
print("Single-file model detected")
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
with safe_open(single_file_path, framework="pt") as f:
|
||||
keys = f.keys()
|
||||
print("Tensors in model:")
|
||||
for key in sorted(keys):
|
||||
tensor = f.get_tensor(key)
|
||||
print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}")
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}")
|
||||
print("Available files:")
|
||||
if os.path.exists(model_path):
|
||||
for item in sorted(os.listdir(model_path)):
|
||||
print(f" {item}")
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def read_safetensors_header(file_path: Path) -> dict:
|
||||
with open(file_path, 'rb') as f:
|
||||
header_size = struct.unpack('<Q', f.read(8))[0]
|
||||
return json.loads(f.read(header_size))
|
||||
|
||||
|
||||
def get_tensor_size_bytes(tensor_meta: dict) -> int:
|
||||
offsets = tensor_meta.get("data_offsets")
|
||||
if offsets and len(offsets) == 2:
|
||||
return offsets[1] - offsets[0]
|
||||
n_elements = 1
|
||||
for d in tensor_meta.get("shape", []):
|
||||
n_elements *= d
|
||||
return n_elements * DTYPE_SIZES.get(tensor_meta.get("dtype", "F32"), 4)
|
||||
|
||||
|
||||
def format_size(size_bytes: int) -> str:
|
||||
val = float(size_bytes)
|
||||
for unit in SIZE_UNITS[:-1]:
|
||||
if val < 1024.0:
|
||||
return f"{val:.2f} {unit}"
|
||||
val /= 1024.0
|
||||
return f"{val:.2f} {SIZE_UNITS[-1]}"
|
||||
|
||||
|
||||
def get_all_tensor_metadata(model_path: Path) -> dict[str, dict]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
file_to_tensors: dict[str, list[str]] = {}
|
||||
for tensor_name, file_name in weight_map.items():
|
||||
file_to_tensors.setdefault(file_name, []).append(tensor_name)
|
||||
|
||||
all_metadata: dict[str, dict] = {}
|
||||
for file_name, tensor_names in file_to_tensors.items():
|
||||
try:
|
||||
header = read_safetensors_header(model_path / file_name)
|
||||
for tensor_name in tensor_names:
|
||||
if tensor_name in header:
|
||||
all_metadata[tensor_name] = header[tensor_name]
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not read header from {file_name}: {e}", file=sys.stderr)
|
||||
return all_metadata
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
header = read_safetensors_header(single_file)
|
||||
return {k: v for k, v in header.items() if k != "__metadata__"}
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(
|
||||
model_path: Path,
|
||||
short: bool = False,
|
||||
show_sizes: bool = False,
|
||||
):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
metadata: Optional[dict[str, dict]] = None
|
||||
if show_sizes:
|
||||
metadata = get_all_tensor_metadata(model_path)
|
||||
|
||||
total_bytes = 0
|
||||
|
||||
if short:
|
||||
seen: dict[str, str] = {}
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen[normalized] = tensor_name
|
||||
display_pairs = list(sorted(seen.items()))
|
||||
name_width = max((len(n) for n, _ in display_pairs), default=0)
|
||||
for normalized, first_name in display_pairs:
|
||||
if metadata and first_name in metadata:
|
||||
m = metadata[first_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{normalized:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(normalized)
|
||||
else:
|
||||
print(f" Directory {model_path} does not exist")
|
||||
exit(1)
|
||||
name_width = max((len(n) for n in tensor_names), default=0)
|
||||
for tensor_name in sorted(tensor_names):
|
||||
if metadata and tensor_name in metadata:
|
||||
m = metadata[tensor_name]
|
||||
size = get_tensor_size_bytes(m)
|
||||
total_bytes += size
|
||||
print(f"{tensor_name:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}")
|
||||
else:
|
||||
print(tensor_name)
|
||||
|
||||
if show_sizes:
|
||||
print(f"\nTotal: {format_size(total_bytes)}")
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
header = read_safetensors_header(file_path)
|
||||
tensor_meta = header.get(tensor_name, {})
|
||||
dtype_str = tensor_meta.get("dtype")
|
||||
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if dtype_str:
|
||||
print(f"Dtype: {dtype_str}")
|
||||
if tensor_meta:
|
||||
print(f"Size: {format_size(get_tensor_size_bytes(tensor_meta))}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
if not dtype_str:
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?",
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list-all-short",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-la", "--list-all",
|
||||
action="store_true",
|
||||
help="List all tensor names with actual layer numbers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s", "--sizes",
|
||||
action="store_true",
|
||||
help="Show dtype, shape, and size for each tensor when listing"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list_all_short or args.list_all:
|
||||
list_all_tensors(model_path, short=args.list_all_short, show_sizes=args.sizes)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list-all-short or --list-all")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
MODEL_SAFETENSORS_FILE = "model.safetensors"
|
||||
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
|
||||
|
||||
|
||||
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
|
||||
index_file = model_path / MODEL_SAFETENSORS_INDEX
|
||||
|
||||
if index_file.exists():
|
||||
with open(index_file, 'r') as f:
|
||||
index = json.load(f)
|
||||
return index.get("weight_map", {})
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_all_tensor_names(model_path: Path) -> list[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return list(weight_map.keys())
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
try:
|
||||
with safe_open(single_file, framework="pt", device="cpu") as f:
|
||||
return list(f.keys())
|
||||
except Exception as e:
|
||||
print(f"Error reading {single_file}: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Error: No safetensors files found in {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
|
||||
weight_map = get_weight_map(model_path)
|
||||
|
||||
if weight_map is not None:
|
||||
return weight_map.get(tensor_name)
|
||||
|
||||
single_file = model_path / MODEL_SAFETENSORS_FILE
|
||||
if single_file.exists():
|
||||
return single_file.name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_tensor_name(tensor_name: str) -> str:
|
||||
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
|
||||
normalized = re.sub(r'\.\d+$', '.#', normalized)
|
||||
return normalized
|
||||
|
||||
|
||||
def list_all_tensors(model_path: Path, unique: bool = False):
|
||||
tensor_names = get_all_tensor_names(model_path)
|
||||
|
||||
if unique:
|
||||
seen = set()
|
||||
for tensor_name in sorted(tensor_names):
|
||||
normalized = normalize_tensor_name(tensor_name)
|
||||
if normalized not in seen:
|
||||
seen.add(normalized)
|
||||
print(normalized)
|
||||
else:
|
||||
for tensor_name in sorted(tensor_names):
|
||||
print(tensor_name)
|
||||
|
||||
|
||||
def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None):
|
||||
tensor_file = find_tensor_file(model_path, tensor_name)
|
||||
|
||||
if tensor_file is None:
|
||||
print(f"Error: Could not find tensor '{tensor_name}' in model index")
|
||||
print(f"Model path: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
file_path = model_path / tensor_file
|
||||
|
||||
try:
|
||||
with safe_open(file_path, framework="pt", device="cpu") as f:
|
||||
if tensor_name in f.keys():
|
||||
tensor_slice = f.get_slice(tensor_name)
|
||||
shape = tensor_slice.get_shape()
|
||||
print(f"Tensor: {tensor_name}")
|
||||
print(f"File: {tensor_file}")
|
||||
print(f"Shape: {shape}")
|
||||
if num_values is not None:
|
||||
tensor = f.get_tensor(tensor_name)
|
||||
print(f"Dtype: {tensor.dtype}")
|
||||
flat = tensor.flatten()
|
||||
n = min(num_values, flat.numel())
|
||||
print(f"Values: {flat[:n].tolist()}")
|
||||
else:
|
||||
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
|
||||
sys.exit(1)
|
||||
|
||||
except FileNotFoundError:
|
||||
print(f"Error: The file '{file_path}' was not found.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Print tensor information from a safetensors model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"tensor_name",
|
||||
nargs="?", # optional (if --list is used for example)
|
||||
help="Name of the tensor to inspect"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--model-path",
|
||||
type=Path,
|
||||
help="Path to the model directory (default: MODEL_PATH environment variable)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l", "--list",
|
||||
action="store_true",
|
||||
help="List unique tensor patterns in the model (layer numbers replaced with #)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--num-values",
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model_path
|
||||
if model_path is None:
|
||||
model_path_str = os.environ.get("MODEL_PATH")
|
||||
if model_path_str is None:
|
||||
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
|
||||
sys.exit(1)
|
||||
model_path = Path(model_path_str)
|
||||
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model path does not exist: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if not model_path.is_dir():
|
||||
print(f"Error: Model path is not a directory: {model_path}")
|
||||
sys.exit(1)
|
||||
|
||||
if args.list:
|
||||
list_all_tensors(model_path, unique=True)
|
||||
else:
|
||||
if args.tensor_name is None:
|
||||
print("Error: tensor_name is required when not using --list")
|
||||
sys.exit(1)
|
||||
print_tensor_info(model_path, args.tensor_name, args.num_values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -5,12 +5,15 @@
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sampling.seed = 1234;
|
||||
|
||||
const std::string_view state_file = "dump_state.bin";
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
@@ -53,35 +56,16 @@ int main(int argc, char ** argv) {
|
||||
// tokenize prompt
|
||||
auto tokens = common_tokenize(ctx, params.prompt, true);
|
||||
|
||||
// prepare the batch
|
||||
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
common_batch_add(batch, tokens[i], i, {0}, false);
|
||||
const bool save_state = true;
|
||||
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
|
||||
return 1;
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true; // generate next token
|
||||
|
||||
// evaluate prompt
|
||||
llama_decode(ctx, batch);
|
||||
n_past += batch.n_tokens;
|
||||
|
||||
// save state (rng, logits, embedding and kv_cache) to file
|
||||
{
|
||||
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
|
||||
|
||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||
fwrite(state_mem.data(), 1, written, fp_write);
|
||||
fclose(fp_write);
|
||||
|
||||
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
|
||||
}
|
||||
|
||||
// save state (last tokens)
|
||||
const auto n_past_saved = n_past;
|
||||
|
||||
// first run
|
||||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
llama_batch batch = llama_batch_init(1, 0, 1);
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = common_token_to_piece(ctx, next_token);
|
||||
@@ -111,27 +95,23 @@ int main(int argc, char ** argv) {
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
// load state from file
|
||||
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
|
||||
size_t n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
@@ -160,7 +140,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// make new context
|
||||
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
|
||||
auto params_ctx3 = common_context_params_to_llama(params);
|
||||
params_ctx3.n_seq_max = 2;
|
||||
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
|
||||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
@@ -169,26 +151,21 @@ int main(int argc, char ** argv) {
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
std::vector<uint8_t> state_mem;
|
||||
n_token_count_out = 0;
|
||||
|
||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||
fseek(fp_read, 0, SEEK_END);
|
||||
state_mem.resize(ftell(fp_read));
|
||||
fseek(fp_read, 0, SEEK_SET);
|
||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||
fclose(fp_read);
|
||||
|
||||
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
|
||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
|
||||
if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
|
||||
fprintf(stderr, "\n%s : failed to load state\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
|
||||
|
||||
// restore state (last tokens)
|
||||
n_past = n_past_saved;
|
||||
n_past = n_token_count_out;
|
||||
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
|
||||
@@ -42,6 +42,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -55,9 +56,10 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
|
||||
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
|
||||
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
|
||||
@@ -77,6 +79,7 @@
|
||||
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
|
||||
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -86,6 +89,7 @@
|
||||
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
|
||||
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
@@ -110,6 +114,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -123,6 +128,7 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
@@ -148,6 +154,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -161,6 +168,7 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
@@ -187,6 +195,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -199,6 +208,7 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
@@ -230,6 +240,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -243,6 +254,7 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
@@ -276,6 +288,7 @@
|
||||
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
|
||||
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
|
||||
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
|
||||
#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
|
||||
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
|
||||
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
|
||||
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
|
||||
@@ -289,6 +302,7 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
|
||||
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
|
||||
#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K
|
||||
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
|
||||
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
|
||||
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
|
||||
|
||||
@@ -785,6 +785,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
|
||||
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 1x8 tile = 2 x 4
|
||||
float32x4_t acc_f32[col_groups];
|
||||
|
||||
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
|
||||
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
|
||||
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
|
||||
float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
|
||||
float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
|
||||
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
|
||||
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
|
||||
float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
|
||||
float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
|
||||
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
|
||||
int32x4_t acc_lo[col_groups];
|
||||
int32x4_t acc_hi[col_groups];
|
||||
|
||||
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
|
||||
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
|
||||
int16_t bsums_arr[8];
|
||||
vst1q_s16(bsums_arr, bsums);
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
for (int i = 0; i < col_groups; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_mins[2];
|
||||
int16x8_t q5sb_scales[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
int8x16_t q8_qs[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
|
||||
}
|
||||
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
uint8x16_t q5_cols[8];
|
||||
uint8x16_t hbit_lo[8];
|
||||
uint8x16_t hbit_hi[8];
|
||||
int8x16_t q5_lo[8];
|
||||
int8x16_t q5_hi[8];
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
|
||||
hbit_lo[i] = vandq_u8(qh[c][i], mone);
|
||||
hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
|
||||
qh[c][i] = vshrq_n_u8(qh[c][i], 2);
|
||||
q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
|
||||
q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
|
||||
}
|
||||
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
|
||||
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
|
||||
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
|
||||
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
|
||||
}
|
||||
|
||||
// Scales
|
||||
// row c0123 blk0 and blk1
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
|
||||
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
|
||||
// row c4567 blk0 and blk1
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
|
||||
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
|
||||
|
||||
// Bias Correction
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
|
||||
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
} // for sb
|
||||
|
||||
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
|
||||
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
|
||||
} // for b
|
||||
|
||||
int base = x * ncols_interleaved;
|
||||
vst1q_f32(s + base, acc_f32[0]);
|
||||
vst1q_f32(s + base + 4, acc_f32[1]);
|
||||
} // for x
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
@@ -3205,6 +3364,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
||||
constexpr int ncols_interleaved = 8;
|
||||
constexpr int blocklen = 4;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
|
||||
constexpr int col_groups = ncols_interleaved / 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mtwo = vdupq_n_u8(2);
|
||||
|
||||
// 8 accumulators: 2 row pairs, 4 col pairs
|
||||
float32x4_t acc_f32[acc_size];
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_f32[i] = vdupq_n_f32(0);
|
||||
}
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// d5 0 1 2 3, 4 5 6 7
|
||||
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
|
||||
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
|
||||
// d8 0 1 2 3
|
||||
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
|
||||
// mins
|
||||
float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));
|
||||
float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4));
|
||||
|
||||
// Precomputation of scales and mins
|
||||
float32x4_t sbd_scale_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_scale_4567[q8_k_blocklen];
|
||||
float32x4_t sbd_min_0123[q8_k_blocklen];
|
||||
float32x4_t sbd_min_4567[q8_k_blocklen];
|
||||
|
||||
sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0);
|
||||
sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0);
|
||||
sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0);
|
||||
sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0);
|
||||
|
||||
sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1);
|
||||
sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1);
|
||||
sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1);
|
||||
sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1);
|
||||
|
||||
sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2);
|
||||
sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2);
|
||||
sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2);
|
||||
sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2);
|
||||
|
||||
sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3);
|
||||
sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3);
|
||||
sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3);
|
||||
sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3);
|
||||
|
||||
// Precomputation of bsums, each vpaddq calcs all the bsums for each row
|
||||
const int16x8_t bsums[q8_k_blocklen] = {
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
int16_t bsums_arr[QK_K / 64][8];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
|
||||
}
|
||||
|
||||
// interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 ..
|
||||
int32x4_t bias_acc[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
bias_acc[i] = vdupq_n_s32(0);
|
||||
}
|
||||
|
||||
uint8x16_t qh[col_groups][8];
|
||||
for (int c = 0; c < col_groups; c++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
|
||||
}
|
||||
}
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Int accumulators for qs vecdot (4 row * 2 col quartets)
|
||||
int32x4_t acc_lo[acc_size];
|
||||
int32x4_t acc_hi[acc_size];
|
||||
for (int i = 0; i < acc_size; i++) {
|
||||
acc_lo[i] = vdupq_n_s32(0);
|
||||
acc_hi[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
int16x8_t q5sb_scales[2];
|
||||
int16x8_t q5sb_mins[2];
|
||||
for (int i = 0; i < 2; i++) {
|
||||
int8_t aux_q5sb[8];
|
||||
const int offset = sb * 24 + i * 12;
|
||||
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
|
||||
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
|
||||
}
|
||||
|
||||
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
|
||||
for (int k = 0; k < reads_per_sb; k++) {
|
||||
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
|
||||
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
|
||||
|
||||
// 0..3 & 32..35
|
||||
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
|
||||
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
|
||||
|
||||
// NOTE: This is the only difference with q4_K
|
||||
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
|
||||
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
|
||||
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
|
||||
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
|
||||
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
|
||||
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
|
||||
// From here, same as q4_K
|
||||
|
||||
const int8x16_t q5_0123_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
|
||||
const int8x16_t q5_0123_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
|
||||
|
||||
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
|
||||
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
|
||||
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
|
||||
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
|
||||
|
||||
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
|
||||
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
|
||||
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
|
||||
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
|
||||
|
||||
const int8x16_t q5_4567_lo =
|
||||
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
|
||||
const int8x16_t q5_4567_hi =
|
||||
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
|
||||
|
||||
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
|
||||
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
|
||||
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
|
||||
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
|
||||
|
||||
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
|
||||
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
|
||||
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
|
||||
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
|
||||
}
|
||||
|
||||
// Scale and bias application
|
||||
// acc is stored interleaved to match output layout
|
||||
const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
|
||||
const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
|
||||
const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
// Bias correction
|
||||
// row c0123 blk0 and blk1
|
||||
const float32x4_t sumf_0123 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
|
||||
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
|
||||
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
const float32x4_t sumf_4567 =
|
||||
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
|
||||
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
|
||||
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
|
||||
|
||||
// Bias
|
||||
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
|
||||
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
|
||||
|
||||
// row c0123 blk0 and blk1
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
|
||||
|
||||
// row c4567 blk0 and blk1
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
|
||||
bias_acc[2 * row + 1] =
|
||||
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
|
||||
}
|
||||
} // for sb
|
||||
|
||||
for (int row = 0; row < q8_k_blocklen; row++) {
|
||||
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
|
||||
acc_f32[2 * row + 1] =
|
||||
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
|
||||
}
|
||||
} // for b
|
||||
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
vst1q_f32(s + offset, acc_f32[2 * i + j]);
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
|
||||
@@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[ncols_interleaved];
|
||||
float sum_minf[ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N>
|
||||
static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
constexpr int blocklen = M;
|
||||
constexpr int ncols_interleaved = N;
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][ncols_interleaved];
|
||||
float sum_minf[4][ncols_interleaved];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
constexpr int scale_stride = 32;
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16;
|
||||
|
||||
const int qh_shift = (k / (32 / blocklen)) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * blocklen + i) % 32;
|
||||
const int qh_chunk = qh_idx / blocklen;
|
||||
const int qh_pos = qh_idx % blocklen;
|
||||
const int b_qh_offset =
|
||||
qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k / (32 / blocklen)) * 256 +
|
||||
(k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(bs);
|
||||
UNUSED(nr);
|
||||
|
||||
float sumf[8];
|
||||
float sum_minf[8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
|
||||
@@ -1494,107 +1610,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
|
||||
float * GGML_RESTRICT s,
|
||||
size_t bs,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
int nr,
|
||||
int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 8;
|
||||
const int blocklen = 8;
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
float sumf[4][8];
|
||||
float sum_minf[4][8];
|
||||
uint32_t utmp[32];
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
for (int l = 0; l < nb; l++) {
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
|
||||
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
|
||||
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
|
||||
utmp[sb * 4 + 2] = uaux_0;
|
||||
utmp[sb * 4 + 0] &= kmask1;
|
||||
}
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
|
||||
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
|
||||
|
||||
const int qh_shift = (k / 4) * 2;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; ++i) {
|
||||
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
|
||||
|
||||
const int qh_idx = (k * 8 + i) % 32;
|
||||
const int qh_chunk = qh_idx / 8;
|
||||
const int qh_pos = qh_idx % 8;
|
||||
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
|
||||
|
||||
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
|
||||
const uint8_t h0 = (qh_val >> qh_shift) & 1;
|
||||
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
|
||||
|
||||
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
|
||||
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
|
||||
|
||||
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
|
||||
|
||||
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
|
||||
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
|
||||
sumi1 = sumi1 * scales_0[j];
|
||||
sumi2 = sumi2 * scales_1[j];
|
||||
sumi += sumi1 + sumi2;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int sb = 0; sb < 8; sb++) {
|
||||
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
|
||||
for (int m = 0; m < 4; m++) {
|
||||
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
|
||||
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
@@ -2029,18 +2050,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
||||
|
||||
const int end = QK_K * 4 / blck_size_interleave;
|
||||
|
||||
// Interleave Q5_K quants by taking 8 bytes at a time
|
||||
// Interleave Q5_K quants by taking blck_size_interleave bytes at a time
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 8;
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// Repeat for low bits 8 bytes at a time as well, since
|
||||
// Repeat for high bits with the same chunk size, since
|
||||
// the high bits are interleaved in Q5_K and the index is
|
||||
// qh_idx = (qs_idx % 32);
|
||||
// qh_val = qh[qh_idx] >> (qs_idx / 32);
|
||||
@@ -2049,9 +2068,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in
|
||||
int src_offset = (i / 8) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
uint64_t elems;
|
||||
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
|
||||
memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave);
|
||||
}
|
||||
|
||||
// The below logic is copied over from Q4_K
|
||||
@@ -2249,7 +2266,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
|
||||
const void * GGML_RESTRICT data,
|
||||
size_t data_size) {
|
||||
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
|
||||
GGML_ASSERT(interleave_block == 8);
|
||||
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
|
||||
constexpr int nrows_interleaved = 8;
|
||||
|
||||
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
|
||||
@@ -2523,6 +2540,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
|
||||
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size);
|
||||
}
|
||||
|
||||
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
|
||||
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
|
||||
}
|
||||
@@ -2591,6 +2612,10 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -2654,6 +2679,10 @@ template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
|
||||
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
|
||||
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
@@ -3068,6 +3097,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
|
||||
|
||||
// instance for Q5_K
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K;
|
||||
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
|
||||
|
||||
// instance for Q6_K
|
||||
@@ -3130,6 +3160,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
||||
return &q5_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q5_K_8x4_q8_K;
|
||||
}
|
||||
}
|
||||
} else if (cur->type == GGML_TYPE_Q6_K) {
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
||||
@@ -111,6 +111,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -122,6 +123,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -143,6 +145,7 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
@@ -154,6 +157,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
||||
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
||||
@@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe
|
||||
return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[1] != y->ne[1]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
@@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return opt_experimental;
|
||||
}
|
||||
|
||||
static bool hex_supported_src0_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src2_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type2(ggml_type t) {
|
||||
return t == GGML_TYPE_F16;
|
||||
}
|
||||
|
||||
static bool hex_supported_src1_type3(ggml_type t) {
|
||||
return t == GGML_TYPE_I32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dst_type(ggml_type t) {
|
||||
return t == GGML_TYPE_F32;
|
||||
}
|
||||
|
||||
static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) {
|
||||
// TODO: support broadcast for ne[2 and 3]
|
||||
if (x->ne[0] != y->ne[0]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[2] != y->ne[2]) {
|
||||
return false;
|
||||
}
|
||||
if (x->ne[3] != y->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) {
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
@@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_can_repeat(src1, src0)) {
|
||||
if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, dst)) {
|
||||
if (!ggml_are_same_shape(src0, dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session *
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
||||
const struct ggml_tensor * src1 = op->src[1];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session
|
||||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dims2(src0, src1)) {
|
||||
if (!ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
if (!ggml_is_contiguous(src1)) {
|
||||
@@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
|
||||
return false; // FIXME: add support for sinks
|
||||
}
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (src1) {
|
||||
if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
if (src0->ne[0] != src1->ne[0]) {
|
||||
@@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess
|
||||
const struct ggml_tensor * src2 = op->src[2];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
if (!hex_supported_src0_type(src0->type)) {
|
||||
if (src0->type != GGML_TYPE_F32) {
|
||||
return false; // FIXME: add support for GGML_TYPE_F16 for src0
|
||||
}
|
||||
if (!hex_supported_dst_type(dst->type)) {
|
||||
if (dst->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
if (!hex_supported_src1_type3(src1->type)) {
|
||||
if (src1->type != GGML_TYPE_I32) {
|
||||
return false;
|
||||
}
|
||||
if (src2) {
|
||||
if (!hex_supported_src2_type(src2->type)) {
|
||||
if (src2->type != GGML_TYPE_F32) {
|
||||
return false;
|
||||
}
|
||||
int n_dims = op_params[1];
|
||||
|
||||
@@ -69,27 +69,45 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
struct htp_act_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
const uint8_t * data_src1;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t src1_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t src1_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t src1_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
int nc;
|
||||
};
|
||||
|
||||
static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
|
||||
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
@@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least "
|
||||
"%zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
const float alpha = ((const float *) (op_params))[2];
|
||||
const float limit = ((const float *) (op_params))[3];
|
||||
const float alpha = ((const float *) (actx->octx->op_params))[2];
|
||||
const float limit = ((const float *) (actx->octx->op_params))[3];
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
@@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0,
|
||||
}
|
||||
|
||||
|
||||
static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
// nc/ne0 matches.
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// In gelu = x*sigmoid(x*1.702)
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// gelu = x * sigmoid(1.702 * x) // current implementation
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
|
||||
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble2;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t src0_row_size = actx->src0_row_size;
|
||||
const size_t dst_row_size = actx->dst_row_size;
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
uint8_t * data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * data_src0 = actx->data_src0;
|
||||
uint8_t * data_dst = actx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
const int ne0_val = actx->nc; // == dst->ne[0]
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = actx->block;
|
||||
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
|
||||
|
||||
// silu = x * sigmoid(x)
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
|
||||
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val);
|
||||
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val);
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
@@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
|
||||
static const float GELU_COEF_A = 0.044715f;
|
||||
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||
|
||||
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
const struct htp_tensor * src1,
|
||||
struct htp_tensor * dst,
|
||||
const int32_t * op_params,
|
||||
struct htp_spad * src0_spad,
|
||||
struct htp_spad * src1_spad,
|
||||
struct htp_spad * dst_spad,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread,
|
||||
dma_queue * dma_queue) {
|
||||
static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_act_context * actx = (struct htp_act_context *) data;
|
||||
const struct htp_tensor * src0 = &actx->octx->src0;
|
||||
const struct htp_tensor * src1 = &actx->octx->src1;
|
||||
const struct htp_tensor * dst = &actx->octx->dst;
|
||||
htp_act_preamble3;
|
||||
|
||||
size_t src0_row_size = nb01;
|
||||
size_t src1_row_size = nb11;
|
||||
size_t dst_row_size = nb1;
|
||||
size_t src0_row_size = actx->src0_row_size;
|
||||
size_t src1_row_size = actx->src1_row_size;
|
||||
size_t dst_row_size = actx->dst_row_size;
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows = actx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const uint8_t * restrict data_src0 = actx->data_src0;
|
||||
const uint8_t * restrict data_src1 = actx->data_src1;
|
||||
uint8_t * restrict data_dst = actx->data_dst;
|
||||
|
||||
const bool src1_valid = src1->ne[0];
|
||||
const int nc = (src1_valid) ? ne00 : ne00 / 2;
|
||||
if (!src1_valid) {
|
||||
const int32_t swapped = op_params[1];
|
||||
data_src1 = data_src0;
|
||||
src1_row_size = src0_row_size;
|
||||
const int nc = actx->nc;
|
||||
|
||||
const size_t nc_in_bytes = nc * SIZEOF_FP32;
|
||||
data_src0 += swapped ? nc_in_bytes : 0;
|
||||
data_src1 += swapped ? 0 : nc_in_bytes;
|
||||
}
|
||||
const size_t src0_row_size_aligned = actx->src0_row_size_aligned;
|
||||
const size_t src1_row_size_aligned = actx->src1_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = actx->dst_row_size_aligned;
|
||||
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread);
|
||||
|
||||
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
|
||||
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
|
||||
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
|
||||
size_t src0_spad_half_size = actx->src0_spad_half_size;
|
||||
size_t src1_spad_half_size = actx->src1_spad_half_size;
|
||||
size_t dst_spad_half_size = actx->dst_spad_half_size;
|
||||
|
||||
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
|
||||
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
|
||||
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
|
||||
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
|
||||
|
||||
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
|
||||
const int BLOCK = actx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR,
|
||||
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
src0_spad->size_per_thread, src0_row_size_aligned);
|
||||
actx->octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
dma_queue * dma_queue = actx->octx->ctx->dma[ith];
|
||||
|
||||
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
@@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
|
||||
octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
|
||||
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
|
||||
}
|
||||
|
||||
static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
@@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_UNARY_SILU:
|
||||
act_op_func = unary_silu_f32;
|
||||
act_op_func = (worker_callback_t)unary_silu_f32_per_thread;
|
||||
op_type = "silu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU:
|
||||
act_op_func = glu_swiglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread;
|
||||
op_type = "swiglu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_SWIGLU_OAI:
|
||||
act_op_func = glu_swiglu_oai_f32;
|
||||
act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread;
|
||||
op_type = "swiglu-oai-f32";
|
||||
break;
|
||||
case HTP_OP_UNARY_GELU:
|
||||
act_op_func = unary_gelu_f32;
|
||||
act_op_func = (worker_callback_t)unary_gelu_f32_per_thread;
|
||||
op_type = "gelu-f32";
|
||||
break;
|
||||
|
||||
case HTP_OP_GLU_GEGLU:
|
||||
act_op_func = glu_geglu_f32;
|
||||
act_op_func = (worker_callback_t)glu_geglu_f32_per_thread;
|
||||
op_type = "geglu-f32";
|
||||
break;
|
||||
default:
|
||||
@@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
}
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs);
|
||||
if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
return err;
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
// Prepare context
|
||||
struct htp_act_context actx;
|
||||
actx.octx = octx;
|
||||
|
||||
actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
actx.src0_row_size = src0_row_size;
|
||||
actx.src1_row_size = src1_row_size;
|
||||
actx.dst_row_size = dst_row_size;
|
||||
|
||||
actx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
actx.src1_row_size_aligned = src1_row_size_aligned;
|
||||
actx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
|
||||
actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2;
|
||||
actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2;
|
||||
actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2;
|
||||
|
||||
actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned;
|
||||
actx.src0_nrows = src0_nrows;
|
||||
|
||||
actx.nc = dst->ne[0];
|
||||
|
||||
// Pointers and GLU logic
|
||||
const uint8_t * data_src0 = (const uint8_t *) src0->data;
|
||||
const uint8_t * data_src1 = (const uint8_t *) src1->data;
|
||||
|
||||
if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) {
|
||||
const int32_t swapped = octx->op_params[1];
|
||||
data_src1 = data_src0;
|
||||
actx.src1_row_size = actx.src0_row_size;
|
||||
|
||||
size_t nc_in_bytes = actx.nc * SIZEOF_FP32;
|
||||
if (swapped) {
|
||||
data_src0 += nc_in_bytes;
|
||||
} else {
|
||||
data_src1 += nc_in_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
actx.data_src0 = data_src0;
|
||||
actx.data_src1 = data_src1;
|
||||
actx.data_dst = (uint8_t *) dst->data;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
int op_activations(struct htp_ops_context * octx) {
|
||||
|
||||
@@ -15,6 +15,13 @@
|
||||
#include "htp-ops.h"
|
||||
#include "hvx-utils.h"
|
||||
|
||||
struct get_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
struct fastdiv_values get_rows_div_ne10;
|
||||
struct fastdiv_values get_rows_div_ne10_ne11;
|
||||
};
|
||||
|
||||
#define get_rows_preamble \
|
||||
const uint32_t ne00 = octx->src0.ne[0]; \
|
||||
const uint32_t ne01 = octx->src0.ne[1]; \
|
||||
@@ -39,20 +46,22 @@
|
||||
\
|
||||
const uint32_t nr = ne10 * ne11 * ne12;
|
||||
|
||||
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct get_rows_context * grctx = (struct get_rows_context *)data;
|
||||
struct htp_ops_context * octx = grctx->octx;
|
||||
get_rows_preamble;
|
||||
|
||||
// parallelize by src1 elements (which correspond to dst rows)
|
||||
const uint32_t dr = octx->src1_nrows_per_thread;
|
||||
const uint32_t dr = grctx->src1_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
|
||||
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
|
||||
const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11);
|
||||
const uint32_t rem = i - i12 * ne11 * ne10;
|
||||
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
|
||||
const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10);
|
||||
const uint32_t i10 = rem - i11 * ne10;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
|
||||
hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_get_rows(struct htp_ops_context * octx) {
|
||||
@@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
struct get_rows_context grctx;
|
||||
grctx.octx = octx;
|
||||
grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
|
||||
grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs);
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q,
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = desc;
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src);
|
||||
// FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
@@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst);
|
||||
// FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) {
|
||||
dma_ptr dptr = { NULL };
|
||||
|
||||
if (q->push_idx == q->pop_idx) {
|
||||
return dptr;
|
||||
}
|
||||
|
||||
dptr = q->dptr[q->pop_idx];
|
||||
|
||||
// FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src);
|
||||
q->pop_idx = (q->pop_idx + 1) & q->idx_mask;
|
||||
return dptr;
|
||||
}
|
||||
|
||||
static inline bool dma_queue_empty(dma_queue * q) {
|
||||
return q->push_idx == q->pop_idx;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_depth(dma_queue * q) {
|
||||
return (q->push_idx - q->pop_idx) & q->idx_mask;
|
||||
}
|
||||
|
||||
static inline uint32_t dma_queue_capacity(dma_queue * q) {
|
||||
return q->capacity;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // extern "C"
|
||||
#endif
|
||||
|
||||
@@ -44,32 +44,6 @@ struct htp_ops_context {
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t src1_nrows_per_thread;
|
||||
|
||||
struct fastdiv_values src0_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src0_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src0_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src1_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src1_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src1_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values src3_div1; // fastdiv values for ne1
|
||||
struct fastdiv_values src3_div2; // fastdiv values for ne2
|
||||
struct fastdiv_values src3_div3; // fastdiv values for ne3
|
||||
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
|
||||
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
|
||||
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
|
||||
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
|
||||
|
||||
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
|
||||
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
|
||||
|
||||
uint32_t flags;
|
||||
};
|
||||
|
||||
|
||||
@@ -49,62 +49,6 @@ struct htp_matmul_context {
|
||||
struct fastdiv_values mm_div_r3;
|
||||
};
|
||||
|
||||
// vdelta control to replicate first 4x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04,
|
||||
0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10,
|
||||
};
|
||||
|
||||
// vdelta control to replicate and interleave first 8x fp32 values across lanes
|
||||
static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00,
|
||||
0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20,
|
||||
0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04,
|
||||
0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40,
|
||||
0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44,
|
||||
0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp32 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10,
|
||||
0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04,
|
||||
0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08,
|
||||
0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08,
|
||||
0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04,
|
||||
0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10,
|
||||
0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02,
|
||||
0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04,
|
||||
0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08,
|
||||
0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02,
|
||||
0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02,
|
||||
0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10,
|
||||
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to replicate first fp16 value across all elements
|
||||
static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = {
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
|
||||
};
|
||||
|
||||
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements
|
||||
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
|
||||
@@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
|
||||
|
||||
// Convert to QF32
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
|
||||
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes
|
||||
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes
|
||||
|
||||
// Combine and convert to fp16
|
||||
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
|
||||
@@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
|
||||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
|
||||
@@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric
|
||||
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf));
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
|
||||
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
|
||||
HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes
|
||||
HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
@@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric
|
||||
|
||||
// Compute max and scale
|
||||
HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf));
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf);
|
||||
|
||||
// Replicate first fp16 scale across all lanes
|
||||
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16;
|
||||
vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl);
|
||||
vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes
|
||||
|
||||
HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
|
||||
HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16);
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -21,6 +22,9 @@
|
||||
#define HTP_ROPE_TYPE_NORMAL 0
|
||||
#define HTP_ROPE_TYPE_NEOX 2
|
||||
|
||||
#define HTP_ROPE_SPAD_NROWS 16
|
||||
#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2)
|
||||
|
||||
#define htp_rope_preamble \
|
||||
const uint32_t ne00 = src0->ne[0]; \
|
||||
const uint32_t ne01 = src0->ne[1]; \
|
||||
@@ -42,7 +46,7 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct rope_th_ctx {
|
||||
struct htp_rope_context {
|
||||
int32_t n_dims;
|
||||
int32_t mode;
|
||||
int32_t n_ctx_orig;
|
||||
@@ -57,7 +61,19 @@ struct rope_th_ctx {
|
||||
float theta_scale;
|
||||
float corr_dims[2];
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
size_t theta_cache_offset;
|
||||
uint32_t src0_nrows;
|
||||
|
||||
uint64_t t_start;
|
||||
};
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
@@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims,
|
||||
dims[1] = MIN(n_dims - 1, end);
|
||||
}
|
||||
|
||||
static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) {
|
||||
memset(rope_ctx, 0, sizeof(struct rope_th_ctx));
|
||||
static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2
|
||||
|
||||
rope_ctx->n_dims = ((const int32_t *) op_params)[1];
|
||||
rope_ctx->mode = ((const int32_t *) op_params)[2];
|
||||
rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
uint32_t he = ne / 2; // half_dims offset in elements
|
||||
uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors
|
||||
|
||||
memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i += 2) {
|
||||
HVX_Vector v0 = vsrc[i/2+0];
|
||||
HVX_Vector v1 = vsrc[i/2+hv];
|
||||
|
||||
rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims);
|
||||
|
||||
rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast,
|
||||
rope_ctx->beta_slow, rope_ctx->corr_dims);
|
||||
|
||||
rope_ctx->octx = octx;
|
||||
FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims,
|
||||
rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor);
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[num_elems/2];
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[num_elems/2] = x0*sin_theta + x1*cos_theta;
|
||||
|
||||
//src += 1;
|
||||
//dst += 1;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
int half_size = (sizeof(float) * (num_elems / 2));
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
|
||||
@@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
|
||||
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
|
||||
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
|
||||
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);
|
||||
vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4);
|
||||
vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5);
|
||||
}
|
||||
|
||||
src0_curr += VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i/2];
|
||||
float x1 = src0[i/2 + he];
|
||||
dst[i/2] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
float * restrict dst,
|
||||
const int num_elems,
|
||||
const float * restrict theta_cache) {
|
||||
// for (int i = 0; i < num_elems; i += 2) {
|
||||
//const float cos_theta = theta_cache[i + 0];
|
||||
//const float sin_theta = theta_cache[i + 1];
|
||||
static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) {
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0;
|
||||
const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector *) dst;
|
||||
|
||||
//const float x0 = src[0];
|
||||
//const float x1 = src[1];
|
||||
uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two
|
||||
|
||||
//dst[0] = x0*cos_theta - x1*sin_theta;
|
||||
//dst[1] = x0*sin_theta + x1*cos_theta;
|
||||
#pragma unroll(2)
|
||||
for (uint32_t i = 0; i < nvec; i+=2) {
|
||||
HVX_Vector v0 = vsrc[i+0];
|
||||
HVX_Vector v1 = vsrc[i+1];
|
||||
|
||||
//src += 2;
|
||||
//dst += 2;
|
||||
// }
|
||||
|
||||
const uint8_t * restrict src0_curr = (const uint8_t *) src0;
|
||||
const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache;
|
||||
uint8_t * restrict dst_curr = (uint8_t *) dst;
|
||||
|
||||
int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once
|
||||
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
HVX_Vector v0 = *(HVX_Vector *) src0_curr;
|
||||
HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN);
|
||||
|
||||
HVX_Vector v2 = *(HVX_Vector *) theta_curr;
|
||||
HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN);
|
||||
HVX_Vector v2 = vtheta[i+0];
|
||||
HVX_Vector v3 = vtheta[i+1];
|
||||
|
||||
HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1
|
||||
HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta
|
||||
@@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0,
|
||||
|
||||
HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4);
|
||||
|
||||
*(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore);
|
||||
*(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore);
|
||||
vdst[i+0] = Q6_V_lo_W(vstore);
|
||||
vdst[i+1] = Q6_V_hi_W(vstore);
|
||||
}
|
||||
|
||||
src0_curr += 2 * VLEN;
|
||||
theta_curr += 2 * VLEN;
|
||||
dst_curr += 2 * VLEN;
|
||||
for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) {
|
||||
const float cos_theta = theta_cache[i+0];
|
||||
const float sin_theta = theta_cache[i+1];
|
||||
float x0 = src0[i+0];
|
||||
float x1 = src0[i+1];
|
||||
dst[i+0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst[i+1] = x0 * sin_theta + x1 * cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
|
||||
const uint32_t ir0,
|
||||
const uint32_t ir1,
|
||||
int nth,
|
||||
int ith,
|
||||
const int opt_path) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src,
|
||||
uint32_t nr, uint32_t ne0, const float * restrict theta_cache) {
|
||||
#pragma unroll(4)
|
||||
for (uint32_t i = 0; i < nr; i++) {
|
||||
float * d = (float *) (dst + i * rctx->dst_row_size_aligned);
|
||||
float * s = (float *) (src + i * rctx->src0_row_size_aligned);
|
||||
|
||||
hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache);
|
||||
|
||||
// fill the remain channels with data from src tensor
|
||||
if (rctx->n_dims < ne0) {
|
||||
hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_rope_context * rctx = (struct htp_rope_context *) data;
|
||||
struct htp_ops_context * octx = rctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
const int32_t mode = rope_ctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
|
||||
float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01));
|
||||
|
||||
const float * freq_factors = NULL;
|
||||
if (src2 != NULL) {
|
||||
freq_factors = (const float *) src2->data;
|
||||
}
|
||||
|
||||
const uint32_t i1_end = MIN(ir1, ne1);
|
||||
const int32_t half_dims = rope_ctx->n_dims / 2;
|
||||
const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float);
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
const int32_t p = pos[i2];
|
||||
|
||||
rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
|
||||
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);
|
||||
|
||||
for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
|
||||
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
|
||||
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);
|
||||
|
||||
const float * src_loc = src;
|
||||
float * dst_data_loc = dst_data;
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (is_neox) {
|
||||
hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
} else {
|
||||
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
|
||||
}
|
||||
|
||||
src_loc += rope_ctx->n_dims;
|
||||
dst_data_loc += rope_ctx->n_dims;
|
||||
} else {
|
||||
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
|
||||
const float cos_theta = wp0[i0 + 0];
|
||||
const float sin_theta = wp0[i0 + 1];
|
||||
|
||||
if (is_neox) {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[half_dims];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 1;
|
||||
dst_data_loc += 1;
|
||||
} else {
|
||||
const float x0 = src_loc[0];
|
||||
const float x1 = src_loc[1];
|
||||
|
||||
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
|
||||
dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta;
|
||||
|
||||
src_loc += 2;
|
||||
dst_data_loc += 2;
|
||||
}
|
||||
}
|
||||
|
||||
src_loc += (is_neox ? half_dims : 0);
|
||||
dst_data_loc += (is_neox ? half_dims : 0);
|
||||
}
|
||||
|
||||
// TODO: use simd to speed up the remaining elements copy
|
||||
memcpy(dst_data_loc, src_loc, remain_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = rope_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_rope_preamble;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows = rctx->src0_nrows;
|
||||
const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
uint64_t tt = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) ||
|
||||
(0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n");
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const int32_t mode = rctx->mode;
|
||||
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
|
||||
|
||||
// VTCM setup
|
||||
uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
float * theta_cache = (float *) (src0_spad_base);
|
||||
src0_spad_base = src0_spad_base + rctx->theta_cache_offset;
|
||||
uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
const int32_t * pos = (const int32_t *) src1->data;
|
||||
const float * freq_factors = src2->data ? (const float *) src2->data : NULL;
|
||||
|
||||
uint32_t ir = 0;
|
||||
uint32_t prev_i2 = (uint32_t) -1;
|
||||
|
||||
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
|
||||
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
||||
for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads
|
||||
if (ir < src0_start_row) { ir++; i1++; continue; }
|
||||
if (ir >= src0_end_row) goto done;
|
||||
|
||||
// Rows in this block
|
||||
const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1);
|
||||
|
||||
// Depth before prefetch
|
||||
uint32_t dma_depth = dma_queue_depth(dma_queue);
|
||||
|
||||
// FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
// Prefetch loop
|
||||
for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) {
|
||||
pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint32_t pi1 = i1 + pr;
|
||||
uint32_t pir = ir + pr;
|
||||
|
||||
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0);
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
|
||||
// Update theta cache
|
||||
if (i2 != prev_i2) {
|
||||
prev_i2 = i2;
|
||||
|
||||
const int32_t p = pos[i2];
|
||||
rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale);
|
||||
|
||||
// FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
}
|
||||
|
||||
// Skip DMA transactions from prev block (if any)
|
||||
// No need to wait for these since the DMA is setup for in-order processing
|
||||
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }
|
||||
|
||||
// Compute loop
|
||||
for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) {
|
||||
// Number of rows to compute
|
||||
cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK);
|
||||
|
||||
uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src;
|
||||
uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr,
|
||||
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
|
||||
|
||||
if (is_neox) {
|
||||
rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
} else {
|
||||
rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache);
|
||||
}
|
||||
|
||||
uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1;
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr);
|
||||
|
||||
// Prefetch more rows (if any)
|
||||
if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) {
|
||||
uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK);
|
||||
uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS;
|
||||
uint32_t pir = ir + HTP_ROPE_SPAD_NROWS;
|
||||
|
||||
const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr),
|
||||
rctx->src0_row_size_aligned, rctx->src0_row_size, pnr);
|
||||
|
||||
// FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path);
|
||||
done:
|
||||
dma_queue_flush(dma_queue);
|
||||
tt = HAP_perf_get_qtimer_count() - tt;
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row,
|
||||
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data;
|
||||
|
||||
rope_job_f32_per_thread(rope_ctx, n, i);
|
||||
FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt));
|
||||
}
|
||||
|
||||
static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
@@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src2 = &octx->src2;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct rope_th_ctx rope_ctx;
|
||||
const char * op_type = "rope-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_ROPE:
|
||||
op_func = rope_job_dispatcher_f32;
|
||||
op_type = "rope-f32";
|
||||
|
||||
init_rope_ctx(&rope_ctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) {
|
||||
const uint32_t n_threads = octx->n_threads;
|
||||
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t src1_row_size = src0_row_size;
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
// Aligned row sizes for VTCM
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
// Calculate spad sizes per thread
|
||||
size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned;
|
||||
size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned;
|
||||
size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread;
|
||||
|
||||
if (src2->ne[0]) {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u "
|
||||
"dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
} else {
|
||||
FARF(HIGH,
|
||||
"%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n",
|
||||
op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2],
|
||||
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size,
|
||||
octx->dst_spad.size);
|
||||
}
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
// Check if we fit in VTCM
|
||||
size_t total_vtcm_needed = spad_per_thread * n_threads;
|
||||
if (octx->ctx->vtcm_size < total_vtcm_needed) {
|
||||
FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size;
|
||||
// Assign sizes
|
||||
octx->src0_spad.size_per_thread = src0_spad_per_thread;
|
||||
octx->dst_spad.size_per_thread = dst_spad_per_thread;
|
||||
octx->src0_spad.size = n_threads * src0_spad_per_thread;
|
||||
octx->dst_spad.size = n_threads * dst_spad_per_thread;
|
||||
octx->src1_spad.size = 0;
|
||||
|
||||
// Assign pointers
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->src1_spad.data = NULL;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
// Fill context
|
||||
struct htp_rope_context rctx;
|
||||
memset(&rctx, 0, sizeof(struct htp_rope_context));
|
||||
|
||||
rctx.t_start = HAP_perf_get_qtimer_count();
|
||||
|
||||
rctx.octx = octx;
|
||||
|
||||
const int32_t * op_params = &octx->op_params[0];
|
||||
rctx.n_dims = ((const int32_t *) op_params)[1];
|
||||
rctx.mode = ((const int32_t *) op_params)[2];
|
||||
rctx.n_ctx_orig = ((const int32_t *) op_params)[4];
|
||||
|
||||
memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float));
|
||||
memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float));
|
||||
memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float));
|
||||
memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float));
|
||||
memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float));
|
||||
memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float));
|
||||
memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4);
|
||||
|
||||
rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims);
|
||||
|
||||
rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims);
|
||||
|
||||
rctx.src0_row_size = src0_row_size;
|
||||
rctx.dst_row_size = dst_row_size;
|
||||
rctx.src0_row_size_aligned = src0_row_size_aligned;
|
||||
rctx.dst_row_size_aligned = dst_row_size_aligned;
|
||||
rctx.theta_cache_offset = theta_cache_size_aligned;
|
||||
|
||||
uint32_t ne0 = dst->ne[0];
|
||||
uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3];
|
||||
rctx.src0_nrows = src0_nrows;
|
||||
|
||||
FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0,
|
||||
rctx.ext_factor, rctx.theta_scale, rctx.attn_factor);
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs);
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -43,11 +43,21 @@
|
||||
\
|
||||
const uint32_t nr = ne01;
|
||||
|
||||
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
struct htp_set_rows_context {
|
||||
struct htp_ops_context * octx;
|
||||
struct fastdiv_values div_ne12;
|
||||
struct fastdiv_values div_ne11;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
};
|
||||
|
||||
static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
@@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
|
||||
struct htp_ops_context * octx = srctx->octx;
|
||||
|
||||
set_rows_preamble;
|
||||
|
||||
// parallelize by rows of src0
|
||||
const uint32_t dr = octx->src0_nrows_per_thread;
|
||||
const uint32_t dr = srctx->src0_nrows_per_thread;
|
||||
const uint32_t ir0 = dr * ith;
|
||||
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
|
||||
|
||||
@@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
||||
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (uint32_t i = ir0; i < ir1; ++i) {
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
|
||||
const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
|
||||
const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
|
||||
const uint32_t i10 = i;
|
||||
|
||||
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
|
||||
@@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
|
||||
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_set_rows(struct htp_ops_context * octx) {
|
||||
@@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
|
||||
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
|
||||
struct htp_set_rows_context srctx;
|
||||
srctx.octx = octx;
|
||||
srctx.div_ne12 = init_fastdiv_values(ne12);
|
||||
srctx.div_ne11 = init_fastdiv_values(ne11);
|
||||
|
||||
const uint32_t n_jobs = MIN(nr, octx->n_threads);
|
||||
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
|
||||
|
||||
switch(octx->dst.type) {
|
||||
case HTP_TYPE_F32:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs);
|
||||
break;
|
||||
case HTP_TYPE_F16:
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
|
||||
worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs);
|
||||
break;
|
||||
default:
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "hex-dma.h"
|
||||
#include "hvx-utils.h"
|
||||
#include "hex-fastdiv.h"
|
||||
|
||||
#define GGML_COMMON_DECL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -48,7 +49,7 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3];
|
||||
|
||||
struct softmax_th_ctx {
|
||||
struct htp_softmax_context {
|
||||
bool use_f16;
|
||||
bool use_src1;
|
||||
uint32_t n_head;
|
||||
@@ -59,28 +60,48 @@ struct softmax_th_ctx {
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t src0_nrows_per_thread;
|
||||
struct fastdiv_values fastdiv_ne01;
|
||||
struct fastdiv_values fastdiv_ne02;
|
||||
struct fastdiv_values fastdiv_ne12; // For mask broadcasting
|
||||
struct fastdiv_values fastdiv_ne13; // For mask broadcasting
|
||||
size_t spad_stride;
|
||||
|
||||
struct htp_ops_context * octx;
|
||||
};
|
||||
|
||||
static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) {
|
||||
static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
|
||||
memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx));
|
||||
memset(smctx, 0, sizeof(struct htp_softmax_context));
|
||||
|
||||
memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float));
|
||||
memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float));
|
||||
|
||||
softmax_ctx->n_head = src0->ne[2];
|
||||
softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head));
|
||||
smctx->n_head = src0->ne[2];
|
||||
smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head));
|
||||
|
||||
softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2);
|
||||
softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2);
|
||||
smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2);
|
||||
smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2);
|
||||
|
||||
softmax_ctx->use_src1 = (src1->ne[0] != 0);
|
||||
softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
smctx->use_src1 = (src1->ne[0] != 0);
|
||||
smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16);
|
||||
|
||||
softmax_ctx->octx = octx;
|
||||
smctx->octx = octx;
|
||||
|
||||
// Initialize fastdiv values
|
||||
const uint32_t ne01 = src0->ne[1];
|
||||
const uint32_t ne02 = src0->ne[2];
|
||||
|
||||
if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01);
|
||||
if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02);
|
||||
|
||||
const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1;
|
||||
const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1;
|
||||
|
||||
if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12);
|
||||
if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13);
|
||||
}
|
||||
|
||||
static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src,
|
||||
@@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1);
|
||||
}
|
||||
|
||||
HVX_Vector v = hvx_vec_reduce_max_f32(max_vec);
|
||||
max_vec = hvx_vec_repl4(v);
|
||||
max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
@@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
|
||||
v_pad[i] = v3;
|
||||
}
|
||||
|
||||
v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
|
||||
sum_vec = hvx_vec_repl4(v);
|
||||
sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes
|
||||
|
||||
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
|
||||
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
|
||||
@@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src,
|
||||
return sum;
|
||||
}
|
||||
|
||||
static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_softmax_preamble3;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
for (uint32_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (uint32_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (uint32_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const uint32_t i11 = i01;
|
||||
const uint32_t i12 = i02 % ne12;
|
||||
const uint32_t i13 = i03 % ne13;
|
||||
|
||||
// ALiBi
|
||||
const uint32_t h = i02; // head
|
||||
|
||||
const float slope = (softmax_ctx->max_bias > 0.0f) ?
|
||||
h < softmax_ctx->n_head_log2 ?
|
||||
powf(softmax_ctx->m0, h + 1) :
|
||||
powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (softmax_ctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (softmax_ctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale);
|
||||
if (mp_f32) {
|
||||
if (softmax_ctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) {
|
||||
struct htp_ops_context * octx = softmax_ctx->octx;
|
||||
static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) {
|
||||
struct htp_softmax_context * smctx = (struct htp_softmax_context *) data;
|
||||
struct htp_ops_context * octx = smctx->octx;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
@@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
||||
htp_softmax_preamble3;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
@@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int
|
||||
opt_path = 1;
|
||||
}
|
||||
|
||||
softmax_htp_f32(nth, ith, softmax_ctx, opt_path);
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride);
|
||||
|
||||
float * wp0 = (float *) src0_spad_data;
|
||||
float * wp1 = (float *) src1_spad_data;
|
||||
float * wp2 = (float *) dst_spad_data;
|
||||
|
||||
uint32_t prev_i2 = (uint32_t)-1;
|
||||
float slope = 1.0f;
|
||||
|
||||
for (uint32_t r = src0_start_row; r < src0_end_row; ++r) {
|
||||
uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01);
|
||||
uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01);
|
||||
uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02);
|
||||
uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02);
|
||||
|
||||
// Map to original logic indices
|
||||
// i01 = i1
|
||||
// i02 = i2
|
||||
// i03 = i3
|
||||
|
||||
const uint32_t i11 = i1;
|
||||
// const uint32_t i12 = i2 % ne12;
|
||||
// const uint32_t i13 = i3 % ne13;
|
||||
|
||||
uint32_t i12, i13;
|
||||
if (ne12 == ne02) {
|
||||
i12 = i2;
|
||||
} else {
|
||||
i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12);
|
||||
}
|
||||
|
||||
if (ne13 == ne03) {
|
||||
i13 = i3;
|
||||
} else {
|
||||
i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13);
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (i2 != prev_i2) {
|
||||
const uint32_t h = i2; // head
|
||||
|
||||
slope = (smctx->max_bias > 0.0f) ?
|
||||
h < smctx->n_head_log2 ?
|
||||
powf(smctx->m0, h + 1) :
|
||||
powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) :
|
||||
1.0f;
|
||||
prev_i2 = i2;
|
||||
}
|
||||
|
||||
float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03);
|
||||
float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3);
|
||||
|
||||
// broadcast the mask across rows
|
||||
__fp16 * mp_f16 = (smctx->use_src1) ?
|
||||
(__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
float * mp_f32 = (smctx->use_src1) ?
|
||||
(float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) :
|
||||
NULL;
|
||||
|
||||
if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) {
|
||||
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale,
|
||||
(const uint8_t *) mp_f32, slope);
|
||||
} else {
|
||||
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale);
|
||||
if (mp_f32) {
|
||||
if (smctx->use_f16) {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * (float) mp_f16[i];
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ne00; ++i) {
|
||||
wp0[i] += slope * mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00);
|
||||
} else {
|
||||
float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00);
|
||||
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
|
||||
sum = sum > 0.0 ? (1.0 / sum) : 1;
|
||||
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum);
|
||||
}
|
||||
}
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
|
||||
softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13,
|
||||
ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) {
|
||||
struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data;
|
||||
softmax_job_f32_per_thread(p_softmax_ctx, n, i);
|
||||
}
|
||||
|
||||
static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
@@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
const struct htp_tensor * src1 = &octx->src1;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t op_func;
|
||||
const char * op_type = NULL;
|
||||
|
||||
struct softmax_th_ctx softmax_ctx;
|
||||
struct htp_softmax_context smctx;
|
||||
const char * op_type = "softmax-f32";
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_SOFTMAX:
|
||||
op_func = softmax_job_dispatcher_f32;
|
||||
op_type = "softmax-f32";
|
||||
|
||||
init_softmax_ctx(&softmax_ctx, octx);
|
||||
init_softmax_ctx(&smctx, octx);
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads;
|
||||
|
||||
// Use stride for calculating offset
|
||||
smctx.spad_stride = hex_round_up(src0_row_size, 128);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size;
|
||||
|
||||
if (src1->ne[0]) {
|
||||
@@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) {
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs);
|
||||
smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -17,7 +17,6 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
|
||||
#define sum_rows_preamble \
|
||||
struct htp_tensor *src0 = &octx->src0;\
|
||||
struct htp_tensor *dst = &octx->dst; \
|
||||
@@ -42,53 +41,54 @@
|
||||
const uint32_t nb2 = dst->nb[2]; \
|
||||
const uint32_t nb3 = dst->nb[3]; \
|
||||
|
||||
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
|
||||
sum_rows_preamble;
|
||||
struct sum_rows_context {
|
||||
const uint8_t * src_data;
|
||||
uint8_t * dst_data;
|
||||
uint32_t ne00;
|
||||
size_t src_stride;
|
||||
size_t dst_stride;
|
||||
uint32_t rows_per_thread;
|
||||
uint32_t total_rows;
|
||||
bool opt_path;
|
||||
};
|
||||
|
||||
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) {
|
||||
const struct sum_rows_context * smctx = (const struct sum_rows_context *) data;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const uint32_t rows_per_thread = smctx->rows_per_thread;
|
||||
const uint32_t total_rows = smctx->total_rows;
|
||||
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
const uint32_t start_row = rows_per_thread * ith;
|
||||
const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
|
||||
|
||||
// no work for this thread
|
||||
if (src0_start_row >= src0_end_row) {
|
||||
return HTP_STATUS_OK;
|
||||
if (start_row >= end_row) {
|
||||
return;
|
||||
}
|
||||
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
}
|
||||
const size_t src_stride = smctx->src_stride;
|
||||
const size_t dst_stride = smctx->dst_stride;
|
||||
const uint32_t ne00 = smctx->ne00;
|
||||
const bool opt_path = smctx->opt_path;
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride));
|
||||
float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride));
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
// Calculate actual number of rows for this thread
|
||||
const uint32_t n_rows = end_row - start_row;
|
||||
|
||||
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * ne00);
|
||||
for (uint32_t ir = 0; ir < n_rows; ir++) {
|
||||
const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float)));
|
||||
|
||||
if (ir + 1 < src0_nrows_per_thread) {
|
||||
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
|
||||
if (ir + 1 < n_rows) {
|
||||
hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
if (opt_path) {
|
||||
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
|
||||
} else {
|
||||
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
|
||||
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
|
||||
}
|
||||
|
||||
int op_sum_rows(struct htp_ops_context * octx) {
|
||||
@@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) {
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03;
|
||||
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
|
||||
bool opt_path = false;
|
||||
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = true;
|
||||
}
|
||||
|
||||
struct sum_rows_context smctx = {
|
||||
.src_data = (const uint8_t *) src0->data,
|
||||
.dst_data = (uint8_t *) dst->data,
|
||||
.ne00 = ne00,
|
||||
.src_stride = nb01,
|
||||
.dst_stride = nb1,
|
||||
.rows_per_thread = rows_per_thread,
|
||||
.total_rows = src0_nrows,
|
||||
.opt_path = opt_path,
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs);
|
||||
|
||||
return HTP_STATUS_OK;
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,28 @@
|
||||
#include "htp-msg.h"
|
||||
#include "htp-ops.h"
|
||||
|
||||
struct htp_unary_context {
|
||||
struct htp_ops_context * octx;
|
||||
|
||||
// Precomputed values
|
||||
const uint8_t * data_src0;
|
||||
uint8_t * data_dst;
|
||||
|
||||
size_t src0_row_size;
|
||||
size_t dst_row_size;
|
||||
|
||||
size_t src0_row_size_aligned;
|
||||
size_t dst_row_size_aligned;
|
||||
|
||||
size_t src0_spad_half_size;
|
||||
size_t dst_spad_half_size;
|
||||
|
||||
uint32_t block;
|
||||
uint32_t src0_nrows;
|
||||
uint32_t src0_nrows_per_thread;
|
||||
uint32_t nc;
|
||||
};
|
||||
|
||||
#define htp_unary_preamble \
|
||||
const uint32_t ne00 = src->ne[0]; \
|
||||
const uint32_t ne01 = src->ne[1]; \
|
||||
@@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
sum_v = hvx_vec_repl4(reduced_sum);
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
@@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
||||
}
|
||||
}
|
||||
|
||||
static void scale_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void scale_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float scale = 0.f;
|
||||
float bias = 0.f;
|
||||
memcpy(&scale, &op_params[0], sizeof(float));
|
||||
memcpy(&bias, &op_params[1], sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
|
||||
}
|
||||
}
|
||||
|
||||
static void rms_norm_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void rms_norm_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
float epsilon = 0.f;
|
||||
memcpy(&epsilon, op_params, sizeof(float));
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
} else {
|
||||
float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems);
|
||||
|
||||
const float mean = sum / row_elems;
|
||||
const float scale = 1.0f / sqrtf(mean + epsilon);
|
||||
|
||||
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale);
|
||||
}
|
||||
hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqr_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqr_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void sqrt_htp_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params,
|
||||
int opt_path) {
|
||||
static void sqrt_f32(const float * restrict src,
|
||||
float * restrict dst,
|
||||
uint8_t * restrict spad,
|
||||
const uint32_t num_rows,
|
||||
const uint32_t row_elems,
|
||||
const size_t row_size,
|
||||
int32_t * op_params) {
|
||||
|
||||
for (uint32_t ir = 0; ir < num_rows; ir++) {
|
||||
const float * restrict src_local = src + (ir * row_elems);
|
||||
float * restrict dst_local = dst + (ir * row_elems);
|
||||
const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size);
|
||||
uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size);
|
||||
|
||||
if (ir + 1 < num_rows) {
|
||||
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
|
||||
}
|
||||
|
||||
if (1 == opt_path) {
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
} else {
|
||||
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
|
||||
}
|
||||
}
|
||||
|
||||
static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
struct htp_tensor * dst,
|
||||
uint8_t * spad,
|
||||
int htp_op,
|
||||
int32_t * op_params,
|
||||
uint32_t nth,
|
||||
uint32_t ith,
|
||||
uint32_t src0_nrows_per_thread) {
|
||||
static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) {
|
||||
const struct htp_unary_context * uctx = (const struct htp_unary_context *) data;
|
||||
struct htp_ops_context * octx = uctx->octx;
|
||||
const struct htp_tensor * src = &octx->src0;
|
||||
const struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
htp_unary_preamble;
|
||||
|
||||
const size_t src0_row_size = nb01;
|
||||
const size_t dst_row_size = nb1;
|
||||
int htp_op = octx->op;
|
||||
int32_t * op_params = octx->op_params;
|
||||
uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread;
|
||||
|
||||
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
|
||||
const size_t src0_row_size = uctx->src0_row_size;
|
||||
const size_t dst_row_size = uctx->dst_row_size;
|
||||
|
||||
const size_t src0_row_size_aligned = uctx->src0_row_size_aligned;
|
||||
const size_t dst_row_size_aligned = uctx->dst_row_size_aligned;
|
||||
|
||||
const uint32_t src0_nrows = uctx->src0_nrows;
|
||||
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
|
||||
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
|
||||
|
||||
@@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
|
||||
uint64_t t1, t2;
|
||||
t1 = HAP_perf_get_qtimer_count();
|
||||
|
||||
int is_aligned = 1;
|
||||
int opt_path = 0;
|
||||
if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) {
|
||||
is_aligned = 0;
|
||||
}
|
||||
if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
|
||||
opt_path = 1;
|
||||
const uint8_t * restrict data_src = uctx->data_src0;
|
||||
uint8_t * restrict data_dst = uctx->data_dst;
|
||||
|
||||
uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread);
|
||||
uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread);
|
||||
|
||||
size_t src0_spad_half_size = uctx->src0_spad_half_size;
|
||||
size_t dst_spad_half_size = uctx->dst_spad_half_size;
|
||||
|
||||
const int BLOCK = uctx->block;
|
||||
if (BLOCK == 0) {
|
||||
FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
|
||||
octx->src0_spad.size_per_thread, src0_row_size_aligned);
|
||||
return;
|
||||
}
|
||||
|
||||
const uint8_t * restrict data_src = (const uint8_t *) src->data;
|
||||
uint8_t * restrict data_dst = (uint8_t *) dst->data;
|
||||
dma_queue * dma_queue = octx->ctx->dma[ith];
|
||||
|
||||
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
|
||||
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
|
||||
uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01);
|
||||
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
|
||||
break;
|
||||
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
|
||||
dst_row_size, dst_row_size_aligned, 0);
|
||||
|
||||
default:
|
||||
break;
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, block_size);
|
||||
}
|
||||
|
||||
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
|
||||
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
|
||||
|
||||
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
|
||||
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
|
||||
|
||||
// Process block in VTCM
|
||||
switch (htp_op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
dma_queue_push_vtcm_to_ddr(dma_queue,
|
||||
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad),
|
||||
dst_row_size, dst_row_size_aligned, block_size);
|
||||
|
||||
// prefetch N+2 loop iteration if any
|
||||
const uint32_t pref_block = (ir + BLOCK * 2);
|
||||
if (pref_block < src0_end_row) {
|
||||
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
|
||||
dma_queue_push_ddr_to_vtcm(dma_queue,
|
||||
dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)),
|
||||
src0_row_size_aligned, src0_row_size, pref_block_size);
|
||||
}
|
||||
}
|
||||
|
||||
dma_queue_flush(dma_queue);
|
||||
|
||||
t2 = HAP_perf_get_qtimer_count();
|
||||
|
||||
FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0],
|
||||
FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0],
|
||||
src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
|
||||
}
|
||||
|
||||
static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) {
|
||||
struct htp_ops_context * octx = (struct htp_ops_context *) data;
|
||||
|
||||
unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i,
|
||||
octx->src0_nrows_per_thread);
|
||||
}
|
||||
|
||||
static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
int err = HTP_STATUS_OK;
|
||||
|
||||
const struct htp_tensor * src0 = &octx->src0;
|
||||
struct htp_tensor * dst = &octx->dst;
|
||||
|
||||
worker_callback_t unary_op_func;
|
||||
const char * op_type = NULL;
|
||||
const char * op_type = NULL;
|
||||
|
||||
switch (octx->op) {
|
||||
case HTP_OP_RMS_NORM:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "rmsnorm-f32";
|
||||
op_type = "rmsnorm-f32";
|
||||
break;
|
||||
case HTP_OP_SCALE:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "scale-f32";
|
||||
op_type = "scale-f32";
|
||||
break;
|
||||
case HTP_OP_SQR:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqr-f32";
|
||||
op_type = "sqr-f32";
|
||||
break;
|
||||
case HTP_OP_SQRT:
|
||||
unary_op_func = unary_job_dispatcher_f32;
|
||||
op_type = "sqrt-f32";
|
||||
op_type = "sqrt-f32";
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
|
||||
const size_t src0_row_size = src0->nb[1];
|
||||
const size_t dst_row_size = dst->nb[1];
|
||||
|
||||
// VTCM scratchpads for all tensors
|
||||
octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads;
|
||||
octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads;
|
||||
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
|
||||
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
|
||||
|
||||
size_t spad_size = octx->src0_spad.size + octx->dst_spad.size;
|
||||
// VTCM scratchpads for all tensors
|
||||
// N rows per thread, padded to HVX vector size
|
||||
// Double buffering requires 2x size per buffer
|
||||
|
||||
size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned);
|
||||
size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (vtcm_row_per_thread == 0) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size_per_row * n_threads);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2;
|
||||
|
||||
octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread;
|
||||
octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread;
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size);
|
||||
|
||||
// Make sure the reserved vtcm size is sufficient
|
||||
if (octx->ctx->vtcm_size < spad_size) {
|
||||
FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size,
|
||||
spad_size);
|
||||
return HTP_STATUS_VTCM_TOO_SMALL;
|
||||
}
|
||||
|
||||
octx->src0_spad.data = octx->ctx->vtcm_base;
|
||||
octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size;
|
||||
|
||||
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
|
||||
uint32_t n_jobs = MIN(n_threads, src0_nrows);
|
||||
|
||||
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
|
||||
struct htp_unary_context uctx = {
|
||||
.octx = octx,
|
||||
.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs,
|
||||
.src0_nrows = src0_nrows,
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs);
|
||||
.data_src0 = (const uint8_t *)src0->data,
|
||||
.data_dst = (uint8_t *)dst->data,
|
||||
|
||||
.src0_row_size = src0_row_size,
|
||||
.dst_row_size = dst_row_size,
|
||||
|
||||
.src0_row_size_aligned = src0_row_size_aligned,
|
||||
.dst_row_size_aligned = dst_row_size_aligned,
|
||||
|
||||
.src0_spad_half_size = octx->src0_spad.size_per_thread / 2,
|
||||
.dst_spad_half_size = octx->dst_spad.size_per_thread / 2,
|
||||
|
||||
.block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned,
|
||||
.nc = src0->ne[0],
|
||||
};
|
||||
|
||||
worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs);
|
||||
}
|
||||
|
||||
return err;
|
||||
|
||||
@@ -403,19 +403,20 @@ enum FaCodePath {
|
||||
};
|
||||
|
||||
struct vk_fa_pipeline_state {
|
||||
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
|
||||
: HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
|
||||
|
||||
uint32_t HSK, HSV;
|
||||
bool small_rows, small_cache;
|
||||
uint32_t Br, Bc;
|
||||
uint32_t D_split, row_split;
|
||||
bool shmem_staging;
|
||||
FaCodePath path;
|
||||
uint32_t workgroup_size, subgroup_size;
|
||||
bool aligned;
|
||||
bool f32acc;
|
||||
uint32_t flags;
|
||||
uint32_t limit_occupancy_shmem;
|
||||
|
||||
bool operator<(const vk_fa_pipeline_state &b) const {
|
||||
return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
|
||||
std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
|
||||
return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
|
||||
std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -623,6 +624,8 @@ struct vk_device_struct {
|
||||
// floor(log2(maxComputeWorkGroupInvocations))
|
||||
uint32_t max_workgroup_size_log2 {};
|
||||
|
||||
bool flash_attention_fp16;
|
||||
|
||||
bool coopmat_support;
|
||||
bool coopmat_acc_f32_support {};
|
||||
bool coopmat_acc_f16_support {};
|
||||
@@ -1656,6 +1659,7 @@ static bool vk_perf_logger_concurrent = false;
|
||||
static bool vk_enable_sync_logger = false;
|
||||
// number of calls between perf logger prints
|
||||
static uint32_t vk_perf_logger_frequency = 1;
|
||||
static std::string vk_pipeline_stats_filter;
|
||||
|
||||
class vk_perf_logger {
|
||||
public:
|
||||
@@ -2172,7 +2176,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
|
||||
executableInfo.pipeline = pipeline->pipeline;
|
||||
|
||||
auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
|
||||
|
||||
bool print_stats = !vk_pipeline_stats_filter.empty() &&
|
||||
pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
|
||||
if (print_stats) {
|
||||
std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
|
||||
}
|
||||
|
||||
for (auto & s : statistics) {
|
||||
if (print_stats) {
|
||||
std::cerr << "ggml_vulkan: " << s.name.data() << ": ";
|
||||
switch (s.format) {
|
||||
case vk::PipelineExecutableStatisticFormatKHR::eBool32:
|
||||
std::cerr << (s.value.b32 ? "true" : "false");
|
||||
break;
|
||||
case vk::PipelineExecutableStatisticFormatKHR::eInt64:
|
||||
std::cerr << s.value.i64;
|
||||
break;
|
||||
case vk::PipelineExecutableStatisticFormatKHR::eUint64:
|
||||
std::cerr << s.value.u64;
|
||||
break;
|
||||
case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
|
||||
std::cerr << s.value.f64;
|
||||
break;
|
||||
}
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
// "Register Count" is reported by NVIDIA drivers.
|
||||
if (strcmp(s.name, "Register Count") == 0) {
|
||||
VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
|
||||
@@ -2755,78 +2784,214 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
||||
);
|
||||
}
|
||||
|
||||
// number of rows/cols for flash attention shader
|
||||
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
||||
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
||||
struct vk_fa_tuning_params {
|
||||
FaCodePath path;
|
||||
uint32_t workgroup_size;
|
||||
uint32_t subgroup_size;
|
||||
uint32_t block_rows;
|
||||
uint32_t block_cols;
|
||||
uint32_t d_split;
|
||||
uint32_t row_split;
|
||||
bool shmem_staging;
|
||||
bool disable_subgroups;
|
||||
uint32_t limit_occupancy_shmem;
|
||||
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
if (hsv >= 192) {
|
||||
return 2;
|
||||
} else if ((hsv | hsk) & 8 || small_cache) {
|
||||
return 4;
|
||||
} else {
|
||||
return 8;
|
||||
void print() const {
|
||||
std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
|
||||
" block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
|
||||
" row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
|
||||
" limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
||||
// 128 threads split into four subgroups, each subgroup does 1/4
|
||||
// of the Bc dimension.
|
||||
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
||||
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
||||
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
|
||||
|
||||
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
||||
if (path == FA_COOPMAT2) {
|
||||
return flash_attention_num_small_rows;
|
||||
static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(kv_type);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_SCALAR;
|
||||
|
||||
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
// Disable subgroup use due to performance issues when enforcing subgroup sizes
|
||||
result.subgroup_size = 32;
|
||||
result.disable_subgroups = true;
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
|
||||
result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
|
||||
} else {
|
||||
return scalar_flash_attention_num_small_rows;
|
||||
result.subgroup_size = device->subgroup_size;
|
||||
}
|
||||
}
|
||||
|
||||
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
|
||||
GGML_UNUSED(clamp);
|
||||
// Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
|
||||
uint32_t row_split_max_hsk = 64;
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
|
||||
row_split_max_hsk = n_rows <= 8 ? 64 : 128;
|
||||
}
|
||||
result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
|
||||
|
||||
if (path == FA_SCALAR) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
|
||||
result.workgroup_size = result.subgroup_size * 2;
|
||||
} else {
|
||||
result.workgroup_size = result.subgroup_size * 4;
|
||||
}
|
||||
|
||||
const uint32_t D = hsk | hsv;
|
||||
|
||||
const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
|
||||
|
||||
if (n_rows == 1) {
|
||||
result.block_rows = 1;
|
||||
result.block_cols = 64;
|
||||
} else {
|
||||
// row_split 1 means higher register use per row, so block size has to be adjusted
|
||||
if (result.row_split == 1) {
|
||||
result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
|
||||
} else {
|
||||
if ((hsv | hsk) & 8) {
|
||||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
||||
} else {
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
|
||||
}
|
||||
result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
|
||||
}
|
||||
|
||||
result.block_cols = (D & 8) ? 64 : 32;
|
||||
}
|
||||
|
||||
const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
|
||||
|
||||
result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
|
||||
result.block_rows /= 2;
|
||||
}
|
||||
|
||||
// On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
|
||||
// at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
|
||||
// This targets an occupancy of 4 subgroups per SIMD.
|
||||
if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
|
||||
if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
|
||||
// 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
|
||||
// Values are guessed, tested on RDNA2
|
||||
result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
|
||||
} else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
|
||||
// Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
|
||||
// Here low-batch FA with large head size is affected.
|
||||
// n_rows < 4 switch because workgroup size switches from 128 to 256 there.
|
||||
result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(n_rows);
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(kv_type);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_COOPMAT1;
|
||||
|
||||
const uint32_t D = hsk | hsv;
|
||||
|
||||
const uint32_t coopmat_block_rows = 16;
|
||||
const uint32_t coopmat_block_cols = 16;
|
||||
|
||||
const uint32_t num_subgroups = 4;
|
||||
|
||||
result.block_rows = coopmat_block_rows;
|
||||
result.block_cols = coopmat_block_cols * num_subgroups;
|
||||
result.row_split = num_subgroups;
|
||||
result.subgroup_size = device->subgroup_size;
|
||||
result.workgroup_size = num_subgroups * result.subgroup_size;
|
||||
|
||||
const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
|
||||
result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
GGML_UNUSED(n_kv);
|
||||
GGML_UNUSED(f32acc);
|
||||
|
||||
vk_fa_tuning_params result{};
|
||||
result.path = FA_COOPMAT2;
|
||||
|
||||
const uint32_t D = hsk | hsv;
|
||||
|
||||
const bool small_rows = n_rows < 32;
|
||||
|
||||
if (small_rows) {
|
||||
result.block_rows = 32;
|
||||
result.block_cols = 32;
|
||||
} else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
|
||||
result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
|
||||
result.block_cols = 32;
|
||||
} else {
|
||||
result.block_rows = 64;
|
||||
result.block_cols = 64;
|
||||
}
|
||||
|
||||
result.subgroup_size = device->subgroup_size;
|
||||
result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
|
||||
FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
|
||||
device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
|
||||
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
||||
} else {
|
||||
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
||||
bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
|
||||
(!f32acc && device->coopmat_support_16x16x16_f16acc);
|
||||
const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
|
||||
|
||||
if (!shape_ok || !shmem_ok) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
}
|
||||
|
||||
// small rows, large cols
|
||||
if (small_rows) {
|
||||
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
||||
// scalar is faster than coopmat when N==1
|
||||
if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// small cols to reduce register count
|
||||
if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
|
||||
if (hsk >= 512 || hsv >= 512) {
|
||||
return {32, 32};
|
||||
} else {
|
||||
return {64, 32};
|
||||
}
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
case FA_COOPMAT1:
|
||||
return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
case FA_COOPMAT2:
|
||||
return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
|
||||
default:
|
||||
throw std::runtime_error("unsupported FaCodePath");
|
||||
}
|
||||
return {64, 64};
|
||||
}
|
||||
|
||||
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
|
||||
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
|
||||
static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
|
||||
bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
|
||||
uint32_t flags = (use_mask_opt ? 1 : 0) |
|
||||
(use_mask ? 2 : 0) |
|
||||
(use_logit_softcap ? 4 : 0);
|
||||
|
||||
const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
|
||||
|
||||
return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
|
||||
}
|
||||
|
||||
static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
|
||||
return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
|
||||
state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
@@ -3193,76 +3358,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
|
||||
};
|
||||
|
||||
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
|
||||
// For large number of rows, 128 invocations seems to work best.
|
||||
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
||||
// can't use 256 for D==80.
|
||||
// For scalar, use 128 (arbitrary)
|
||||
// The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
|
||||
const uint32_t D = (hsk|hsv);
|
||||
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
|
||||
|
||||
uint32_t wg_size;
|
||||
switch (path) {
|
||||
case FA_COOPMAT2:
|
||||
wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
|
||||
break;
|
||||
case FA_COOPMAT1:
|
||||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
break;
|
||||
default:
|
||||
wg_size = scalar_flash_attention_workgroup_size;
|
||||
break;
|
||||
}
|
||||
|
||||
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
||||
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
||||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
// Nvidia prefers shared memory use to load large tiles of K.
|
||||
// Switch to loading from global memory when it would use too much shared memory.
|
||||
// AMD prefers loading K directly from global memory
|
||||
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
|
||||
};
|
||||
|
||||
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
||||
for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
|
||||
uint32_t HSK = fa.first.HSK; \
|
||||
uint32_t HSV = fa.first.HSV; \
|
||||
bool small_rows = fa.first.small_rows; \
|
||||
bool small_cache = fa.first.small_cache; \
|
||||
FaCodePath path = fa.first.path; \
|
||||
uint32_t Br = fa.first.Br; \
|
||||
uint32_t Bc = fa.first.Bc; \
|
||||
bool aligned = fa.first.aligned; \
|
||||
bool f32acc = fa.first.f32acc; \
|
||||
uint32_t flags = fa.first.flags; \
|
||||
uint32_t fa_sgs = fa.first.subgroup_size; \
|
||||
bool fa_ds = fa.first.subgroup_size == 0; \
|
||||
if (path == FAPATH) { \
|
||||
if (aligned) { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} else { \
|
||||
if (f32acc) { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} else { \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
|
||||
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
if (device->flash_attention_fp16) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
||||
} else {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
|
||||
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
|
||||
}
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (device->coopmat1_fa_support) {
|
||||
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
|
||||
@@ -3780,10 +3912,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
&& !device->coopmat_bf16_support
|
||||
#endif
|
||||
) {
|
||||
const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
|
||||
|
||||
// use scalar tile sizes
|
||||
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
||||
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
|
||||
l_wg_denoms = {128, 128, 1 };
|
||||
m_wg_denoms = { 64, 64, 1 };
|
||||
@@ -4533,6 +4667,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
}
|
||||
|
||||
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
|
||||
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
|
||||
|
||||
static vk_device ggml_vk_get_device(size_t idx) {
|
||||
VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
|
||||
@@ -4749,6 +4884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->shader_core_count = sm_props.shaderSMCount;
|
||||
} else if (amd_shader_core_properties2) {
|
||||
device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
|
||||
} else {
|
||||
device->shader_core_count = 0;
|
||||
}
|
||||
@@ -4968,11 +5105,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
|
||||
#if defined(VK_KHR_cooperative_matrix)
|
||||
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
||||
|
||||
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
||||
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
||||
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
||||
device->subgroup_max_size >= 32;
|
||||
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
|
||||
#endif
|
||||
|
||||
if (coopmat2_support) {
|
||||
@@ -5290,6 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
||||
device->mmvq_mode = 1;
|
||||
}
|
||||
|
||||
// Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613
|
||||
const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary;
|
||||
device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn;
|
||||
|
||||
return device;
|
||||
}
|
||||
|
||||
@@ -5540,6 +5677,10 @@ static void ggml_vk_instance_init() {
|
||||
vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
|
||||
vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
|
||||
vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
|
||||
const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
|
||||
if (GGML_VK_PIPELINE_STATS != nullptr) {
|
||||
vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
|
||||
}
|
||||
const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
|
||||
|
||||
if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
|
||||
@@ -8419,21 +8560,27 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
GGML_UNUSED(f32acc);
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
||||
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
|
||||
const uint32_t Bc = scalar_flash_attention_Bc;
|
||||
const uint32_t wg_size = params.workgroup_size;
|
||||
const uint32_t Br = params.block_rows;
|
||||
const uint32_t Bc = params.block_cols;
|
||||
|
||||
const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
// tmpsh is overestimated slightly
|
||||
const uint32_t tmpsh = wg_size * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
|
||||
const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
|
||||
|
||||
const uint32_t masksh = Bc * Br * sizeof(float);
|
||||
const uint32_t masksh = Bc * (Br + 1) * float_type_size;
|
||||
|
||||
const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
|
||||
const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
|
||||
const uint32_t D = std::max(hsk, hsv);
|
||||
const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
|
||||
|
||||
const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
|
||||
@@ -8441,18 +8588,17 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
|
||||
return supported;
|
||||
}
|
||||
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
|
||||
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
|
||||
// Needs to be kept up to date on shader changes
|
||||
GGML_UNUSED(hsv);
|
||||
const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
const uint32_t Br = params.block_rows;
|
||||
const uint32_t Bc = params.block_cols;
|
||||
|
||||
const uint32_t MatBr = 16, MatBc = 16;
|
||||
|
||||
const uint32_t row_split = Bc / MatBc;
|
||||
|
||||
const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
|
||||
const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
|
||||
|
||||
const uint32_t acctype = f32acc ? 4 : 2;
|
||||
const uint32_t f16vec4 = 8;
|
||||
@@ -8468,17 +8614,19 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
||||
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
|
||||
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
|
||||
const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
|
||||
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
||||
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
|
||||
const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
|
||||
|
||||
const uint32_t osh_stride = params.row_split * MatBr / 4;
|
||||
const uint32_t pvsh = MatBc * osh_stride * f16vec4;
|
||||
|
||||
const uint32_t slope = Br * acctype;
|
||||
|
||||
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
|
||||
const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
|
||||
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
||||
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
|
||||
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
||||
|
||||
return supported;
|
||||
}
|
||||
@@ -8536,48 +8684,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
assert(q->type == GGML_TYPE_F32);
|
||||
assert(k->type == v->type);
|
||||
|
||||
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
||||
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
||||
|
||||
if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) {
|
||||
// Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
if (path == FA_COOPMAT1) {
|
||||
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
||||
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
||||
|
||||
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
|
||||
|
||||
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t gqa_ratio = 1;
|
||||
uint32_t qk_ratio = neq2 / nek2;
|
||||
uint32_t workgroups_x = (uint32_t)neq1;
|
||||
uint32_t workgroups_y = (uint32_t)neq2;
|
||||
uint32_t workgroups_z = (uint32_t)neq3;
|
||||
|
||||
const bool small_cache = nek1 < 1024;
|
||||
const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
|
||||
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
||||
uint32_t max_gqa;
|
||||
switch (path) {
|
||||
case FA_SCALAR:
|
||||
case FA_COOPMAT1:
|
||||
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
||||
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
|
||||
break;
|
||||
case FA_COOPMAT2:
|
||||
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(0);
|
||||
}
|
||||
vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
|
||||
const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
|
||||
|
||||
if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
||||
qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
|
||||
@@ -8589,24 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
workgroups_y /= gqa_ratio;
|
||||
}
|
||||
|
||||
bool small_rows = N <= get_fa_num_small_rows(path);
|
||||
|
||||
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
||||
// So use scalar instead.
|
||||
if (small_rows && path == FA_COOPMAT1) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// scalar is faster than coopmat2 when N==1
|
||||
if (N == 1 && path == FA_COOPMAT2) {
|
||||
path = FA_SCALAR;
|
||||
}
|
||||
|
||||
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
|
||||
if (path == FA_SCALAR &&
|
||||
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
|
||||
small_rows = true;
|
||||
}
|
||||
tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
|
||||
|
||||
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
|
||||
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
|
||||
@@ -8620,18 +8721,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
v_stride /= 4;
|
||||
}
|
||||
|
||||
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
|
||||
const uint32_t alignment = tuning_params.block_cols;
|
||||
bool aligned = (KV % alignment) == 0 &&
|
||||
// the "aligned" shader variant will forcibly align strides, for performance
|
||||
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
||||
|
||||
// Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
|
||||
if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
|
||||
if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
|
||||
aligned = false;
|
||||
}
|
||||
|
||||
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
@@ -8646,12 +8745,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
// Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
|
||||
bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
|
||||
|
||||
uint32_t flags = (use_mask_opt ? 1 : 0) |
|
||||
(mask != nullptr ? 2 : 0) |
|
||||
(logit_softcap != 0 ? 4 : 0);
|
||||
|
||||
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
|
||||
vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc,
|
||||
mask != nullptr, use_mask_opt, logit_softcap != 0);
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
@@ -8673,22 +8768,35 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
uint32_t split_kv = KV;
|
||||
uint32_t split_k = 1;
|
||||
|
||||
// Intel Alchemist prefers more workgroups
|
||||
const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
|
||||
|
||||
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
||||
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
|
||||
|
||||
const uint32_t Br = fa_pipeline_state.Br;
|
||||
const uint32_t Bc = fa_pipeline_state.Bc;
|
||||
|
||||
GGML_ASSERT(Br == pipeline->wg_denoms[0]);
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
// Try to use split_k when KV is large enough to be worth the overhead.
|
||||
// Must either be a single batch or be using gqa, we can't mix the two.
|
||||
if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
|
||||
// Try to run two workgroups per SM.
|
||||
if (gqa_ratio > 1 && workgroups_x <= Br) {
|
||||
split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
} else if (gqa_ratio <= 1) {
|
||||
uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
|
||||
if (total_wgs_no_split < shader_core_count * 2) {
|
||||
split_k = shader_core_count * 2 / total_wgs_no_split;
|
||||
}
|
||||
}
|
||||
|
||||
if (split_k > 1) {
|
||||
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
||||
// of "align", so recompute split_k based on that.
|
||||
split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
|
||||
split_k = CEIL_DIV(KV, split_kv);
|
||||
}
|
||||
|
||||
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
|
||||
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
|
||||
// For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
|
||||
@@ -8702,10 +8810,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||
}
|
||||
|
||||
auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
|
||||
const uint32_t Br = rows_cols[0];
|
||||
const uint32_t Bc = rows_cols[1];
|
||||
|
||||
const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
|
||||
const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
|
||||
|
||||
@@ -8785,15 +8889,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
if (ctx->prealloc_split_k_need_sync) {
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
}
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
|
||||
// We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
uint32_t dispatch_x;
|
||||
if (gqa_ratio > 1) {
|
||||
workgroups_x *= pipeline->wg_denoms[0];
|
||||
dispatch_x = split_k * workgroups_x;
|
||||
} else {
|
||||
dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
|
||||
}
|
||||
|
||||
vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
||||
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
|
||||
// We only use split_k when group query attention is enabled, which means
|
||||
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
||||
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
||||
// cancel out the divide by wg_denoms[0].
|
||||
pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
|
||||
pc, { dispatch_x, workgroups_y, workgroups_z });
|
||||
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
|
||||
@@ -15418,6 +15528,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
|
||||
}
|
||||
}
|
||||
|
||||
static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
|
||||
VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
|
||||
|
||||
if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t device_id = props.properties.deviceID;
|
||||
|
||||
switch (device_id) {
|
||||
case 0x56A6: // A310
|
||||
return 6;
|
||||
case 0x5693: // A370M
|
||||
case 0x56A5: // A380
|
||||
case 0x56B1: // Pro A40/A50
|
||||
return 8;
|
||||
case 0x5697: // A530M
|
||||
return 12;
|
||||
case 0x5692: // A550M
|
||||
case 0x56B3: // Pro A60
|
||||
return 16;
|
||||
case 0x56A2: // A580
|
||||
return 24;
|
||||
case 0x5691: // A730M
|
||||
case 0x56A1: // A750
|
||||
return 28;
|
||||
case 0x56A0: // A770
|
||||
case 0x5690: // A770M
|
||||
return 32;
|
||||
case 0xE212: // Pro B50
|
||||
return 16;
|
||||
case 0xE20C: // B570
|
||||
return 18;
|
||||
case 0xE20B: // B580
|
||||
return 20;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// checks
|
||||
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
@@ -16094,7 +16244,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
ggml_vk_print_graph_origin(tensor, done);
|
||||
}
|
||||
|
||||
if (avg_err > 0.5 || std::isnan(avg_err)) {
|
||||
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
||||
std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
|
||||
std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
|
||||
if (src0 != nullptr) {
|
||||
|
||||
@@ -3,9 +3,13 @@
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
|
||||
#ifdef FLOAT16
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_subgroup_extended_types_float16 : require
|
||||
#endif
|
||||
|
||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
||||
#extension GL_KHR_shader_subgroup_vote : enable
|
||||
|
||||
@@ -15,8 +19,10 @@
|
||||
const uint32_t HSK_per_thread = HSK / D_split;
|
||||
const uint32_t HSV_per_thread = HSV / D_split;
|
||||
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
|
||||
|
||||
|
||||
layout (binding = 0) readonly buffer Q {float data_q[];};
|
||||
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
// If SubGroupSize is set to 0 then only use shmem reductions
|
||||
const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
|
||||
shared float tmpsh[tmpsh_size];
|
||||
shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
|
||||
|
||||
shared FLOAT_TYPE tmpsh[WorkGroupSize];
|
||||
shared vec4 tmpshv4[WorkGroupSize];
|
||||
const uint32_t masksh_stride = Br + 1;
|
||||
shared FLOAT_TYPE masksh[Bc * masksh_stride];
|
||||
|
||||
shared float masksh[Bc][Br];
|
||||
shared vec4 Qf[Br][HSK / 4];
|
||||
const uint32_t qf_stride = HSK / 4 + 1;
|
||||
shared FLOAT_TYPEV4 Qf[Br * qf_stride];
|
||||
|
||||
const uint32_t D = HSK > HSV ? HSK : HSV;
|
||||
const uint32_t kvsh_stride = D / 4 + 1;
|
||||
shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
|
||||
|
||||
shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
@@ -50,8 +58,24 @@ void main() {
|
||||
init_indices();
|
||||
|
||||
const uint32_t tid = gl_LocalInvocationIndex;
|
||||
const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
|
||||
const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
|
||||
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
|
||||
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
|
||||
const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
|
||||
|
||||
if (LIMIT_OCCUPANCY_SHMEM > 0) {
|
||||
// This just exists to avoid the occupancy_limiter array getting optimized out
|
||||
occupancy_limiter[tid] = vec4(tid);
|
||||
|
||||
barrier();
|
||||
|
||||
if (occupancy_limiter[tid] == vec4(99999.0)) {
|
||||
data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
|
||||
}
|
||||
}
|
||||
|
||||
#define tile_row(r) (row_tid * rows_per_thread + (r))
|
||||
|
||||
uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
|
||||
|
||||
@@ -60,37 +84,37 @@ void main() {
|
||||
uint32_t r = (idx + tid) / (HSK / 4);
|
||||
if (r < Br && d < HSK / 4 &&
|
||||
i * Br + r < N) {
|
||||
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
|
||||
Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
vec4 Of[Br][HSV_per_thread / 4];
|
||||
FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = vec4(0.0);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = FLOAT_TYPEV4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
float Lf[Br], Mf[Br];
|
||||
float Lf[rows_per_thread], Mf[rows_per_thread];
|
||||
|
||||
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lf[r] = 0;
|
||||
Mf[r] = NEG_FLT_MAX_OVER_2;
|
||||
}
|
||||
|
||||
float slope[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = 1.0;
|
||||
ACC_TYPE slope[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
slope[r] = ACC_TYPE(1.0);
|
||||
}
|
||||
|
||||
// ALiBi
|
||||
if (p.max_bias > 0.0f) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,75 +137,141 @@ void main() {
|
||||
|
||||
uint32_t mask_opt = 0;
|
||||
uint32_t mask_opt_idx = ~0;
|
||||
uint32_t mask_opt_bits = 0;
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
if (MASK_ENABLE) {
|
||||
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
||||
mask_opt_idx = j / 16;
|
||||
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
||||
}
|
||||
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
||||
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
||||
// skip this block
|
||||
continue;
|
||||
}
|
||||
// Only load if the block is not all zeros
|
||||
if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
||||
mask_opt_idx = j / 16;
|
||||
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c * masksh_stride + r] = m;
|
||||
max_mask = max(max_mask, float(m));
|
||||
} else {
|
||||
masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
||||
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
||||
// skip this block
|
||||
continue;
|
||||
}
|
||||
// Only load if the block is not all zeros
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
||||
|
||||
float max_mask = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) % Bc;
|
||||
uint32_t r = (idx + tid) / Bc;
|
||||
if (idx + tid < Bc * Br) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
|
||||
float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
|
||||
masksh[c][r] = m;
|
||||
max_mask = max(max_mask, m);
|
||||
ACC_TYPE Sf[rows_per_thread][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = ACC_TYPE(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
// More d iterations means Q register caching becomes relevant
|
||||
// Few iterations means the additional registers needed are worse than the speed-up from caching
|
||||
if (HSK_per_thread / 4 > 4) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Q_cache[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
masksh[c][r] = float(0);
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
// skip the block if the mask is entirely -inf
|
||||
bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
|
||||
max_mask = max(max_mask, tmpsh[s]);
|
||||
}
|
||||
if (max_mask <= NEG_FLT_MAX_OVER_2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
float Sf[Br][cols_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
} else {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = 0.0;
|
||||
}
|
||||
}
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 K_Tf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,89 +279,109 @@ void main() {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
// Compute sum across the D_split
|
||||
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (LOGIT_SOFTCAP) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
|
||||
Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float mvf = masksh[c * cols_per_iter + col_tid][r];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
|
||||
|
||||
Sf[r][c] += slope[r]*mvf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
rowmaxf[r] = NEG_FLT_MAX_OVER_2;
|
||||
float eMf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = NEG_FLT_MAX_OVER_2;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
Moldf[r] = Mf[r];
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// P = e^(S - M)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf[r], Moldf[r]);
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
|
||||
}
|
||||
eMf[r] = exp(Moldf[r] - Mf[r]);
|
||||
|
||||
// Compute sum across row of P
|
||||
rowsumf[r] = 0.0;
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
rowsumf[r] += Pf[r][c];
|
||||
}
|
||||
|
||||
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf[r] = exp(Moldf - Mf[r]);
|
||||
Lf[r] = eMf[r]*Lf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] = eMf[r] * Of[r][d];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
|
||||
}
|
||||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
barrier();
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSV / 4);
|
||||
uint32_t c = (idx + tid) / (HSV / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
|
||||
FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
|
||||
if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
|
||||
continue;
|
||||
}
|
||||
|
||||
FLOAT_TYPE Pf[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
|
||||
Lf[r] += Pf[r];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
FLOAT_TYPEV4 Vf;
|
||||
if (SHMEM_STAGING != 0) {
|
||||
Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
|
||||
} else {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
|
||||
#endif
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] += Pf[r][c] * Vf;
|
||||
}
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
barrier();
|
||||
}
|
||||
|
||||
// prevent race on tmpsh
|
||||
@@ -279,58 +389,108 @@ void main() {
|
||||
|
||||
// reduce across threads
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float rowmaxf, eMf;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float rowmaxf = Mf[r];
|
||||
|
||||
tmpsh[tid] = Mf[r];
|
||||
// Compute max across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
|
||||
if (SubGroupSize > 0) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
|
||||
}
|
||||
if (row_split == 1) {
|
||||
// Reduce inside workgroup with shmem
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
||||
}
|
||||
barrier();
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpsh[tid] = rowmaxf;
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
// eM = e^(Mold - M)
|
||||
Mf[r] = max(rowmaxf, Moldf);
|
||||
eMf = exp(Moldf - Mf[r]);
|
||||
float eMf = exp(Moldf - Mf[r]);
|
||||
|
||||
Lf[r] = eMf*Lf[r];
|
||||
|
||||
tmpsh[tid] = Lf[r];
|
||||
|
||||
// Compute sum across the row
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
|
||||
if (SubGroupSize > 0) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Lf[r] += subgroupShuffleXor(Lf[r], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
||||
}
|
||||
barrier();
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
}
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
Of[r][d] = eMf * Of[r][d];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
|
||||
tmpsh[tid] = Lf[r];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
|
||||
if (tid < s) {
|
||||
Of[r][d] += tmpshv4[tid + s];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[d_tid];
|
||||
barrier();
|
||||
Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
|
||||
|
||||
if (SubGroupSize > 0) {
|
||||
[[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
|
||||
Of[r][d] += subgroupShuffleXor(Of[r][d], s);
|
||||
}
|
||||
if (row_split == 1) {
|
||||
barrier();
|
||||
if (gl_SubgroupInvocationID == d_tid) {
|
||||
tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
Of[r][d] = tmpshv4[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Of[r][d] += tmpshv4[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
barrier();
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
barrier();
|
||||
[[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
|
||||
if (rowgroup_tid < s) {
|
||||
Of[r][d] += tmpshv4[tid ^ s];
|
||||
tmpshv4[tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,33 +498,53 @@ void main() {
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
if (p.gqa_ratio > 1) {
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
if (row < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
if (row < N) {
|
||||
perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
const uint global_row = i * Br + row;
|
||||
|
||||
if (global_row < N) {
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
|
||||
}
|
||||
}
|
||||
|
||||
if (global_row < N && d_tid == 0 && col_tid == 0) {
|
||||
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
|
||||
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
|
||||
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
|
||||
|
||||
float ms = 1.0f;
|
||||
float vs = 1.0f;
|
||||
@@ -373,7 +553,7 @@ void main() {
|
||||
ms = exp(Mf[r] - sink);
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
Of[r][d] *= ms;
|
||||
Of[r][d] *= FLOAT_TYPE(ms);
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
@@ -383,39 +563,37 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
float Lfrcp[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Of[r][d] *= Lfrcp[r];
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
|
||||
#if defined(FLOAT_TYPE_MAX)
|
||||
Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
||||
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (r < N) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
if (row < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
|
||||
}
|
||||
gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
if (i * Br + r < N) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
if (i * Br + row < N) {
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
|
||||
}
|
||||
data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t HSK = 32;
|
||||
layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
layout (constant_id = 7) const uint32_t SubGroupSize = 32;
|
||||
layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
|
||||
layout (constant_id = 9) const uint32_t Flags = 0;
|
||||
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
|
||||
layout (constant_id = 1) const uint32_t Br = 1;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t HSK = 32;
|
||||
layout (constant_id = 4) const uint32_t HSV = 32;
|
||||
layout (constant_id = 5) const uint32_t Clamp = 0;
|
||||
layout (constant_id = 6) const uint32_t D_split = 16;
|
||||
layout (constant_id = 7) const uint32_t row_split = 1;
|
||||
layout (constant_id = 8) const uint32_t SubGroupSize = 32;
|
||||
layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0;
|
||||
layout (constant_id = 10) const uint32_t Flags = 0;
|
||||
layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
|
||||
|
||||
const bool USE_MASK_OPT = (Flags & 1) != 0;
|
||||
const bool MASK_ENABLE = (Flags & 2) != 0;
|
||||
@@ -69,6 +71,7 @@ layout (push_constant) uniform parameter {
|
||||
layout (binding = 4) readonly buffer S {float data_s[];};
|
||||
|
||||
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
|
||||
layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
|
||||
|
||||
layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
|
||||
|
||||
@@ -94,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
|
||||
#define BLOCK_SIZE 4
|
||||
#define BLOCK_BYTE_SIZE 16
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
// iqs is currently always zero in the flash attention shaders
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
return k_packed.k_data_packed[a_offset + ib];
|
||||
return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
|
||||
} else {
|
||||
return v_packed.v_data_packed[a_offset + ib];
|
||||
return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -107,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define BLOCK_BYTE_SIZE 18
|
||||
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
@@ -115,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
|
||||
} else {
|
||||
uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
|
||||
uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
|
||||
@@ -123,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
vui_lo >>= shift;
|
||||
vui_hi >>= shift;
|
||||
|
||||
return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q8_0)
|
||||
#define BLOCK_BYTE_SIZE 34
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
|
||||
if (binding_idx == BINDING_IDX_K) {
|
||||
const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
} else {
|
||||
const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
|
||||
const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
|
||||
|
||||
return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
|
||||
return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -189,10 +192,16 @@ void init_indices()
|
||||
KV = p.KV;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
// batch and split_k share gl_WorkGroupID.x
|
||||
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
|
||||
split_k_index = gl_WorkGroupID.x % p.k_num;
|
||||
if (p.gqa_ratio > 1) {
|
||||
i = 0;
|
||||
// batch and split_k share gl_WorkGroupID.x
|
||||
gqa_iq1 = gl_WorkGroupID.x / p.k_num;
|
||||
split_k_index = gl_WorkGroupID.x % p.k_num;
|
||||
} else {
|
||||
gqa_iq1 = 0;
|
||||
split_k_index = gl_WorkGroupID.x % p.k_num;
|
||||
i = gl_WorkGroupID.x / p.k_num;
|
||||
}
|
||||
} else if (p.gqa_ratio > 1) {
|
||||
i = 0;
|
||||
gqa_iq1 = gl_WorkGroupID.x;
|
||||
@@ -244,3 +253,11 @@ void init_indices()
|
||||
// Bias applied to softmax to stay in fp16 range.
|
||||
// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
|
||||
const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * HSV / 4 + c;
|
||||
data_ov4[o_offset + offset] = D_TYPEV4(elems);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
const uint32_t MatBr = 16;
|
||||
const uint32_t MatBc = 16;
|
||||
|
||||
const uint32_t row_split = Bc / MatBc;
|
||||
const uint32_t rows_per_thread = Br / row_split;
|
||||
const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
|
||||
const uint32_t cols_per_thread = Bc / cols_per_iter;
|
||||
@@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
|
||||
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
|
||||
layout (binding = 3) readonly buffer M {float16_t data_m[];};
|
||||
|
||||
// Store the output when doing grouped query attention.
|
||||
// Rows index by Q's dimension 2, and the first N rows are valid.
|
||||
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
uint32_t offset = (iq2 + r) * HSV + c;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
return elem;
|
||||
}
|
||||
|
||||
shared float tmpsh[row_split];
|
||||
|
||||
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
|
||||
@@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride];
|
||||
const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
|
||||
shared ACC_TYPEV4 sfsh[Bc * sfshstride];
|
||||
|
||||
const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
|
||||
const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
|
||||
const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
|
||||
const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
|
||||
const uint vsh_stride = v_cols;
|
||||
shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
|
||||
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
|
||||
|
||||
const uint32_t osh_stride = row_split * MatBr / 4;
|
||||
shared f16vec4 pvsh[MatBc * osh_stride];
|
||||
|
||||
shared ACC_TYPE slope[Br];
|
||||
|
||||
@@ -84,11 +78,6 @@ void main() {
|
||||
Qf[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
[[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
|
||||
if (i + tid < Bc * kshstride) {
|
||||
ksh[i + tid] = f16vec4(0);
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -104,10 +93,10 @@ void main() {
|
||||
}
|
||||
barrier();
|
||||
|
||||
ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
|
||||
f16vec4 Of[rows_per_thread][d_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
[[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
|
||||
Of[r][d] = ACC_TYPEV4(0.0);
|
||||
Of[r][d] = f16vec4(0.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,22 +142,22 @@ void main() {
|
||||
|
||||
uint32_t mask_opt = 0;
|
||||
uint32_t mask_opt_idx = ~0;
|
||||
uint32_t mask_opt_bits = 0;
|
||||
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint32_t j = start_j; j < end_j; ++j) {
|
||||
|
||||
f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
|
||||
[[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
|
||||
mask_cache[idx] = f16vec4(0);
|
||||
}
|
||||
|
||||
if (MASK_ENABLE) {
|
||||
|
||||
if (USE_MASK_OPT && mask_opt_idx != j / 16) {
|
||||
mask_opt_idx = j / 16;
|
||||
mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
|
||||
}
|
||||
uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
||||
mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
|
||||
if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
|
||||
// skip this block
|
||||
continue;
|
||||
@@ -231,24 +220,24 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (K_LOAD_SHMEM != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK / 4);
|
||||
uint32_t c = (idx + tid) / (HSK / 4);
|
||||
if (c < Bc && d < HSK / 4) {
|
||||
if (SHMEM_STAGING != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSK_pad / 4);
|
||||
uint32_t c = (idx + tid) / (HSK_pad / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
|
||||
f16vec4 K_Tf = f16vec4(0);
|
||||
if (!KV_bounds_check || j * Bc + c < KV) {
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
ksh[c * kshstride + d] = K_Tf;
|
||||
kvsh[c * kvsh_stride + d] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
@@ -262,7 +251,11 @@ void main() {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
|
||||
if (K_LOAD_SHMEM == 0) {
|
||||
// If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
|
||||
// If not, f16 K is loaded directly from global memory if aligned, otherwise
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If K is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK) {
|
||||
#endif
|
||||
@@ -277,13 +270,13 @@ void main() {
|
||||
uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
|
||||
K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
|
||||
#else
|
||||
K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
|
||||
#endif
|
||||
}
|
||||
|
||||
ksh[row * kshstride + col_vec] = K_Tf;
|
||||
kvsh[row * kvsh_stride + col_vec] = K_Tf;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
@@ -295,8 +288,8 @@ void main() {
|
||||
if (KV_bounds_check || d * 16 + 16 > HSK)
|
||||
#endif
|
||||
{
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
#if BLOCK_SIZE == 1
|
||||
else {
|
||||
@@ -305,8 +298,8 @@ void main() {
|
||||
}
|
||||
#endif
|
||||
} else {
|
||||
uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
|
||||
coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
|
||||
coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
@@ -329,7 +322,7 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (MASK_ENABLE) {
|
||||
if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t c = (idx + tid) / (Br / 4);
|
||||
uint32_t r = (idx + tid) % (Br / 4);
|
||||
@@ -374,7 +367,7 @@ void main() {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
|
||||
Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,19 +390,47 @@ void main() {
|
||||
}
|
||||
}
|
||||
|
||||
if (SHMEM_STAGING != 0) {
|
||||
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
|
||||
uint32_t d = (idx + tid) % (HSV_pad / 4);
|
||||
uint32_t c = (idx + tid) / (HSV_pad / 4);
|
||||
if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
|
||||
f16vec4 V_Tf = f16vec4(0);
|
||||
if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
|
||||
uint ib = coord / BLOCK_SIZE;
|
||||
uint iqs = (coord % BLOCK_SIZE);
|
||||
V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
|
||||
#endif
|
||||
}
|
||||
|
||||
kvsh[c * kvsh_stride + d] = V_Tf;
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
|
||||
const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
|
||||
|
||||
// Each subgroup handles HSV/4 columns
|
||||
[[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
|
||||
const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
|
||||
|
||||
SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
|
||||
|
||||
// Preload V tiles for [Bc, 16 * num subgroups]
|
||||
const uint v_rows = Bc;
|
||||
const uint v_total = v_rows * v_cols;
|
||||
const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
|
||||
|
||||
// If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
|
||||
// If not, f16 V is loaded directly from global memory if aligned, otherwise
|
||||
// staged through a Bc * MatBr size staging buffer.
|
||||
// If V is not type f16, then it is always staged for dequantization.
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
// For f16, only preload if not aligned
|
||||
if (KV_bounds_check) {
|
||||
@@ -428,44 +449,52 @@ void main() {
|
||||
|
||||
if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
|
||||
#if BLOCK_SIZE > 1
|
||||
ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
|
||||
kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
|
||||
#else
|
||||
ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
|
||||
#endif
|
||||
} else {
|
||||
ksh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
kvsh[row * vsh_stride + col] = f16vec4(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
#if BLOCK_SIZE == 1
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
barrier();
|
||||
|
||||
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
|
||||
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
const uint o_offset = gl_SubgroupID * MatBr / 4;
|
||||
|
||||
if (hsv_offset < HSV_pad) {
|
||||
[[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
|
||||
coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
|
||||
|
||||
if (SHMEM_STAGING == 0) {
|
||||
#if BLOCK_SIZE == 1
|
||||
if (!KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
} else
|
||||
if (!KV_bounds_check) {
|
||||
// F16 values can be loaded directly from global memory
|
||||
const uint v_tile_row = j * Bc + bc_chunk * MatBc;
|
||||
const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
|
||||
coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
|
||||
coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
{
|
||||
const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
|
||||
coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
} else {
|
||||
const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
|
||||
coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
PVMat = coopMatMulAdd(KMat, QMat, PVMat);
|
||||
}
|
||||
|
||||
SfMat = coopMatMulAdd(KMat, QMat, SfMat);
|
||||
// Store PVMat to pvsh and load into Of
|
||||
coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
|
||||
// Store SfMat to sfsh and load into Of
|
||||
const uint osh_stride = row_split * MatBc / 4;
|
||||
const uint o_offset = gl_SubgroupID * MatBc / 4;
|
||||
coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
|
||||
barrier();
|
||||
|
||||
const uint hsv_per_tile = row_split * MatBc;
|
||||
@@ -484,7 +513,7 @@ void main() {
|
||||
|
||||
if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
|
||||
const uint local_hsv = (hsv_col - hsv_base) / 4;
|
||||
Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
|
||||
Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -500,27 +529,48 @@ void main() {
|
||||
// If there is split_k, then the split_k resolve shader does the final
|
||||
// division by L. Store the intermediate O value and per-row m and L values.
|
||||
if (p.k_num > 1) {
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
if (p.gqa_ratio > 1) {
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV/4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV/4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
if (tile_row(r) < N) {
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
|
||||
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
const uint row = tile_row(r);
|
||||
const uint global_row = i * Br + row;
|
||||
|
||||
if (global_row < N) {
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
|
||||
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV/4) break;
|
||||
data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
|
||||
}
|
||||
}
|
||||
|
||||
if (global_row < N && col_tid == 0) {
|
||||
uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
|
||||
data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
|
||||
data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,7 +589,7 @@ void main() {
|
||||
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
Of[r][d_local] *= ACC_TYPE(ms);
|
||||
Of[r][d_local] *= float16_t(ms);
|
||||
}
|
||||
} else {
|
||||
vs = exp(sink - Mf[r]);
|
||||
@@ -557,14 +607,14 @@ void main() {
|
||||
[[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
|
||||
Of[r][d_local] *= float16_t(Lfrcp[r]);
|
||||
#if defined(FLOAT_TYPE_MAX)
|
||||
Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
|
||||
uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
|
||||
|
||||
if (p.gqa_ratio > 1) {
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
@@ -573,9 +623,7 @@ void main() {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV / 4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
|
||||
}
|
||||
gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -586,9 +634,7 @@ void main() {
|
||||
const uint d = d0 + col_tid;
|
||||
if (d >= HSV / 4) break;
|
||||
const uint d_local = d0 / threads_per_rowgroup;
|
||||
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
|
||||
data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
|
||||
}
|
||||
data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store O values for non-GQA split_k. Rows are tokens, not heads.
|
||||
D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
|
||||
uint32_t global_row = i * Br + r;
|
||||
if (global_row < N && c < HSV) {
|
||||
uint32_t o_off = HSV * p.ne1
|
||||
* (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
|
||||
data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store L/M values for non-GQA split_k.
|
||||
ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
|
||||
uint32_t global_row = i * Br + r;
|
||||
if (global_row < N && c == 0) {
|
||||
uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
|
||||
+ p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
|
||||
data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
@@ -290,13 +312,19 @@ void main() {
|
||||
if (p.k_num > 1) {
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
|
||||
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
if (p.gqa_ratio > 1) {
|
||||
// note: O and Q have swapped coord 1,2.
|
||||
uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
|
||||
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
||||
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
||||
o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
|
||||
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
|
||||
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
|
||||
} else {
|
||||
coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
|
||||
coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
|
||||
coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -167,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
|
||||
@@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
uint id = ids[iter++];
|
||||
uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
|
||||
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
ballots_sh[gl_SubgroupID] = ballot;
|
||||
}
|
||||
barrier();
|
||||
|
||||
uint subgroup_base = 0;
|
||||
|
||||
@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
void process_shaders() {
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
|
||||
// matmul
|
||||
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
|
||||
// No coopmats
|
||||
@@ -622,49 +620,63 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
// flash attention
|
||||
for (const auto& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
|
||||
if (f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
for (const bool& fp16 : {false, true}) {
|
||||
std::map<std::string, std::string> base_dict;
|
||||
if (fp16) {
|
||||
base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
|
||||
} else {
|
||||
base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
|
||||
} else {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
|
||||
// flash attention
|
||||
for (const bool& f16acc : {false, true}) {
|
||||
std::map<std::string, std::string> fa_base_dict = base_dict;
|
||||
fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
|
||||
fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
|
||||
if (fp16 && f16acc) {
|
||||
fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
|
||||
}
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
if (tname == "bf16") continue;
|
||||
|
||||
if (fp16) {
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
|
||||
} else {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
|
||||
}
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
|
||||
}
|
||||
#endif
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
|
||||
}
|
||||
|
||||
if (tname == "f16") {
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
|
||||
} else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
|
||||
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
|
||||
|
||||
for (const auto& tname : type_names) {
|
||||
// mul mat vec
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
|
||||
@@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 --device $device $cli_opts $@ \
|
||||
"
|
||||
|
||||
@@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \
|
||||
./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -fa on \
|
||||
-ngl 99 -no-cnv --device $device $cli_opts $@ \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 -no-cnv --device $device $cli_opts $@ \
|
||||
"
|
||||
|
||||
@@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \
|
||||
cd $basedir; ulimit -c unlimited; \
|
||||
LD_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
ADSP_LIBRARY_PATH=$basedir/$branch/lib \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
|
||||
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--mmproj $basedir/../gguf/$mmproj \
|
||||
--image $basedir/../gguf/$image \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \
|
||||
-ngl 99 --device $device -v $cli_opts $@ \
|
||||
$verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \
|
||||
./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \
|
||||
--mmproj $basedir/../gguf/$mmproj \
|
||||
--image $basedir/../gguf/$image \
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \
|
||||
--ctx-size 8192 --ubatch-size 256 -fa on \
|
||||
-ngl 99 --device $device -v $cli_opts $@ \
|
||||
"
|
||||
|
||||
@@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib"
|
||||
& "$basedir\bin\llama-completion.exe" `
|
||||
--no-mmap -no-cnv -m $basedir\..\..\gguf\$model `
|
||||
--poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 `
|
||||
--ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on `
|
||||
--ctx-size 8192 --ubatch-size 128 -fa on `
|
||||
-ngl 99 --device $device $cli_opts
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
HTTPLIB_VERSION = "refs/tags/v0.33.1"
|
||||
HTTPLIB_VERSION = "refs/tags/v0.34.0"
|
||||
|
||||
vendor = {
|
||||
"https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp",
|
||||
|
||||
@@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
// TODO: add more model-specific info which should prevent loading the session file if not identical
|
||||
}
|
||||
|
||||
// write output ids
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
|
||||
|
||||
const auto n_outputs = this->n_outputs;
|
||||
const auto & output_ids = this->output_ids;
|
||||
|
||||
std::vector<int32_t> w_output_pos;
|
||||
|
||||
w_output_pos.resize(n_outputs);
|
||||
|
||||
// build a more compact representation of the output ids
|
||||
for (size_t i = 0; i < n_batch(); ++i) {
|
||||
// map an output id to a position in the batch
|
||||
int64_t pos = output_ids[i];
|
||||
if (pos >= 0) {
|
||||
GGML_ASSERT(pos < n_outputs);
|
||||
w_output_pos[pos] = i;
|
||||
}
|
||||
}
|
||||
|
||||
io.write(&n_outputs, sizeof(n_outputs));
|
||||
|
||||
if (n_outputs) {
|
||||
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
// [TAG_CONTEXT_STATE_LOGITS]
|
||||
// write logits
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
|
||||
|
||||
const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens());
|
||||
|
||||
io.write(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (logits_size) {
|
||||
io.write(logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// write embeddings
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
|
||||
|
||||
const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd);
|
||||
|
||||
io.write(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (embd_size) {
|
||||
io.write(embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle sampling buffers and samplers state ?
|
||||
// https://github.com/ggml-org/llama.cpp/pull/17004
|
||||
|
||||
if (memory != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
|
||||
memory->state_write(io);
|
||||
@@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
// TODO: add more info which needs to be identical but which is not verified otherwise
|
||||
}
|
||||
|
||||
// read output ids
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
|
||||
|
||||
auto n_outputs = this->n_outputs;
|
||||
io.read_to(&n_outputs, sizeof(n_outputs));
|
||||
|
||||
if (n_outputs > output_reserve(n_outputs)) {
|
||||
throw std::runtime_error("could not reserve outputs");
|
||||
}
|
||||
|
||||
std::vector<int32_t> output_pos;
|
||||
|
||||
if (n_outputs) {
|
||||
output_pos.resize(n_outputs);
|
||||
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
|
||||
int32_t id = output_pos[i];
|
||||
if ((uint32_t) id >= n_batch()) {
|
||||
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
|
||||
}
|
||||
this->output_ids[id] = i;
|
||||
}
|
||||
|
||||
this->n_outputs = n_outputs;
|
||||
}
|
||||
}
|
||||
|
||||
// read logits
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
|
||||
|
||||
uint64_t logits_size;
|
||||
io.read_to(&logits_size, sizeof(logits_size));
|
||||
|
||||
if (this->logits.size < logits_size) {
|
||||
throw std::runtime_error("logits buffer too small");
|
||||
}
|
||||
|
||||
if (logits_size) {
|
||||
io.read_to(this->logits.data, logits_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// read embeddings
|
||||
{
|
||||
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
|
||||
|
||||
uint64_t embd_size;
|
||||
io.read_to(&embd_size, sizeof(embd_size));
|
||||
|
||||
if (this->embd.size < embd_size) {
|
||||
throw std::runtime_error("embeddings buffer too small");
|
||||
}
|
||||
|
||||
if (embd_size) {
|
||||
io.read_to(this->embd.data, embd_size * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: handle sampling buffers and samplers state ?
|
||||
// https://github.com/ggml-org/llama.cpp/pull/17004
|
||||
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);
|
||||
|
||||
|
||||
@@ -361,7 +361,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
||||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfify sequence 0
|
||||
// Verify sequence 0
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx);
|
||||
@@ -379,7 +379,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
||||
}
|
||||
|
||||
|
||||
// Verfify sequence 1
|
||||
// Verify sequence 1
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(1);
|
||||
|
||||
@@ -395,7 +395,7 @@ static void test_backend_temp_sampling(const test_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
// lambda to testing non-positive temperature values.
|
||||
// lambda for testing non-positive temperature values.
|
||||
auto test_argmax_temp = [&](float temp) {
|
||||
printf("\nTesting temperature = %.1f\n", temp);
|
||||
|
||||
@@ -454,7 +454,7 @@ static void test_backend_temp_ext_sampling(const test_params & params) {
|
||||
}
|
||||
}
|
||||
|
||||
// lambda to testing non-positive temp/delta/exponent values.
|
||||
// lambda for testing non-positive temp/delta/exponent values.
|
||||
auto test_argmax_temp = [&](float temp, float delta, float exponent) {
|
||||
printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
|
||||
|
||||
@@ -530,7 +530,7 @@ static void test_backend_min_p_sampling(const test_params & params) {
|
||||
printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
// Decode and sampler 10 more tokens
|
||||
// Decode and sample 10 more tokens
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
|
||||
@@ -582,7 +582,7 @@ static void test_backend_top_p_sampling(const test_params & params) {
|
||||
printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
// Decode and sampler 10 more tokens
|
||||
// Decode and sample 10 more tokens
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx);
|
||||
@@ -619,7 +619,7 @@ static void test_backend_multi_sequence_sampling(const test_params & params) {
|
||||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfiy sequence 0
|
||||
// Verify sequence 0
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
|
||||
@@ -763,7 +763,7 @@ static void test_backend_logit_bias_sampling(const test_params & params) {
|
||||
printf("backend logit bias sampling test PASSED\n");
|
||||
}
|
||||
|
||||
// This test verifies that it is possible to have two different backend sampler,
|
||||
// This test verifies that it is possible to have two different backend samplers,
|
||||
// one that uses the backend dist sampler, and another that uses CPU dist sampler.
|
||||
static void test_backend_mixed_sampling(const test_params & params) {
|
||||
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
|
||||
@@ -791,7 +791,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
|
||||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
// Verfiy sequence 0 that used the dist backend sampler.
|
||||
// Verify sequence 0 that used the dist backend sampler.
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx);
|
||||
@@ -802,7 +802,7 @@ static void test_backend_mixed_sampling(const test_params & params) {
|
||||
//GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0);
|
||||
}
|
||||
|
||||
// Verfiy sequence 1 that used the top-k backend sampler.
|
||||
// Verify sequence 1 that used the top-k backend sampler.
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(1);
|
||||
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx);
|
||||
@@ -934,7 +934,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) {
|
||||
// samplers.
|
||||
llama_set_sampler(test_ctx.ctx.get(), 0, nullptr);
|
||||
|
||||
// Create a CPU sampler and verify we can sampler from it.
|
||||
// Create a CPU sampler and verify we can sample from it.
|
||||
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
llama_sampler_ptr chain(llama_sampler_chain_init(chain_params));
|
||||
llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy());
|
||||
|
||||
@@ -387,6 +387,17 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
|
||||
|
||||
// Logits are not stored as part of the session state so we need to
|
||||
// "replay" the last token to get logits for sampling.
|
||||
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
|
||||
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
session_do_save = false;
|
||||
LOG_INF("%s: replayed last token from session\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
@@ -675,40 +686,27 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
if (!embd.empty()) {
|
||||
int n_eval = (int) embd.size();
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
GGML_ASSERT(n_eval <= params.n_batch);
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
|
||||
const bool save_now = session_do_save && is_last_batch;
|
||||
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_past += n_eval;
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
|
||||
n_session_consumed = session_tokens.size();
|
||||
session_do_save = false;
|
||||
|
||||
LOG_DBG("n_past = %d\n", n_past);
|
||||
|
||||
// Display total tokens alongside total time
|
||||
if (params.n_print > 0 && n_past % params.n_print == 0) {
|
||||
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
if (!embd.empty() && !path_session.empty()) {
|
||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
|
||||
n_session_consumed = session_tokens.size();
|
||||
}
|
||||
}
|
||||
|
||||
embd.clear();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
if (session_do_save) {
|
||||
session_do_save = false;
|
||||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
|
||||
LOG_DBG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
|
||||
const llama_token id = common_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
|
||||
Binary file not shown.
@@ -204,7 +204,8 @@ task_params server_task::params_from_json_cmpl(
|
||||
params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt);
|
||||
params.return_tokens = json_value(data, "return_tokens", false);
|
||||
params.return_progress = json_value(data, "return_progress", false);
|
||||
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
||||
auto max_tokens = json_value(data, "max_tokens", defaults.n_predict);
|
||||
params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens));
|
||||
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
||||
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
||||
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
||||
|
||||
@@ -114,6 +114,11 @@
|
||||
label: 'Render user content as Markdown',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS,
|
||||
label: 'Use full height code blocks',
|
||||
type: SettingsFieldType.CHECKBOX
|
||||
},
|
||||
{
|
||||
key: SETTINGS_KEYS.DISABLE_AUTO_SCROLL,
|
||||
label: 'Disable automatic scroll',
|
||||
|
||||
@@ -38,6 +38,8 @@
|
||||
import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import type { DatabaseMessageExtra } from '$lib/types/database';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { SETTINGS_KEYS } from '$lib/constants/settings-keys';
|
||||
|
||||
interface Props {
|
||||
attachments?: DatabaseMessageExtra[];
|
||||
@@ -593,7 +595,12 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div bind:this={containerRef} class={className}>
|
||||
<div
|
||||
bind:this={containerRef}
|
||||
class="{className}{config()[SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS]
|
||||
? ' full-height-code-blocks'
|
||||
: ''}"
|
||||
>
|
||||
{#each renderedBlocks as block (block.id)}
|
||||
<div class="markdown-block" data-block-id={block.id}>
|
||||
<!-- eslint-disable-next-line no-at-html-tags -->
|
||||
@@ -914,6 +921,16 @@
|
||||
line-height: 1.3;
|
||||
}
|
||||
|
||||
.full-height-code-blocks :global(.code-block-wrapper) {
|
||||
max-height: none;
|
||||
}
|
||||
|
||||
.full-height-code-blocks :global(.code-block-scroll-container),
|
||||
.full-height-code-blocks .streaming-code-scroll-container {
|
||||
max-height: none;
|
||||
overflow-y: visible;
|
||||
}
|
||||
|
||||
div :global(.code-block-header) {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
|
||||
@@ -22,6 +22,7 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
|
||||
alwaysShowSidebarOnDesktop: false,
|
||||
autoShowSidebarOnNewChat: true,
|
||||
autoMicOnEmpty: false,
|
||||
fullHeightCodeBlocks: false,
|
||||
// make sure these default values are in sync with `common.h`
|
||||
samplers: 'top_k;typ_p;top_p;min_p;temperature',
|
||||
backend_sampling: false,
|
||||
@@ -113,6 +114,8 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
|
||||
'Automatically show sidebar when starting a new chat. Disable to keep the sidebar hidden until you click on it.',
|
||||
autoMicOnEmpty:
|
||||
'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.',
|
||||
fullHeightCodeBlocks:
|
||||
'Always display code blocks at their full natural height, overriding any height limits.',
|
||||
pyInterpreterEnabled:
|
||||
'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.',
|
||||
enableContinueGeneration:
|
||||
|
||||
@@ -23,6 +23,7 @@ export const SETTINGS_KEYS = {
|
||||
DISABLE_AUTO_SCROLL: 'disableAutoScroll',
|
||||
ALWAYS_SHOW_SIDEBAR_ON_DESKTOP: 'alwaysShowSidebarOnDesktop',
|
||||
AUTO_SHOW_SIDEBAR_ON_NEW_CHAT: 'autoShowSidebarOnNewChat',
|
||||
FULL_HEIGHT_CODE_BLOCKS: 'fullHeightCodeBlocks',
|
||||
// Sampling
|
||||
TEMPERATURE: 'temperature',
|
||||
DYNATEMP_RANGE: 'dynatemp_range',
|
||||
|
||||
@@ -153,6 +153,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [
|
||||
serverKey: 'enableContinueGeneration',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
},
|
||||
{
|
||||
key: 'fullHeightCodeBlocks',
|
||||
serverKey: 'fullHeightCodeBlocks',
|
||||
type: SyncableParameterType.BOOLEAN,
|
||||
canSync: true
|
||||
}
|
||||
];
|
||||
|
||||
|
||||
82
vendor/cpp-httplib/httplib.cpp
vendored
82
vendor/cpp-httplib/httplib.cpp
vendored
@@ -1660,6 +1660,7 @@ public:
|
||||
bool is_readable() const override;
|
||||
bool wait_readable() const override;
|
||||
bool wait_writable() const override;
|
||||
bool is_peer_alive() const override;
|
||||
ssize_t read(char *ptr, size_t size) override;
|
||||
ssize_t write(const char *ptr, size_t size) override;
|
||||
void get_remote_ip_and_port(std::string &ip, int &port) const override;
|
||||
@@ -3313,10 +3314,10 @@ bool write_content_with_progress(Stream &strm,
|
||||
return ok;
|
||||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
|
||||
|
||||
while (offset < end_offset && !is_shutting_down()) {
|
||||
if (!strm.wait_writable()) {
|
||||
if (!strm.wait_writable() || !strm.is_peer_alive()) {
|
||||
error = Error::Write;
|
||||
return false;
|
||||
} else if (!content_provider(offset, end_offset - offset, data_sink)) {
|
||||
@@ -3328,6 +3329,11 @@ bool write_content_with_progress(Stream &strm,
|
||||
}
|
||||
}
|
||||
|
||||
if (offset < end_offset) { // exited due to is_shutting_down(), not completion
|
||||
error = Error::Write;
|
||||
return false;
|
||||
}
|
||||
|
||||
error = Error::Success;
|
||||
return true;
|
||||
}
|
||||
@@ -3367,12 +3373,12 @@ write_content_without_length(Stream &strm,
|
||||
return ok;
|
||||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
|
||||
|
||||
data_sink.done = [&](void) { data_available = false; };
|
||||
|
||||
while (data_available && !is_shutting_down()) {
|
||||
if (!strm.wait_writable()) {
|
||||
if (!strm.wait_writable() || !strm.is_peer_alive()) {
|
||||
return false;
|
||||
} else if (!content_provider(offset, 0, data_sink)) {
|
||||
return false;
|
||||
@@ -3380,7 +3386,8 @@ write_content_without_length(Stream &strm,
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return !data_available; // true only if done() was called, false if shutting
|
||||
// down
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@@ -3416,7 +3423,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
|
||||
return ok;
|
||||
};
|
||||
|
||||
data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); };
|
||||
data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); };
|
||||
|
||||
auto done_with_trailer = [&](const Headers *trailer) {
|
||||
if (!ok) { return; }
|
||||
@@ -3466,7 +3473,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
|
||||
};
|
||||
|
||||
while (data_available && !is_shutting_down()) {
|
||||
if (!strm.wait_writable()) {
|
||||
if (!strm.wait_writable() || !strm.is_peer_alive()) {
|
||||
error = Error::Write;
|
||||
return false;
|
||||
} else if (!content_provider(offset, 0, data_sink)) {
|
||||
@@ -3478,6 +3485,11 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider,
|
||||
}
|
||||
}
|
||||
|
||||
if (data_available) { // exited due to is_shutting_down(), not done()
|
||||
error = Error::Write;
|
||||
return false;
|
||||
}
|
||||
|
||||
error = Error::Success;
|
||||
return true;
|
||||
}
|
||||
@@ -4646,6 +4658,7 @@ public:
|
||||
bool is_readable() const override;
|
||||
bool wait_readable() const override;
|
||||
bool wait_writable() const override;
|
||||
bool is_peer_alive() const override;
|
||||
ssize_t read(char *ptr, size_t size) override;
|
||||
ssize_t write(const char *ptr, size_t size) override;
|
||||
void get_remote_ip_and_port(std::string &ip, int &port) const override;
|
||||
@@ -6069,8 +6082,11 @@ bool SocketStream::wait_readable() const {
|
||||
}
|
||||
|
||||
bool SocketStream::wait_writable() const {
|
||||
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
|
||||
is_socket_alive(sock_);
|
||||
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0;
|
||||
}
|
||||
|
||||
bool SocketStream::is_peer_alive() const {
|
||||
return detail::is_socket_alive(sock_);
|
||||
}
|
||||
|
||||
ssize_t SocketStream::read(char *ptr, size_t size) {
|
||||
@@ -6401,7 +6417,11 @@ bool SSLSocketStream::wait_readable() const {
|
||||
|
||||
bool SSLSocketStream::wait_writable() const {
|
||||
return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 &&
|
||||
is_socket_alive(sock_) && !tls::is_peer_closed(session_, sock_);
|
||||
!tls::is_peer_closed(session_, sock_);
|
||||
}
|
||||
|
||||
bool SSLSocketStream::is_peer_alive() const {
|
||||
return !tls::is_peer_closed(session_, sock_);
|
||||
}
|
||||
|
||||
ssize_t SSLSocketStream::read(char *ptr, size_t size) {
|
||||
@@ -6925,35 +6945,33 @@ bool Server::write_response_core(Stream &strm, bool close_connection,
|
||||
if (post_routing_handler_) { post_routing_handler_(req, res); }
|
||||
|
||||
// Response line and headers
|
||||
{
|
||||
detail::BufferStream bstrm;
|
||||
if (!detail::write_response_line(bstrm, res.status)) { return false; }
|
||||
if (header_writer_(bstrm, res.headers) <= 0) { return false; }
|
||||
detail::BufferStream bstrm;
|
||||
if (!detail::write_response_line(bstrm, res.status)) { return false; }
|
||||
if (header_writer_(bstrm, res.headers) <= 0) { return false; }
|
||||
|
||||
// Flush buffer
|
||||
auto &data = bstrm.get_buffer();
|
||||
detail::write_data(strm, data.data(), data.size());
|
||||
// Combine small body with headers to reduce write syscalls
|
||||
if (req.method != "HEAD" && !res.body.empty() && !res.content_provider_) {
|
||||
bstrm.write(res.body.data(), res.body.size());
|
||||
}
|
||||
|
||||
// Body
|
||||
// Log before writing to avoid race condition with client-side code that
|
||||
// accesses logger-captured data immediately after receiving the response.
|
||||
output_log(req, res);
|
||||
|
||||
// Flush buffer
|
||||
auto &data = bstrm.get_buffer();
|
||||
if (!detail::write_data(strm, data.data(), data.size())) { return false; }
|
||||
|
||||
// Streaming body
|
||||
auto ret = true;
|
||||
if (req.method != "HEAD") {
|
||||
if (!res.body.empty()) {
|
||||
if (!detail::write_data(strm, res.body.data(), res.body.size())) {
|
||||
ret = false;
|
||||
}
|
||||
} else if (res.content_provider_) {
|
||||
if (write_content_with_provider(strm, req, res, boundary, content_type)) {
|
||||
res.content_provider_success_ = true;
|
||||
} else {
|
||||
ret = false;
|
||||
}
|
||||
if (req.method != "HEAD" && res.content_provider_) {
|
||||
if (write_content_with_provider(strm, req, res, boundary, content_type)) {
|
||||
res.content_provider_success_ = true;
|
||||
} else {
|
||||
ret = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Log
|
||||
output_log(req, res);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
31
vendor/cpp-httplib/httplib.h
vendored
31
vendor/cpp-httplib/httplib.h
vendored
@@ -8,8 +8,8 @@
|
||||
#ifndef CPPHTTPLIB_HTTPLIB_H
|
||||
#define CPPHTTPLIB_HTTPLIB_H
|
||||
|
||||
#define CPPHTTPLIB_VERSION "0.33.1"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002101"
|
||||
#define CPPHTTPLIB_VERSION "0.34.0"
|
||||
#define CPPHTTPLIB_VERSION_NUM "0x002200"
|
||||
|
||||
/*
|
||||
* Platform compatibility check
|
||||
@@ -1038,6 +1038,32 @@ make_file_provider(const std::string &name, const std::string &filepath,
|
||||
return fdp;
|
||||
}
|
||||
|
||||
inline std::pair<size_t, ContentProvider>
|
||||
make_file_body(const std::string &filepath) {
|
||||
std::ifstream f(filepath, std::ios::binary | std::ios::ate);
|
||||
if (!f) { return {0, ContentProvider{}}; }
|
||||
auto size = static_cast<size_t>(f.tellg());
|
||||
|
||||
ContentProvider provider = [filepath](size_t offset, size_t length,
|
||||
DataSink &sink) -> bool {
|
||||
std::ifstream f(filepath, std::ios::binary);
|
||||
if (!f) { return false; }
|
||||
f.seekg(static_cast<std::streamoff>(offset));
|
||||
if (!f.good()) { return false; }
|
||||
char buf[8192];
|
||||
while (length > 0) {
|
||||
auto to_read = (std::min)(sizeof(buf), length);
|
||||
f.read(buf, static_cast<std::streamsize>(to_read));
|
||||
auto n = static_cast<size_t>(f.gcount());
|
||||
if (n == 0) { break; }
|
||||
if (!sink.write(buf, n)) { return false; }
|
||||
length -= n;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return {size, std::move(provider)};
|
||||
}
|
||||
|
||||
using ContentReceiverWithProgress = std::function<bool(
|
||||
const char *data, size_t data_length, size_t offset, size_t total_length)>;
|
||||
|
||||
@@ -1352,6 +1378,7 @@ public:
|
||||
virtual bool is_readable() const = 0;
|
||||
virtual bool wait_readable() const = 0;
|
||||
virtual bool wait_writable() const = 0;
|
||||
virtual bool is_peer_alive() const { return wait_writable(); }
|
||||
|
||||
virtual ssize_t read(char *ptr, size_t size) = 0;
|
||||
virtual ssize_t write(const char *ptr, size_t size) = 0;
|
||||
|
||||
Reference in New Issue
Block a user