Compare commits: b7760...gg/scripts (23 commits, mirror of https://github.com/ggerganov/llama.cpp.git)
247 examples/llama-eval/llama-eval-discussion.md Normal file
@@ -0,0 +1,247 @@
|
||||
# llama-eval Implementation Discussion
|
||||
|
||||
## Overview
|
||||
Discussion about implementing a lean evaluation tool for llama.cpp based on ggerganov's feedback in PR #18892.
|
||||
|
||||
## Key Requirements from ggerganov
|
||||
|
||||
### 1. Simplify and Focus on One Eval
|
||||
- Start with AIME2025 (most familiar with it)
|
||||
- Don't support multiple evals initially
|
||||
|
||||
### 2. Implement an "eval state" object
|
||||
- ID
|
||||
- List of tasks
|
||||
- Task states
|
||||
- Sampling config
|
||||
|
||||
### 3. Implement a "processor" object
|
||||
- List of endpoints
|
||||
- Threads per endpoint
|
||||
- Grade/judge type (regex, endpoint, or CLI tool)
|
||||
|
||||
### 4. Processor responsibilities
|
||||
- Accepts eval state
|
||||
- Starts processing
|
||||
- Dumps eval state periodically as it progresses
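A minimal sketch of what the periodic dump could look like, assuming the eval state carries the fields listed above and that completed results arrive as `(case_id, result)` pairs; the `dump_every` knob and the atomic-write helper are illustrative, not part of the existing scripts.

```python
import json
import os
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Iterable, List, Tuple

@dataclass
class EvalState:
    id: str
    tasks: List[str]
    task_states: Dict[str, Any] = field(default_factory=dict)
    sampling_config: Dict[str, Any] = field(default_factory=dict)

def dump_state(state: EvalState, path: str) -> None:
    # Write to a temp file first so an interrupted dump never corrupts the previous snapshot.
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(asdict(state), f, indent=2)
    os.replace(tmp, path)

def run(state: EvalState, results: Iterable[Tuple[str, Any]], path: str, dump_every: int = 10) -> None:
    """Record each completed task and persist the state every `dump_every` completions."""
    for n, (case_id, result) in enumerate(results, start=1):
        state.task_states[case_id] = result
        if n % dump_every == 0:
            dump_state(state, path)
    dump_state(state, path)  # final dump after the last task
```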
|
||||
|
||||
### 5. Real-time feedback
|
||||
- Default: show "correct / not correct" for each task
|
||||
- Verbose mode: show produced answer vs expected answer as soon as it completes
|
||||
|
||||
### 6. Grading approach
|
||||
- Abstract grading to support external "grader" or "judge"
|
||||
- Use LLM post-processing instead of regex (to avoid issues from GPT-OSS evals)
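A hedged sketch of what an endpoint-based judge could look like; the prompt wording, the judge endpoint URL, and the JSON verdict format are assumptions for illustration, not an interface that exists in the scripts yet.

```python
import json
import requests

JUDGE_PROMPT = (
    "You are grading a math answer.\n"
    "Reference answer: {gold}\n"
    "Candidate answer: {pred}\n"
    'Reply with JSON only: {{"correct": true}} or {{"correct": false}}.'
)

def grade_with_judge(gold: str, pred: str,
                     endpoint: str = "http://localhost:8034/v1/chat/completions") -> bool:
    """Ask a separate LLM endpoint whether the candidate answer matches the reference."""
    payload = {
        "messages": [{"role": "user", "content": JUDGE_PROMPT.format(gold=gold, pred=pred)}],
        "temperature": 0,
        "max_tokens": 32,
    }
    resp = requests.post(endpoint, json=payload, timeout=60)
    resp.raise_for_status()
    verdict = resp.json()["choices"][0]["message"]["content"]
    try:
        return bool(json.loads(verdict).get("correct", False))
    except json.JSONDecodeError:
        return False  # an unparseable verdict is treated as incorrect
```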
|
||||
|
||||
### 7. Output format
|
||||
- Use structured output (JSON) instead of boxed text
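A sketch of requesting the final answer as structured JSON, assuming the serving endpoint honors an OpenAI-style `response_format` field with a JSON schema; the schema and field name are illustrative.

```python
import json
import requests

ANSWER_SCHEMA = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

def ask_structured(prompt: str, server: str = "http://localhost:8033") -> str:
    """Request the final answer as a JSON object instead of boxed text."""
    payload = {
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "final_answer", "schema": ANSWER_SCHEMA},
        },
    }
    resp = requests.post(f"{server}/v1/chat/completions", json=payload, timeout=600)
    resp.raise_for_status()
    content = resp.json()["choices"][0]["message"]["content"]
    return json.loads(content)["answer"]
```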
|
||||
|
||||
## Current Implementation Analysis
|
||||
|
||||
### What exists in llama-eval.py:
|
||||
- Multiple task implementations (AIME, GSM8K, MMLU, HellaSwag, ARC, WinoGrande)
|
||||
- Regex-based answer extraction
|
||||
- HTTP requests to OpenAI-compatible endpoint
|
||||
- Checkpointing/resume capability
|
||||
- Thread-based parallel execution
|
||||
- Summary reporting
|
||||
|
||||
### What needs to be removed:
|
||||
- All task implementations except AIME
|
||||
- Regex-based grading
|
||||
- Multiple endpoint support
|
||||
- Complex task loading logic
|
||||
- Summary reporting (replace with real-time feedback)
|
||||
|
||||
## Discussion Points
|
||||
|
||||
### 1. Eval State Object Structure
|
||||
**Status: Under Discussion**
|
||||
|
||||
Questions:
|
||||
- What fields should be in the eval state object?
|
||||
- Should it include the actual prompts, or just metadata?
|
||||
- How should task states be tracked?
|
||||
|
||||
### 2. Processor Architecture
|
||||
**Status: Not Started**
|
||||
|
||||
Questions:
|
||||
- Should the processor handle multiple endpoints (for distributed evaluation)?
|
||||
- What's the threading model?
|
||||
- How are endpoints configured?
|
||||
|
||||
### 3. Grader Interface
|
||||
**Status: Not Started**
|
||||
|
||||
Questions:
|
||||
- How should the grader be configured?
|
||||
- Should it be a separate service, or a local LLM call?
|
||||
- What's the interface for grading?
|
||||
|
||||
### 4. Checkpointing
|
||||
**Status: Not Started**
|
||||
|
||||
Questions:
|
||||
- Should the eval state be serialized to disk?
|
||||
- How often should it be dumped?
|
||||
- What format should it use?
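One possible answer, sketched as an assumption rather than a decision: keep the whole eval state in a single JSON file and, on restart, skip cases that already have a recorded result.

```python
import json
from pathlib import Path
from typing import Any, Dict, List

def load_state(path: Path) -> Dict[str, Any]:
    """Reload a previously dumped eval state, or start from an empty one."""
    if path.is_file():
        with path.open() as f:
            return json.load(f)
    return {"id": "aime-2025", "tasks": ["aime"], "task_states": {}, "sampling_config": {}}

def pending_cases(state: Dict[str, Any], all_case_ids: List[str]) -> List[str]:
    """Case IDs that still need to be processed after a restart."""
    done = set(state["task_states"])
    return [cid for cid in all_case_ids if cid not in done]
```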
|
||||
|
||||
### 5. Real-time Output
|
||||
**Status: Not Started**
|
||||
|
||||
Questions:
|
||||
- How should progress be displayed?
|
||||
- Console output, file logging, or both?
|
||||
- What verbosity levels are needed?
|
||||
|
||||
### 6. Output Format
|
||||
**Status: Not Started**
|
||||
|
||||
Questions:
|
||||
- Should responses be in JSON format?
|
||||
- How should the grader interface work with JSON output?
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Eval State Object** - Currently discussing
|
||||
2. Processor Architecture
|
||||
3. Grader Interface
|
||||
4. Checkpointing
|
||||
5. Real-time Output
|
||||
6. Output Format
|
||||
|
||||
## References
|
||||
- PR #18892: https://github.com/ggml-org/llama.cpp/pull/18892
|
||||
- Discussion #18195: https://github.com/ggml-org/llama.cpp/discussions/18195
|
||||
|
||||
## Session Work Summary
|
||||
|
||||
### llama-server-simulator Implementation
|
||||
|
||||
**Created:**
|
||||
- `llama-server-simulator.py` - Standalone Python script simulating llama-server HTTP endpoint
|
||||
- `test-simulator.sh` - Test script for verifying simulator functionality
|
||||
- `llama-server-simulator-plan.md` - Implementation plan
|
||||
- `simulator-summary.md` - Summary of implementation
|
||||
|
||||
**Features Implemented:**
|
||||
1. HTTP Server - Flask-based `/v1/chat/completions` endpoint with OpenAI-compatible format
|
||||
2. AIME Dataset Integration - Loads 90 questions from HuggingFace with automatic local caching
|
||||
3. Intelligent Question Matching - Uses exact matching, LaTeX stripping, and bigram (Dice) similarity for fuzzy matching
|
||||
4. Response Generation - Configurable success rate (0-1) for correct/wrong answer generation
|
||||
5. Debug Logging - Helps troubleshoot matching issues
|
||||
|
||||
**Testing Results:**
|
||||
- ✅ Correct answers returned when success rate allows
|
||||
- ✅ Wrong answers returned when success rate doesn't allow
|
||||
- ✅ An error is returned when no matching question is found
|
||||
- ✅ Success rate verified (80% in 10 requests)
|
||||
- ✅ HuggingFace dataset caching working correctly
|
||||
|
||||
**Key Technical Decisions:**
|
||||
- Used bigram Dice similarity for fuzzy question matching (threshold: 0.3)
|
||||
- Automatic caching via HuggingFace datasets library
|
||||
- Wrong answers generated by incrementing expected answer
|
||||
- Debug output written to stderr for better visibility
|
||||
|
||||
**Refactoring:**
|
||||
- Extracted repeating question string into TEST_QUESTION variable
|
||||
- Created make_request() helper function to reduce code duplication
|
||||
- Added proper error handling for error responses
|
||||
- Fixed simulator stopping issue at script completion
|
||||
|
||||
### llama-eval-new.py Implementation
|
||||
|
||||
**Created:**
|
||||
- `llama-eval-new.py` - Simplified evaluation tool focused on AIME
|
||||
|
||||
**Features Implemented:**
|
||||
1. **Eval State Object** - Structured dataclass with ID, tasks, task states, and sampling config
|
||||
2. **Processor Object** - Handles processing, grading, and state management
|
||||
3. **Real-time Feedback** - Shows correct/incorrect status for each case
|
||||
4. **Flexible Grading System** - Supports regex and CLI-based grading
|
||||
5. **Structured JSON Output** - Saves complete eval state to JSON file
|
||||
6. **HuggingFace Dataset Caching** - Uses cached dataset path to avoid HF Hub requests
|
||||
|
||||
**Grading System:**
|
||||
- **Regex Grading**: Built-in patterns for different task types
|
||||
- `aime`: `\\boxed\{(\d+)\}|\b(\d+)\b` (matches the number inside a literal `\boxed{...}`, or any standalone number)
|
||||
- `gsm8k`: `\b(\d+)\b` (extract first number)
|
||||
- `mmlu`, `hellaswag`, `arc`, `winogrande`: `[A-D]` (extract single letter)
|
||||
- **CLI Grading**: External script interface (an example grader script follows this list)
|
||||
- Script accepts `--answer <pred>` and `--expected <gold>`
|
||||
- Returns exit code 0 if correct, non-zero if incorrect
|
||||
- 30-second timeout to prevent hanging
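An illustrative grader script matching the interface above (`--answer`, `--expected`, exit code 0 when correct); the numeric-extraction policy is just an example.

```python
#!/usr/bin/env python3
"""Example CLI grader: exit 0 if the expected integer appears among the numbers in the answer."""
import argparse
import re
import sys

def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument("--answer", required=True, help="model prediction")
    parser.add_argument("--expected", required=True, help="gold answer")
    args = parser.parse_args()

    numbers = re.findall(r"-?\d+", args.answer)
    return 0 if args.expected.strip() in numbers else 1

if __name__ == "__main__":
    sys.exit(main())
```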
|
||||
|
||||
**Configuration Options:**
|
||||
- `--server`: llama-server URL (default: http://localhost:8033)
|
||||
- `--n_cases`: Number of cases to evaluate (default: all)
|
||||
- `--n_predict`: Max tokens to predict per prompt (default: 2048)
|
||||
- `--threads`: Number of threads for parallel requests (default: 32)
|
||||
- `--verbose`: Show detailed output for each case
|
||||
- `--output`: Output file for eval state (default: llama-eval-state.json)
|
||||
- `--grader-type`: `regex` or `cli`
|
||||
- `--grader-regex-type`: aime, gsm8k, mmlu, hellaswag, arc, winogrande
|
||||
- `--grader-script`: Path to CLI grader script
|
||||
|
||||
**Testing Results:**
|
||||
- ✅ Works with simulator at 100% success rate (all correct)
|
||||
- ✅ Works with simulator at 0% success rate (all incorrect)
|
||||
- ✅ Works with simulator at 80% success rate (8/10 correct)
|
||||
- ✅ Real-time verbose output shows gold/pred/status for each case
|
||||
- ✅ JSON output contains complete eval state with all cases
|
||||
- ✅ HF Hub telemetry disabled (no warnings)
|
||||
- ✅ Uses cached dataset path to avoid HF Hub requests when available
|
||||
|
||||
**Key Technical Decisions:**
|
||||
- Removed fuzzy question matching - eval script only sends requests and validates answers
|
||||
- Abstract grading interface for external grader support
|
||||
- Exact match requirement for regex patterns
|
||||
- Handles both boxed and plain text formats for AIME answers
|
||||
- 30-second timeout for CLI grader
|
||||
- Validates script exists before running
|
||||
|
||||
**Refactoring:**
|
||||
- Removed all task implementations except AIME
|
||||
- Removed regex-based grading (moved to flexible grader system)
|
||||
- Removed multiple endpoint support
|
||||
- Removed complex task loading logic
|
||||
- Removed summary reporting (replaced with real-time feedback)
|
||||
- Added HuggingFace dataset caching optimization
|
||||
|
||||
### llama-eval-new.py Threading and Model Parameter Updates
|
||||
|
||||
**Changes Made:**
|
||||
1. **Threading Support** - Added ThreadPoolExecutor for parallel request processing
|
||||
- Added `from concurrent.futures import ThreadPoolExecutor, as_completed`
|
||||
- Created `_process_single_case()` method for thread-safe case processing
|
||||
- Refactored `process()` to use ThreadPoolExecutor with configurable thread count
|
||||
- Updated progress tracking to work with concurrent execution
|
||||
- Thread-safe eval state updates (task_states and counters)
|
||||
|
||||
2. **Model Parameter** - Added `--model` argument to specify model name in request data
|
||||
- Added `model_name` parameter to Processor.__init__()
|
||||
- Updated `_make_request()` to use provided model name or default to "llama"
|
||||
- Added `--model` argument to argument parser
|
||||
- Model name is included in request JSON as `"model": "gpt-oss-20b-hf"`
|
||||
|
||||
**Testing Results:**
|
||||
- ✅ Works with 2 threads (5 cases processed in ~0.2s)
|
||||
- ✅ Works with 4 threads (slightly faster throughput)
|
||||
- ✅ Model parameter correctly added to request data
|
||||
- ✅ Thread-safe progress tracking with tqdm
|
||||
- ✅ No race conditions in eval state updates
|
||||
|
||||
**Key Technical Decisions:**
|
||||
- Used ThreadPoolExecutor for simple, effective parallelism
|
||||
- No rate limiting needed (server can handle concurrent requests)
|
||||
- Thread-safe counter updates for correct/total tracking
|
||||
- Progress bar shows completion status across all threads
|
||||
- Model parameter is optional - defaults to "llama" if not specified
|
||||
|
||||
**Refactoring:**
|
||||
- Extracted single case processing into `_process_single_case()` method
|
||||
- Changed from sequential loop to ThreadPoolExecutor with futures
|
||||
- Updated verbose output to show total count instead of index
|
||||
- Made eval state updates thread-safe
|
||||
401 examples/llama-eval/llama-eval-new.py Executable file
@@ -0,0 +1,401 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
GRADER_PATTERNS = {
|
||||
"aime": r'\boxed{(\d+)}|\b(\d+)\b',
|
||||
"gsm8k": r'\b(\d+)\b',
|
||||
"mmlu": r'[A-D]',
|
||||
"hellaswag": r'[A-D]',
|
||||
"arc": r'[A-D]',
|
||||
"winogrande": r'[A-D]',
|
||||
}
|
||||
|
||||
TEMPLATE_REGISTRY = {
|
||||
"aime": """{question}
|
||||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
""",
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
id: str
|
||||
tasks: List[str]
|
||||
task_states: Dict[str, Dict[str, Any]]
|
||||
sampling_config: Dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class TaskState:
|
||||
case_id: str
|
||||
prompt: str
|
||||
gold: str
|
||||
pred: Optional[str] = None
|
||||
correct: bool = False
|
||||
status: str = "pending"
|
||||
|
||||
def normalize_number(s: str) -> Optional[int]:
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
return None
|
||||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
def __init__(self, split: str = "train"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
||||
def _load_dataset(self):
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
from datasets import load_dataset
|
||||
|
||||
cache_path = cache_dir / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
|
||||
self.questions = []
|
||||
for row in ds:
|
||||
question = dict(row)
|
||||
question["dataset_type"] = "aime"
|
||||
self.questions.append(question)
|
||||
|
||||
print(f"AIME dataset loaded: {len(self.questions)} questions")
|
||||
|
||||
def get_question(self, index: int) -> Dict:
|
||||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
answer = question["answer"]
|
||||
if isinstance(answer, str):
|
||||
normalized = normalize_number(answer)
|
||||
return str(normalized) if normalized is not None else answer
|
||||
return str(answer)
|
||||
|
||||
class Grader:
|
||||
def __init__(
|
||||
self,
|
||||
grader_type: str = "regex",
|
||||
grader_regex_type: str = "aime",
|
||||
grader_script: Optional[str] = None
|
||||
):
|
||||
self.grader_type = grader_type
|
||||
self.grader_regex_type = grader_regex_type
|
||||
self.grader_script = grader_script
|
||||
self.pattern = self._get_pattern()
|
||||
|
||||
def _get_pattern(self) -> str:
|
||||
if self.grader_type == "regex":
|
||||
if self.grader_regex_type not in GRADER_PATTERNS:
|
||||
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
|
||||
return GRADER_PATTERNS[self.grader_regex_type]
|
||||
return None
|
||||
|
||||
def _grade_regex(self, gold: str, pred: str) -> bool:
|
||||
"""Grade using regex pattern matching"""
|
||||
matches = re.findall(self.pattern, pred, re.IGNORECASE)
|
||||
if not matches:
|
||||
return False
|
||||
|
||||
for match in matches:
|
||||
if isinstance(match, tuple):
|
||||
match = match[0] if match[0] else match[1]
|
||||
if match.strip() == gold.strip():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _grade_cli(self, gold: str, pred: str) -> bool:
|
||||
"""Grade using external CLI script"""
|
||||
if not self.grader_script:
|
||||
raise ValueError("CLI grader requires --grader-script")
|
||||
|
||||
script_path = Path(self.grader_script)
|
||||
if not script_path.exists():
|
||||
raise FileNotFoundError(f"Grader script not found: {self.grader_script}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[str(script_path), "--answer", pred, "--expected", gold],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
return result.returncode == 0
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def grade(self, gold: str, pred: str) -> bool:
|
||||
"""Grade the response"""
|
||||
if self.grader_type == "regex":
|
||||
return self._grade_regex(gold, pred)
|
||||
elif self.grader_type == "cli":
|
||||
return self._grade_cli(gold, pred)
|
||||
else:
|
||||
raise ValueError(f"Unknown grader type: {self.grader_type}")
|
||||
|
||||
class Processor:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
n_predict: int = 2048,
|
||||
threads: int = 32,
|
||||
verbose: bool = False,
|
||||
grader: Optional[Grader] = None,
|
||||
model_name: Optional[str] = None
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
self.threads = threads
|
||||
self.verbose = verbose
|
||||
self.model_name = model_name
|
||||
self.dataset = AimeDataset()
|
||||
self.grader = grader or Grader()
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
task_states={},
|
||||
sampling_config={"temperature": 0, "max_tokens": n_predict}
|
||||
)
|
||||
|
||||
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
||||
"""Make HTTP request to the server"""
|
||||
url = f"{self.server_url}/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": self.model_name if self.model_name else "llama",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0,
|
||||
"max_tokens": self.n_predict
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _process_single_case(self, i: int, task_id: str) -> TaskState:
|
||||
"""Process a single case (thread-safe)"""
|
||||
question = self.dataset.get_question(i)
|
||||
dataset_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
gold = self.dataset.get_answer(question)
|
||||
|
||||
# Apply template if available
|
||||
if question["dataset_type"] in TEMPLATE_REGISTRY:
|
||||
prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"])
|
||||
else:
|
||||
prompt = question["problem"]
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=task_id,
|
||||
prompt=prompt,
|
||||
gold=gold
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._make_request(prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
task_state.correct = self.grader.grade(gold, pred)
|
||||
task_state.status = "ok"
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
||||
return task_state
|
||||
|
||||
def process(self, n_cases: int = None, seed: int = 1234):
|
||||
"""Process cases and update eval state"""
|
||||
if n_cases is None:
|
||||
n_cases = len(self.dataset.questions)
|
||||
|
||||
print(f"\nProcessing {n_cases} AIME questions...")
|
||||
print(f"Server: {self.server_url}")
|
||||
print(f"Threads: {self.threads}")
|
||||
print(f"Max tokens: {self.n_predict}")
|
||||
print()
|
||||
|
||||
dataset_size = len(self.dataset.questions)
|
||||
random.seed(seed)
|
||||
|
||||
task_list = []
|
||||
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
|
||||
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
|
||||
indices = list(range(dataset_size))
|
||||
random.shuffle(indices)
|
||||
chunk_indices = indices[:chunk_size]
|
||||
|
||||
for i in chunk_indices:
|
||||
task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}"
|
||||
task_list.append((i, task_id))
|
||||
|
||||
# Print task summary table
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
|
||||
for i, task_id in task_list:
|
||||
question = self.dataset.get_question(i)
|
||||
prompt = question["problem"]
|
||||
gold = self.dataset.get_answer(question)
|
||||
truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
|
||||
print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
|
||||
print()
|
||||
|
||||
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
|
||||
total = 0
|
||||
correct = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.threads) as executor:
|
||||
futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
|
||||
|
||||
for future in as_completed(futures):
|
||||
task_state = future.result()
|
||||
task_states["aime"].append(task_state)
|
||||
total += 1
|
||||
|
||||
if task_state.correct:
|
||||
correct += 1
|
||||
|
||||
# Print task completion status
|
||||
pred_display = task_state.pred if task_state.pred else "N/A"
|
||||
success_ratio = correct / total if total > 0 else 0.0
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:50]:<50} {task_state.gold:<10} {pred_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
|
||||
if self.verbose:
|
||||
print(f"\nCase {total}: {task_state.correct}")
|
||||
print(f" Gold: {task_state.gold}")
|
||||
if task_state.pred:
|
||||
print(f" Pred: {task_state.pred}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
self.eval_state.task_states["aime"] = {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"cases": task_states
|
||||
}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return self.eval_state
|
||||
|
||||
def dump_state(self, output_file: Path):
|
||||
"""Dump eval state to JSON file"""
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(asdict(self.eval_state), f, indent=2)
|
||||
print(f"\nEval state dumped to {output_file}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Simplified AIME evaluation tool for llama.cpp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
type=str,
|
||||
default="http://localhost:8033",
|
||||
help="llama-server URL (default: http://localhost:8033)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_cases",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of cases to evaluate (default: all)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=1234,
|
||||
help="Random seed for shuffling (default: 1234)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_predict",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Max tokens to predict per prompt (default: 2048)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--threads",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Number of threads for parallel requests (default: 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Show detailed output for each case"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=Path,
|
||||
default=Path("llama-eval-state.json"),
|
||||
help="Output file for eval state (default: llama-eval-state.json)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-type",
|
||||
type=str,
|
||||
default="regex",
|
||||
choices=["regex", "cli"],
|
||||
help="Grader type: regex or cli (default: regex)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-regex-type",
|
||||
type=str,
|
||||
default="aime",
|
||||
choices=list(GRADER_PATTERNS.keys()),
|
||||
help="Regex grader type (default: aime)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-script",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CLI grader script path (required for --grader-type cli)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_regex_type=args.grader_regex_type,
|
||||
grader_script=args.grader_script
|
||||
)
|
||||
|
||||
processor = Processor(
|
||||
server_url=args.server,
|
||||
n_predict=args.n_predict,
|
||||
threads=args.threads,
|
||||
verbose=args.verbose,
|
||||
grader=grader,
|
||||
model_name=args.model
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
processor.dump_state(args.output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
703 examples/llama-eval/llama-eval.py Normal file
@@ -0,0 +1,703 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import os
|
||||
from time import time
|
||||
from typing import Union, Any, Mapping, cast
|
||||
|
||||
import datasets
|
||||
import logging
|
||||
import requests
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
from typing import Iterator, Set
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import json
|
||||
import threading
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger("llama-eval")
|
||||
|
||||
MATH_TEMPLATE = """
|
||||
{question}
|
||||
Do not include any explanation. Put your final answer within \\boxed{{}}.
|
||||
"""
|
||||
|
||||
|
||||
def format_multiple_choice(prompt: str, choices: list[str]):
|
||||
lines = [prompt]
|
||||
|
||||
labels = [chr(ord("A") + i) for i in range(len(choices))]
|
||||
for l, c in zip(labels, choices):
|
||||
lines.append(f"({l}): {c.strip()}")
|
||||
lines.append(
|
||||
"Do not include any explanation. Answer with the corresponding option letter only"
|
||||
)
|
||||
lines.append(", ".join(labels))
|
||||
lines.append("Put your final answer within \\boxed{{}}.")
|
||||
|
||||
return "\n".join(lines), labels
|
||||
|
||||
|
||||
def extract_boxed_text(text: str) -> str:
|
||||
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
logger.debug(matches)
|
||||
if matches:
|
||||
for match in matches[::-1]:
|
||||
for group in match:
|
||||
if group != "":
|
||||
return group.split(",")[-1].strip()
|
||||
logger.debug("Could not extract boxed text. Maybe expand context window")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Case:
|
||||
task: str
|
||||
kind: str
|
||||
case_id: str
|
||||
prompt: str
|
||||
gold: str
|
||||
meta_data: dict[str, Any]
|
||||
|
||||
|
||||
class TaskSpec(ABC):
|
||||
name: str
|
||||
kind: str
|
||||
|
||||
@abstractmethod
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class MCTaskSpec(TaskSpec):
|
||||
@staticmethod
|
||||
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||
logger.debug(f"response {response}")
|
||||
result = {
|
||||
"task": case.task,
|
||||
"case_id": case.case_id,
|
||||
"correct": 0,
|
||||
"pred": None,
|
||||
"gold": case.gold,
|
||||
"status": "ok",
|
||||
}
|
||||
|
||||
try:
|
||||
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
|
||||
except Exception as e:
|
||||
result["status"] = "error"
|
||||
logger.warning("ERROR: extract_boxed_text")
|
||||
|
||||
return result
|
||||
|
||||
if not extracted_answer:
|
||||
result["status"] = "invalid"
|
||||
logger.warning("INVALID: extract_boxed_text")
|
||||
return result
|
||||
|
||||
logger.debug(f"extracted_answer {extracted_answer}")
|
||||
logger.debug(f"data['answer'] {case.gold}")
|
||||
result["pred"] = extracted_answer
|
||||
result["correct"] = 1 if extracted_answer == case.gold else 0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class MathTaskSpec(TaskSpec):
|
||||
|
||||
@staticmethod
|
||||
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||
logger.debug(f"response {response}")
|
||||
result = {
|
||||
"task": case.task,
|
||||
"case_id": case.case_id,
|
||||
"correct": 0,
|
||||
"gold": case.gold,
|
||||
"status": "ok",
|
||||
"pred": None,
|
||||
}
|
||||
|
||||
try:
|
||||
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
|
||||
except Exception:
|
||||
result["status"] = "error"
|
||||
logger.warning("ERROR: extract_boxed_text")
|
||||
return result
|
||||
|
||||
source_answer = case.gold
|
||||
try: # All AIME answers are integers, so we convert the extracted answer to an integer
|
||||
extracted_answer = int(extracted_answer)
|
||||
source_answer = int(case.gold)
|
||||
except (ValueError, TypeError):
|
||||
result["status"] = "invalid"
|
||||
return result
|
||||
|
||||
logger.debug(f"extracted_answer {extracted_answer}")
|
||||
logger.debug(f"data['answer'] {case.gold}")
|
||||
result["pred"] = extracted_answer
|
||||
result["correct"] = 1 if extracted_answer == source_answer else 0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ARC_Task(MCTaskSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.name = "arc"
|
||||
self.kind = "mc"
|
||||
self.config = "ARC-Challenge"
|
||||
self.split = "test"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split)
|
||||
ds = ds.add_column("_row_id", list(range(len(ds))))
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
|
||||
for doc in ds:
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
|
||||
prompt, labels = format_multiple_choice(
|
||||
doc["question"], doc["choices"]["text"]
|
||||
)
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}",
|
||||
prompt=prompt,
|
||||
gold=doc["answerKey"],
|
||||
meta_data={"labels": labels},
|
||||
)
|
||||
|
||||
|
||||
class WinoGrande_Task(MCTaskSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.name = "winogrande"
|
||||
self.kind = "mc"
|
||||
self.config = "winogrande_debiased"
|
||||
self.split = "validation"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("winogrande", self.config, split=self.split)
|
||||
|
||||
ds = ds.add_column("_row_id", list(range(len(ds))))
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
|
||||
for doc in ds:
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
|
||||
prompt, labels = format_multiple_choice(
|
||||
doc["sentence"], [doc["option1"], doc["option2"]]
|
||||
)
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}",
|
||||
prompt=prompt,
|
||||
gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based
|
||||
meta_data={"labels": labels},
|
||||
)
|
||||
|
||||
|
||||
class MMLU_Task(MCTaskSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.name = "mmlu"
|
||||
self.kind = "mc"
|
||||
self.config = "all"
|
||||
self.split = "test"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split)
|
||||
ds = ds.add_column("_row_id", list(range(len(ds))))
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
|
||||
for doc in ds:
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
|
||||
prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}",
|
||||
prompt=prompt,
|
||||
gold=labels[int(doc["answer"])],
|
||||
meta_data={"subject": doc["subject"], "labels": labels},
|
||||
)
|
||||
|
||||
|
||||
class Hellaswag_Task(MCTaskSpec):
|
||||
|
||||
# Preprocess hellaswag
|
||||
@staticmethod
|
||||
def preprocess(text: str):
|
||||
text = text.strip()
|
||||
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
|
||||
text = text.replace(" [title]", ". ")
|
||||
text = re.sub("\\[.*?\\]", "", text)
|
||||
text = text.replace(" ", " ")
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def hellaswag_process_doc(doc: dict[str, str]):
|
||||
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
|
||||
question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx)
|
||||
proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]]
|
||||
prompt, labels = format_multiple_choice(question, proc_answers)
|
||||
out_doc = {
|
||||
"prompt": prompt,
|
||||
"gold": labels[int(doc["label"])],
|
||||
}
|
||||
return out_doc
|
||||
|
||||
def __init__(self):
|
||||
self.name = "hellaswag"
|
||||
self.kind = "mc"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("Rowan/hellaswag", split="validation")
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
ds = ds.map(Hellaswag_Task.hellaswag_process_doc)
|
||||
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
for doc in ds:
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"hellaswag_{doc['split']}_{doc['ind']}",
|
||||
prompt=doc["prompt"],
|
||||
gold=doc["gold"],
|
||||
meta_data={},
|
||||
)
|
||||
|
||||
|
||||
class Aime_Task(MathTaskSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.name = "aime"
|
||||
self.kind = "math"
|
||||
self.split = "train"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
|
||||
ds = ds.map(
|
||||
lambda ex: {
|
||||
"prompt": MATH_TEMPLATE.format(
|
||||
question=ex["problem"],
|
||||
)
|
||||
}
|
||||
)
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
|
||||
for i, doc in enumerate(ds):
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"aime_{self.split}_{doc['id']}",
|
||||
prompt=doc["prompt"],
|
||||
gold=doc["answer"],
|
||||
meta_data={"id": doc["id"]},
|
||||
)
|
||||
|
||||
|
||||
class Gsm8k_Task(MathTaskSpec):
|
||||
|
||||
def __init__(self):
|
||||
self.name = "gsm8k"
|
||||
self.kind = "math"
|
||||
self.config = "main"
|
||||
self.split = "test"
|
||||
|
||||
def load(self, limit, seed) -> datasets.Dataset:
|
||||
ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split)
|
||||
ds = ds.add_column("_row_id", list(range(len(ds))))
|
||||
if limit:
|
||||
ds = ds.shuffle(seed=seed)
|
||||
ds = ds.select(range(min(limit, len(ds))))
|
||||
|
||||
ds = ds.map(
|
||||
lambda k: {
|
||||
"prompt": MATH_TEMPLATE.format(
|
||||
question=k["question"],
|
||||
),
|
||||
"gold": k["answer"].split("### ")[-1].rstrip(),
|
||||
}
|
||||
)
|
||||
return ds
|
||||
|
||||
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||
ds = self.load(limit, seed)
|
||||
|
||||
for doc in ds:
|
||||
doc = cast(Mapping[str, Any], doc)
|
||||
yield Case(
|
||||
task=self.name,
|
||||
kind=self.kind,
|
||||
case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}",
|
||||
prompt=doc["prompt"],
|
||||
gold=doc["gold"],
|
||||
meta_data={},
|
||||
)
|
||||
|
||||
|
||||
TASK_DICT: dict[str, type[TaskSpec]] = {
|
||||
"mmlu": MMLU_Task,
|
||||
"aime": Aime_Task,
|
||||
"gsm8k": Gsm8k_Task,
|
||||
"hellaswag": Hellaswag_Task,
|
||||
"arc": ARC_Task,
|
||||
"winogrande": WinoGrande_Task,
|
||||
}
|
||||
|
||||
|
||||
def build_request(case: Case, n_predict: int) -> dict[str, Any]:
|
||||
json_data = {
|
||||
"n_predict": n_predict,
|
||||
"max_tokens": n_predict,
|
||||
"temperature": 0,
|
||||
"prompt": case.prompt,
|
||||
}
|
||||
return json_data
|
||||
|
||||
|
||||
def write_checkpoint_line(
|
||||
checkpoint_file: Path,
|
||||
row: dict[str, Any],
|
||||
file_lock: threading.Lock,
|
||||
):
|
||||
with file_lock:
|
||||
with checkpoint_file.open(mode="a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(row) + "\n")
|
||||
|
||||
|
||||
def send_prompt(
|
||||
case: Case,
|
||||
data: dict,
|
||||
) -> dict[str, Union[str, int]]:
|
||||
result = {
|
||||
"task": case.task,
|
||||
"case_id": case.case_id,
|
||||
"status": "error",
|
||||
"correct": 0,
|
||||
"gold": case.gold,
|
||||
"pred": "",
|
||||
"error": "",
|
||||
}
|
||||
session: requests.Session = data["session"]
|
||||
server_address: str = data["server_address"]
|
||||
task = TASK_DICT.get(case.task)
|
||||
if task is None:
|
||||
result["error"] = f"unknown_task: {case.task}"
|
||||
return result
|
||||
logger.debug(case.prompt)
|
||||
|
||||
json_data = build_request(case, data["n_predict"])
|
||||
res_json = {}
|
||||
try:
|
||||
response = session.post(f"{server_address}/v1/completions", json=json_data)
|
||||
res_json = response.json()
|
||||
result["status"] = "ok"
|
||||
except Exception as e:
|
||||
result["error"] = f"http_exception: {e}"
|
||||
logger.warning(result["error"])
|
||||
|
||||
if result["status"] == "ok":
|
||||
result = TASK_DICT[case.task].grade(case, res_json)
|
||||
|
||||
write_checkpoint_line(
|
||||
data["checkpoint_file"],
|
||||
result.copy(),
|
||||
data["file_lock"],
|
||||
)
|
||||
return result
|
||||
|
||||
def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
|
||||
tmp = {
|
||||
"total": 0,
|
||||
"error": 0,
|
||||
"invalid": 0,
|
||||
"correct": 0,
|
||||
}
|
||||
agg: dict[str, dict[str, int]] = {}
|
||||
for row in results:
|
||||
d = agg.get(row["task"], tmp.copy())
|
||||
d["total"] += 1
|
||||
status = row["status"]
|
||||
if status == "ok":
|
||||
d["correct"] += row["correct"]
|
||||
elif status == "invalid":
|
||||
d["invalid"] += 1
|
||||
elif status == "error":
|
||||
d["error"] += 1
|
||||
|
||||
agg[row["task"]] = d
|
||||
return agg
|
||||
|
||||
|
||||
def print_summary(pertask_results: dict[str, dict[str, int]]):
|
||||
print("\n=== llama-eval suite summary ===")
|
||||
print(
|
||||
f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}"
|
||||
)
|
||||
print("-" * 65)
|
||||
|
||||
suite_total = 0
|
||||
suite_correct = 0
|
||||
|
||||
for task in sorted(pertask_results.keys()):
|
||||
stats = pertask_results[task]
|
||||
total = stats["total"]
|
||||
correct = stats["correct"]
|
||||
invalid = stats["invalid"]
|
||||
error = stats["error"]
|
||||
|
||||
acc = (correct / total) if total > 0 else 0.0
|
||||
|
||||
print(
|
||||
f"{task:<15} "
|
||||
f"{acc:8.3f} "
|
||||
f"{correct:8d} "
|
||||
f"{total:8d} "
|
||||
f"{invalid:8d} "
|
||||
f"{error:8d}"
|
||||
)
|
||||
|
||||
suite_total += total
|
||||
suite_correct += correct
|
||||
|
||||
# Overall summary
|
||||
print("-" * 65)
|
||||
suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0
|
||||
print(
|
||||
f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}"
|
||||
)
|
||||
|
||||
|
||||
def read_checkpoint(
|
||||
checkpoint_file: Path, resume_flag: bool
|
||||
) -> tuple[Set[str], Set[str], list[dict[str, Any]]]:
|
||||
done = set()
|
||||
errored = set()
|
||||
results = []
|
||||
if not resume_flag or not checkpoint_file.is_file():
|
||||
return done, errored, results
|
||||
|
||||
with checkpoint_file.open(mode="r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
row = json.loads(line)
|
||||
except Exception as e:
|
||||
logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}")
|
||||
continue
|
||||
|
||||
case_id = row.get("case_id")
|
||||
if not case_id:
|
||||
continue
|
||||
|
||||
if row["status"] == "error":
|
||||
errored.add(case_id)
|
||||
else:
|
||||
done.add(case_id)
|
||||
results.append(row)
|
||||
errored -= done
|
||||
return done, errored, results
|
||||
|
||||
|
||||
def benchmark(
|
||||
path_server: str,
|
||||
prompt_source: str,
|
||||
n_prompts: int,
|
||||
n_predict: int,
|
||||
rng_seed: int,
|
||||
resume_flag: bool,
|
||||
checkpoint_file: Path,
|
||||
log_level: int,
|
||||
):
|
||||
logger.setLevel(log_level)
|
||||
done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag)
|
||||
|
||||
if not path_server.startswith("http://") and not path_server.startswith("https://"):
|
||||
logger.error("ERROR: malformed server path")
|
||||
return
|
||||
|
||||
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
|
||||
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
|
||||
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
|
||||
|
||||
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
|
||||
|
||||
task_queue: set[TaskSpec] = set()
|
||||
for src in prompt_source.split(","):
|
||||
if src == "all":
|
||||
for v in TASK_DICT.values():
|
||||
task_queue.add(v())
|
||||
break
|
||||
task_queue.add(TASK_DICT[src]())
|
||||
|
||||
session = None
|
||||
try:
|
||||
server_address: str = path_server
|
||||
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
|
||||
session = requests.Session()
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
file_lock = threading.Lock()
|
||||
cases: list[Case] = []
|
||||
data: list[dict] = []
|
||||
for task in task_queue:
|
||||
for case in task.iter_cases(n_prompts, rng_seed):
|
||||
if case.case_id in done or case.case_id in errored:
|
||||
logger.debug(f"Skipping case_id {case.case_id} from checkpoint")
|
||||
continue
|
||||
|
||||
cases.append(case)
|
||||
data.append(
|
||||
{
|
||||
"prompt_source": prompt_source,
|
||||
"session": session,
|
||||
"server_address": server_address,
|
||||
"n_predict": n_predict,
|
||||
"file_lock": file_lock,
|
||||
"checkpoint_file": checkpoint_file,
|
||||
}
|
||||
)
|
||||
logger.info("Starting the benchmark...\n")
|
||||
t0 = time()
|
||||
results: list[dict[str, Union[str, int]]] = thread_map(
|
||||
send_prompt,
|
||||
cases,
|
||||
data,
|
||||
max_workers=parallel,
|
||||
chunksize=1,
|
||||
)
|
||||
finally:
|
||||
if session is not None:
|
||||
session.close()
|
||||
|
||||
t1 = time()
|
||||
logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
|
||||
results.extend(checkpoint_results)
|
||||
pertask_results = aggregate_by_task(results)
|
||||
print_summary(pertask_results)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
|
||||
"Results are printed to console and visualized as plots (saved to current working directory). "
|
||||
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
|
||||
"The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
|
||||
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_server",
|
||||
type=str,
|
||||
default="http://localhost:8033",
|
||||
help="llama-server url",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_source",
|
||||
type=str,
|
||||
default="mmlu",
|
||||
help=f"Eval types supported: all,{list(TASK_DICT.keys())}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_prompts", type=int, default=None, help="Number of prompts to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rng_seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Number to see rng (Used to select prompts from datasource)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_predict",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Max. number of tokens to predict per prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
dest="resume_flag",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable resuming from last state stored in checkpoint file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-resume",
|
||||
dest="resume_flag",
|
||||
action="store_false",
|
||||
help="Disble resuming from last state stored in checkpoint file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-file",
|
||||
type=Path,
|
||||
dest="checkpoint_file",
|
||||
default="./llama-eval-checkpoint.jsonl",
|
||||
help="Checkpoint file to read last state from",
|
||||
)
|
||||
parser.set_defaults(log_level=logging.INFO)
|
||||
parser.add_argument(
|
||||
"--quiet", action="store_const", dest="log_level", const=logging.ERROR
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_const",
|
||||
|
||||
dest="log_level",
|
||||
const=logging.DEBUG,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
benchmark(**vars(args))
|
||||
184 examples/llama-eval/llama-server-simulator-plan.md Normal file
@@ -0,0 +1,184 @@
|
||||
# llama-server-simulator Implementation Plan
|
||||
|
||||
## Overview
|
||||
Create a standalone Python script that simulates a llama-server HTTP endpoint for testing the eval script.
|
||||
|
||||
## Goals
|
||||
1. Simulate llama-server's `/v1/chat/completions` endpoint
|
||||
2. Accept requests and respond with expected answers from AIME dataset
|
||||
3. Implement configurable success rate (sometimes right, sometimes wrong)
|
||||
4. Use regex matching to find questions in incoming requests
|
||||
5. Test with curl requests before integrating with eval script
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Phase 1: Basic Simulator Structure
|
||||
- Create `llama-server-simulator.py` script
|
||||
- Set up Flask/FastAPI HTTP server
|
||||
- Implement `/v1/chat/completions` endpoint
|
||||
- Handle basic request/response format
|
||||
|
||||
### Phase 2: AIME Dataset Integration
|
||||
- Load AIME dataset
|
||||
- Store questions and expected answers
|
||||
- Implement regex matching to find questions in incoming requests
|
||||
- Extract expected answer from matched question
|
||||
|
||||
### Phase 3: Response Generation
|
||||
- Implement success rate configuration
|
||||
- Randomly determine if response should be correct or incorrect
|
||||
- Generate appropriate response based on success determination
|
||||
- Format response in OpenAI-compatible format
|
||||
|
||||
### Phase 4: Testing
|
||||
- Write curl commands to test basic functionality
|
||||
- Test correct responses
|
||||
- Test incorrect responses
|
||||
- Test edge cases (no question found, etc.)
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Server Framework
|
||||
- Use Flask for simplicity
|
||||
- Listen on configurable port
|
||||
- Support JSON request/response format
|
||||
|
||||
### Request Format
|
||||
```json
|
||||
{
|
||||
"model": "llama",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Question text here"}
|
||||
],
|
||||
"temperature": 0,
|
||||
"max_tokens": 2048
|
||||
}
|
||||
```
|
||||
|
||||
### Response Format
|
||||
```json
|
||||
{
|
||||
"id": "chatcmpl-xxx",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "llama",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Answer text here"
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### AIME Dataset Integration
|
||||
- Load from HuggingFace: "AI-MO/aimo-validation-aime"
|
||||
- Store in memory for fast lookup
|
||||
- Regex pattern to find question text in request
|
||||
- Extract answer from matched question
|
||||
|
||||
### Success Rate Configuration
|
||||
- Command-line argument: `--success-rate 0.8` (80% success rate)
|
||||
- Randomly determine correctness based on rate
|
||||
- Log when responses are correct vs incorrect
|
||||
|
||||
### Testing Strategy
|
||||
1. Start simulator with default settings
|
||||
2. Send curl request with known question
|
||||
3. Verify response contains expected answer
|
||||
4. Test with different success rates (see the snippet after this list)
|
||||
5. Test edge cases
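A small Python sketch for step 4, assuming the simulator is already running locally; the question text, expected answer, and request count are placeholders.

```python
import requests

def measure_success_rate(question: str, expected: str, n: int = 10,
                         url: str = "http://localhost:8033/v1/chat/completions") -> float:
    """Send the same question n times and report how often the expected answer comes back."""
    hits = 0
    for _ in range(n):
        resp = requests.post(url, json={
            "model": "llama",
            "messages": [{"role": "user", "content": question}],
        }, timeout=30)
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        hits += int(content.strip() == expected)
    return hits / n
```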
|
||||
|
||||
## Implementation Steps
|
||||
|
||||
### Step 1: Basic Server Setup
|
||||
```python
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/v1/chat/completions', methods=['POST'])
|
||||
def chat_completions():
|
||||
# Handle request
|
||||
return jsonify(response)
|
||||
```
|
||||
|
||||
### Step 2: Load AIME Dataset
|
||||
```python
|
||||
import datasets
|
||||
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
|
||||
# Store in memory
|
||||
```
|
||||
|
||||
### Step 3: Regex Matching
|
||||
```python
|
||||
import re
|
||||
|
||||
def find_question_in_request(request_text):
|
||||
# Regex pattern to find question
|
||||
pattern = r"question:\s*(.*?)\n"
|
||||
match = re.search(pattern, request_text, re.DOTALL)
|
||||
return match.group(1) if match else None
|
||||
```
|
||||
|
||||
### Step 4: Response Generation
|
||||
```python
|
||||
import random
|
||||
|
||||
def generate_response(question, success_rate):
|
||||
if random.random() < success_rate:
|
||||
return get_expected_answer(question)
|
||||
else:
|
||||
return get_wrong_answer(question)
|
||||
```
|
||||
|
||||
### Step 5: Testing with Curl
|
||||
```bash
|
||||
curl -X POST http://localhost:8033/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama",
|
||||
"messages": [{"role": "user", "content": "Question text"}]
|
||||
}'
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
- `--port`: Server port (default: 8033)
|
||||
- `--success-rate`: Success rate 0-1 (default: 0.8)
|
||||
- `--host`: Server host (default: localhost)
|
||||
- `--dataset-split`: AIME split to use (default: train)
|
||||
|
||||
## Expected Output
|
||||
```
|
||||
=== llama-server-simulator ===
|
||||
Server running on http://localhost:8033
|
||||
Success rate: 0.8
|
||||
AIME dataset loaded: 90 questions
|
||||
```
|
||||
|
||||
## Testing Checklist
|
||||
- [ ] Server starts successfully
|
||||
- [ ] Basic request/response works
|
||||
- [ ] Correct answer returned when success rate allows
|
||||
- [ ] Wrong answer returned when success rate doesn't allow
|
||||
- [ ] No question found returns error
|
||||
- [ ] Multiple requests work correctly
|
||||
- [ ] Different success rates work as expected
|
||||
|
||||
## Next Steps
|
||||
1. Implement basic server structure
|
||||
2. Load AIME dataset
|
||||
3. Implement regex matching
|
||||
4. Add response generation with success rate
|
||||
5. Test with curl commands
|
||||
6. Integrate with eval script once simulator works
|
||||
283 examples/llama-eval/llama-server-simulator.py Executable file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
# Set cache directory for HuggingFace datasets
|
||||
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
|
||||
|
||||
def dice(s1: str, s2: str) -> float:
|
||||
"""Calculate Dice coefficient between two strings based on bigram overlap."""
|
||||
if not s1 and not s2:
|
||||
return 1.0
|
||||
|
||||
def _bigrams(s: str):
|
||||
return [s[i : i + 2] for i in range(len(s) - 1)]
|
||||
|
||||
bigrams1 = _bigrams(s1)
|
||||
bigrams2 = _bigrams(s2)
|
||||
|
||||
if not bigrams1 and not bigrams2:
|
||||
return 1.0
|
||||
|
||||
from collections import Counter
|
||||
|
||||
freq1 = Counter(bigrams1)
|
||||
freq2 = Counter(bigrams2)
|
||||
|
||||
intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
|
||||
dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
|
||||
return dice_coeff
|
||||
|
||||
def debug_log(message: str):
|
||||
"""Log debug messages to both stdout and a file"""
|
||||
print(message, file=sys.stderr)
|
||||
with open("/tmp/simulator-debug.log", "a") as f:
|
||||
f.write(message + "\n")
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
id: str
|
||||
tasks: List[str]
|
||||
task_states: Dict[str, Dict]
|
||||
sampling_config: Dict
|
||||
|
||||
def normalize_number(s: str) -> Optional[int]:
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
return None
|
||||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
def __init__(self, split: str = "train"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
||||
def _load_dataset(self):
|
||||
print(f"Loading AIME dataset (split: {self.split})...")
|
||||
|
||||
cache_path = Path.home() / ".cache" / "huggingface" / "datasets" / "AI-MO___aimo-validation-aime" / "default" / "0.0.0"
|
||||
if cache_path.exists():
|
||||
print(f"Using cached dataset from {cache_path}")
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split, cache_dir=str(cache_path))
|
||||
else:
|
||||
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
|
||||
|
||||
self.questions = list(ds)
|
||||
print(f"AIME dataset loaded: {len(self.questions)} questions")
|
||||
|
||||
def find_question(self, request_text: str) -> Optional[Dict]:
|
||||
best_match = None
|
||||
best_distance = -1
|
||||
best_index = -1
|
||||
|
||||
for i, question in enumerate(self.questions):
|
||||
question_text = question["problem"]
|
||||
request_lower = request_text.lower()
|
||||
question_lower = question_text.lower()
|
||||
|
||||
# Exact match
|
||||
if question_lower == request_lower:
|
||||
debug_log(f"DEBUG: Found exact match at index {i}")
|
||||
return question
|
||||
|
||||
# Remove LaTeX formatting for more flexible matching
|
||||
question_no_latex = re.sub(r'\$[^$]+\$', '', question_text)
|
||||
if question_no_latex.lower() == request_lower:
|
||||
debug_log(f"DEBUG: Found match (no LaTeX) at index {i}")
|
||||
return question
|
||||
|
||||
# Compute bigram Dice similarity for partial matches
|
||||
# Only consider if request is at least 50% of question length
|
||||
if len(request_lower) >= len(question_lower) * 0.5:
|
||||
distance = dice(question_lower, request_lower)
|
||||
|
||||
if distance > best_distance:
|
||||
best_distance = distance
|
||||
best_match = question
|
||||
best_index = i
|
||||
|
||||
if best_match and best_distance > 0.3: # Threshold for partial match
|
||||
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
|
||||
return best_match
|
||||
|
||||
debug_log(f"DEBUG: No matching question found for: {request_text[:100]}...")
|
||||
return None
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
answer = question["answer"]
|
||||
if isinstance(answer, str):
|
||||
normalized = normalize_number(answer)
|
||||
return str(normalized) if normalized is not None else answer
|
||||
return str(answer)
|
||||
|
||||
class Simulator:
    def __init__(
        self,
        port: int = 8033,
        host: str = "localhost",
        success_rate: float = 0.8,
        dataset_split: str = "train"
    ):
        self.port = port
        self.host = host
        self.success_rate = success_rate
        self.dataset = AimeDataset(dataset_split)
        self.eval_state = EvalState(
            id="aime-2025",
            tasks=["aime"],
            task_states={},
            sampling_config={"temperature": 0, "max_tokens": 2048}
        )

    def _generate_response(
        self,
        question: Dict,
        should_be_correct: bool
    ) -> Dict:
        expected_answer = self.dataset.get_answer(question)

        if should_be_correct:
            response_text = expected_answer
        else:
            response_text = self._generate_wrong_answer(question)

        return {
            "id": f"chatcmpl-{int(time.time())}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": "llama",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_text
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50,
                "total_tokens": 150
            }
        }

    def _generate_wrong_answer(self, question: Dict) -> str:
        expected_answer = self.dataset.get_answer(question)

        if expected_answer.isdigit():
            wrong_answer = str(int(expected_answer) + 1)
        else:
            wrong_answer = expected_answer + " (wrong)"

        return wrong_answer

    def _process_request(self, request_data: Dict) -> Dict:
        messages = request_data.get("messages", [])
        if not messages:
            return {"error": "No messages in request"}

        request_text = messages[0].get("content", "")
        debug_log(f"DEBUG: Received request with content: {request_text[:150]}...")

        question = self.dataset.find_question(request_text)
        if not question:
            debug_log("DEBUG: find_question returned None")
            return {"error": "No matching question found"}

        should_be_correct = random.random() < self.success_rate

        response = self._generate_response(question, should_be_correct)

        # Record the outcome in the eval state (keyed by task id, so it is overwritten per request)
        task_id = "aime"
        self.eval_state.task_states[task_id] = {
            "correct": should_be_correct,
            "expected": self.dataset.get_answer(question),
            "predicted": response["choices"][0]["message"]["content"]
        }

        return response

@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    try:
        request_data = request.get_json()

        if not request_data:
            return jsonify({"error": "Invalid JSON"}), 400

        response = simulator._process_request(request_data)

        return jsonify(response)

    except Exception as e:
        print(f"Error processing request: {e}")
        return jsonify({"error": str(e)}), 500


def main():
    parser = argparse.ArgumentParser(
        description="llama-server simulator for testing eval scripts"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8033,
        help="Server port (default: 8033)"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server host (default: localhost)"
    )
    parser.add_argument(
        "--success-rate",
        type=float,
        default=0.8,
        help="Success rate 0-1 (default: 0.8)"
    )
    parser.add_argument(
        "--dataset-split",
        type=str,
        default="train",
        help="AIME dataset split to use (default: train)"
    )

    args = parser.parse_args()

    global simulator
    simulator = Simulator(
        port=args.port,
        host=args.host,
        success_rate=args.success_rate,
        dataset_split=args.dataset_split
    )

    print("\n=== llama-server-simulator ===")
    print(f"Server running on http://{args.host}:{args.port}")
    print(f"Success rate: {args.success_rate}")
    print(f"AIME dataset loaded: {len(simulator.dataset.questions)} questions")
    print("\nPress Ctrl+C to stop\n")

    app.run(host=args.host, port=args.port, debug=False)


if __name__ == "__main__":
    main()
135
examples/llama-eval/simulator-summary.md
Normal file
@@ -0,0 +1,135 @@
# llama-server-simulator Implementation Summary

## Overview
Successfully implemented a standalone Python script that simulates a llama-server HTTP endpoint for testing the eval script.

## Features Implemented

### 1. HTTP Server
- Flask-based `/v1/chat/completions` endpoint
- OpenAI-compatible response format
- Configurable port and host

### 2. AIME Dataset Integration
- Loads the AIME dataset from HuggingFace
- In-memory storage for fast lookup
- 90 questions loaded from the train split

### 3. Intelligent Question Matching
- **Exact matching**: direct, case-insensitive string comparison
- **LaTeX removal**: strips `$...$` formatting for more flexible matching
- **Bigram Dice similarity**: scores how similar the request is to each stored question
- **Partial matching**: returns the best-scoring question even with small differences

### 4. Response Generation
- Configurable success rate (0-1)
- Returns the correct answer when a per-request coin flip falls under the success rate
- Returns a deliberately wrong answer otherwise
- Wrong answers are generated by incrementing the expected answer (see the sketch below)

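As a rough illustration, condensed from `_process_request()` and `_generate_wrong_answer()` above (not a verbatim excerpt; the helper name here is ours), each request reduces to a single coin flip against the configured success rate:

```python
import random

def simulate_answer(expected: str, success_rate: float) -> str:
    """Return the expected answer with probability success_rate, otherwise a wrong one."""
    if random.random() < success_rate:
        return expected
    # Wrong answers are produced by nudging the expected value, mirroring _generate_wrong_answer()
    return str(int(expected) + 1) if expected.isdigit() else expected + " (wrong)"

# With success_rate=0.8, roughly 8 out of 10 calls return "116"
print(simulate_answer("116", 0.8))
```
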
### 5. Debug Logging
- Debug messages written to stderr and appended to `/tmp/simulator-debug.log`
- Logs request content, matching results, and similarity scores
- Helps troubleshoot matching issues

## Configuration Options

```bash
python3 llama-server-simulator.py \
    --port 8034 \
    --host localhost \
    --success-rate 0.8 \
    --dataset-split train
```

## Testing Results

### Test 1: Correct Answer
- **Success rate**: 0.8
- **Expected answer**: 116
- **Result**: ✓ Correct (116)

### Test 2: Wrong Answer
- **Success rate**: 0.0
- **Expected answer**: 116
- **Result**: ✓ Wrong (117)

### Test 3: No Matching Question
- **Request**: "What is the capital of France?"
- **Result**: ✓ Returns error "No matching question found"

### Test 4: Success Rate Verification
- **Success rate**: 0.8
- **Requests**: 10
- **Correct answers**: 8/10 (80%)
- **Result**: ✓ Success rate working as expected

## Technical Details

### Matching Algorithm
1. Try exact match (case-insensitive)
2. Try match after removing LaTeX formatting
3. Compute bigram Dice similarity against every question for partial matches
4. Return the best match if its similarity exceeds 0.3 (see the sketch below)

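A condensed sketch of that cascade, adapted from the `dice()` and `find_question()` implementations earlier in this diff (the 50%-length guard and debug logging are omitted for brevity):

```python
import re
from collections import Counter
from typing import List, Optional

def dice(s1: str, s2: str) -> float:
    """Bigram Dice coefficient: 2 * overlap / (len(bigrams1) + len(bigrams2))."""
    bigrams1 = [s1[i:i + 2] for i in range(len(s1) - 1)]
    bigrams2 = [s2[i:i + 2] for i in range(len(s2) - 1)]
    freq1, freq2 = Counter(bigrams1), Counter(bigrams2)
    intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
    return 2 * intersection / (len(bigrams1) + len(bigrams2))

def match_question(request_text: str, questions: List[str]) -> Optional[str]:
    req = request_text.lower()
    best, best_score = None, -1.0
    for q in questions:
        if q.lower() == req:                               # 1. exact match
            return q
        if re.sub(r'\$[^$]+\$', '', q).lower() == req:     # 2. match with LaTeX stripped
            return q
        score = dice(q.lower(), req)                       # 3. bigram Dice similarity
        if score > best_score:
            best, best_score = q, score
    return best if best_score > 0.3 else None              # 4. similarity threshold
```
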
### Response Format
```json
{
  "id": "chatcmpl-1769864875",
  "object": "chat.completion",
  "created": 1769864875,
  "model": "llama",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "116"
      },
      "finish_reason": "stop"
    }
  ],
  "usage": {
    "prompt_tokens": 100,
    "completion_tokens": 50,
    "total_tokens": 150
  }
}
```

## Files Created

1. `llama-server-simulator.py` - Main simulator script
2. `test-simulator.sh` - Basic test script
3. `test-simulator-comprehensive.sh` - Comprehensive test script
4. `llama-server-simulator-plan.md` - Implementation plan
5. `llama-eval-discussion.md` - Discussion notes

## Next Steps

1. ✓ Basic simulator structure
2. ✓ AIME dataset integration
3. ✓ Question matching with bigram Dice similarity
4. ✓ Response generation with configurable success rate
5. ✓ Testing with curl requests
6. ⏭️ Integrate with eval script
7. ⏭️ Implement eval state object
8. ⏭️ Implement processor object
9. ⏭️ Add real-time progress reporting

## Known Limitations

1. Only supports AIME dataset (train split)
2. Matching is case-insensitive
3. Wrong answers are simple increments (not realistic)
4. No support for multiple endpoints
5. No distributed evaluation

## Future Enhancements

1. Support multiple datasets
2. More sophisticated wrong answer generation
3. Multiple endpoint support
4. Distributed evaluation
5. Real-time progress reporting
6. Eval state serialization (see the sketch below)

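For item 6, a minimal sketch of what eval-state serialization could look like, assuming the `EvalState` dataclass from the simulator is reused (the file path and dump cadence are placeholders, not part of the current code):

```python
import json
from dataclasses import asdict, dataclass
from typing import Dict, List

# Same fields as the EvalState dataclass in llama-server-simulator.py
@dataclass
class EvalState:
    id: str
    tasks: List[str]
    task_states: Dict[str, Dict]
    sampling_config: Dict

def dump_eval_state(state: EvalState, path: str = "/tmp/eval-state.json") -> None:
    """Write the eval state as JSON so a run can be inspected or resumed later."""
    with open(path, "w") as f:
        json.dump(asdict(state), f, indent=2)

def load_eval_state(path: str = "/tmp/eval-state.json") -> EvalState:
    with open(path) as f:
        return EvalState(**json.load(f))
```
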
26
examples/llama-eval/test-grader.py
Executable file
@@ -0,0 +1,26 @@
#!/usr/bin/env python3

import sys
import argparse

def main():
    parser = argparse.ArgumentParser(description="Test grader script")
    parser.add_argument("--answer", type=str, required=True, help="Predicted answer")
    parser.add_argument("--expected", type=str, required=True, help="Expected answer")
    args = parser.parse_args()

    pred = args.answer.strip()
    gold = args.expected.strip()

    print(f"Gold: {gold}")
    print(f"Pred: {pred}")

    if pred == gold:
        print("Correct!")
        sys.exit(0)
    else:
        print("Incorrect")
        sys.exit(1)

if __name__ == "__main__":
    main()
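One way a future processor could drive this CLI grader is through its exit code; this is a hedged sketch (the `grade_with_cli` helper is hypothetical; only the `--answer`/`--expected` flags and the 0/1 exit codes come from the script above):

```python
import subprocess

def grade_with_cli(predicted: str, expected: str) -> bool:
    """Invoke test-grader.py and treat exit code 0 as correct, non-zero as incorrect."""
    result = subprocess.run(
        ["python3", "test-grader.py", "--answer", predicted, "--expected", expected],
        capture_output=True,
        text=True,
    )
    return result.returncode == 0

# Example usage
print(grade_with_cli("116", "116"))  # True
print(grade_with_cli("117", "116"))  # False
```
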
86
examples/llama-eval/test-simulator.sh
Executable file
@@ -0,0 +1,86 @@
#!/bin/bash

set -e

# Get the directory where this script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

echo "=== llama-server-simulator Test Script ==="
echo ""

PORT=8033
SUCCESS_RATE=0.8
TEST_PORT=8034

echo "Starting simulator on port $PORT with success rate $SUCCESS_RATE..."
source "$SCRIPT_DIR/venv/bin/activate"
python3 "$SCRIPT_DIR/llama-server-simulator.py" --port $PORT --success-rate $SUCCESS_RATE > /tmp/simulator-test.log 2>&1 &
SIMULATOR_PID=$!

echo "Waiting for simulator to start..."
sleep 5

# Helper function to make a request and extract the answer
make_request() {
    local question="$1"
    curl -s -X POST http://localhost:$PORT/v1/chat/completions \
        -H "Content-Type: application/json" \
        -d "{
            \"model\": \"llama\",
            \"messages\": [
                {\"role\": \"user\", \"content\": \"$question\"}
            ],
            \"temperature\": 0,
            \"max_tokens\": 2048
        }" | python3 -c "import sys, json; data = json.load(sys.stdin); print(data.get('choices', [{}])[0].get('message', {}).get('content', data.get('error', 'No response')))"
}

# Test question (repeated in multiple tests)
TEST_QUESTION="Quadratic polynomials P(x) and Q(x) have leading coefficients 2 and -2, respectively. The graphs of both polynomials pass through the two points (16,54) and (20,53). Find P(0) + Q(0)."

echo ""
echo "=== Test 1: Correct Answer ==="
echo "Sending request with known question..."
answer=$(make_request "$TEST_QUESTION")
echo "Answer: $answer"
echo "Expected: 116"
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
echo ""
|
||||
echo "=== Test 2: Wrong Answer ==="
|
||||
echo "Sending request with known question (success rate 0.0)..."
|
||||
answer=$(make_request "$TEST_QUESTION")
|
||||
echo "Answer: $answer"
|
||||
echo "Expected: 116"
|
||||
echo "Correct: $([ "$answer" == "116" ] && echo "Yes" || echo "No")"
|
||||
|
||||
echo ""
|
||||
echo "=== Test 3: No Matching Question ==="
|
||||
echo "Sending request with non-matching text..."
|
||||
response=$(make_request "What is the capital of France?")
|
||||
echo "Response: $response"
|
||||
echo "Expected: No matching question found"
|
||||
echo "Correct: $([ "$response" == "No matching question found" ] && echo "Yes" || echo "No")"
|
||||
|
||||
echo ""
|
||||
echo "=== Test 4: Success Rate Verification ==="
|
||||
echo "Sending 10 requests to test success rate..."
|
||||
correct_count=0
|
||||
for i in {1..10}; do
|
||||
answer=$(make_request "$TEST_QUESTION")
|
||||
if [ "$answer" == "116" ]; then
|
||||
correct_count=$((correct_count + 1))
|
||||
fi
|
||||
echo " Request $i: Answer = $answer"
|
||||
done
|
||||
echo "Correct answers: $correct_count/10"
|
||||
echo "Expected: ~8/10 (80% success rate)"
|
||||
echo "Success rate: $(echo "scale=1; $correct_count * 10" | bc)%"
|
||||
|
||||
echo ""
|
||||
echo "=== Test Complete ==="
|
||||
echo "Stopping simulator..."
|
||||
kill $SIMULATOR_PID 2>/dev/null
|
||||
wait $SIMULATOR_PID 2>/dev/null || true
|
||||
|
||||
echo "Simulator stopped."
|
||||