Files
microsoft-SkillOpt/scripts/materialize_searchqa.py
2026-06-16 09:26:56 +08:00

149 lines
4.6 KiB
Python

"""Materialize runnable SearchQA splits from the released ID manifest."""
from __future__ import annotations
import argparse
import json
from collections.abc import Iterable, Mapping
from pathlib import Path
# Directory two levels above this file, i.e. the parent of the directory that
# contains this script. The default manifest and output paths hang off it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Split names. Each one doubles as a sub-directory name under both the
# manifest directory and the output directory.
SPLITS = ("train", "val", "test")
# Fields every source row must carry in addition to its "key".
REQUIRED_FIELDS = ("question", "context", "answers")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for this script.

    Returns:
        A namespace with ``manifest_dir`` and ``output_dir`` (both ``Path``)
        and ``dataset`` (the Hugging Face dataset repository name).
    """
    # Both default locations live under the project's "data" directory.
    data_root = PROJECT_ROOT / "data"

    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--manifest-dir",
        type=Path,
        default=data_root / "searchqa_id_split",
        help="Directory containing train/val/test ID manifests.",
    )
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=data_root / "searchqa_split",
        help="Directory to write runnable train/val/test splits.",
    )
    cli.add_argument(
        "--dataset",
        default="lucadiliello/searchqa",
        help="Hugging Face dataset repository to load.",
    )

    namespace = cli.parse_args()
    return namespace
def load_manifest_ids(manifest_dir: Path) -> dict[str, list[str]]:
    """Read the item IDs listed in each split's manifest.

    For every name in ``SPLITS`` the file ``manifest_dir/SPLIT/items.json`` is
    parsed as JSON, and the ``"id"`` of each entry is collected in file order.

    Args:
        manifest_dir: Directory holding one sub-directory per split.

    Returns:
        A mapping from split name to that split's IDs, each coerced to ``str``.
    """
    ids_by_split: dict[str, list[str]] = {}
    for split_name in SPLITS:
        manifest_path = manifest_dir / split_name / "items.json"
        entries = json.loads(manifest_path.read_text(encoding="utf-8"))
        # IDs are stringified so they compare equal to the stringified
        # source-row keys regardless of how the manifest stored them.
        ids_by_split[split_name] = [str(entry["id"]) for entry in entries]
    return ids_by_split
def _iter_dataset_rows(dataset: Mapping[str, Iterable[dict]]) -> Iterable[dict]:
for source_split in dataset.values():
yield from source_split
def _normalize_row(row: dict) -> dict:
try:
key = str(row["key"])
except KeyError as exc:
raise ValueError("SearchQA source row is missing required field: key") from exc
missing = [field for field in REQUIRED_FIELDS if field not in row]
if missing:
raise ValueError(f"SearchQA source row {key!r} is missing required fields: {', '.join(missing)}")
return {
"id": key,
"question": row["question"],
"context": row["context"],
"answers": row["answers"],
}
def materialize_searchqa_splits(
    manifest_dir: Path,
    output_dir: Path,
    dataset: Mapping[str, Iterable[dict]],
    *,
    dataset_name: str,
) -> dict[str, int]:
    """Write runnable SearchQA train/val/test splits from a source dataset.

    The ID manifests under ``manifest_dir`` decide which source rows go into
    which split. Each split is written to ``output_dir/SPLIT/items.json`` in
    manifest order, followed by a summary in ``output_dir/split_manifest.json``.

    Args:
        manifest_dir: Directory containing the per-split ID manifests.
        output_dir: Directory that receives the materialized splits.
        dataset: Mapping of source split name to an iterable of row dicts.
        dataset_name: Name recorded as ``source_dataset`` in the summary file.

    Returns:
        A mapping from split name to the number of items written for it.

    Raises:
        ValueError: If a manifest ID occurs more than once in the source
            dataset, or a matching row lacks a required field.
        RuntimeError: If any manifest ID has no matching source row.
    """
    manifest_dir = manifest_dir.resolve()
    output_dir = output_dir.resolve()

    split_ids = load_manifest_ids(manifest_dir)
    wanted_ids: set[str] = set()
    for ids in split_ids.values():
        wanted_ids.update(ids)

    # Single pass over the source: keep the first row seen for each wanted ID
    # and remember any ID that turns up again.
    rows_by_id: dict[str, dict] = {}
    repeated_ids: set[str] = set()
    for row in _iter_dataset_rows(dataset):
        row_id = str(row.get("key", ""))
        if row_id not in wanted_ids:
            continue
        if row_id in rows_by_id:
            repeated_ids.add(row_id)
        else:
            rows_by_id[row_id] = _normalize_row(row)

    if repeated_ids:
        preview = ", ".join(sorted(repeated_ids)[:5])
        raise ValueError(f"SearchQA source dataset contains duplicate manifest IDs. First IDs: {preview}")

    absent_ids = sorted(wanted_ids.difference(rows_by_id))
    if absent_ids:
        absent_total = len(absent_ids)
        preview = ", ".join(absent_ids[:5])
        raise RuntimeError(f"SearchQA source dataset is missing {absent_total} manifest IDs. First IDs: {preview}")

    counts: dict[str, int] = {}
    for split, ids in split_ids.items():
        # Preserve the manifest's ordering inside each split.
        split_items = [rows_by_id[item_id] for item_id in ids]
        split_dir = output_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)
        items_path = split_dir / "items.json"
        with items_path.open("w", encoding="utf-8") as handle:
            json.dump(split_items, handle, ensure_ascii=False, indent=2)
        counts[split] = len(split_items)

    summary = {
        "source_manifest_dir": str(manifest_dir),
        "source_dataset": dataset_name,
        "counts": counts,
        "item_fields": ["id"] + list(REQUIRED_FIELDS),
    }
    summary_path = output_dir / "split_manifest.json"
    with summary_path.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    return counts
def main() -> None:
    """Run the command-line workflow: parse options, load the dataset, write splits.

    Raises:
        SystemExit: With installation instructions if the ``datasets`` package
            cannot be imported.
    """
    # Arguments are parsed before the optional import, so option handling
    # does not depend on `datasets` being installed.
    args = parse_args()

    try:
        from datasets import load_dataset
    except ImportError as exc:
        install_hint = (
            "Missing dependency 'datasets'. Install it with:\n"
            " python -m pip install 'skillopt[searchqa]'\n"
            "or:\n"
            " python -m pip install datasets"
        )
        raise SystemExit(install_hint) from exc

    dataset_name = args.dataset
    print(f"Loading {dataset_name}...")
    source = load_dataset(dataset_name)

    split_counts = materialize_searchqa_splits(
        args.manifest_dir,
        args.output_dir,
        source,
        dataset_name=dataset_name,
    )
    written_to = args.output_dir.resolve()
    print(f"Wrote SearchQA splits to {written_to}: {split_counts}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()