sim : fix answer matching

This commit is contained in:
Georgi Gerganov
2026-02-02 19:45:04 +02:00
parent 98e9eabbf4
commit c965abbe6e
2 changed files with 36 additions and 26 deletions

View File

@@ -28,8 +28,7 @@ GRADER_PATTERNS = {
}
TEMPLATE_REGISTRY = {
"aime": """
{question}
"aime": """{question}
Please reason step by step, and put your final answer within \\boxed{{}}.
""",
}

View File

@@ -19,25 +19,28 @@ cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
def levenshtein_distance(s1: str, s2: str) -> int:
"""Calculate Levenshtein distance between two strings"""
if len(s1) < len(s2):
return levenshtein_distance(s2, s1)
def dice(s1: str, s2: str) -> float:
"""Calculate Dice coefficient between two strings based on bigram overlap."""
if not s1 and not s2:
return 1.0
if len(s2) == 0:
return len(s1)
def _bigrams(s: str):
return [s[i : i + 2] for i in range(len(s) - 1)]
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
for j, c2 in enumerate(s2):
insertions = previous_row[j + 1] + 1
deletions = current_row[j] + 1
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
bigrams1 = _bigrams(s1)
bigrams2 = _bigrams(s2)
return previous_row[-1]
if not bigrams1 and not bigrams2:
return 1.0
from collections import Counter
freq1 = Counter(bigrams1)
freq2 = Counter(bigrams2)
intersection = sum(min(freq1[bg], freq2[bg]) for bg in freq1)
dice_coeff = 2 * intersection / (len(bigrams1) + len(bigrams2))
return dice_coeff
def debug_log(message: str):
"""Log debug messages to both stdout and a file"""
@@ -54,6 +57,12 @@ class EvalState:
task_states: Dict[str, Dict]
sampling_config: Dict
def normalize_number(s: str) -> Optional[int]:
match = re.match(r"\d+", s) # match digits from the start
if not match:
return None
return int(match.group(0))
class AimeDataset:
def __init__(self, split: str = "train"):
self.split = split
@@ -75,7 +84,7 @@ class AimeDataset:
def find_question(self, request_text: str) -> Optional[Dict]:
best_match = None
best_distance = float('inf')
best_distance = -1
best_index = -1
for i, question in enumerate(self.questions):
@@ -97,16 +106,14 @@ class AimeDataset:
# Calculate Levenshtein distance for partial matches
# Only consider if request is at least 50% of question length
if len(request_lower) >= len(question_lower) * 0.5:
distance = levenshtein_distance(question_lower, request_lower)
# Normalize distance by length
normalized_distance = distance / len(question_lower)
distance = dice(question_lower, request_lower)
if normalized_distance < best_distance:
best_distance = normalized_distance
if distance > best_distance:
best_distance = distance
best_match = question
best_index = i
if best_match and best_distance < 0.3: # Threshold for partial match
if best_match and best_distance > 0.3: # Threshold for partial match
debug_log(f"DEBUG: Found best partial match at index {best_index} with distance {best_distance:.3f}")
return best_match
@@ -114,7 +121,11 @@ class AimeDataset:
return None
def get_answer(self, question: Dict) -> str:
return str(question["answer"])
answer = question["answer"]
if isinstance(answer, str):
normalized = normalize_number(answer)
return str(normalized) if normalized is not None else answer
return str(answer)
class Simulator:
def __init__(