mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
313 lines
11 KiB
Python
313 lines
11 KiB
Python
"""SkillOpt-Sleep — Stage 2: mine.

Turn :class:`SessionDigest` objects into :class:`TaskRecord` training units.

Two miners:
  * heuristic_mine — deterministic, no API. Detects retry chains (a prompt
    re-asked after negative feedback => the early attempt failed), extracts
    the user's recurring intents, and labels outcomes from feedback signals.
  * llm_mine — optional; uses an optimizer backend to produce richer
    TaskRecords with checkable references. Falls back to heuristic on error.

The heuristic miner is what makes the whole cycle runnable offline and is the
basis of the deterministic experiment.
"""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import os
|
|
import re
|
|
from collections import Counter
|
|
from typing import Any, Callable, List, Optional, Set, Tuple
|
|
|
|
from skillopt_sleep.types import SessionDigest, TaskRecord
|
|
|
|
|
|
def _tid(project: str, intent: str) -> str:
    """Build a stable task id from a project name and an intent string.

    The id is ``"task_"`` followed by the first 12 hex characters of the
    SHA-256 digest of ``"<project>::<intent>"`` (UTF-8 encoded), so the same
    pair always maps to the same id.
    """
    key = "::".join((project, intent)).encode("utf-8")
    digest = hashlib.sha256(key).hexdigest()
    return "task_" + digest[:12]
|
|
|
|
|
|
def _short(text: str, n: int = 600) -> str:
    """Strip *text* and cap it at *n* characters.

    Falsy input (``None``, ``""``) becomes ``""``. When the stripped text is
    longer than *n*, the first *n* characters are kept and ``" …"`` is
    appended to mark the cut.
    """
    cleaned = (text or "").strip()
    if len(cleaned) > n:
        return cleaned[:n] + " …"
    return cleaned
|
|
|
|
|
|
def _looks_negative(signals: List[str]) -> bool:
    """Return True when any feedback signal carries the ``"neg:"`` prefix."""
    for signal in signals:
        if signal.startswith("neg:"):
            return True
    return False
|
|
|
|
|
|
def _looks_positive(signals: List[str]) -> bool:
    """Return True when any feedback signal carries the ``"pos:"`` prefix."""
    for signal in signals:
        if signal.startswith("pos:"):
            return True
    return False
|
|
|
|
|
|
# Generic words that appear in almost any skill file or transcript; they carry
# no signal about which skill a task belongs to and are dropped by
# ``_target_tokens``.
_TARGET_STOPWORDS = {
    "about", "after", "again", "agent", "agents", "all", "also", "always",
    "and", "any", "are", "before", "being", "but", "can", "codex",
    "current", "default", "docs", "does", "done", "each", "file", "files",
    "for", "from", "have", "into", "keep", "must", "not", "only", "path",
    "paths", "project", "read", "repo", "request", "requests", "rule",
    "rules", "same", "should", "skill", "skills", "source", "start",
    "task", "tasks", "that", "the", "their", "then", "this", "unless",
    "update", "user", "users", "when", "with", "work", "workflow",
}


def _target_tokens(text: str) -> List[str]:
    """Tokenize *text* into lowercase keyword candidates.

    A raw token is a run of word characters that may contain ``.`` or ``-``
    (so ``merge-conflicts`` and ``skill.md`` survive as single tokens). For a
    compound token the result holds the compound itself followed by its
    pieces split on non-word characters and underscores.

    Tokens shorter than 3 characters, purely numeric tokens, and stopwords
    are dropped. Falsy input yields an empty list.

    Trailing ``.``/``-`` are stripped from each raw token first, so
    sentence-final words tokenize the same as mid-sentence ones and a
    stopword followed by a full stop is still filtered. A token with no
    separators is emitted once per occurrence rather than twice, so
    frequency counts built from the result are not skewed.
    """
    tokens: List[str] = []
    for match in re.findall(r"[\w][\w.-]*", (text or "").lower(), flags=re.UNICODE):
        # The pattern happily swallows sentence punctuation ("server.").
        raw = match.rstrip(".-")
        pieces = [p for p in re.split(r"[\W_]+", raw, flags=re.UNICODE) if p]
        # Only add the pieces when splitting actually produced something new.
        candidates = pieces if pieces == [raw] else [raw] + pieces
        for part in candidates:
            if len(part) < 3 or part.isdigit() or part in _TARGET_STOPWORDS:
                continue
            tokens.append(part)
    return tokens
