diff --git a/skillopt/envs/alfworld/rollout.py b/skillopt/envs/alfworld/rollout.py index 8c3b4ac..18264c3 100644 --- a/skillopt/envs/alfworld/rollout.py +++ b/skillopt/envs/alfworld/rollout.py @@ -7,12 +7,10 @@ Provides: """ from __future__ import annotations +import concurrent.futures import json import os import re -import sys -import concurrent.futures -import numpy as np from skillopt.model import chat_target @@ -65,6 +63,25 @@ def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) -> return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n" +def _resolve_alfworld_gamefile(gamefile: str) -> str: + path = os.path.expanduser(os.path.expandvars(str(gamefile))) + if os.path.isabs(path): + return path + + data_root = os.environ.get("ALFWORLD_DATA", "").strip() + if not data_root: + return path + + root = os.path.expanduser(os.path.expandvars(data_root)) + return os.path.abspath(os.path.join(root, path)) + + +def _resolve_alfworld_gamefiles(gamefiles: list[str] | None) -> list[str] | None: + if gamefiles is None: + return None + return [_resolve_alfworld_gamefile(gamefile) for gamefile in gamefiles] + + # ── Environment builder ────────────────────────────────────────────────────── @@ -86,9 +103,10 @@ def build_alfworld_env( Returns: env_manager: AlfWorldEnvironmentManager instance """ - from omegaconf import OmegaConf from functools import partial + from omegaconf import OmegaConf + from skillopt.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs from skillopt.envs.alfworld.vendor.alfworld_projection import alfworld_projection from skillopt.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager @@ -97,6 +115,7 @@ def build_alfworld_env( alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml") env_kwargs = {"eval_dataset": eval_dataset} + resolved_gamefiles = _resolve_alfworld_gamefiles(specific_gamefiles) envs = build_alfworld_envs( alf_config_path, @@ -106,7 +125,7 @@ def build_alfworld_env( is_train=is_train, env_kwargs=env_kwargs, resources_per_worker=None, - gamefiles=specific_gamefiles, + gamefiles=resolved_gamefiles, ) config = OmegaConf.create( @@ -222,7 +241,7 @@ def run_alfworld_batch( if _extract_action(response) is None: return idx, "missing action taglook" return idx, response - except Exception as e: + except Exception: return idx, "errorlook" executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers) diff --git a/tests/test_alfworld_paths.py b/tests/test_alfworld_paths.py new file mode 100644 index 0000000..eb0229b --- /dev/null +++ b/tests/test_alfworld_paths.py @@ -0,0 +1,33 @@ +import os + +from skillopt.envs.alfworld.rollout import _resolve_alfworld_gamefile, _resolve_alfworld_gamefiles + + +def test_resolve_alfworld_gamefile_uses_alfworld_data_for_relative_paths(monkeypatch, tmp_path): + data_root = tmp_path / "alfworld_data" + monkeypatch.setenv("ALFWORLD_DATA", str(data_root)) + + resolved = _resolve_alfworld_gamefile("json_2.1.1/valid_seen/task/game.tw-pddl") + + assert resolved == os.path.join(str(data_root), "json_2.1.1/valid_seen/task/game.tw-pddl") + + +def test_resolve_alfworld_gamefile_keeps_absolute_paths(monkeypatch, tmp_path): + monkeypatch.setenv("ALFWORLD_DATA", str(tmp_path / "alfworld_data")) + absolute = tmp_path / "elsewhere" / "game.tw-pddl" + + assert _resolve_alfworld_gamefile(str(absolute)) == str(absolute) + + +def test_resolve_alfworld_gamefile_keeps_relative_path_without_alfworld_data(monkeypatch): + monkeypatch.delenv("ALFWORLD_DATA", raising=False) + + assert _resolve_alfworld_gamefile("json_2.1.1/train/task/game.tw-pddl") == ( + "json_2.1.1/train/task/game.tw-pddl" + ) + + +def test_resolve_alfworld_gamefiles_handles_none(monkeypatch): + monkeypatch.setenv("ALFWORLD_DATA", "/tmp/alfworld_data") + + assert _resolve_alfworld_gamefiles(None) is None