eval : support multiple dataset runs

This commit is contained in:
Georgi Gerganov
2026-02-02 22:34:25 +02:00
parent c965abbe6e
commit 3754239e43

View File

@@ -12,6 +12,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Any
import requests
from tqdm import tqdm
import random
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
@@ -194,10 +195,10 @@ class Processor:
response.raise_for_status()
return response.json()
def _process_single_case(self, i: int) -> TaskState:
def _process_single_case(self, i: int, task_id: str) -> TaskState:
"""Process a single case (thread-safe)"""
question = self.dataset.get_question(i)
case_id = f"aime_{self.dataset.split}_{question['id']}"
dataset_id = f"aime_{self.dataset.split}_{question['id']}"
gold = self.dataset.get_answer(question)
# Apply template if available
@@ -207,7 +208,7 @@ class Processor:
prompt = question["problem"]
task_state = TaskState(
case_id=case_id,
case_id=task_id,
prompt=prompt,
gold=gold
)
@@ -223,7 +224,7 @@ class Processor:
return task_state
def process(self, n_cases: int = None, seed: int = 42):
def process(self, n_cases: int = None, seed: int = 1234):
"""Process cases and update eval state"""
if n_cases is None:
n_cases = len(self.dataset.questions)
@@ -234,26 +235,37 @@ class Processor:
print(f"Max tokens: {self.n_predict}")
print()
dataset_size = len(self.dataset.questions)
random.seed(seed)
task_list = []
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
random.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}"
task_list.append((i, task_id))
# Print task summary table
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
for i in range(min(n_cases, len(self.dataset.questions))):
for i, task_id in task_list:
question = self.dataset.get_question(i)
case_id = f"aime_{self.dataset.split}_{question['id']}"
prompt = question["problem"]
gold = self.dataset.get_answer(question)
truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
print(f" {case_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
print()
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
total = 0
correct = 0
indices = list(range(min(n_cases, len(self.dataset.questions))))
with ThreadPoolExecutor(max_workers=self.threads) as executor:
futures = {executor.submit(self._process_single_case, i): i for i in indices}
futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
for future in as_completed(futures):
task_state = future.result()
@@ -309,6 +321,12 @@ def main():
default=None,
help="Number of cases to evaluate (default: all)"
)
parser.add_argument(
"--seed",
type=int,
default=1234,
help="Random seed for shuffling (default: 1234)"
)
parser.add_argument(
"--n_predict",
type=int,
@@ -376,7 +394,7 @@ def main():
model_name=args.model
)
eval_state = processor.process(n_cases=args.n_cases)
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
processor.dump_state(args.output)
if __name__ == "__main__":