|
|
|
|
|
|
def _expand_target_keywords(keywords: Set[str]) -> None:
    """Grow *keywords* in place with related terms for known topics.

    * If ``"mcp"`` is present, add setup/connection vocabulary (English and
      Russian).
    * If ``"conflict"`` or ``"conflicts"`` is present, add git merge-conflict
      vocabulary (English and Russian).

    Returns ``None``; the set passed in is mutated.
    """
    mcp_related = (
        "configure", "configuration", "connect", "connected", "enable",
        "enabled", "install", "installed", "server", "servers",
        "настрой", "настроить", "подключи", "подключить",
    )
    conflict_related = (
        "cherry", "conflict", "conflicts", "git", "merge", "rebase",
        "unmerged", "конфликт", "конфликты",
    )
    if "mcp" in keywords:
        keywords.update(mcp_related)
    if not keywords.isdisjoint(("conflict", "conflicts")):
        keywords.update(conflict_related)
|
|
|
|
|
|
def target_task_keywords(
    target_skill_text: str,
    target_skill_path: str = "",
    *,
    limit: int = 180,
) -> Tuple[Set[str], Set[str]]:
    """Return (strong, weak) keywords that describe a target skill.

    ``strong`` comes from the skill's path components and its Markdown
    headings (lines starting with one or more ``#``). ``weak`` is ``strong``
    plus the *limit* most frequent tokens of the whole skill text. Both sets
    are then passed through ``_expand_target_keywords``.
    """
    skill_text = target_skill_text or ""
    path_words = (target_skill_path or "").replace(os.sep, " ")
    heading_lines = re.findall(r"(?m)^#+\s+(.+)$", skill_text)

    strong = set(_target_tokens(path_words + "\n" + "\n".join(heading_lines)))

    frequent = Counter(_target_tokens(skill_text)).most_common(limit)
    weak = strong | {token for token, _ in frequent}

    _expand_target_keywords(strong)
    _expand_target_keywords(weak)
    return strong, weak
|
|
|
|
|
|
def _task_search_text(task: TaskRecord) -> str:
    """Join a task's intent, context excerpt and tags into one searchable string.

    The three parts are separated by newlines; tags are space-joined. Missing
    (falsy) parts contribute an empty string.
    """
    intent = task.intent or ""
    excerpt = task.context_excerpt or ""
    tag_line = " ".join(task.tags or [])
    return "\n".join((intent, excerpt, tag_line))
|
|
|
|
|
|
def filter_tasks_for_target(
    tasks: List[TaskRecord],
    target_skill_text: str,
    target_skill_path: str = "",
) -> List[TaskRecord]:
    """Prefer tasks whose language overlaps the explicit target skill.

    A task is kept when it shares at least one strong keyword, or at least two
    weak keywords, with the target skill. Kept tasks are ordered by score
    (3 points per strong hit plus 1 per weak hit, highest first); ties keep
    their original relative order.

    If nothing matches, return the original list. This keeps a target run
    useful even when transcripts are too sparse or the skill is too generic.
    """
    strong, weak = target_task_keywords(target_skill_text, target_skill_path)
    if not tasks or not (strong or weak):
        return tasks

    scored = []
    for position, task in enumerate(tasks):
        words = set(_target_tokens(_task_search_text(task)))
        n_strong = len(words & strong)
        n_weak = len(words & weak)
        if n_strong or n_weak >= 2:
            scored.append((3 * n_strong + n_weak, position, task))

    if not scored:
        return tasks

    # Highest score first; the original position breaks ties.
    scored.sort(key=lambda entry: (-entry[0], entry[1]))
    return [entry[2] for entry in scored]
