Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2026-02-05 13:53:23 +02:00)
eval : support multiple dataset runs
This commit is contained in:
@@ -12,6 +12,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -194,10 +195,10 @@ class Processor:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _process_single_case(self, i: int) -> TaskState:
|
||||
def _process_single_case(self, i: int, task_id: str) -> TaskState:
|
||||
"""Process a single case (thread-safe)"""
|
||||
question = self.dataset.get_question(i)
|
||||
case_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
dataset_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
gold = self.dataset.get_answer(question)
|
||||
|
||||
# Apply template if available
|
||||
@@ -207,7 +208,7 @@ class Processor:
|
||||
prompt = question["problem"]
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=case_id,
|
||||
case_id=task_id,
|
||||
prompt=prompt,
|
||||
gold=gold
|
||||
)
|
||||
@@ -223,7 +224,7 @@ class Processor:
|
||||
|
||||
return task_state
|
||||
|
||||
def process(self, n_cases: int = None, seed: int = 42):
|
||||
def process(self, n_cases: int = None, seed: int = 1234):
|
||||
"""Process cases and update eval state"""
|
||||
if n_cases is None:
|
||||
n_cases = len(self.dataset.questions)
|
||||
@@ -234,26 +235,37 @@ class Processor:
|
||||
print(f"Max tokens: {self.n_predict}")
|
||||
print()
|
||||
|
||||
dataset_size = len(self.dataset.questions)
|
||||
random.seed(seed)
|
||||
|
||||
task_list = []
|
||||
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
|
||||
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
|
||||
indices = list(range(dataset_size))
|
||||
random.shuffle(indices)
|
||||
chunk_indices = indices[:chunk_size]
|
||||
|
||||
for i in chunk_indices:
|
||||
task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}"
|
||||
task_list.append((i, task_id))
|
||||
|
||||
# Print task summary table
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
|
||||
for i in range(min(n_cases, len(self.dataset.questions))):
|
||||
for i, task_id in task_list:
|
||||
question = self.dataset.get_question(i)
|
||||
case_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
prompt = question["problem"]
|
||||
gold = self.dataset.get_answer(question)
|
||||
truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
|
||||
print(f" {case_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
|
||||
print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
|
||||
print()
|
||||
|
||||
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
|
||||
total = 0
|
||||
correct = 0
|
||||
|
||||
indices = list(range(min(n_cases, len(self.dataset.questions))))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.threads) as executor:
|
||||
futures = {executor.submit(self._process_single_case, i): i for i in indices}
|
||||
futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
|
||||
|
||||
for future in as_completed(futures):
|
||||
task_state = future.result()
|
||||
@@ -309,6 +321,12 @@ def main():
|
||||
default=None,
|
||||
help="Number of cases to evaluate (default: all)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=1234,
|
||||
help="Random seed for shuffling (default: 1234)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_predict",
|
||||
type=int,
|
||||
@@ -376,7 +394,7 @@ def main():
|
||||
model_name=args.model
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases)
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
processor.dump_state(args.output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user