# Reviewer summary of this patch (notes only; the patch payload below is unchanged).
#
# Files touched:
#   * data/README.md - documents how to materialize the runnable SearchQA split:
#     install the `skillopt[searchqa]` extra, run `python scripts/materialize_searchqa.py`,
#     output goes to `data/searchqa_split`.
#   * pyproject.toml - adds the optional extra `searchqa = ["datasets>=2.18.0"]`.
#   * scripts/materialize_searchqa.py (new file):
#       - parse_args: `--manifest-dir` (default PROJECT_ROOT/data/searchqa_id_split),
#         `--output-dir` (default PROJECT_ROOT/data/searchqa_split),
#         `--dataset` (default "lucadiliello/searchqa").
#       - load_manifest_ids: reads `<manifest_dir>/<split>/items.json` for each of
#         train/val/test and collects `str(item["id"])` in file order.
#       - materialize_searchqa_splits: iterates every split of the source dataset,
#         keeps rows whose `key` is a manifest ID, and reduces each kept row to
#         id/question/context/answers. Raises ValueError if a wanted key occurs more
#         than once in the source, RuntimeError if any manifest ID is not found; both
#         checks run before any output file is written. Then writes
#         `<output_dir>/<split>/items.json` in manifest order plus
#         `<output_dir>/split_manifest.json`, and returns the per-split counts.
#       - main: imports `datasets.load_dataset` lazily and exits with an install hint
#         (SystemExit) when the import fails.
#   * tests/test_materialize_searchqa.py (new file) - checks that manifest order is
#     preserved, that only id/question/context/answers are written, that
#     split_manifest.json records the dataset name and counts, and that an unknown
#     manifest ID raises RuntimeError.
#
# NOTE(review): the hunk headers declare 148 and 66 added lines for the two new files,
#   but this copy of the patch has its line breaks collapsed - presumably an
#   extraction artifact. Regenerate the patch from the branch before applying; TODO confirm.
# NOTE(review): materialize_searchqa_splits filters with `row.get("key", "")`, so the
#   missing-`key` branch in _normalize_row is reached only when the manifest itself
#   contains an empty-string ID.
# NOTE(review): an ID listed in more than one manifest split is not rejected; it is
#   written into every split that lists it. Confirm this is intended.
# NOTE(review): the missing-ID test uses `match="missing"`, which also matches the
#   fixed wording of the error message, so it does not pin the ID preview itself.
diff --git a/data/README.md b/data/README.md index 8cb5fd7..a31c337 100644 --- a/data/README.md +++ b/data/README.md @@ -138,6 +138,20 @@ ALFWorld: `searchqa_id_split/` is an ID-only manifest. Each released `id` exactly matches the `key` field in `lucadiliello/searchqa`. +To materialize the runnable SearchQA split used by +`configs/searchqa/default.yaml`, install the optional dependency and run: + +```bash +python -m pip install 'skillopt[searchqa]' +python scripts/materialize_searchqa.py +``` + +This writes full examples to: + +```text +data/searchqa_split +``` + Materialized examples must include the fields consumed by the SearchQA environment, including: diff --git a/pyproject.toml b/pyproject.toml index e6a9021..48da25c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ alfworld = ["alfworld>=0.4.0", "gymnasium>=0.29.0"] claude = ["claude-agent-sdk>=0.1.0"] # Qwen local model backend (via vLLM) qwen = ["vllm>=0.4.0"] +# SearchQA data materialization +searchqa = ["datasets>=2.18.0"] # Documentation site docs = ["mkdocs-material>=9.5.0", "mkdocstrings[python]>=0.24.0"] # WebUI dashboard diff --git a/scripts/materialize_searchqa.py b/scripts/materialize_searchqa.py new file mode 100644 index 0000000..30838ac --- /dev/null +++ b/scripts/materialize_searchqa.py @@ -0,0 +1,148 @@ +"""Materialize runnable SearchQA splits from the released ID manifest.""" + +from __future__ import annotations + +import argparse +import json +from collections.abc import Iterable, Mapping +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +SPLITS = ("train", "val", "test") +REQUIRED_FIELDS = ("question", "context", "answers") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--manifest-dir", + type=Path, + default=PROJECT_ROOT / "data" / "searchqa_id_split", + help="Directory containing train/val/test ID manifests.", + ) + parser.add_argument( + "--output-dir", 
type=Path, + default=PROJECT_ROOT / "data" / "searchqa_split", + help="Directory to write runnable train/val/test splits.", + ) + parser.add_argument( + "--dataset", + default="lucadiliello/searchqa", + help="Hugging Face dataset repository to load.", + ) + return parser.parse_args() + + +def load_manifest_ids(manifest_dir: Path) -> dict[str, list[str]]: + split_ids = {} + for split in SPLITS: + path = manifest_dir / split / "items.json" + with path.open(encoding="utf-8") as file: + items = json.load(file) + split_ids[split] = [str(item["id"]) for item in items] + return split_ids + + +def _iter_dataset_rows(dataset: Mapping[str, Iterable[dict]]) -> Iterable[dict]: + for source_split in dataset.values(): + yield from source_split + + +def _normalize_row(row: dict) -> dict: + try: + key = str(row["key"]) + except KeyError as exc: + raise ValueError("SearchQA source row is missing required field: key") from exc + + missing = [field for field in REQUIRED_FIELDS if field not in row] + if missing: + raise ValueError(f"SearchQA source row {key!r} is missing required fields: {', '.join(missing)}") + + return { + "id": key, + "question": row["question"], + "context": row["context"], + "answers": row["answers"], + } + + +def materialize_searchqa_splits( + manifest_dir: Path, + output_dir: Path, + dataset: Mapping[str, Iterable[dict]], + *, + dataset_name: str, +) -> dict[str, int]: + """Write runnable SearchQA train/val/test splits from a source dataset.""" + manifest_dir = manifest_dir.resolve() + output_dir = output_dir.resolve() + split_ids = load_manifest_ids(manifest_dir) + wanted_ids = {item_id for ids in split_ids.values() for item_id in ids} + + selected: dict[str, dict] = {} + duplicate_ids: set[str] = set() + for row in _iter_dataset_rows(dataset): + key = str(row.get("key", "")) + if key not in wanted_ids: + continue + if key in selected: + duplicate_ids.add(key) + continue + selected[key] = _normalize_row(row) + + if duplicate_ids: + preview = ", 
".join(sorted(duplicate_ids)[:5]) + raise ValueError(f"SearchQA source dataset contains duplicate manifest IDs. First IDs: {preview}") + + missing = sorted(wanted_ids - selected.keys()) + if missing: + preview = ", ".join(missing[:5]) + raise RuntimeError(f"SearchQA source dataset is missing {len(missing)} manifest IDs. First IDs: {preview}") + + counts = {} + for split, ids in split_ids.items(): + items = [selected[item_id] for item_id in ids] + split_dir = output_dir / split + split_dir.mkdir(parents=True, exist_ok=True) + with (split_dir / "items.json").open("w", encoding="utf-8") as file: + json.dump(items, file, ensure_ascii=False, indent=2) + counts[split] = len(items) + + manifest = { + "source_manifest_dir": str(manifest_dir), + "source_dataset": dataset_name, + "counts": counts, + "item_fields": ["id", *REQUIRED_FIELDS], + } + with (output_dir / "split_manifest.json").open("w", encoding="utf-8") as file: + json.dump(manifest, file, ensure_ascii=False, indent=2) + + return counts + + +def main() -> None: + args = parse_args() + try: + from datasets import load_dataset + except ImportError as exc: + raise SystemExit( + "Missing dependency 'datasets'. 
Install it with:\n" + " python -m pip install 'skillopt[searchqa]'\n" + "or:\n" + " python -m pip install datasets" + ) from exc + + print(f"Loading {args.dataset}...") + dataset = load_dataset(args.dataset) + counts = materialize_searchqa_splits( + args.manifest_dir, + args.output_dir, + dataset, + dataset_name=args.dataset, + ) + print(f"Wrote SearchQA splits to {args.output_dir.resolve()}: {counts}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_materialize_searchqa.py b/tests/test_materialize_searchqa.py new file mode 100644 index 0000000..bbfb2a8 --- /dev/null +++ b/tests/test_materialize_searchqa.py @@ -0,0 +1,66 @@ +import json +from pathlib import Path + +import pytest + +from scripts.materialize_searchqa import materialize_searchqa_splits + + +def _write_manifest(root: Path, split_ids: dict[str, list[str]]) -> None: + for split, ids in split_ids.items(): + split_dir = root / split + split_dir.mkdir(parents=True) + (split_dir / "items.json").write_text( + json.dumps([{"id": item_id} for item_id in ids]), + encoding="utf-8", + ) + + +def _row(key: str) -> dict: + return { + "key": key, + "question": f"question {key}", + "context": f"context {key}", + "answers": [f"answer {key}"], + "ignored": "not written", + } + + +def test_materialize_searchqa_splits_preserves_manifest_order(tmp_path): + manifest_dir = tmp_path / "manifest" + output_dir = tmp_path / "out" + _write_manifest(manifest_dir, {"train": ["b", "a"], "val": ["c"], "test": ["d"]}) + + counts = materialize_searchqa_splits( + manifest_dir, + output_dir, + {"train": [_row("a"), _row("b")], "validation": [_row("c"), _row("d")]}, + dataset_name="example/searchqa", + ) + + assert counts == {"train": 2, "val": 1, "test": 1} + train_items = json.loads((output_dir / "train" / "items.json").read_text(encoding="utf-8")) + assert [item["id"] for item in train_items] == ["b", "a"] + assert train_items[0] == { + "id": "b", + "question": "question b", + "context": "context b", + "answers": 
["answer b"], + } + + split_manifest = json.loads((output_dir / "split_manifest.json").read_text(encoding="utf-8")) + assert split_manifest["source_dataset"] == "example/searchqa" + assert split_manifest["counts"] == counts + + +def test_materialize_searchqa_splits_fails_on_missing_manifest_id(tmp_path): + manifest_dir = tmp_path / "manifest" + _write_manifest(manifest_dir, {"train": ["a"], "val": ["missing"], "test": []}) + + with pytest.raises(RuntimeError, match="missing"): + materialize_searchqa_splits( + manifest_dir, + tmp_path / "out", + {"train": [_row("a")]}, + dataset_name="example/searchqa", + )