|
|
|
|
|
|
def heuristic_mine(
    digests: List[SessionDigest],
    *,
    max_tasks: int = 40,
) -> List[TaskRecord]:
    """Deterministic miner — no API calls.

    Strategy:
      * Each session yields at most one TaskRecord whose intent is the FIRST
        substantive prompt (at least 8 characters after stripping). Short
        openers such as "hi" are skipped rather than discarding the session;
        a session with no substantive prompt yields nothing.
      * Outcome is inferred from the feedback signals:
          - any negative feedback                     -> "fail"
          - positive feedback and no negative         -> "success"
          - no feedback but >= 3 user turns (re-asks) -> "mixed"
          - otherwise                                 -> "unknown"
      * attempted_solution = the last assistant final (what was produced).
      * context_excerpt lists up to three prompts that followed the intent;
        later prompts often carry the corrective detail / real constraints.
      * reference_kind defaults to "none"; the consolidation step will use a
        rubric judge for these. (Exact refs are added by the experiment data
        or by the LLM miner when it can derive a checkable answer.)

    Args:
        digests: Session digests to mine, processed in order.
        max_tasks: Upper bound on the number of records returned; a value
            <= 0 yields an empty list.

    Returns:
        The mined TaskRecords, in digest order.
    """
    tasks: List[TaskRecord] = []
    for d in digests:
        if len(tasks) >= max_tasks:
            break

        prompts = d.user_prompts or []
        # Index of the first prompt long enough to be a real request.
        first = next(
            (i for i, p in enumerate(prompts) if len((p or "").strip()) >= 8),
            None,
        )
        if first is None:
            continue
        intent = prompts[first]

        negative = _looks_negative(d.feedback_signals)
        positive = _looks_positive(d.feedback_signals)
        if negative:
            outcome = "fail"
        elif positive:
            outcome = "success"
        elif d.n_user_turns >= 3:
            outcome = "mixed"
        else:
            outcome = "unknown"

        attempted = d.assistant_finals[-1] if d.assistant_finals else ""

        context = ""
        followups = prompts[first + 1:first + 4]
        if followups:
            context = "Follow-up constraints from the same session:\n- " + "\n- ".join(
                _short(p, 200) for p in followups
            )

        tags = []
        if d.tools_used:
            tags.append("tools:" + "+".join(d.tools_used[:4]))
        if d.git_branch:
            tags.append("branch:" + d.git_branch)

        tasks.append(
            TaskRecord(
                # The id hashes the full intent, not the truncated copy below.
                id=_tid(d.project, intent),
                project=d.project,
                intent=_short(intent, 800),
                context_excerpt=_short(context, 600),
                attempted_solution=_short(attempted, 600),
                outcome=outcome,
                reference_kind="none",
                reference="",
                tags=tags,
                source_sessions=[d.session_id],
            )
        )
    return tasks
|
|
|
|
|
|
def dedup_tasks(tasks: List[TaskRecord]) -> List[TaskRecord]:
    """Merge tasks sharing an id (same project+intent across sessions).

    The first task seen for an id is kept and updated in place: its
    ``source_sessions`` gains the later duplicates' sessions (order preserved,
    no repeats), and its ``outcome`` is replaced when a duplicate has a more
    resolved one (success > fail > mixed > unknown). Result order follows
    first appearance.
    """
    # Higher rank = more resolved; unrecognised outcomes rank as 0.
    rank = {"success": 3, "fail": 2, "mixed": 1, "unknown": 0}
    merged: dict = {}
    for task in tasks:
        kept = merged.get(task.id)
        if kept is None:
            merged[task.id] = task
            continue
        combined = kept.source_sessions + task.source_sessions
        kept.source_sessions = list(dict.fromkeys(combined))
        # prefer a resolved outcome if either session resolved it
        if rank.get(task.outcome, 0) > rank.get(kept.outcome, 0):
            kept.outcome = task.outcome
    return list(merged.values())
