mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
Merge pull request #65 from summerview1997/codex/searchqa-materialize-splits
Add SearchQA split materialization helper
This commit is contained in:
@@ -138,6 +138,20 @@ ALFWorld:
|
||||
`searchqa_id_split/` is an ID-only manifest. Each released `id` exactly matches
|
||||
the `key` field in `lucadiliello/searchqa`.
|
||||
|
||||
To materialize the runnable SearchQA split used by
|
||||
`configs/searchqa/default.yaml`, install the optional dependency and run:
|
||||
|
||||
```bash
|
||||
python -m pip install 'skillopt[searchqa]'
|
||||
python scripts/materialize_searchqa.py
|
||||
```
|
||||
|
||||
This writes full examples to:
|
||||
|
||||
```text
|
||||
data/searchqa_split
|
||||
```
|
||||
|
||||
Materialized examples must include the fields consumed by the SearchQA
|
||||
environment, including:
|
||||
|
||||
|
||||
@@ -40,6 +40,8 @@ alfworld = ["alfworld>=0.4.0", "gymnasium>=0.29.0"]
|
||||
claude = ["claude-agent-sdk>=0.1.0"]
|
||||
# Qwen local model backend (via vLLM)
|
||||
qwen = ["vllm>=0.4.0"]
|
||||
# SearchQA data materialization
|
||||
searchqa = ["datasets>=2.18.0"]
|
||||
# Documentation site
|
||||
docs = ["mkdocs-material>=9.5.0", "mkdocstrings[python]>=0.24.0"]
|
||||
# WebUI dashboard
|
||||
|
||||
148
scripts/materialize_searchqa.py
Normal file
148
scripts/materialize_searchqa.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Materialize runnable SearchQA splits from the released ID manifest."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
SPLITS = ("train", "val", "test")
|
||||
REQUIRED_FIELDS = ("question", "context", "answers")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=PROJECT_ROOT / "data" / "searchqa_id_split",
|
||||
help="Directory containing train/val/test ID manifests.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=PROJECT_ROOT / "data" / "searchqa_split",
|
||||
help="Directory to write runnable train/val/test splits.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default="lucadiliello/searchqa",
|
||||
help="Hugging Face dataset repository to load.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_manifest_ids(manifest_dir: Path) -> dict[str, list[str]]:
|
||||
split_ids = {}
|
||||
for split in SPLITS:
|
||||
path = manifest_dir / split / "items.json"
|
||||
with path.open(encoding="utf-8") as file:
|
||||
items = json.load(file)
|
||||
split_ids[split] = [str(item["id"]) for item in items]
|
||||
return split_ids
|
||||
|
||||
|
||||
def _iter_dataset_rows(dataset: Mapping[str, Iterable[dict]]) -> Iterable[dict]:
|
||||
for source_split in dataset.values():
|
||||
yield from source_split
|
||||
|
||||
|
||||
def _normalize_row(row: dict) -> dict:
|
||||
try:
|
||||
key = str(row["key"])
|
||||
except KeyError as exc:
|
||||
raise ValueError("SearchQA source row is missing required field: key") from exc
|
||||
|
||||
missing = [field for field in REQUIRED_FIELDS if field not in row]
|
||||
if missing:
|
||||
raise ValueError(f"SearchQA source row {key!r} is missing required fields: {', '.join(missing)}")
|
||||
|
||||
return {
|
||||
"id": key,
|
||||
"question": row["question"],
|
||||
"context": row["context"],
|
||||
"answers": row["answers"],
|
||||
}
|
||||
|
||||
|
||||
def materialize_searchqa_splits(
|
||||
manifest_dir: Path,
|
||||
output_dir: Path,
|
||||
dataset: Mapping[str, Iterable[dict]],
|
||||
*,
|
||||
dataset_name: str,
|
||||
) -> dict[str, int]:
|
||||
"""Write runnable SearchQA train/val/test splits from a source dataset."""
|
||||
manifest_dir = manifest_dir.resolve()
|
||||
output_dir = output_dir.resolve()
|
||||
split_ids = load_manifest_ids(manifest_dir)
|
||||
wanted_ids = {item_id for ids in split_ids.values() for item_id in ids}
|
||||
|
||||
selected: dict[str, dict] = {}
|
||||
duplicate_ids: set[str] = set()
|
||||
for row in _iter_dataset_rows(dataset):
|
||||
key = str(row.get("key", ""))
|
||||
if key not in wanted_ids:
|
||||
continue
|
||||
if key in selected:
|
||||
duplicate_ids.add(key)
|
||||
continue
|
||||
selected[key] = _normalize_row(row)
|
||||
|
||||
if duplicate_ids:
|
||||
preview = ", ".join(sorted(duplicate_ids)[:5])
|
||||
raise ValueError(f"SearchQA source dataset contains duplicate manifest IDs. First IDs: {preview}")
|
||||
|
||||
missing = sorted(wanted_ids - selected.keys())
|
||||
if missing:
|
||||
preview = ", ".join(missing[:5])
|
||||
raise RuntimeError(f"SearchQA source dataset is missing {len(missing)} manifest IDs. First IDs: {preview}")
|
||||
|
||||
counts = {}
|
||||
for split, ids in split_ids.items():
|
||||
items = [selected[item_id] for item_id in ids]
|
||||
split_dir = output_dir / split
|
||||
split_dir.mkdir(parents=True, exist_ok=True)
|
||||
with (split_dir / "items.json").open("w", encoding="utf-8") as file:
|
||||
json.dump(items, file, ensure_ascii=False, indent=2)
|
||||
counts[split] = len(items)
|
||||
|
||||
manifest = {
|
||||
"source_manifest_dir": str(manifest_dir),
|
||||
"source_dataset": dataset_name,
|
||||
"counts": counts,
|
||||
"item_fields": ["id", *REQUIRED_FIELDS],
|
||||
}
|
||||
with (output_dir / "split_manifest.json").open("w", encoding="utf-8") as file:
|
||||
json.dump(manifest, file, ensure_ascii=False, indent=2)
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError as exc:
|
||||
raise SystemExit(
|
||||
"Missing dependency 'datasets'. Install it with:\n"
|
||||
" python -m pip install 'skillopt[searchqa]'\n"
|
||||
"or:\n"
|
||||
" python -m pip install datasets"
|
||||
) from exc
|
||||
|
||||
print(f"Loading {args.dataset}...")
|
||||
dataset = load_dataset(args.dataset)
|
||||
counts = materialize_searchqa_splits(
|
||||
args.manifest_dir,
|
||||
args.output_dir,
|
||||
dataset,
|
||||
dataset_name=args.dataset,
|
||||
)
|
||||
print(f"Wrote SearchQA splits to {args.output_dir.resolve()}: {counts}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
66
tests/test_materialize_searchqa.py
Normal file
66
tests/test_materialize_searchqa.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.materialize_searchqa import materialize_searchqa_splits
|
||||
|
||||
|
||||
def _write_manifest(root: Path, split_ids: dict[str, list[str]]) -> None:
|
||||
for split, ids in split_ids.items():
|
||||
split_dir = root / split
|
||||
split_dir.mkdir(parents=True)
|
||||
(split_dir / "items.json").write_text(
|
||||
json.dumps([{"id": item_id} for item_id in ids]),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
def _row(key: str) -> dict:
|
||||
return {
|
||||
"key": key,
|
||||
"question": f"question {key}",
|
||||
"context": f"context {key}",
|
||||
"answers": [f"answer {key}"],
|
||||
"ignored": "not written",
|
||||
}
|
||||
|
||||
|
||||
def test_materialize_searchqa_splits_preserves_manifest_order(tmp_path):
|
||||
manifest_dir = tmp_path / "manifest"
|
||||
output_dir = tmp_path / "out"
|
||||
_write_manifest(manifest_dir, {"train": ["b", "a"], "val": ["c"], "test": ["d"]})
|
||||
|
||||
counts = materialize_searchqa_splits(
|
||||
manifest_dir,
|
||||
output_dir,
|
||||
{"train": [_row("a"), _row("b")], "validation": [_row("c"), _row("d")]},
|
||||
dataset_name="example/searchqa",
|
||||
)
|
||||
|
||||
assert counts == {"train": 2, "val": 1, "test": 1}
|
||||
train_items = json.loads((output_dir / "train" / "items.json").read_text(encoding="utf-8"))
|
||||
assert [item["id"] for item in train_items] == ["b", "a"]
|
||||
assert train_items[0] == {
|
||||
"id": "b",
|
||||
"question": "question b",
|
||||
"context": "context b",
|
||||
"answers": ["answer b"],
|
||||
}
|
||||
|
||||
split_manifest = json.loads((output_dir / "split_manifest.json").read_text(encoding="utf-8"))
|
||||
assert split_manifest["source_dataset"] == "example/searchqa"
|
||||
assert split_manifest["counts"] == counts
|
||||
|
||||
|
||||
def test_materialize_searchqa_splits_fails_on_missing_manifest_id(tmp_path):
|
||||
manifest_dir = tmp_path / "manifest"
|
||||
_write_manifest(manifest_dir, {"train": ["a"], "val": ["missing"], "test": []})
|
||||
|
||||
with pytest.raises(RuntimeError, match="missing"):
|
||||
materialize_searchqa_splits(
|
||||
manifest_dir,
|
||||
tmp_path / "out",
|
||||
{"train": [_row("a")]},
|
||||
dataset_name="example/searchqa",
|
||||
)
|
||||
Reference in New Issue
Block a user