mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2026-02-19 14:13:22 +02:00
Compare commits
4 Commits
gg/scripts
...
b8068
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
267ba5a1d9 | ||
|
|
ff4affb4c1 | ||
|
|
55d58599c8 | ||
|
|
1a8c700bfd |
@@ -1,190 +0,0 @@
|
||||
# llama-eval Codebase Guidelines
|
||||
|
||||
## Overview
|
||||
|
||||
This directory contains Python evaluation tools for llama.cpp:
|
||||
- `llama-eval.py` - Main evaluation tool with multiple datasets (AIME, AIME2025, GSM8K, GPQA)
|
||||
- `llama-server-simulator.py` - Flask-based server simulator for testing
|
||||
- `test-simulator.sh` - Test script for the simulator
|
||||
|
||||
## Build/Run Commands
|
||||
|
||||
### Virtual Environment
|
||||
The project uses a virtual environment located at `venv/`:
|
||||
```bash
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
### Running the Main Evaluator
|
||||
```bash
|
||||
python llama-eval.py \
|
||||
--server http://127.0.0.1:8013 \
|
||||
--model gpt-oss-20b-hf-low \
|
||||
--dataset aime \
|
||||
--n_cases 10 \
|
||||
--grader-type llm \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
### Running the Simulator (for testing)
|
||||
```bash
|
||||
python llama-server-simulator.py --port 8033 --success-rate 0.8
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
./test-simulator.sh
|
||||
```
|
||||
|
||||
## Code Style Guidelines
|
||||
|
||||
### Imports
|
||||
- Standard library imports first (argparse, json, os, re, subprocess, sys, time)
|
||||
- Third-party imports (requests, tqdm, datasets, flask) after standard library
|
||||
- Relative imports not used
|
||||
- Group imports by category with blank line between groups
|
||||
|
||||
### Formatting
|
||||
- 4-space indentation
|
||||
- Max line length: 125 characters (per parent project's .flake8)
|
||||
- Use double quotes for strings
|
||||
- Use triple double quotes for docstrings
|
||||
- Binary operators at the beginning of continued lines
|
||||
|
||||
### Naming Conventions
|
||||
- Classes: PascalCase (e.g., `AimeDataset`, `Grader`, `Processor`)
|
||||
- Functions: snake_case (e.g., `normalize_number`, `get_prompt`)
|
||||
- Variables: snake_case (e.g., `question_text`, `correct_count`)
|
||||
- Constants: UPPER_SNAKE_CASE (e.g., `GRADER_PATTERNS`, `TEMPLATE_REGISTRY`)
|
||||
- Private methods: prefix with underscore (e.g., `_load_dataset`, `_grade_regex`)
|
||||
|
||||
### Types
|
||||
- Use type hints for all function signatures
|
||||
- Import from `typing` module: `Dict`, `List`, `Optional`, `Any`, `Tuple`
|
||||
- Use `@dataclass` for data structures
|
||||
- Prefer `Optional[T]` over `Union[T, None]`
|
||||
|
||||
### Error Handling
|
||||
- Use try/except for network requests and file operations
|
||||
- Return `None` or `False` on errors when appropriate
|
||||
- Use `ValueError` for invalid arguments
|
||||
- Use `FileNotFoundError` for missing files
|
||||
- CLI scripts should handle exceptions gracefully
|
||||
|
||||
### Dataclasses
|
||||
- Use `@dataclass` for structured data
|
||||
- Define fields with explicit types
|
||||
- Use `Optional[T]` for nullable fields
|
||||
- Provide default values where appropriate
|
||||
|
||||
### String Formatting
|
||||
- Use f-strings for formatting (Python 3.6+)
|
||||
- Use triple double quotes for multi-line strings
|
||||
- Escape backslashes in regex patterns: `r'\\boxed{(\d+)}'`
|
||||
|
||||
### File Paths
|
||||
- Use `pathlib.Path` instead of string paths
|
||||
- Create directories with `mkdir(parents=True, exist_ok=True)`
|
||||
- Use `Path.home()` for user home directory
|
||||
|
||||
### Logging
|
||||
- Use `print()` for user-facing output
|
||||
- Use `sys.stderr` for debug logging
|
||||
- Simulator writes debug logs to `/tmp/simulator-debug.log`
|
||||
|
||||
### Testing
|
||||
|
||||
- Test script uses bash with `set -e` for strict error handling
|
||||
- Simulator runs in background with PID tracking
|
||||
- Tests verify correct answers, error cases, and edge cases
|
||||
- Use `curl` for HTTP testing in shell scripts
|
||||
|
||||
### Whitespace Cleanup
|
||||
- Remove trailing whitespace from all lines
|
||||
- When making edits, do not leave trailing whitespace
|
||||
|
||||
## Dataset Support
|
||||
|
||||
### AIME Dataset
|
||||
- 90 questions from 2025 AIME competition
|
||||
- Answers in `\boxed{answer}` format
|
||||
- Supports regex, CLI, and LLM grading
|
||||
|
||||
### AIME2025 Dataset
|
||||
- 30 questions from 2025 AIME I & II
|
||||
- Answers in `\boxed{answer}` format
|
||||
- Requires loading two config parts
|
||||
|
||||
### GSM8K Dataset
|
||||
- 7473 math word problems
|
||||
- Answers numeric values with `####` separator
|
||||
- Supports regex, CLI, and LLM grading
|
||||
|
||||
### GPQA Dataset
|
||||
- 198 questions from GPQA Diamond
|
||||
- Multiple choice with shuffled options (A, B, C, D)
|
||||
- **Requires LLM grader** (returns letter A/B/C/D)
|
||||
|
||||
## Grading Types
|
||||
|
||||
### Regex Grader
|
||||
- Built-in patterns per dataset
|
||||
- Prioritizes `\boxed{}` for AIME datasets
|
||||
- Extracts last number for GSM8K
|
||||
|
||||
### CLI Grader
|
||||
- External script interface
|
||||
- Call: `grader.sh --answer <pred> --expected <gold>`
|
||||
- Exit code 0 = correct, non-zero = incorrect
|
||||
|
||||
### LLM Grader
|
||||
- Uses judge model for answer extraction
|
||||
- Includes few-shot examples
|
||||
- Case-insensitive comparison
|
||||
- Required for GPQA
|
||||
|
||||
## Configuration
|
||||
|
||||
### Sampling Parameters (Optional)
|
||||
- `--temperature`: Sampling temperature
|
||||
- `--top-k`: Top K sampling
|
||||
- `--top-p`: Top P sampling
|
||||
- `--min-p`: Min P sampling
|
||||
- Only passed to API if explicitly specified
|
||||
|
||||
### Default Values
|
||||
- `--n_predict`: -1 (infinite)
|
||||
- `--grader-type`: llm
|
||||
- `--seed`: 1234
|
||||
- `--threads`: 32
|
||||
- `--output`: llama-eval-state.json
|
||||
|
||||
## Output Format
|
||||
|
||||
### Progress Table
|
||||
- Shows task ID, dataset, prompt (truncated to 43 chars), expected answer, status
|
||||
- Uses `tqdm` for progress bars
|
||||
|
||||
### Results Summary
|
||||
- Format: `Results: X/Y correct (Z%)`
|
||||
- Displayed after all tasks complete
|
||||
|
||||
### JSON Output
|
||||
- Complete eval state saved to output file
|
||||
- Contains: task IDs, correctness, prompts, extracted answers, sampling config
|
||||
- Uses `dataclasses.asdict()` for serialization
|
||||
|
||||
## HuggingFace Datasets
|
||||
|
||||
- Cache directory: `~/.cache/huggingface/datasets`
|
||||
- Set via `HF_DATASETS_CACHE` environment variable
|
||||
- Telemetry disabled via `HF_HUB_DISABLE_TELEMETRY=1`
|
||||
- Datasets loaded with `datasets.load_dataset()`
|
||||
|
||||
## Flask Simulator
|
||||
|
||||
- Runs on configurable port (default: 5000)
|
||||
- Endpoint: `/v1/chat/completions` (OpenAI-compatible)
|
||||
- Uses Dice coefficient for question matching
|
||||
- Configurable success rate for testing
|
||||
- Debug logs to `/tmp/simulator-debug.log`
|
||||
@@ -1,94 +0,0 @@
|
||||
# llama-eval Implementation Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Simple evaluation tool for llama.cpp with support for multiple datasets (AIME, GSM8K, GPQA) and flexible grading (regex, CLI, LLM).
|
||||
|
||||
## Key Features
|
||||
|
||||
- **Multiple Datasets**: AIME, GSM8K, GPQA with proper answer extraction
|
||||
- **Flexible Grading**: Regex, CLI, or LLM-based grading
|
||||
- **Parallel Processing**: Configurable thread count for concurrent requests
|
||||
- **Sampling Parameters**: Temperature, Top K, Top P, Min P (optional)
|
||||
- **Real-time Feedback**: Progress tracking with detailed output
|
||||
- **JSON Output**: Complete eval state saved for debugging
|
||||
- **GPQA Support**: Answer shuffling with reproducible results
|
||||
|
||||
## Architecture
|
||||
|
||||
### Eval State
|
||||
```python
|
||||
@dataclass
|
||||
class EvalState:
|
||||
id: str
|
||||
tasks: List[str]
|
||||
task_states: Dict[str, Dict[str, Any]]
|
||||
sampling_config: Dict[str, Any]
|
||||
```
|
||||
|
||||
### Processor
|
||||
- Handles processing, grading, and state management
|
||||
- Thread-safe concurrent execution
|
||||
- Configurable sampling parameters
|
||||
|
||||
### Grader
|
||||
- Abstract grading interface supporting multiple types
|
||||
- Regex grader with dataset-specific patterns
|
||||
- CLI grader with external script interface
|
||||
- LLM grader with configurable server and model
|
||||
|
||||
### Datasets
|
||||
- `AimeDataset`: 90 AIME 2025 questions
|
||||
- `Aime2025Dataset`: 30 AIME 2025 I & II questions
|
||||
- `Gsm8kDataset`: 7473 math word problems
|
||||
- `GpqaDataset`: 198 GPQA Diamond questions with shuffling
|
||||
|
||||
## Configuration
|
||||
|
||||
### Sampling Parameters (Optional)
|
||||
- `--temperature`: Sampling temperature
|
||||
- `--top-k`: Top K sampling
|
||||
- `--top-p`: Top P sampling
|
||||
- `--min-p`: Min P sampling
|
||||
- Only passed if explicitly specified
|
||||
|
||||
### Grading Types
|
||||
- **regex**: Built-in patterns for each dataset
|
||||
- **cli**: External script with `--answer` and `--expected` args
|
||||
- **llm**: LLM-based extraction with few-shot examples and configurable server/model
|
||||
|
||||
### Dataset Requirements
|
||||
- **AIME**: Supports regex, CLI, or LLM grader
|
||||
- **AIME2025**: Supports regex, CLI, or LLM grader
|
||||
- **GSM8K**: Supports regex, CLI, or LLM grader
|
||||
- **GPQA**: Requires LLM grader
|
||||
|
||||
## Output Format
|
||||
|
||||
### Progress Table
|
||||
```
|
||||
Task ID Dataset Prompt (first 43 chars) Expected Status
|
||||
aime_000_001 AIME Complete the following reactions and sel... A pending
|
||||
```
|
||||
|
||||
### Results Summary
|
||||
```
|
||||
============================================================
|
||||
Results: 8/10 correct (80.0%)
|
||||
============================================================
|
||||
```
|
||||
|
||||
### JSON Output
|
||||
Complete eval state with task IDs, correctness, prompts, extracted answers, and sampling configuration.
|
||||
|
||||
## Technical Details
|
||||
|
||||
- Default max tokens: -1 (infinite)
|
||||
- Default grader type: llm
|
||||
- Default seed: 1234
|
||||
- Default threads: 32
|
||||
- Prompt truncation: First 43 chars + padding + "..."
|
||||
- Response truncation: Last 10 lines for grading
|
||||
- GPQA requires LLM grader (returns letter A/B/C/D)
|
||||
- Judge model defaults to evaluated model if not specified
|
||||
- Sample answers defined in SAMPLE_ANSWERS dict for few-shot learning
|
||||
@@ -1,112 +0,0 @@
|
||||
# llama-eval Evaluation Tool
|
||||
|
||||
Simple evaluation tool for llama.cpp with support for multiple datasets.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Datasets**: AIME, GSM8K, GPQA
|
||||
- **Flexible Grading**: Regex, CLI, or LLM-based grading
|
||||
- **Parallel Processing**: Configurable thread count
|
||||
- **Real-time Feedback**: Progress tracking with detailed output
|
||||
- **Sampling Parameters**: Temperature, Top K, Top P, Min P
|
||||
- **JSON Output**: Complete eval state saved for debugging
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
python llama-eval.py \
|
||||
--server http://127.0.0.1:8013 \
|
||||
--model gpt-oss-20b-hf-low \
|
||||
--judge-model gpt-oss-20b-hf-medium \
|
||||
--dataset aime \
|
||||
--n_cases 10 \
|
||||
--grader-type llm \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
## CLI Arguments
|
||||
|
||||
- `--server`: llama-server URL (default: http://127.0.0.1:8013)
|
||||
- `--model`: Model name for evaluation (default: llama)
|
||||
- `--judge-model`: Model name for LLM judge (default: same as main model)
|
||||
- `--judge-server`: Server URL for LLM judge (default: same as main server)
|
||||
- `--dataset`: Dataset type (aime, aime2025, gsm8k, gpqa)
|
||||
- `--n_cases`: Number of cases to evaluate (default: all)
|
||||
- `--n_predict`: Max tokens to predict per prompt (default: -1, infinite)
|
||||
- `--temperature`: Sampling temperature (default: not passed)
|
||||
- `--top-k`: Top K sampling (default: not passed)
|
||||
- `--top-p`: Top P sampling (default: not passed)
|
||||
- `--min-p`: Min P sampling (default: not passed)
|
||||
- `--threads`: Number of threads for parallel requests (default: 32)
|
||||
- `--verbose`: Show detailed output for each case
|
||||
- `--output`: Output file for eval state (default: llama-eval-state.json)
|
||||
- `--grader-type`: Grader type (regex, cli, llm, default: llm)
|
||||
- `--grader-script`: Path to CLI grader script (required for --grader-type cli)
|
||||
- `--seed`: Random seed for shuffling (default: 1234)
|
||||
|
||||
## Datasets
|
||||
|
||||
### AIME
|
||||
- 90 questions from 2025 AIME competition
|
||||
- Answers in boxed format: `\boxed{answer}`
|
||||
- Requires regex grader or LLM grader
|
||||
|
||||
### AIME2025
|
||||
- 30 questions from 2025 AIME I & II competitions
|
||||
- Answers in boxed format: `\boxed{answer}`
|
||||
- Supports regex, CLI, or LLM grader
|
||||
|
||||
### GSM8K
|
||||
- 7473 math word problems
|
||||
- Answers are numeric values
|
||||
- Requires regex grader or LLM grader
|
||||
|
||||
### GPQA
|
||||
- 198 questions from GPQA Diamond dataset
|
||||
- Multiple choice with shuffled options
|
||||
- Requires LLM grader (returns letter A, B, C, or D)
|
||||
|
||||
## Grading Types
|
||||
|
||||
### Regex Grader
|
||||
Built-in patterns for different datasets:
|
||||
- AIME: `\boxed{(\d+)}|\b(\d+)\b`
|
||||
- AIME2025: `\boxed{(\d+)}|\b(\d+)\b`
|
||||
- GSM8K: `\b(\d+)\b`
|
||||
- GPQA: Letter extraction (A, B, C, D)
|
||||
|
||||
### CLI Grader
|
||||
External script interface:
|
||||
```bash
|
||||
./grader.sh --answer <pred> --expected <gold>
|
||||
```
|
||||
Returns exit code 0 if correct, non-zero if incorrect.
|
||||
|
||||
### LLM Grader
|
||||
Uses LLM to extract and compare answers:
|
||||
- Configurable server and model
|
||||
- Includes few-shot examples from sample answers
|
||||
- Case-insensitive comparison
|
||||
- Required for GPQA dataset
|
||||
|
||||
## Output
|
||||
|
||||
### Progress Table
|
||||
```
|
||||
Task ID Dataset Prompt (first 43 chars) Expected Status
|
||||
aime_000_001 AIME Complete the following reactions and sel... A pending
|
||||
```
|
||||
|
||||
### Results
|
||||
```
|
||||
============================================================
|
||||
Results: 8/10 correct (80.0%)
|
||||
============================================================
|
||||
```
|
||||
|
||||
### JSON Output
|
||||
Complete eval state saved to output file with:
|
||||
- Task IDs and correctness status
|
||||
- Prompts and extracted answers
|
||||
- Sampling configuration
|
||||
- Processing metadata
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,36 +0,0 @@
|
||||
# llama-server-simulator
|
||||
|
||||
Standalone Python script simulating llama-server HTTP endpoint for testing.
|
||||
|
||||
## Features
|
||||
|
||||
- HTTP Server with OpenAI-compatible `/v1/chat/completions` endpoint
|
||||
- AIME Dataset Integration - Loads 90 questions from HuggingFace
|
||||
- Intelligent Question Matching - Uses exact matching, LaTeX removal, and Levenshtein distance
|
||||
- Configurable Success Rate - Control correct/wrong answer generation (0-1)
|
||||
- Debug Logging - Troubleshoot matching issues
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
python llama-server-simulator.py --success-rate 0.8
|
||||
```
|
||||
|
||||
## Arguments
|
||||
|
||||
- `--success-rate`: Probability of returning correct answer (0.0-1.0, default: 0.8)
|
||||
- `--port`: Server port (default: 8033)
|
||||
- `--debug`: Enable debug logging (default: False)
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
./test-simulator.sh
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
- Uses Levenshtein distance for partial matching (threshold: 0.3)
|
||||
- Automatic caching via HuggingFace datasets library
|
||||
- Wrong answers generated by incrementing expected answer
|
||||
- Debug output written to stderr
|
||||
@@ -1,283 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
# Set cache directory for HuggingFace datasets
|
||||
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
|
||||
|
||||
def dice(s1: str, s2: str) -> float:
|
||||
"""Calculate Dice coefficient between two strings based on bigram overlap."""
|
||||
if not s1 and not s2:
|
||||
return 1.0
|
||||
|
||||
def _bigrams(s: str):
|
||||
return [s[i : i + 2] for i in range(len(s) - 1)]
|
||||
|
||||
bigrams1 = _bigrams(s1)
|
||||
bigrams2 = _bigrams(s2)
|
||||
|
||||
if not bigrams1 and not bigrams2:
|
||||
return 1.0
|
||||
|
||||
from collections import Counter
|
||||
|
||||
freq1 = Counter(bigrams1)
|
||||
freq2 = Counter(bigrams2)
|
||||
|
||||
intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
|
||||
dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
|
||||
return dice_coeff
|
||||
|
||||
def debug_log(message: str):
|
||||
"""Log debug messages to both stdout and a file"""
|
||||
print(message, file=sys.stderr)
|
||||
with open("/tmp/simulator-debug.log", "a") as f:
|
||||
f.write(message + "\n")
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
id: str
|
||||
tasks: List[str]
|
||||
task_states: Dict[str, Dict]
|
||||
sampling_config: Dict
|
||||
|
||||
def normalize_number(s: str) -> Optional[int]:
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
return None
|
||||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
def __init__(self, split: str = "train"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
||||
def _load_dataset(self):
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
|
||||
self.questions = list(ds)
|
||||
print(f"AIME dataset loaded: {len(self.questions)} questions")
|
||||
|
||||
def find_question(self, request_text: str) -> Optional[Dict]:
|
||||
best_match = None
|
||||
best_distance = -1
|
||||
best_index = -1
|
||||
|
||||
for i, question in enumerate(self.questions):
|
||||
question_text = question["problem"]
|
||||
request_lower = request_text.lower()
|
||||
question_lower = question_text.lower()
|
||||
|
||||
# Exact match
|
||||
if question_lower == request_lower:
|
||||
debug_log(f"DEBUG: Found exact match at index {i}")
|
||||
return question
|
||||
|
||||
# Remove LaTeX formatting for more flexible matching
|
||||
question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
|
||||
if question_no_latex.lower() == request_lower:
|
||||
debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
|
||||
return question
|
||||
|
||||
# Calculate Levenshtein distance for partial matches
|
||||
# Only consider if request is at least 50% of question length
|
||||
if len(request_lower) >= len(question_lower) * 0.5:
|
||||
distance = dice(question_lower, request_lower)
|
||||
|
||||
if distance > best_distance:
|
||||
best_distance = distance
|
||||
best_match = question
|
||||
best_index = i
|
||||
|
||||
if best_match and best_distance > 0.3: # Threshold for partial match
|
||||
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
|
||||
return best_match
|
||||
|
||||
debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
|
||||
return None
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
answer = question["answer"]
|
||||
if isinstance(answer, str):
|
||||
normalized = normalize_number(answer)
|
||||
return str(normalized) if normalized is not None else answer
|
||||
return str(answer)
|
||||
|
||||
class Simulator:
|
||||
def __init__(
|
||||
self,
|
||||
port: int = 8033,
|
||||
host: str = "localhost",
|
||||
success_rate: float = 0.8,
|
||||
dataset_split: str = "train"
|
||||
):
|
||||
self.port = port
|
||||
self.host = host
|
||||
self.success_rate = success_rate
|
||||
self.dataset = AimeDataset(dataset_split)
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
task_states={},
|
||||
sampling_config={"temperature": 0, "max_tokens": 2048}
|
||||
)
|
||||
|
||||
def _generate_response(
|
||||
self,
|
||||
question: Dict,
|
||||
should_be_correct: bool
|
||||
) -> Dict:
|
||||
expected_answer = self.dataset.get_answer(question)
|
||||
|
||||
if should_be_correct:
|
||||
response_text = expected_answer
|
||||
else:
|
||||
response_text = self._generate_wrong_answer(question)
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{int(time.time())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response_text
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150
|
||||
}
|
||||
}
|
||||
|
||||
def _generate_wrong_answer(self, question: Dict) -> str:
|
||||
expected_answer = self.dataset.get_answer(question)
|
||||
|
||||
if expected_answer.isdigit():
|
||||
wrong_answer = str(int(expected_answer) + 1)
|
||||
else:
|
||||
wrong_answer = expected_answer + " (wrong)"
|
||||
|
||||
return wrong_answer
|
||||
|
||||
def _process_request(self, request_data: Dict) -> Dict:
|
||||
messages = request_data.get("messages", [])
|
||||
if not messages:
|
||||
return {"error": "No messages in request"}
|
||||
|
||||
request_text = messages[0].get("content", "")
|
||||
debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")
|
||||
|
||||
question = self.dataset.find_question(request_text)
|
||||
if not question:
|
||||
debug_log(f"DEBUG: find_question returned None")
|
||||
return {"error": "No matching question found"}
|
||||
|
||||
should_be_correct = random.random() < self.success_rate
|
||||
|
||||
response = self._generate_response(question, should_be_correct)
|
||||
|
||||
task_id = "aime"
|
||||
self.eval_state.task_states[task_id] = {
|
||||
"correct": should_be_correct,
|
||||
"expected": self.dataset.get_answer(question),
|
||||
"predicted": response["choices"][0]["message"]["content"]
|
||||
}
|
||||
|
||||
return response
|
||||
|
||||
@app.route('/v1/chat/completions', methods=['POST'])
|
||||
def chat_completions():
|
||||
try:
|
||||
request_data = request.get_json()
|
||||
|
||||
if not request_data:
|
||||
return jsonify({"error": "Invalid JSON"}), 400
|
||||
|
||||
response = simulator._process_request(request_data)
|
||||
|
||||
return jsonify(response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing request: {e}")
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="llama-server simulator for testing eval scripts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8033,
|
||||
help="Server port (default: 8033)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Server host (default: localhost)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--success-rate",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="Success rate 0-1 (default: 0.8)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default="train",
|
||||
help="AIME dataset split to use (default: train)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
global simulator
|
||||
simulator = Simulator(
|
||||
port=args.port,
|
||||
host=args.host,
|
||||
success_rate=args.success_rate,
|
||||
dataset_split=args.dataset_split
|
||||
)
|
||||
|
||||
print("\n=== llama-server-simulator ===")
|
||||
print(f"Server running on http://{args.host}:{args.port}")
|
||||
print(f"Success rate: {args.success_rate}")
|
||||
print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
|
||||
print("\nPress Ctrl+C to stop\n")
|
||||
|
||||
app.run(host=args.host, port=args.port, debug=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,86 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Get the directory where this script is located
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
echo "=== llama-server-simulator Test Script ==="
|
||||
echo ""
|
||||
|
||||
PORT=8033
|
||||
SUCCESS_RATE=0.8
|
||||
TEST_PORT=8034
|
||||
|
||||
echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..."
|
||||
source "$SCRIPT_DIR/venv/bin/activate"
|
||||
python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 &
|
||||
SIMULATOR_PID=$!
|
||||
|
||||
echo "Waiting for simulator to start..."
|
||||
sleep 5
|
||||
|
||||
# Helper function to make a request and extract the answer
|
||||
make_request() {
|
||||
local question="$1"
|
||||
curl -s -X POST http://localhost:$PORT/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"llama\",
|
||||
\"messages\": [
|
||||
{\"role\": \"user\", \"content\": \"$question\"}
|
||||
],
|
||||
\"temperature\": 0,
|
||||
\"max_tokens\": 2048
|
||||
}" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))"
|
||||
}
|
||||
|
||||
# Test question (repeated in multiple tests)
|
||||
TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)."
|
||||
|
||||
echo ""
|
||||
echo "=== Test 1: Correct Answer ==="
|
||||
echo "Sending request with known question..."
|
||||
answer=$(make_request "$TEST_QUESTION")
|
||||
echo "Answer: $answer"
|
||||
echo "Expected: 116"
|
||||
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
|
||||
|
||||
echo ""
|
||||
echo "=== Test 2: Wrong Answer ==="
|
||||
echo "Sending request with known question (success rate 0.0)..."
|
||||
answer=$(make_request "$TEST_QUESTION")
|
||||
echo "Answer: $answer"
|
||||
echo "Expected: 116"
|
||||
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
|
||||
|
||||
echo ""
|
||||
echo "=== Test 3: No Matching Question ==="
|
||||
echo "Sending request with non-matching text..."
|
||||
response=$(make_request "What is the capital of France?")
|
||||
echo "Response: $response"
|
||||
echo "Expected: No matching question found"
|
||||
echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")"
|
||||
|
||||
echo ""
|
||||
echo "=== Test 4: Success Rate Verification ==="
|
||||
echo "Sending 10 requests to test success rate..."
|
||||
correct_count=0
|
||||
for i in {1..10}; do
|
||||
answer=$(make_request "$TEST_QUESTION")
|
||||
if [ "$answer" == "116" ]; then
|
||||
correct_count=$((correct_count + 1))
|
||||
fi
|
||||
echo " Request $i: Answer = $answer"
|
||||
done
|
||||
echo "Correct answers: $correct_count/10"
|
||||
echo "Expected: ~8/10 (80% success rate)"
|
||||
echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%"
|
||||
|
||||
echo ""
|
||||
echo "=== Test Complete ==="
|
||||
echo "Stopping simulator..."
|
||||
kill $SIMULATOR_PID 2>/dev/null
|
||||
wait $SIMULATOR_PID 2>/dev/null || true
|
||||
|
||||
echo "Simulator stopped."
|
||||
@@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 9)
|
||||
set(GGML_VERSION_PATCH 5)
|
||||
set(GGML_VERSION_PATCH 7)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
|
||||
|
||||
@@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (svcntb() * 8 == 256) {
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock
|
||||
// 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
|
||||
int32_t bsums_arr32[4][8];
|
||||
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr[b].scales[offset];
|
||||
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01);
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23);
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01);
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23);
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
}
|
||||
|
||||
const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
}
|
||||
#endif // SVE compile-time end
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
||||
@@ -1 +1 @@
|
||||
a8db410a252c8c8f2d120c6f2e7133ebe032f35d
|
||||
d6754f3d0e6d0acd21c12442353c9fd2f94188e7
|
||||
|
||||
Reference in New Issue
Block a user