mirror of
https://github.com/microsoft/SkillOpt.git
synced 2026-07-03 14:02:58 +08:00
Fix ALFWorld gamefile paths relative to ALFWORLD_DATA
This commit is contained in:
@@ -7,12 +7,10 @@ Provides:
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import concurrent.futures
|
||||
import numpy as np
|
||||
|
||||
from skillopt.model import chat_target
|
||||
|
||||
@@ -65,6 +63,25 @@ def _append_diagnostic_instruction(prompt: str, diagnostic_instruction: str) ->
|
||||
return f"{prompt}\n\n## Training Readout\n{diagnostic_instruction.strip()}\n"
|
||||
|
||||
|
||||
def _resolve_alfworld_gamefile(gamefile: str) -> str:
|
||||
path = os.path.expanduser(os.path.expandvars(str(gamefile)))
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
data_root = os.environ.get("ALFWORLD_DATA", "").strip()
|
||||
if not data_root:
|
||||
return path
|
||||
|
||||
root = os.path.expanduser(os.path.expandvars(data_root))
|
||||
return os.path.abspath(os.path.join(root, path))
|
||||
|
||||
|
||||
def _resolve_alfworld_gamefiles(gamefiles: list[str] | None) -> list[str] | None:
|
||||
if gamefiles is None:
|
||||
return None
|
||||
return [_resolve_alfworld_gamefile(gamefile) for gamefile in gamefiles]
|
||||
|
||||
|
||||
# ── Environment builder ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -86,9 +103,10 @@ def build_alfworld_env(
|
||||
Returns:
|
||||
env_manager: AlfWorldEnvironmentManager instance
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from functools import partial
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from skillopt.envs.alfworld.vendor.alfworld_envs import build_alfworld_envs
|
||||
from skillopt.envs.alfworld.vendor.alfworld_projection import alfworld_projection
|
||||
from skillopt.envs.alfworld.vendor.env_manager import AlfWorldEnvironmentManager
|
||||
@@ -97,6 +115,7 @@ def build_alfworld_env(
|
||||
|
||||
alf_config_path = os.path.join(HERE, "vendor", "config_tw.yaml")
|
||||
env_kwargs = {"eval_dataset": eval_dataset}
|
||||
resolved_gamefiles = _resolve_alfworld_gamefiles(specific_gamefiles)
|
||||
|
||||
envs = build_alfworld_envs(
|
||||
alf_config_path,
|
||||
@@ -106,7 +125,7 @@ def build_alfworld_env(
|
||||
is_train=is_train,
|
||||
env_kwargs=env_kwargs,
|
||||
resources_per_worker=None,
|
||||
gamefiles=specific_gamefiles,
|
||||
gamefiles=resolved_gamefiles,
|
||||
)
|
||||
|
||||
config = OmegaConf.create(
|
||||
@@ -222,7 +241,7 @@ def run_alfworld_batch(
|
||||
if _extract_action(response) is None:
|
||||
return idx, "<think>missing action tag</think><action>look</action>"
|
||||
return idx, response
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return idx, "<think>error</think><action>look</action>"
|
||||
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_api_workers)
|
||||
|
||||
Reference in New Issue
Block a user