Make detect_archive_format/safe_extract_tarball public; add CLI tests for `workflow add` with archives

Agent-Logs-Url: https://github.com/github/spec-kit/sessions/845e41d1-75e3-49fb-a580-a7fb805dd716

Co-authored-by: mnriem <15701806+mnriem@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot]
2026-05-06 21:50:25 +00:00
committed by GitHub
parent e0495ebc38
commit 0a02369ebe
5 changed files with 267 additions and 29 deletions

View File

@@ -2629,7 +2629,7 @@ def preset_add(
import urllib.request
import urllib.error
import tempfile
from .extensions import _detect_archive_format as _det_fmt
from .extensions import detect_archive_format as _det_fmt
with tempfile.TemporaryDirectory() as tmpdir:
archive_fmt = _det_fmt(from_url)
@@ -3628,7 +3628,7 @@ def extension_add(
import urllib.request
import urllib.error
from urllib.parse import urlparse
from .extensions import _detect_archive_format
from .extensions import detect_archive_format
# Validate URL
parsed = urlparse(from_url)
@@ -3647,14 +3647,14 @@ def extension_add(
# Download archive to temp location; detect format from URL or Content-Type.
download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads"
download_dir.mkdir(parents=True, exist_ok=True)
archive_fmt = _detect_archive_format(from_url)
archive_fmt = detect_archive_format(from_url)
archive_path = None
try:
with urllib.request.urlopen(from_url, timeout=60) as response:
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(from_url, content_type)
archive_fmt = detect_archive_format(from_url, content_type)
archive_data = response.read()
if not archive_fmt:
@@ -4331,9 +4331,9 @@ def extension_update(
try:
# 6. Validate extension ID from archive BEFORE modifying installation
# Handle both root-level and nested extension.yml (GitHub auto-generated archives)
from .extensions import _detect_archive_format
from .extensions import detect_archive_format
import tarfile
archive_fmt = _detect_archive_format(str(archive_path))
archive_fmt = detect_archive_format(str(archive_path))
import yaml
manifest_data = None
@@ -5029,7 +5029,7 @@ def workflow_add(
from ipaddress import ip_address
from urllib.parse import urlparse
from urllib.request import urlopen # noqa: S310
from .extensions import _detect_archive_format
from .extensions import detect_archive_format
parsed_src = urlparse(source)
src_host = parsed_src.hostname or ""
@@ -5062,10 +5062,10 @@ def workflow_add(
raise typer.Exit(1)
# Detect archive format from the final URL or Content-Type header.
archive_fmt = _detect_archive_format(final_url)
archive_fmt = detect_archive_format(final_url)
if not archive_fmt:
content_type = resp.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(final_url, content_type)
archive_fmt = detect_archive_format(final_url, content_type)
raw_data = resp.read()
except typer.Exit:
@@ -5119,8 +5119,8 @@ def workflow_add(
source.lower().endswith(".tar.gz") or source.lower().endswith(".tgz") or source.lower().endswith(".zip")
):
# Local archive file containing workflow.yml
from .extensions import _detect_archive_format
local_fmt = _detect_archive_format(source)
from .extensions import detect_archive_format
local_fmt = detect_archive_format(source)
try:
wf_yaml = _extract_workflow_yml(source_path, local_fmt)
except Exception as exc:
@@ -5199,7 +5199,7 @@ def workflow_add(
try:
from urllib.request import urlopen # noqa: S310 — URL comes from catalog
from .extensions import _detect_archive_format
from .extensions import detect_archive_format
workflow_dir.mkdir(parents=True, exist_ok=True)
with urlopen(workflow_url, timeout=30) as response: # noqa: S310
@@ -5224,10 +5224,10 @@ def workflow_add(
raise typer.Exit(1)
# Detect archive format from the final URL or Content-Type header.
cat_archive_fmt = _detect_archive_format(final_url)
cat_archive_fmt = detect_archive_format(final_url)
if not cat_archive_fmt:
cat_ct = response.headers.get("Content-Type", "")
cat_archive_fmt = _detect_archive_format(final_url, cat_ct)
cat_archive_fmt = detect_archive_format(final_url, cat_ct)
raw_response = response.read()

View File

@@ -108,7 +108,7 @@ def normalize_priority(value: Any, default: int = 10) -> int:
return priority if priority >= 1 else default
def _detect_archive_format(url: str, content_type: str = "") -> str:
def detect_archive_format(url: str, content_type: str = "") -> str:
"""Detect archive format from URL path extension or Content-Type header.
Args:
@@ -143,7 +143,7 @@ def _detect_archive_format(url: str, content_type: str = "") -> str:
return ""
def _safe_extract_tarball(
def safe_extract_tarball(
archive_path: Path,
dest_dir: Path,
error_class: "type[Exception]" = Exception,
@@ -1340,11 +1340,11 @@ class ExtensionManager:
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir)
archive_fmt = _detect_archive_format(str(zip_path))
archive_fmt = detect_archive_format(str(zip_path))
if archive_fmt == "tar.gz":
# Extract tarball safely (prevent tar slip attack)
_safe_extract_tarball(zip_path, temp_path, ValidationError)
safe_extract_tarball(zip_path, temp_path, ValidationError)
else:
# Extract ZIP safely (prevent Zip Slip attack)
with zipfile.ZipFile(zip_path, 'r') as zf:
@@ -2140,14 +2140,14 @@ class ExtensionCatalog:
version = ext_info.get("version", "unknown")
# Detect archive format from URL; resolve via Content-Type when needed.
archive_fmt = _detect_archive_format(download_url)
archive_fmt = detect_archive_format(download_url)
# Download the archive
try:
with self._open_url(download_url, timeout=60) as response:
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(download_url, content_type)
archive_fmt = detect_archive_format(download_url, content_type)
archive_data = response.read()
except urllib.error.URLError as e:

View File

@@ -27,7 +27,7 @@ import yaml
from packaging import version as pkg_version
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from .extensions import ExtensionRegistry, normalize_priority, _detect_archive_format, _safe_extract_tarball
from .extensions import ExtensionRegistry, normalize_priority, detect_archive_format, safe_extract_tarball
def _substitute_core_template(
@@ -1626,11 +1626,11 @@ class PresetManager:
with tempfile.TemporaryDirectory() as tmpdir:
temp_path = Path(tmpdir)
archive_fmt = _detect_archive_format(str(zip_path))
archive_fmt = detect_archive_format(str(zip_path))
if archive_fmt == "tar.gz":
# Extract tarball safely (prevent tar slip attack)
_safe_extract_tarball(zip_path, temp_path, PresetValidationError)
safe_extract_tarball(zip_path, temp_path, PresetValidationError)
else:
with zipfile.ZipFile(zip_path, 'r') as zf:
temp_path_resolved = temp_path.resolve()
@@ -2314,13 +2314,13 @@ class PresetCatalog:
version = pack_info.get("version", "unknown")
# Detect archive format from URL; resolve via Content-Type when needed.
archive_fmt = _detect_archive_format(download_url)
archive_fmt = detect_archive_format(download_url)
try:
with self._open_url(download_url, timeout=60) as response:
if not archive_fmt:
content_type = response.headers.get("Content-Type", "")
archive_fmt = _detect_archive_format(download_url, content_type)
archive_fmt = detect_archive_format(download_url, content_type)
archive_data = response.read()
except urllib.error.URLError as e:

View File

@@ -178,14 +178,14 @@ class TestNormalizePriority:
assert normalize_priority("invalid", default=1) == 1
# ===== _detect_archive_format Tests =====
# ===== detect_archive_format Tests =====
class TestDetectArchiveFormat:
"""Test the _detect_archive_format helper."""
"""Test the detect_archive_format helper."""
def _fmt(self, url, ct=""):
from specify_cli.extensions import _detect_archive_format
return _detect_archive_format(url, ct)
from specify_cli.extensions import detect_archive_format
return detect_archive_format(url, ct)
def test_zip_url_extension(self):
assert self._fmt("https://example.com/ext-1.0.0.zip") == "zip"

View File

@@ -1843,3 +1843,241 @@ steps:
assert state.status == RunStatus.COMPLETED
assert "do-plan" in state.step_results
assert "do-specify" not in state.step_results
# ===== workflow add archive CLI tests =====
# Smallest workflow definition packed into the test archives below: a single
# shell step.  The tests assert on the "arc-workflow" id, so the `workflow:`
# mapping and the step's keys must keep their YAML nesting.
MINIMAL_WORKFLOW_YAML = """\
schema_version: "1.0"
workflow:
  id: "arc-workflow"
  name: "Archive Workflow"
  version: "1.0.0"
  description: "Installed from archive"
steps:
  - id: step-one
    type: shell
    run: "echo hello"
"""
class TestWorkflowAddArchive:
    """CLI-level tests for `workflow add` with local and downloaded archives.

    Each test packs a file into a ZIP or gzipped tarball, runs
    ``workflow add`` through Typer's ``CliRunner`` with ``Path.cwd`` patched
    to a temporary project directory, and checks the exit code and output.
    """

    @pytest.fixture
    def project_dir(self, tmp_path):
        """Create a minimal spec-kit project."""
        (tmp_path / ".specify" / "workflows").mkdir(parents=True)
        return tmp_path

    def _runner_and_app(self):
        """Return a fresh ``CliRunner`` together with the CLI ``app``."""
        from typer.testing import CliRunner
        from specify_cli import app

        return CliRunner(), app

    # -- Helpers --------------------------------------------------------------

    @staticmethod
    def _zip_bytes(member_name, content):
        """Return the bytes of a ZIP archive holding one member."""
        import io
        import zipfile

        buf = io.BytesIO()
        with zipfile.ZipFile(buf, "w") as zf:
            zf.writestr(member_name, content)
        return buf.getvalue()

    @staticmethod
    def _tar_gz_bytes(member_name, content):
        """Return the bytes of a gzipped tarball holding one member.

        ``content`` may be ``str`` (encoded with the default codec) or
        ``bytes``.
        """
        import io
        import tarfile

        payload = content.encode() if isinstance(content, str) else content
        buf = io.BytesIO()
        with tarfile.open(fileobj=buf, mode="w:gz") as tf:
            info = tarfile.TarInfo(name=member_name)
            # TarInfo.size defaults to 0; it must match the payload length
            # or the member would be stored empty.
            info.size = len(payload)
            tf.addfile(info, io.BytesIO(payload))
        return buf.getvalue()

    @staticmethod
    def _fake_response(final_url, content_type, body):
        """Build a context-manager mock standing in for ``urlopen``'s result.

        ``geturl()`` returns ``final_url``, ``headers.get(...)`` returns
        ``content_type`` and ``read()`` returns ``body``.
        """
        from unittest.mock import MagicMock

        resp = MagicMock()
        resp.geturl.return_value = final_url
        resp.headers.get.return_value = content_type
        resp.read.return_value = body
        resp.__enter__.return_value = resp
        resp.__exit__.return_value = False
        return resp

    def _invoke_add(self, project_dir, source, *, catch_exceptions=False, response=None):
        """Run ``workflow add <source>`` with ``Path.cwd`` patched to *project_dir*.

        When *response* is given, ``urllib.request.urlopen`` is patched to
        return it so no network access happens.  Returns the ``CliRunner``
        result.
        """
        from contextlib import ExitStack
        from pathlib import Path
        from unittest.mock import patch

        runner, app = self._runner_and_app()
        with ExitStack() as stack:
            if response is not None:
                stack.enter_context(
                    patch("urllib.request.urlopen", return_value=response)
                )
            stack.enter_context(
                patch.object(Path, "cwd", return_value=project_dir)
            )
            return runner.invoke(
                app,
                ["workflow", "add", str(source)],
                catch_exceptions=catch_exceptions,
            )

    # -- Local ZIP archive --------------------------------------------------

    def test_workflow_add_local_zip_flat(self, project_dir):
        """workflow add installs from a local ZIP with workflow.yml at root."""
        archive = project_dir / "workflow.zip"
        archive.write_bytes(self._zip_bytes("workflow.yml", MINIMAL_WORKFLOW_YAML))

        result = self._invoke_add(project_dir, archive)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output
        installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml"
        assert installed.exists()

    def test_workflow_add_local_zip_nested(self, project_dir):
        """workflow add installs from a local ZIP with workflow.yml in a subdirectory."""
        archive = project_dir / "workflow.zip"
        archive.write_bytes(
            self._zip_bytes("repo-1.0/workflow.yml", MINIMAL_WORKFLOW_YAML)
        )

        result = self._invoke_add(project_dir, archive)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output

    def test_workflow_add_local_zip_missing_workflow_yml(self, project_dir):
        """workflow add exits with an error when the ZIP has no workflow.yml."""
        archive = project_dir / "empty.zip"
        archive.write_bytes(self._zip_bytes("README.md", "nothing here"))

        result = self._invoke_add(project_dir, archive, catch_exceptions=True)

        assert result.exit_code != 0
        # NOTE(review): loose check -- "workflow" may also appear in the
        # output for unrelated reasons (e.g. an echoed temp path); consider
        # asserting on the exact error message instead.
        output = result.output.lower()
        assert "extract" in output or "workflow" in output

    # -- Local tar.gz archive -----------------------------------------------

    def test_workflow_add_local_tar_gz_flat(self, project_dir):
        """workflow add installs from a local .tar.gz with workflow.yml at root."""
        archive = project_dir / "workflow.tar.gz"
        archive.write_bytes(self._tar_gz_bytes("workflow.yml", MINIMAL_WORKFLOW_YAML))

        result = self._invoke_add(project_dir, archive)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output
        installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml"
        assert installed.exists()

    def test_workflow_add_local_tar_gz_nested(self, project_dir):
        """workflow add installs from a local .tar.gz with workflow.yml in a subdirectory."""
        archive = project_dir / "workflow.tar.gz"
        archive.write_bytes(
            self._tar_gz_bytes("repo-1.0/workflow.yml", MINIMAL_WORKFLOW_YAML)
        )

        result = self._invoke_add(project_dir, archive)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output

    def test_workflow_add_local_tgz_flat(self, project_dir):
        """workflow add recognises the .tgz extension as a gzipped tarball."""
        archive = project_dir / "workflow.tgz"
        archive.write_bytes(self._tar_gz_bytes("workflow.yml", MINIMAL_WORKFLOW_YAML))

        result = self._invoke_add(project_dir, archive)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output

    def test_workflow_add_local_tar_gz_missing_workflow_yml(self, project_dir):
        """workflow add exits with an error when the .tar.gz has no workflow.yml."""
        archive = project_dir / "empty.tar.gz"
        archive.write_bytes(self._tar_gz_bytes("README.md", b"nothing"))

        result = self._invoke_add(project_dir, archive, catch_exceptions=True)

        assert result.exit_code != 0
        # NOTE(review): loose check -- "workflow" may also appear in the
        # output for unrelated reasons (e.g. an echoed temp path); consider
        # asserting on the exact error message instead.
        output = result.output.lower()
        assert "extract" in output or "workflow" in output

    # -- URL archive download -----------------------------------------------

    def test_workflow_add_url_tar_gz(self, project_dir):
        """workflow add downloads a .tar.gz from a URL and installs the workflow."""
        url = "https://example.com/workflow.tar.gz"
        response = self._fake_response(
            url,
            "application/gzip",
            self._tar_gz_bytes("workflow.yml", MINIMAL_WORKFLOW_YAML),
        )

        result = self._invoke_add(project_dir, url, response=response)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output

    def test_workflow_add_url_zip(self, project_dir):
        """workflow add downloads a .zip from a URL and installs the workflow."""
        url = "https://example.com/workflow.zip"
        response = self._fake_response(
            url,
            "application/zip",
            self._zip_bytes("workflow.yml", MINIMAL_WORKFLOW_YAML),
        )

        result = self._invoke_add(project_dir, url, response=response)

        assert result.exit_code == 0, result.output
        assert "arc-workflow" in result.output