From 0a02369ebe2177c28a54ec616d7b5f9f43383aa9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 6 May 2026 21:50:25 +0000 Subject: [PATCH] Make detect_archive_format/safe_extract_tarball public; add workflow add archive CLI tests Agent-Logs-Url: https://github.com/github/spec-kit/sessions/845e41d1-75e3-49fb-a580-a7fb805dd716 Co-authored-by: mnriem <15701806+mnriem@users.noreply.github.com> --- src/specify_cli/__init__.py | 28 ++-- src/specify_cli/extensions.py | 12 +- src/specify_cli/presets.py | 10 +- tests/test_extensions.py | 8 +- tests/test_workflows.py | 238 ++++++++++++++++++++++++++++++++++ 5 files changed, 267 insertions(+), 29 deletions(-) diff --git a/src/specify_cli/__init__.py b/src/specify_cli/__init__.py index 1da35136c..ce29f48ae 100644 --- a/src/specify_cli/__init__.py +++ b/src/specify_cli/__init__.py @@ -2629,7 +2629,7 @@ def preset_add( import urllib.request import urllib.error import tempfile - from .extensions import _detect_archive_format as _det_fmt + from .extensions import detect_archive_format as _det_fmt with tempfile.TemporaryDirectory() as tmpdir: archive_fmt = _det_fmt(from_url) @@ -3628,7 +3628,7 @@ def extension_add( import urllib.request import urllib.error from urllib.parse import urlparse - from .extensions import _detect_archive_format + from .extensions import detect_archive_format # Validate URL parsed = urlparse(from_url) @@ -3647,14 +3647,14 @@ def extension_add( # Download archive to temp location; detect format from URL or Content-Type. 
download_dir = project_root / ".specify" / "extensions" / ".cache" / "downloads" download_dir.mkdir(parents=True, exist_ok=True) - archive_fmt = _detect_archive_format(from_url) + archive_fmt = detect_archive_format(from_url) archive_path = None try: with urllib.request.urlopen(from_url, timeout=60) as response: if not archive_fmt: content_type = response.headers.get("Content-Type", "") - archive_fmt = _detect_archive_format(from_url, content_type) + archive_fmt = detect_archive_format(from_url, content_type) archive_data = response.read() if not archive_fmt: @@ -4331,9 +4331,9 @@ def extension_update( try: # 6. Validate extension ID from archive BEFORE modifying installation # Handle both root-level and nested extension.yml (GitHub auto-generated archives) - from .extensions import _detect_archive_format + from .extensions import detect_archive_format import tarfile - archive_fmt = _detect_archive_format(str(archive_path)) + archive_fmt = detect_archive_format(str(archive_path)) import yaml manifest_data = None @@ -5029,7 +5029,7 @@ def workflow_add( from ipaddress import ip_address from urllib.parse import urlparse from urllib.request import urlopen # noqa: S310 - from .extensions import _detect_archive_format + from .extensions import detect_archive_format parsed_src = urlparse(source) src_host = parsed_src.hostname or "" @@ -5062,10 +5062,10 @@ def workflow_add( raise typer.Exit(1) # Detect archive format from the final URL or Content-Type header. 
- archive_fmt = _detect_archive_format(final_url) + archive_fmt = detect_archive_format(final_url) if not archive_fmt: content_type = resp.headers.get("Content-Type", "") - archive_fmt = _detect_archive_format(final_url, content_type) + archive_fmt = detect_archive_format(final_url, content_type) raw_data = resp.read() except typer.Exit: @@ -5119,8 +5119,8 @@ def workflow_add( source.lower().endswith(".tar.gz") or source.lower().endswith(".tgz") or source.lower().endswith(".zip") ): # Local archive file containing workflow.yml - from .extensions import _detect_archive_format - local_fmt = _detect_archive_format(source) + from .extensions import detect_archive_format + local_fmt = detect_archive_format(source) try: wf_yaml = _extract_workflow_yml(source_path, local_fmt) except Exception as exc: @@ -5199,7 +5199,7 @@ def workflow_add( try: from urllib.request import urlopen # noqa: S310 — URL comes from catalog - from .extensions import _detect_archive_format + from .extensions import detect_archive_format workflow_dir.mkdir(parents=True, exist_ok=True) with urlopen(workflow_url, timeout=30) as response: # noqa: S310 @@ -5224,10 +5224,10 @@ def workflow_add( raise typer.Exit(1) # Detect archive format from the final URL or Content-Type header. 
- cat_archive_fmt = _detect_archive_format(final_url) + cat_archive_fmt = detect_archive_format(final_url) if not cat_archive_fmt: cat_ct = response.headers.get("Content-Type", "") - cat_archive_fmt = _detect_archive_format(final_url, cat_ct) + cat_archive_fmt = detect_archive_format(final_url, cat_ct) raw_response = response.read() diff --git a/src/specify_cli/extensions.py b/src/specify_cli/extensions.py index 3b25bbfbe..f8c4c3925 100644 --- a/src/specify_cli/extensions.py +++ b/src/specify_cli/extensions.py @@ -108,7 +108,7 @@ def normalize_priority(value: Any, default: int = 10) -> int: return priority if priority >= 1 else default -def _detect_archive_format(url: str, content_type: str = "") -> str: +def detect_archive_format(url: str, content_type: str = "") -> str: """Detect archive format from URL path extension or Content-Type header. Args: @@ -143,7 +143,7 @@ def _detect_archive_format(url: str, content_type: str = "") -> str: return "" -def _safe_extract_tarball( +def safe_extract_tarball( archive_path: Path, dest_dir: Path, error_class: "type[Exception]" = Exception, @@ -1340,11 +1340,11 @@ class ExtensionManager: with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - archive_fmt = _detect_archive_format(str(zip_path)) + archive_fmt = detect_archive_format(str(zip_path)) if archive_fmt == "tar.gz": # Extract tarball safely (prevent tar slip attack) - _safe_extract_tarball(zip_path, temp_path, ValidationError) + safe_extract_tarball(zip_path, temp_path, ValidationError) else: # Extract ZIP safely (prevent Zip Slip attack) with zipfile.ZipFile(zip_path, 'r') as zf: @@ -2140,14 +2140,14 @@ class ExtensionCatalog: version = ext_info.get("version", "unknown") # Detect archive format from URL; resolve via Content-Type when needed. 
- archive_fmt = _detect_archive_format(download_url) + archive_fmt = detect_archive_format(download_url) # Download the archive try: with self._open_url(download_url, timeout=60) as response: if not archive_fmt: content_type = response.headers.get("Content-Type", "") - archive_fmt = _detect_archive_format(download_url, content_type) + archive_fmt = detect_archive_format(download_url, content_type) archive_data = response.read() except urllib.error.URLError as e: diff --git a/src/specify_cli/presets.py b/src/specify_cli/presets.py index e270bb504..16b63862c 100644 --- a/src/specify_cli/presets.py +++ b/src/specify_cli/presets.py @@ -27,7 +27,7 @@ import yaml from packaging import version as pkg_version from packaging.specifiers import SpecifierSet, InvalidSpecifier -from .extensions import ExtensionRegistry, normalize_priority, _detect_archive_format, _safe_extract_tarball +from .extensions import ExtensionRegistry, normalize_priority, detect_archive_format, safe_extract_tarball def _substitute_core_template( @@ -1626,11 +1626,11 @@ class PresetManager: with tempfile.TemporaryDirectory() as tmpdir: temp_path = Path(tmpdir) - archive_fmt = _detect_archive_format(str(zip_path)) + archive_fmt = detect_archive_format(str(zip_path)) if archive_fmt == "tar.gz": # Extract tarball safely (prevent tar slip attack) - _safe_extract_tarball(zip_path, temp_path, PresetValidationError) + safe_extract_tarball(zip_path, temp_path, PresetValidationError) else: with zipfile.ZipFile(zip_path, 'r') as zf: temp_path_resolved = temp_path.resolve() @@ -2314,13 +2314,13 @@ class PresetCatalog: version = pack_info.get("version", "unknown") # Detect archive format from URL; resolve via Content-Type when needed. 
- archive_fmt = _detect_archive_format(download_url) + archive_fmt = detect_archive_format(download_url) try: with self._open_url(download_url, timeout=60) as response: if not archive_fmt: content_type = response.headers.get("Content-Type", "") - archive_fmt = _detect_archive_format(download_url, content_type) + archive_fmt = detect_archive_format(download_url, content_type) archive_data = response.read() except urllib.error.URLError as e: diff --git a/tests/test_extensions.py b/tests/test_extensions.py index 631014007..5cb79472c 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -178,14 +178,14 @@ class TestNormalizePriority: assert normalize_priority("invalid", default=1) == 1 -# ===== _detect_archive_format Tests ===== +# ===== detect_archive_format Tests ===== class TestDetectArchiveFormat: - """Test the _detect_archive_format helper.""" + """Test the detect_archive_format helper.""" def _fmt(self, url, ct=""): - from specify_cli.extensions import _detect_archive_format - return _detect_archive_format(url, ct) + from specify_cli.extensions import detect_archive_format + return detect_archive_format(url, ct) def test_zip_url_extension(self): assert self._fmt("https://example.com/ext-1.0.0.zip") == "zip" diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 4c042fc7d..2b6c87f04 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -1843,3 +1843,241 @@ steps: assert state.status == RunStatus.COMPLETED assert "do-plan" in state.step_results assert "do-specify" not in state.step_results + + +# ===== workflow add archive CLI tests ===== + +MINIMAL_WORKFLOW_YAML = """\ +schema_version: "1.0" +workflow: + id: "arc-workflow" + name: "Archive Workflow" + version: "1.0.0" + description: "Installed from archive" +steps: + - id: step-one + type: shell + run: "echo hello" +""" + + +class TestWorkflowAddArchive: + """CLI-level tests for `workflow add` with local and URL-hosted archive files.""" + + @pytest.fixture + def project_dir(self,
tmp_path): + """Create a minimal spec-kit project.""" + specify = tmp_path / ".specify" + specify.mkdir() + (specify / "workflows").mkdir() + return tmp_path + + def _runner_and_app(self): + from typer.testing import CliRunner + from specify_cli import app + return CliRunner(), app + + # -- Local ZIP archive -------------------------------------------------- + + def test_workflow_add_local_zip_flat(self, project_dir): + """workflow add installs from a local ZIP with workflow.yml at root.""" + import zipfile + runner, app = self._runner_and_app() + + archive = project_dir / "workflow.zip" + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr("workflow.yml", MINIMAL_WORKFLOW_YAML) + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml" + assert installed.exists() + + def test_workflow_add_local_zip_nested(self, project_dir): + """workflow add installs from a local ZIP with workflow.yml in a subdirectory.""" + import zipfile + runner, app = self._runner_and_app() + + archive = project_dir / "workflow.zip" + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr("repo-1.0/workflow.yml", MINIMAL_WORKFLOW_YAML) + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + + def test_workflow_add_local_zip_missing_workflow_yml(self, project_dir): + """workflow add exits with an error when the ZIP has no workflow.yml.""" + import zipfile + 
runner, app = self._runner_and_app() + + archive = project_dir / "empty.zip" + with zipfile.ZipFile(archive, "w") as zf: + zf.writestr("README.md", "nothing here") + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=True) + + assert result.exit_code != 0 + assert "extract" in result.output.lower() or "workflow" in result.output.lower() + + # -- Local tar.gz archive ----------------------------------------------- + + def test_workflow_add_local_tar_gz_flat(self, project_dir): + """workflow add installs from a local .tar.gz with workflow.yml at root.""" + import tarfile, io + runner, app = self._runner_and_app() + + archive = project_dir / "workflow.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + data = MINIMAL_WORKFLOW_YAML.encode() + info = tarfile.TarInfo(name="workflow.yml") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + installed = project_dir / ".specify" / "workflows" / "arc-workflow" / "workflow.yml" + assert installed.exists() + + def test_workflow_add_local_tar_gz_nested(self, project_dir): + """workflow add installs from a local .tar.gz with workflow.yml in a subdirectory.""" + import tarfile, io + runner, app = self._runner_and_app() + + archive = project_dir / "workflow.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + data = MINIMAL_WORKFLOW_YAML.encode() + info = tarfile.TarInfo(name="repo-1.0/workflow.yml") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + with __import__("unittest.mock", 
fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + + def test_workflow_add_local_tgz_flat(self, project_dir): + """workflow add recognises the .tgz extension as a gzipped tarball.""" + import tarfile, io + runner, app = self._runner_and_app() + + archive = project_dir / "workflow.tgz" + with tarfile.open(archive, "w:gz") as tf: + data = MINIMAL_WORKFLOW_YAML.encode() + info = tarfile.TarInfo(name="workflow.yml") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=False) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + + def test_workflow_add_local_tar_gz_missing_workflow_yml(self, project_dir): + """workflow add exits with an error when the .tar.gz has no workflow.yml.""" + import tarfile, io + runner, app = self._runner_and_app() + + archive = project_dir / "empty.tar.gz" + with tarfile.open(archive, "w:gz") as tf: + data = b"nothing" + info = tarfile.TarInfo(name="README.md") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + with __import__("unittest.mock", fromlist=["patch"]).patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke(app, ["workflow", "add", str(archive)], catch_exceptions=True) + + assert result.exit_code != 0 + assert "extract" in result.output.lower() or "workflow" in result.output.lower() + + # -- URL archive download ----------------------------------------------- + + def test_workflow_add_url_tar_gz(self, project_dir): + 
"""workflow add downloads a .tar.gz from a URL and installs the workflow.""" + import tarfile, io + from unittest.mock import patch, MagicMock + runner, app = self._runner_and_app() + + # Build an in-memory tar.gz archive containing workflow.yml. + buf = io.BytesIO() + with tarfile.open(fileobj=buf, mode="w:gz") as tf: + data = MINIMAL_WORKFLOW_YAML.encode() + info = tarfile.TarInfo(name="workflow.yml") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + raw_bytes = buf.getvalue() + + mock_resp = MagicMock() + mock_resp.geturl.return_value = "https://example.com/workflow.tar.gz" + mock_resp.headers.get.return_value = "application/gzip" + mock_resp.read.return_value = raw_bytes + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + ): + result = runner.invoke( + app, ["workflow", "add", "https://example.com/workflow.tar.gz"], + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output + + def test_workflow_add_url_zip(self, project_dir): + """workflow add downloads a .zip from a URL and installs the workflow.""" + import zipfile, io + from unittest.mock import patch, MagicMock + runner, app = self._runner_and_app() + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w") as zf: + zf.writestr("workflow.yml", MINIMAL_WORKFLOW_YAML) + raw_bytes = buf.getvalue() + + mock_resp = MagicMock() + mock_resp.geturl.return_value = "https://example.com/workflow.zip" + mock_resp.headers.get.return_value = "application/zip" + mock_resp.read.return_value = raw_bytes + mock_resp.__enter__ = lambda s: s + mock_resp.__exit__ = MagicMock(return_value=False) + + with patch("urllib.request.urlopen", return_value=mock_resp), \ + patch.object( + __import__("pathlib", fromlist=["Path"]).Path, "cwd", return_value=project_dir + 
): + result = runner.invoke( + app, ["workflow", "add", "https://example.com/workflow.zip"], + catch_exceptions=False, + ) + + assert result.exit_code == 0, result.output + assert "arc-workflow" in result.output