|
|
|
|
|
|
def assign_splits(
    tasks: List[TaskRecord],
    *,
    val_fraction: float = 0.34,
    test_fraction: float = 0.0,
    holdout_fraction: float | None = None,  # legacy alias for val_fraction
    seed: int = 42,
) -> List[TaskRecord]:
    """Deterministically split tasks into train / val / test.

    Anti-overfitting contract (the user's design):
      * ``val`` and ``test`` are drawn ONLY from REAL mined tasks
        (origin != 'dream') and never overlap. val gates updates; test is the
        final held-out measure.
      * ``train`` may include DREAM-augmented tasks (origin == 'dream'); those
        are NEVER placed in val/test.

    A stable hash of ``seed`` + task id places each real task in one of 100
    buckets, which keeps the same real task in the same split across nights
    (a fixed held-out gate, like SkillOpt's D_sel/D_test).

    Small-sample guarantees, applied after hashing in this order; each later
    step is chosen so it cannot undo an earlier one:
      1. with >= 2 real tasks, ``val`` is non-empty;
      2. with >= 2 real tasks, ``train`` is non-empty (dream or real) — the
         donor is the first real task whose move does not empty ``val``;
      3. with ``test_fraction > 0`` and >= 3 real tasks, ``test`` is
         non-empty — the donor is the first real train task when train has
         a task to spare, otherwise the first val task when val has one to
         spare.

    Back-compat: if ``test_fraction`` is 0 (default), this behaves like the old
    two-way replay/holdout split — real tasks divide into train + val, no test.
    ``holdout_fraction`` is accepted as an alias for ``val_fraction``.

    Args:
        tasks: Tasks to label; each task's ``split`` attribute is set in place.
        val_fraction: Target share of real tasks for ``val``.
        test_fraction: Target share of real tasks for ``test``.
        holdout_fraction: Legacy alias; overrides ``val_fraction`` when given.
        seed: Mixed into the hash so different seeds give different splits.

    Returns:
        The same ``tasks`` list that was passed in.
    """
    if holdout_fraction is not None:
        val_fraction = holdout_fraction

    # all dream tasks go to train, unconditionally
    real: List[TaskRecord] = []
    for t in tasks:
        if t.origin == "dream":
            t.split = "train"
        else:
            real.append(t)

    val_cut = int(round(val_fraction * 100))
    test_cut = val_cut + int(round(test_fraction * 100))
    for t in real:
        bucket = int(hashlib.sha256((str(seed) + t.id).encode()).hexdigest(), 16) % 100
        if bucket < val_cut:
            t.split = "val"
        elif bucket < test_cut:
            t.split = "test"
        else:
            t.split = "train"

    def _count(pool, split):
        return sum(1 for item in pool if item.split == split)

    # 1. guarantee val (the gate) is non-empty when we have >=2 real tasks
    if len(real) >= 2 and _count(real, "val") == 0:
        real[-1].split = "val"

    # 2. guarantee a train pool exists (dream or real) when possible, without
    #    taking the only val task away again
    if len(real) >= 2 and _count(tasks, "train") == 0:
        for t in real:
            if t.split != "val" or _count(real, "val") > 1:
                t.split = "train"
                break

    # 3. if test was requested but ended up empty with >=3 real tasks, carve
    #    one out of whichever split can spare a task
    if test_fraction > 0 and len(real) >= 3 and _count(real, "test") == 0:
        donor = None
        if _count(tasks, "train") > 1:
            donor = next((t for t in real if t.split == "train"), None)
        if donor is None and _count(real, "val") > 1:
            donor = next((t for t in real if t.split == "val"), None)
        if donor is not None:
            donor.split = "test"
    return tasks
|
|
|
|
|
|
def normalize_legacy_split(value: str) -> str:
    """Map old split names to the new vocabulary.

    ``"replay"`` becomes ``"train"`` and ``"holdout"`` becomes ``"val"``; any
    other value is returned unchanged.
    """
    legacy_names = {"replay": "train", "holdout": "val"}
    try:
        return legacy_names[value]
    except KeyError:
        return value
|
|
|
|
|
|
def mine(
    digests: List[SessionDigest],
    *,
    max_tasks: int = 40,
    candidate_limit: int = 0,
    holdout_fraction: float = 0.34,
    seed: int = 42,
    llm_miner: Optional[Callable[[List[SessionDigest]], List[TaskRecord]]] = None,
    target_skill_text: str = "",
    target_skill_path: str = "",
) -> List[TaskRecord]:
    """Top-level miner. Uses ``llm_miner`` if provided, else heuristic.

    Pipeline: mine -> dedup by task id -> (optional) rank/filter against the
    target skill -> truncate to ``max_tasks`` -> assign train/val splits.

    Args:
        digests: Session digests to mine.
        max_tasks: Maximum number of tasks returned.
        candidate_limit: How many candidates the heuristic miner may produce
            before filtering; 0 means "same as ``max_tasks``".
        holdout_fraction: Share of real tasks placed in the val split.
        seed: Seed for the deterministic split hash.
        llm_miner: Optional callable producing TaskRecords from the digests.
            If it raises or returns nothing, the heuristic miner is used.
        target_skill_text: Text of the skill being optimized; when this or
            ``target_skill_path`` is given, tasks are filtered toward it.
        target_skill_path: Path of the skill being optimized.

    Returns:
        The mined tasks with their ``split`` assigned.
    """
    candidate_limit = candidate_limit or max_tasks
    tasks: List[TaskRecord] = []
    if llm_miner is not None:
        try:
            tasks = llm_miner(digests) or []
        except Exception:
            # Best-effort by design: fall back to the heuristic miner, but
            # leave a trace so a broken backend does not go unnoticed.
            import logging

            logging.getLogger(__name__).warning(
                "llm_miner failed; falling back to heuristic_mine",
                exc_info=True,
            )
            tasks = []
    if not tasks:
        tasks = heuristic_mine(digests, max_tasks=candidate_limit)
    tasks = dedup_tasks(tasks)
    if target_skill_text or target_skill_path:
        tasks = filter_tasks_for_target(tasks, target_skill_text, target_skill_path)
    tasks = tasks[:max_tasks]
    tasks = assign_splits(tasks, holdout_fraction=holdout_fraction, seed=seed)
    return tasks
|