Merge pull request #65 from summerview1997/codex/searchqa-materialize-splits

Add SearchQA split materialization helper
This commit is contained in:
Yifan Yang
2026-06-17 23:50:38 +08:00
committed by GitHub
4 changed files with 230 additions and 0 deletions

View File

@@ -138,6 +138,20 @@ ALFWorld:
`searchqa_id_split/` is an ID-only manifest. Each released `id` exactly matches
the `key` field in `lucadiliello/searchqa`.
To materialize the runnable SearchQA split used by
`configs/searchqa/default.yaml`, install the optional dependency and run:
```bash
python -m pip install 'skillopt[searchqa]'
python scripts/materialize_searchqa.py
```
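This writes the full examples for each split to `<split>/items.json` (for `train`, `val`, and `test`), together with a `split_manifest.json` summary, under: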
```text
data/searchqa_split
```
Materialized examples must include the fields consumed by the SearchQA
environment, including:

View File

@@ -40,6 +40,8 @@ alfworld = ["alfworld>=0.4.0", "gymnasium>=0.29.0"]
claude = ["claude-agent-sdk>=0.1.0"]
# Qwen local model backend (via vLLM)
qwen = ["vllm>=0.4.0"]
# SearchQA data materialization
searchqa = ["datasets>=2.18.0"]
# Documentation site
docs = ["mkdocs-material>=9.5.0", "mkdocstrings[python]>=0.24.0"]
# WebUI dashboard

View File

@@ -0,0 +1,148 @@
"""Materialize runnable SearchQA splits from the released ID manifest."""
from __future__ import annotations
import argparse
import json
from collections.abc import Iterable, Mapping
from pathlib import Path
# Two levels up from this file, i.e. the directory that contains the script's
# own directory; used as the base for the default manifest/output paths.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Split names. Each one is a sub-directory (holding an items.json) of both the
# ID manifest directory and the materialized output directory.
SPLITS = ("train", "val", "test")
# Fields every selected source row must provide; they are copied into each
# materialized example next to its "id" and listed in split_manifest.json.
REQUIRED_FIELDS = ("question", "context", "answers")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the materialization script.

    Returns:
        Namespace with ``manifest_dir`` and ``output_dir`` (both ``Path``) and
        ``dataset`` (Hugging Face repository name). The two directories
        default to locations under ``PROJECT_ROOT / "data"``.
    """
    data_root = PROJECT_ROOT / "data"
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--manifest-dir",
        default=data_root / "searchqa_id_split",
        type=Path,
        help="Directory containing train/val/test ID manifests.",
    )
    cli.add_argument(
        "--output-dir",
        default=data_root / "searchqa_split",
        type=Path,
        help="Directory to write runnable train/val/test splits.",
    )
    cli.add_argument(
        "--dataset",
        default="lucadiliello/searchqa",
        help="Hugging Face dataset repository to load.",
    )
    namespace = cli.parse_args()
    return namespace
def load_manifest_ids(manifest_dir: Path) -> dict[str, list[str]]:
    """Read the released example IDs of every split.

    Args:
        manifest_dir: Directory containing one ``<split>/items.json`` file per
            name in ``SPLITS``; each file is a JSON list of objects with an
            ``"id"`` entry.

    Returns:
        Mapping of split name to its IDs, stringified and in file order.
    """
    ids_by_split: dict[str, list[str]] = {}
    for split_name in SPLITS:
        manifest_path = manifest_dir / split_name / "items.json"
        entries = json.loads(manifest_path.read_text(encoding="utf-8"))
        # IDs are compared against stringified dataset keys later on.
        ids_by_split[split_name] = [str(entry["id"]) for entry in entries]
    return ids_by_split
def _iter_dataset_rows(dataset: Mapping[str, Iterable[dict]]) -> Iterable[dict]:
    """Yield every row of every source split, split after split.

    The source split names are ignored; only the mapping's values are walked,
    in the mapping's own iteration order.
    """
    for split_rows in dataset.values():
        for row in split_rows:
            yield row
def _normalize_row(row: dict) -> dict:
    """Convert one source dataset row into a materialized example.

    Args:
        row: Source row; must contain ``"key"`` plus every name listed in
            ``REQUIRED_FIELDS``. Any other entries are dropped.

    Returns:
        Dict with ``"id"`` (the stringified key), ``"question"``,
        ``"context"`` and ``"answers"``.

    Raises:
        ValueError: If ``"key"`` or any required field is absent.
    """
    try:
        raw_key = row["key"]
    except KeyError as exc:
        raise ValueError("SearchQA source row is missing required field: key") from exc
    key = str(raw_key)

    missing = tuple(name for name in REQUIRED_FIELDS if name not in row)
    if missing:
        raise ValueError(f"SearchQA source row {key!r} is missing required fields: {', '.join(missing)}")

    return dict(
        id=key,
        question=row["question"],
        context=row["context"],
        answers=row["answers"],
    )
def materialize_searchqa_splits(
    manifest_dir: Path,
    output_dir: Path,
    dataset: Mapping[str, Iterable[dict]],
    *,
    dataset_name: str,
) -> dict[str, int]:
    """Write runnable SearchQA train/val/test splits from a source dataset.

    Every row of every source split is scanned. Rows whose ``key`` appears in
    the ID manifest are normalized and then written, in manifest order, to
    ``<output_dir>/<split>/items.json``. A ``split_manifest.json`` describing
    the run is written to ``output_dir`` last.

    Args:
        manifest_dir: Directory holding the ``<split>/items.json`` ID manifests.
        output_dir: Directory that receives the materialized splits.
        dataset: Mapping of source split name to an iterable of row dicts.
        dataset_name: Source dataset label recorded in ``split_manifest.json``.

    Returns:
        Number of examples written per split.

    Raises:
        ValueError: If a manifest ID occurs more than once in the source
            dataset, or a selected row lacks a required field.
        RuntimeError: If any manifest ID is absent from the source dataset.
    """
    manifest_root = manifest_dir.resolve()
    output_root = output_dir.resolve()

    ids_by_split = load_manifest_ids(manifest_root)
    wanted: set[str] = set()
    for split_members in ids_by_split.values():
        wanted.update(split_members)

    # Scan the whole source (not just until every ID is found) so that keys
    # occurring more than once can be detected.
    examples: dict[str, dict] = {}
    repeated: set[str] = set()
    for source_row in _iter_dataset_rows(dataset):
        row_id = str(source_row.get("key", ""))
        if row_id not in wanted:
            continue
        if row_id in examples:
            repeated.add(row_id)
        else:
            examples[row_id] = _normalize_row(source_row)

    # Both consistency checks run before anything is written to disk.
    if repeated:
        preview = ", ".join(sorted(repeated)[:5])
        raise ValueError(f"SearchQA source dataset contains duplicate manifest IDs. First IDs: {preview}")

    missing = sorted(wanted.difference(examples))
    if missing:
        preview = ", ".join(missing[:5])
        raise RuntimeError(f"SearchQA source dataset is missing {len(missing)} manifest IDs. First IDs: {preview}")

    counts: dict[str, int] = {}
    for split, ids in ids_by_split.items():
        # Manifest order, not source order, decides the item order on disk.
        split_items = [examples[item_id] for item_id in ids]
        split_dir = output_root / split
        split_dir.mkdir(parents=True, exist_ok=True)
        with (split_dir / "items.json").open("w", encoding="utf-8") as items_file:
            json.dump(split_items, items_file, ensure_ascii=False, indent=2)
        counts[split] = len(split_items)

    summary = {
        "source_manifest_dir": str(manifest_root),
        "source_dataset": dataset_name,
        "counts": counts,
        "item_fields": ["id", *REQUIRED_FIELDS],
    }
    with (output_root / "split_manifest.json").open("w", encoding="utf-8") as summary_file:
        json.dump(summary, summary_file, ensure_ascii=False, indent=2)
    return counts
def main() -> None:
    """Command-line entry point: load the source dataset and write the splits.

    Exits with an installation hint when the optional ``datasets`` package
    cannot be imported.
    """
    args = parse_args()

    # Imported here rather than at module level so that importing this module
    # (and calling materialize_searchqa_splits directly) does not require the
    # optional dependency.
    try:
        from datasets import load_dataset
    except ImportError as exc:
        install_hint = (
            "Missing dependency 'datasets'. Install it with:\n"
            " python -m pip install 'skillopt[searchqa]'\n"
            "or:\n"
            " python -m pip install datasets"
        )
        raise SystemExit(install_hint) from exc

    print(f"Loading {args.dataset}...")
    counts = materialize_searchqa_splits(
        args.manifest_dir,
        args.output_dir,
        load_dataset(args.dataset),
        dataset_name=args.dataset,
    )
    print(f"Wrote SearchQA splits to {args.output_dir.resolve()}: {counts}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,66 @@
import json
from pathlib import Path
import pytest
from scripts.materialize_searchqa import materialize_searchqa_splits
def _write_manifest(root: Path, split_ids: dict[str, list[str]]) -> None:
    """Create ``<root>/<split>/items.json`` ID manifests for the given splits.

    Each file is a JSON list of ``{"id": ...}`` objects in the given order.
    The split directories must not exist yet.
    """
    for split_name in split_ids:
        target_dir = root / split_name
        target_dir.mkdir(parents=True)
        entries = [{"id": item_id} for item_id in split_ids[split_name]]
        with (target_dir / "items.json").open("w", encoding="utf-8") as handle:
            handle.write(json.dumps(entries))
def _row(key: str) -> dict:
    """Build a fake source dataset row whose field values are derived from *key*.

    The extra ``"ignored"`` entry stands in for source columns that are not
    supposed to reach the materialized output.
    """
    return dict(
        key=key,
        question=f"question {key}",
        context=f"context {key}",
        answers=[f"answer {key}"],
        ignored="not written",
    )
def test_materialize_searchqa_splits_preserves_manifest_order(tmp_path):
    """Every split is written in manifest order with only the runnable fields."""
    manifest_dir = tmp_path / "manifest"
    output_dir = tmp_path / "out"
    _write_manifest(manifest_dir, {"train": ["b", "a"], "val": ["c"], "test": ["d"]})

    counts = materialize_searchqa_splits(
        manifest_dir,
        output_dir,
        # Source order and source split layout deliberately differ from the
        # manifest: "a" precedes "b", and "c"/"d" share one source split.
        {"train": [_row("a"), _row("b")], "validation": [_row("c"), _row("d")]},
        dataset_name="example/searchqa",
    )

    assert counts == {"train": 2, "val": 1, "test": 1}

    def read_items(split: str) -> list:
        return json.loads((output_dir / split / "items.json").read_text(encoding="utf-8"))

    train_items = read_items("train")
    assert [item["id"] for item in train_items] == ["b", "a"]
    assert train_items[0] == {
        "id": "b",
        "question": "question b",
        "context": "context b",
        "answers": ["answer b"],
    }

    # Rows are routed by manifest membership, not by their source split.
    val_items = read_items("val")
    test_items = read_items("test")
    assert [item["id"] for item in val_items] == ["c"]
    assert [item["id"] for item in test_items] == ["d"]

    # No item in any split carries source-only columns such as "ignored".
    expected_fields = ["id", "question", "context", "answers"]
    for item in [*train_items, *val_items, *test_items]:
        assert sorted(item) == sorted(expected_fields)

    split_manifest = json.loads((output_dir / "split_manifest.json").read_text(encoding="utf-8"))
    assert split_manifest["source_dataset"] == "example/searchqa"
    assert split_manifest["counts"] == counts
    assert split_manifest["item_fields"] == expected_fields
def test_materialize_searchqa_splits_fails_on_missing_manifest_id(tmp_path):
    """A manifest ID absent from the source aborts the run before any output is written."""
    manifest_dir = tmp_path / "manifest"
    output_dir = tmp_path / "out"
    # Use an ID that cannot be confused with the word "missing" in the error
    # text, so the match below really checks which ID is reported.
    _write_manifest(manifest_dir, {"train": ["a"], "val": ["absent-id"], "test": []})

    with pytest.raises(RuntimeError, match=r"missing 1 manifest IDs\. First IDs: absent-id"):
        materialize_searchqa_splits(
            manifest_dir,
            output_dir,
            {"train": [_row("a")]},
            dataset_name="example/searchqa",
        )

    # The consistency check runs before the first split directory is created.
    assert not output_dir.exists()