diff --git a/skillopt/envs/alfworld/rollout.py b/skillopt/envs/alfworld/rollout.py
index 8c3b4ac..18264c3 100644
--- a/skillopt/envs/alfworld/rollout.py
+++ b/skillopt/envs/alfworld/rollout.py
@@ -7,12 +7,10 @@ Provides:
"""
from __future__ import annotations
+import concurrent.futures
import json
import os
import re
-import sys
-import concurrent.futures
-import numpy as np
from skillopt.model import chat_target
@@ -65,6 +63,25 @@ def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) ->
return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n"
+def _resolve_alfworld_gamefile(gamefile: str) -> str:
+    """Expand *gamefile*; anchor a relative result under $ALFWORLD_DATA if set."""
+    expanded = os.path.expanduser(os.path.expandvars(str(gamefile)))
+    data_root = os.environ.get("ALFWORLD_DATA", "").strip()
+    # Absolute paths, or a missing/blank data root, leave the path as expanded.
+    if os.path.isabs(expanded) or not data_root:
+        return expanded
+    # The data root may itself contain "~" or "$VAR" references.
+    base_dir = os.path.expanduser(os.path.expandvars(data_root))
+    anchored = os.path.join(base_dir, expanded)
+    return os.path.abspath(anchored)
+
+
+def _resolve_alfworld_gamefiles(gamefiles: list[str] | None) -> list[str] | None:
+    """Resolve every entry with _resolve_alfworld_gamefile; None stays None."""
+    unset = gamefiles is None
+    return None if unset else list(map(_resolve_alfworld_gamefile, gamefiles))
+
+
# ── Environment builder ──────────────────────────────────────────────────────
@@ -86,9 +103,10 @@ def build_alfworld_env(
Returns:
env_manager: AlfWorldEnvironmentManager instance
"""
- from omegaconf import OmegaConf
from functools import partial
+ from omegaconf import OmegaConf
+
from skillopt.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs
from skillopt.envs.alfworld.vendor.alfworld_projection import alfworld_projection
from skillopt.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager
@@ -97,6 +115,7 @@ def build_alfworld_env(
alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml")
env_kwargs = {"eval_dataset": eval_dataset}
+ resolved_gamefiles = _resolve_alfworld_gamefiles(specific_gamefiles)
envs = build_alfworld_envs(
alf_config_path,
@@ -106,7 +125,7 @@ def build_alfworld_env(
is_train=is_train,
env_kwargs=env_kwargs,
resources_per_worker=None,
- gamefiles=specific_gamefiles,
+ gamefiles=resolved_gamefiles,
)
config = OmegaConf.create(
@@ -222,7 +241,7 @@ def run_alfworld_batch(
if _extract_action(response) is None:
return idx, "missing action taglook"
return idx, response
- except Exception as e:
+ except Exception:
return idx, "errorlook"
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
diff --git a/tests/test_alfworld_paths.py b/tests/test_alfworld_paths.py
new file mode 100644
index 0000000..eb0229b
--- /dev/null
+++ b/tests/test_alfworld_paths.py
@@ -0,0 +1,33 @@
+import os
+
+from skillopt.envs.alfworld.rollout import _resolve_alfworld_gamefile, _resolve_alfworld_gamefiles
+
+
+def test_resolve_alfworld_gamefile_uses_alfworld_data_for_relative_paths(monkeypatch, tmp_path):
+    """A relative gamefile resolves to a normalized path under ALFWORLD_DATA."""
+    data_root = tmp_path / "alfworld_data"
+    monkeypatch.setenv("ALFWORLD_DATA", str(data_root))
+    relative = "json_2.1.1/valid_seen/task/game.tw-pddl"
+    # The resolver calls abspath(), which normalizes separators on Windows.
+    assert _resolve_alfworld_gamefile(relative) == os.path.abspath(data_root / relative)
+
+
+def test_resolve_alfworld_gamefile_keeps_absolute_paths(monkeypatch, tmp_path):
+    """An already-absolute gamefile is returned untouched despite ALFWORLD_DATA."""
+    already_absolute = str(tmp_path / "elsewhere" / "game.tw-pddl")
+    monkeypatch.setenv("ALFWORLD_DATA", str(tmp_path / "alfworld_data"))
+    assert _resolve_alfworld_gamefile(already_absolute) == already_absolute
+
+
+def test_resolve_alfworld_gamefile_keeps_relative_path_without_alfworld_data(monkeypatch):
+    """With ALFWORLD_DATA unset, a relative gamefile comes back unchanged."""
+    monkeypatch.delenv("ALFWORLD_DATA", raising=False)
+    relative = "json_2.1.1/train/task/game.tw-pddl"
+
+    assert _resolve_alfworld_gamefile(relative) == relative
+
+
+def test_resolve_alfworld_gamefiles_handles_none(monkeypatch):
+    """None is passed through even when ALFWORLD_DATA is configured."""
+    monkeypatch.setenv("ALFWORLD_DATA", "/tmp/alfworld_data")
+    assert _resolve_alfworld_gamefiles(None) is None