mirror of
https://github.com/github/spec-kit.git
synced 2026-07-03 12:28:06 +08:00
Make detect_archive_format/safe_extract_tarball public; add workflow add archive CLI tests
Agent-Logs-Url: https://github.com/github/spec-kit/sessions/845e41d1-75e3-49fb-a580-a7fb805dd716 Co-authored-by: mnriem <15701806+mnriem@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
e0495ebc38
commit
0a02369ebe
@@ -2629,7 +2629,7 @@ def preset_add(
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import tempfile
|
||||
from .extensions import _detect_archive_format as _det_fmt
|
||||
from .extensions import detect_archive_format as _det_fmt
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
archive_fmt = _det_fmt(from_url)
|
||||
@@ -3628,7 +3628,7 @@ def extension_add(
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from urllib.parse import urlparse
|
||||
from .extensions import _detect_archive_format
|
||||
from .extensions import detect_archive_format
|
||||
|
||||
# Validate URL
|
||||
parsed = urlparse(from_url)
|
||||
@@ -3647,14 +3647,14 @@ def extension_add(
|
||||
# Download archive to temp location; detect format from URL or Content-Type.
|
||||
download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads"
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
archive_fmt = _detect_archive_format(from_url)
|
||||
archive_fmt = detect_archive_format(from_url)
|
||||
archive_path = None
|
||||
|
||||
try:
|
||||
with urllib.request.urlopen(from_url, timeout=60) as response:
|
||||
if not archive_fmt:
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
archive_fmt = _detect_archive_format(from_url, content_type)
|
||||
archive_fmt = detect_archive_format(from_url, content_type)
|
||||
archive_data = response.read()
|
||||
|
||||
if not archive_fmt:
|
||||
@@ -4331,9 +4331,9 @@ def extension_update(
|
||||
try:
|
||||
# 6. Validate extension ID from archive BEFORE modifying installation
|
||||
# Handle both root-level and nested extension.yml (GitHub auto-generated archives)
|
||||
from .extensions import _detect_archive_format
|
||||
from .extensions import detect_archive_format
|
||||
import tarfile
|
||||
archive_fmt = _detect_archive_format(str(archive_path))
|
||||
archive_fmt = detect_archive_format(str(archive_path))
|
||||
import yaml
|
||||
manifest_data = None
|
||||
|
||||
@@ -5029,7 +5029,7 @@ def workflow_add(
|
||||
from ipaddress import ip_address
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlopen # noqa: S310
|
||||
from .extensions import _detect_archive_format
|
||||
from .extensions import detect_archive_format
|
||||
|
||||
parsed_src = urlparse(source)
|
||||
src_host = parsed_src.hostname or ""
|
||||
@@ -5062,10 +5062,10 @@ def workflow_add(
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Detect archive format from the final URL or Content-Type header.
|
||||
archive_fmt = _detect_archive_format(final_url)
|
||||
archive_fmt = detect_archive_format(final_url)
|
||||
if not archive_fmt:
|
||||
content_type = resp.headers.get("Content-Type", "")
|
||||
archive_fmt = _detect_archive_format(final_url, content_type)
|
||||
archive_fmt = detect_archive_format(final_url, content_type)
|
||||
|
||||
raw_data = resp.read()
|
||||
except typer.Exit:
|
||||
@@ -5119,8 +5119,8 @@ def workflow_add(
|
||||
source.lower().endswith(".tar.gz") or source.lower().endswith(".tgz") or source.lower().endswith(".zip")
|
||||
):
|
||||
# Local archive file containing workflow.yml
|
||||
from .extensions import _detect_archive_format
|
||||
local_fmt = _detect_archive_format(source)
|
||||
from .extensions import detect_archive_format
|
||||
local_fmt = detect_archive_format(source)
|
||||
try:
|
||||
wf_yaml = _extract_workflow_yml(source_path, local_fmt)
|
||||
except Exception as exc:
|
||||
@@ -5199,7 +5199,7 @@ def workflow_add(
|
||||
|
||||
try:
|
||||
from urllib.request import urlopen # noqa: S310 — URL comes from catalog
|
||||
from .extensions import _detect_archive_format
|
||||
from .extensions import detect_archive_format
|
||||
|
||||
workflow_dir.mkdir(parents=True, exist_ok=True)
|
||||
with urlopen(workflow_url, timeout=30) as response: # noqa: S310
|
||||
@@ -5224,10 +5224,10 @@ def workflow_add(
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Detect archive format from the final URL or Content-Type header.
|
||||
cat_archive_fmt = _detect_archive_format(final_url)
|
||||
cat_archive_fmt = detect_archive_format(final_url)
|
||||
if not cat_archive_fmt:
|
||||
cat_ct = response.headers.get("Content-Type", "")
|
||||
cat_archive_fmt = _detect_archive_format(final_url, cat_ct)
|
||||
cat_archive_fmt = detect_archive_format(final_url, cat_ct)
|
||||
|
||||
raw_response = response.read()
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ def normalize_priority(value: Any, default: int = 10) -> int:
|
||||
return priority if priority >= 1 else default
|
||||
|
||||
|
||||
def _detect_archive_format(url: str, content_type: str = "") -> str:
|
||||
def detect_archive_format(url: str, content_type: str = "") -> str:
|
||||
"""Detect archive format from URL path extension or Content-Type header.
|
||||
|
||||
Args:
|
||||
@@ -143,7 +143,7 @@ def _detect_archive_format(url: str, content_type: str = "") -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _safe_extract_tarball(
|
||||
def safe_extract_tarball(
|
||||
archive_path: Path,
|
||||
dest_dir: Path,
|
||||
error_class: "type[Exception]" = Exception,
|
||||
@@ -1340,11 +1340,11 @@ class ExtensionManager:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
temp_path = Path(tmpdir)
|
||||
|
||||
archive_fmt = _detect_archive_format(str(zip_path))
|
||||
archive_fmt = detect_archive_format(str(zip_path))
|
||||
|
||||
if archive_fmt == "tar.gz":
|
||||
# Extract tarball safely (prevent tar slip attack)
|
||||
_safe_extract_tarball(zip_path, temp_path, ValidationError)
|
||||
safe_extract_tarball(zip_path, temp_path, ValidationError)
|
||||
else:
|
||||
# Extract ZIP safely (prevent Zip Slip attack)
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
@@ -2140,14 +2140,14 @@ class ExtensionCatalog:
|
||||
version = ext_info.get("version", "unknown")
|
||||
|
||||
# Detect archive format from URL; resolve via Content-Type when needed.
|
||||
archive_fmt = _detect_archive_format(download_url)
|
||||
archive_fmt = detect_archive_format(download_url)
|
||||
|
||||
# Download the archive
|
||||
try:
|
||||
with self._open_url(download_url, timeout=60) as response:
|
||||
if not archive_fmt:
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
archive_fmt = _detect_archive_format(download_url, content_type)
|
||||
archive_fmt = detect_archive_format(download_url, content_type)
|
||||
archive_data = response.read()
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
|
||||
@@ -27,7 +27,7 @@ import yaml
|
||||
from packaging import version as pkg_version
|
||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
|
||||
from .extensions import ExtensionRegistry, normalize_priority, _detect_archive_format, _safe_extract_tarball
|
||||
from .extensions import ExtensionRegistry, normalize_priority, detect_archive_format, safe_extract_tarball
|
||||
|
||||
|
||||
def _substitute_core_template(
|
||||
@@ -1626,11 +1626,11 @@ class PresetManager:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
temp_path = Path(tmpdir)
|
||||
|
||||
archive_fmt = _detect_archive_format(str(zip_path))
|
||||
archive_fmt = detect_archive_format(str(zip_path))
|
||||
|
||||
if archive_fmt == "tar.gz":
|
||||
# Extract tarball safely (prevent tar slip attack)
|
||||
_safe_extract_tarball(zip_path, temp_path, PresetValidationError)
|
||||
safe_extract_tarball(zip_path, temp_path, PresetValidationError)
|
||||
else:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zf:
|
||||
temp_path_resolved = temp_path.resolve()
|
||||
@@ -2314,13 +2314,13 @@ class PresetCatalog:
|
||||
version = pack_info.get("version", "unknown")
|
||||
|
||||
# Detect archive format from URL; resolve via Content-Type when needed.
|
||||
archive_fmt = _detect_archive_format(download_url)
|
||||
archive_fmt = detect_archive_format(download_url)
|
||||
|
||||
try:
|
||||
with self._open_url(download_url, timeout=60) as response:
|
||||
if not archive_fmt:
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
archive_fmt = _detect_archive_format(download_url, content_type)
|
||||
archive_fmt = detect_archive_format(download_url, content_type)
|
||||
archive_data = response.read()
|
||||
|
||||
except urllib.error.URLError as e:
|
||||
|
||||
@@ -178,14 +178,14 @@ class TestNormalizePriority:
|
||||
assert normalize_priority("invalid", default=1) == 1
|
||||
|
||||
|
||||
# ===== _detect_archive_format Tests =====
|
||||
# ===== detect_archive_format Tests =====
|
||||
|
||||
class TestDetectArchiveFormat:
|
||||
"""Test the _detect_archive_format helper."""
|
||||
"""Test the detect_archive_format helper."""
|
||||
|
||||
def _fmt(self, url, ct=""):
|
||||
from specify_cli.extensions import _detect_archive_format
|
||||
return _detect_archive_format(url, ct)
|
||||
from specify_cli.extensions import detect_archive_format
|
||||
return detect_archive_format(url, ct)
|
||||
|
||||
def test_zip_url_extension(self):
|
||||
assert self._fmt("https://example.com/ext-1.0.0.zip") == "zip"
|
||||
|
||||
@@ -1843,3 +1843,241 @@ steps:
|
||||
assert state.status == RunStatus.COMPLETED
|
||||
assert "do-plan" in state.step_results
|
||||
assert "do-specify" not in state.step_results
|
||||
|
||||
|
||||
# ===== workflow add archive CLI tests =====
|
||||
|
||||
MINIMAL_WORKFLOW_YAML = """\
|
||||
schema_version: "1.0"
|
||||
workflow:
|
||||
id: "arc-workflow"
|
||||
name: "Archive Workflow"
|
||||
version: "1.0.0"
|
||||
description: "Installed from archive"
|
||||
steps:
|
||||
- id: step-one
|
||||
type: shell
|
||||
run: "echo hello"
|
||||
"""
|
||||
|
||||
|
||||
class TestWorkflowAddArchive:
|
||||
"""CLI-level tests for `workflow add` with local archive files."""
|
||||
|
||||
@pytest.fixture
|
||||
def project_dir(self, tmp_path):
|
||||
"""Create a minimal spec-kit project."""
|
||||
specify = tmp_path / ".specify"
|
||||
specify.mkdir()
|
||||
(specify / "workflows").mkdir()
|
||||
return tmp_path
|
||||
|
||||
def _runner_and_app(self):
|
||||
from typer.testing import CliRunner
|
||||
from specify_cli import app
|
||||
return CliRunner(), app
|
||||
|
||||
# -- Local ZIP archive --------------------------------------------------
|
||||
|
||||
def test_workflow_add_local_zip_flat(self, project_dir):
|
||||
"""workflow add installs from a local ZIP with workflow.yml at root."""
|
||||
import zipfile
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "workflow.zip"
|
||||
with zipfile.ZipFile(archive, "w") as zf:
|
||||
zf.writestr("workflow.yml", MINIMAL_WORKFLOW_YAML)
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml"
|
||||
assert installed.exists()
|
||||
|
||||
def test_workflow_add_local_zip_nested(self, project_dir):
|
||||
"""workflow add installs from a local ZIP with workflow.yml in a subdirectory."""
|
||||
import zipfile
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "workflow.zip"
|
||||
with zipfile.ZipFile(archive, "w") as zf:
|
||||
zf.writestr("repo-1.0/workflow.yml", MINIMAL_WORKFLOW_YAML)
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
|
||||
def test_workflow_add_local_zip_missing_workflow_yml(self, project_dir):
|
||||
"""workflow add exits with an error when the ZIP has no workflow.yml."""
|
||||
import zipfile
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "empty.zip"
|
||||
with zipfile.ZipFile(archive, "w") as zf:
|
||||
zf.writestr("README.md", "nothing here")
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=True)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "extract" in result.output.lower() or "workflow" in result.output.lower()
|
||||
|
||||
# -- Local tar.gz archive -----------------------------------------------
|
||||
|
||||
def test_workflow_add_local_tar_gz_flat(self, project_dir):
|
||||
"""workflow add installs from a local .tar.gz with workflow.yml at root."""
|
||||
import tarfile, io
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "workflow.tar.gz"
|
||||
with tarfile.open(archive, "w:gz") as tf:
|
||||
data = MINIMAL_WORKFLOW_YAML.encode()
|
||||
info = tarfile.TarInfo(name="workflow.yml")
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml"
|
||||
assert installed.exists()
|
||||
|
||||
def test_workflow_add_local_tar_gz_nested(self, project_dir):
|
||||
"""workflow add installs from a local .tar.gz with workflow.yml in a subdirectory."""
|
||||
import tarfile, io
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "workflow.tar.gz"
|
||||
with tarfile.open(archive, "w:gz") as tf:
|
||||
data = MINIMAL_WORKFLOW_YAML.encode()
|
||||
info = tarfile.TarInfo(name="repo-1.0/workflow.yml")
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
|
||||
def test_workflow_add_local_tgz_flat(self, project_dir):
|
||||
"""workflow add recognises the .tgz extension as a gzipped tarball."""
|
||||
import tarfile, io
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "workflow.tgz"
|
||||
with tarfile.open(archive, "w:gz") as tf:
|
||||
data = MINIMAL_WORKFLOW_YAML.encode()
|
||||
info = tarfile.TarInfo(name="workflow.yml")
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
|
||||
def test_workflow_add_local_tar_gz_missing_workflow_yml(self, project_dir):
|
||||
"""workflow add exits with an error when the .tar.gz has no workflow.yml."""
|
||||
import tarfile, io
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
archive = project_dir / "empty.tar.gz"
|
||||
with tarfile.open(archive, "w:gz") as tf:
|
||||
data = b"nothing"
|
||||
info = tarfile.TarInfo(name="README.md")
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
|
||||
with __import__("unittest.mock", fromlist=["patch"]).patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=True)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "extract" in result.output.lower() or "workflow" in result.output.lower()
|
||||
|
||||
# -- URL archive download -----------------------------------------------
|
||||
|
||||
def test_workflow_add_url_tar_gz(self, project_dir):
|
||||
"""workflow add downloads a .tar.gz from a URL and installs the workflow."""
|
||||
import tarfile, io
|
||||
from unittest.mock import patch, MagicMock
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
# Build an in-memory tar.gz archive containing workflow.yml.
|
||||
buf = io.BytesIO()
|
||||
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||
data = MINIMAL_WORKFLOW_YAML.encode()
|
||||
info = tarfile.TarInfo(name="workflow.yml")
|
||||
info.size = len(data)
|
||||
tf.addfile(info, io.BytesIO(data))
|
||||
raw_bytes = buf.getvalue()
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.geturl.return_value = "https://example.com/workflow.tar.gz"
|
||||
mock_resp.headers.get.return_value = "application/gzip"
|
||||
mock_resp.read.return_value = raw_bytes
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(
|
||||
app, ["workflow", "add", "https://example.com/workflow.tar.gz"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
|
||||
def test_workflow_add_url_zip(self, project_dir):
|
||||
"""workflow add downloads a .zip from a URL and installs the workflow."""
|
||||
import zipfile, io
|
||||
from unittest.mock import patch, MagicMock
|
||||
runner, app = self._runner_and_app()
|
||||
|
||||
buf = io.BytesIO()
|
||||
with zipfile.ZipFile(buf, "w") as zf:
|
||||
zf.writestr("workflow.yml", MINIMAL_WORKFLOW_YAML)
|
||||
raw_bytes = buf.getvalue()
|
||||
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.geturl.return_value = "https://example.com/workflow.zip"
|
||||
mock_resp.headers.get.return_value = "application/zip"
|
||||
mock_resp.read.return_value = raw_bytes
|
||||
mock_resp.__enter__ = lambda s: s
|
||||
mock_resp.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with patch("urllib.request.urlopen", return_value=mock_resp), \
|
||||
patch.object(
|
||||
__import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir
|
||||
):
|
||||
result = runner.invoke(
|
||||
app, ["workflow", "add", "https://example.com/workflow.zip"],
|
||||
catch_exceptions=False,
|
||||
)
|
||||
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "arc-workflow" in result.output
|
||||
|
||||
Reference in New Issue
Block a user