mirror of
https://github.com/github/spec-kit.git
synced 2026-07-03 12:28:06 +08:00
feat(workflows): honor max_concurrency in fan-out via a bounded thread pool (#3224)
* feat(workflows): honor max_concurrency in fan-out via a bounded thread pool

* feat(workflows): address review — sliding-window fan-out, locked output, faithful halt

Address the reviewer feedback on the bounded fan-out concurrency:

- Sliding submission window: keep at most `workers` items in flight and stop launching new items once the run is halting, instead of submitting all items up front (which let the pool keep starting queued work after a halt).
- Faithful halt prefix: attribute a halt to the specific item whose own recorded result halted the run (replaying the sequential break condition, honoring continue_on_error/aborted), not the shared run status a later concurrent item may have flipped. The returned prefix now includes the actual halting item, matching the sequential path. An item that fails before recording a result (e.g. an unknown step type) is attributed too, since every item runs the same template.
- Lock the parent fan-out output mutation: route the post-fan-out step_results[...]['output'] update through a new RunState.set_step_output() under the run lock, so it cannot race a concurrent save().
- Docstring: describe int() coercion accurately (numeric strings / floats are honored; only non-coercible or <= 1 runs sequentially).

Tests: add concurrent halt-includes-halting-item, continue_on_error-does-not-truncate, and unknown-template-type-matches-sequential coverage; make the timing test use a monotonic clock with a looser threshold to avoid CI flakiness.

* feat(workflows): address second review pass — concurrency hardening

- append_log: serialize the log_entries append + log.jsonl write under a dedicated RunState._log_lock so concurrent fan-out workers can't interleave or corrupt log lines (kept separate from the state lock; never nested).
- _run_fan_out.run_item: read the item output back through the item_ctx it executed against rather than the outer context closure — clearer and robust if StepContext ever stops sharing the steps dict by reference.
- StepBase: document the thread-safety contract — STEP_REGISTRY holds one shared instance per type, so concurrent fan-out invokes execute() on the same object; implementations must be stateless/thread-safe (the built-ins already are).
- test_concurrency_is_real: prove parallelism deterministically with a threading.Barrier (sequential execution can't clear it) instead of a wall-clock timing assertion.

* feat(workflows): address review — stamp updated_at under lock, clarify cancel semantics

- RunState.save(): move the updated_at timestamp assignment inside the run lock so the timestamp matches the snapshot the thread serializes and concurrent savers don't race on it.
- _run_fan_out docstring: clarify that on a halt only not-yet-started items are cancelled; items already running finish but their outputs are ignored (Future.cancel() can't stop running work, and the pool joins on exit).

* feat(workflows): serialize on_step_start callback under a lock

The concurrent fan-out path invokes _execute_steps from worker threads, which calls the engine's on_step_start callback (the CLI sets it to a console.print lambda). Concurrent invocation could interleave/garble progress output. Guard the call with a WorkflowEngine._callback_lock so callbacks are serialized; the lock is uncontended for sequential runs.

* feat(workflows): re-raise worker exceptions in-place to preserve traceback

In _run_fan_out's concurrent path, a worker exception was stashed in first_exc and re-raised after the loop. Re-raise it from within the except block with a bare `raise` (after cancelling outstanding futures) so the original traceback is preserved, and drop the now-unneeded first_exc variable. The ThreadPoolExecutor __exit__ still joins any already-running workers before the exception escapes.

* feat(workflows): lock final fan-out status, drop redundant output write, bound workers

Address third review pass:

- Remove the unlocked `context.steps[step_id]["output"] = …` writes in the fan-out parent update. context.steps[step_id] is the same dict object that set_step_output() updates under the run lock, so the direct (unsynchronized) mutation was redundant.
- Preserve sequential halt semantics under concurrency: a later in-flight item could overwrite state.status after the halting item was identified. _run_fan_out now derives the halting item's run status (item_halt_status, replacing the bool item_halted) and restores it after the pool joins, so the final status is the first halting item's outcome.
- Bound the pool: workers = min(max_concurrency, len(items)) and early-return for empty items, so a user-controlled max_concurrency can't over-allocate threads.

Add coverage that an earlier PAUSED item's status wins over a later concurrent FAILED item.

* feat(workflows): avoid unlocked context.steps writes when it aliases step_results

On a resume run, StepContext is built with steps=state.step_results, so the two direct `context.steps[...] = ...` writes mutated the shared dict outside the run lock and could race save(). Route both through a new _record_result helper that mirrors into context.steps only when it is a distinct object (a fresh run) and otherwise relies solely on record_step_result's locked write.
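The coercion and bounding rules described in the commit message can be sketched in isolation. `effective_workers` is a hypothetical helper name used only for illustration; the repository performs the same steps inline in `_run_fan_out`, and returns early for an empty item list before reaching this logic.

```python
from typing import Any


def effective_workers(max_concurrency: Any, item_count: int) -> int:
    """Hypothetical helper: resolve a user-supplied max_concurrency to a worker count."""
    try:
        # Numeric strings ("4") and floats (4.0) are honored via int().
        workers = max(1, int(max_concurrency))
    except (TypeError, ValueError):
        # None, "abc", and other non-coercible values fall back to sequential.
        workers = 1
    # Never allocate more threads than there are items to run.
    return min(workers, item_count)
```

Under these rules a result of 1 means the sequential path. One edge case worth knowing: `int(float("inf"))` raises `OverflowError`, which this sketch does not catch.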
@@ -97,6 +97,13 @@ class StepBase(ABC):
 
     Every step type — built-in or extension-provided — implements this
     interface and registers in ``STEP_REGISTRY``.
+
+    Thread-safety: ``STEP_REGISTRY`` holds a single shared instance per type, so
+    a concurrent ``fan-out`` (``max_concurrency > 1``) can invoke ``execute`` on
+    the same instance from several threads at once. Implementations must be
+    stateless / thread-safe — derive all per-run state from the ``config`` and
+    ``context`` arguments and never mutate ``self`` in ``execute``. The built-in
+    steps follow this rule.
     """
 
     #: Matches the ``type:`` value in workflow YAML.
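To illustrate the thread-safety contract in the docstring above, here is a minimal sketch of a step whose `execute` reads only its arguments. `StatelessEchoStep` is a stand-in written for this example, not the real `StepBase`, and plain dicts stand in for the real `config` and `context` objects.

```python
from concurrent.futures import ThreadPoolExecutor


class StatelessEchoStep:
    """Stand-in for a step type (hypothetical; not the real StepBase)."""

    def execute(self, config: dict, context: dict) -> dict:
        # All per-run state comes from the arguments; self is never mutated,
        # so one shared instance can serve many threads at once.
        return {"seen": context["item"], "prefix": config.get("prefix", "")}


def run_shared_instance(count: int) -> list:
    shared = StatelessEchoStep()  # one instance, as a registry would hold
    with ThreadPoolExecutor(max_workers=4) as pool:
        # pool.map yields results in input order regardless of completion order.
        return list(
            pool.map(lambda i: shared.execute({"prefix": "x"}, {"item": i}), range(count))
        )
```

A step that stored `context["item"]` on `self` before returning it would not be safe here, because another thread could overwrite the attribute between the write and the read.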
@@ -10,10 +10,14 @@ The engine is the orchestrator that:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -412,6 +416,15 @@ class RunState:
         self.current_step_index = 0
         self.current_step_id: str | None = None
         self.step_results: dict[str, dict[str, Any]] = {}
+        # Guards step_results mutation and save() so a concurrent fan-out cannot
+        # mutate the dict while save() is serializing it (which would raise
+        # "dictionary changed size during iteration").
+        self._lock = threading.Lock()
+        # Serializes append_log's list append + log.jsonl write so concurrent
+        # fan-out workers cannot interleave or corrupt log lines. Kept separate
+        # from _lock so frequent logging never contends with state saves; since
+        # append_log is never called while _lock is held, the two never nest.
+        self._log_lock = threading.Lock()
         self.inputs: dict[str, Any] = {}
         self.created_at = datetime.now(timezone.utc).isoformat()
         self.updated_at = self.created_at
@@ -421,28 +434,72 @@ class RunState:
     def runs_dir(self) -> Path:
         return self.project_root / ".specify" / "workflows" / "runs" / self.run_id
 
+    def record_step_result(self, step_id: str, data: dict[str, Any]) -> None:
+        """Record one step's result under the run lock.
+
+        Routing the mutation through the lock keeps it from racing a concurrent
+        ``save()`` that is iterating ``step_results`` (e.g. during a concurrent
+        fan-out). For a sequential run this is an uncontended lock.
+        """
+        with self._lock:
+            self.step_results[step_id] = data
+
+    def set_step_output(self, step_id: str, output: Any) -> None:
+        """Replace an already-recorded step's ``output`` under the run lock.
+
+        Fan-out updates its parent step's output after the items have run;
+        routing that nested mutation through the lock keeps it from racing a
+        ``save()`` serializing ``step_results`` — the same invariant
+        ``record_step_result`` provides for the top-level assignment.
+        """
+        with self._lock:
+            if step_id in self.step_results:
+                self.step_results[step_id]["output"] = output
+
     def save(self) -> None:
-        """Persist current state to disk."""
-        self.updated_at = datetime.now(timezone.utc).isoformat()
+        """Persist current state to disk.
+
+        Held under the run lock and written atomically (temp file + ``os.replace``)
+        so a concurrent fan-out can neither mutate ``step_results`` mid-serialization
+        nor leave a reader observing a half-written file. Racing writers only
+        contend to be last; they never corrupt.
+        """
         runs_dir = self.runs_dir
         runs_dir.mkdir(parents=True, exist_ok=True)
 
-        state_data = {
-            "run_id": self.run_id,
-            "workflow_id": self.workflow_id,
-            "status": self.status.value,
-            "current_step_index": self.current_step_index,
-            "current_step_id": self.current_step_id,
-            "step_results": self.step_results,
-            "created_at": self.created_at,
-            "updated_at": self.updated_at,
-        }
-        with open(runs_dir / "state.json", "w", encoding="utf-8") as f:
-            json.dump(state_data, f, indent=2)
+        with self._lock:
+            # Stamp updated_at inside the lock so the timestamp matches the
+            # snapshot this thread serializes (concurrent savers don't race it).
+            self.updated_at = datetime.now(timezone.utc).isoformat()
+            state_data = {
+                "run_id": self.run_id,
+                "workflow_id": self.workflow_id,
+                "status": self.status.value,
+                "current_step_index": self.current_step_index,
+                "current_step_id": self.current_step_id,
+                "step_results": self.step_results,
+                "created_at": self.created_at,
+                "updated_at": self.updated_at,
+            }
+            self._atomic_write_json(runs_dir / "state.json", state_data)
+            self._atomic_write_json(runs_dir / "inputs.json", {"inputs": self.inputs})
 
-        inputs_data = {"inputs": self.inputs}
-        with open(runs_dir / "inputs.json", "w", encoding="utf-8") as f:
-            json.dump(inputs_data, f, indent=2)
+    @staticmethod
+    def _atomic_write_json(path: Path, data: dict[str, Any]) -> None:
+        """Write *data* as indented JSON to *path* atomically (temp + ``os.replace``)."""
+        fd, tmp = tempfile.mkstemp(
+            dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp"
+        )
+        try:
+            with os.fdopen(fd, "w", encoding="utf-8") as f:
+                json.dump(data, f, indent=2)
+            os.replace(tmp, path)
+        except BaseException:
+            try:
+                os.unlink(tmp)
+            except OSError:
+                pass
+            raise
 
     @classmethod
     def load(cls, run_id: str, project_root: Path) -> RunState:
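The temp-file-plus-`os.replace` pattern used by `_atomic_write_json` above can be exercised on its own. The sketch below is a standalone copy of that pattern under the assumed name `atomic_write_json`, with a small `demo` function (also hypothetical) showing that a second write replaces the first and leaves no temporary files behind.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any


def atomic_write_json(path: Path, data: dict[str, Any]) -> None:
    """Standalone copy of the pattern: write to a temp file, then rename over the target."""
    fd, tmp = tempfile.mkstemp(dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        # os.replace swaps the finished file into place in a single step.
        os.replace(tmp, path)
    except BaseException:
        # On any failure, remove the temp file and leave the old target untouched.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise


def demo() -> tuple[dict, list[str]]:
    with tempfile.TemporaryDirectory() as d:
        target = Path(d) / "state.json"
        atomic_write_json(target, {"status": "running"})
        atomic_write_json(target, {"status": "completed"})
        loaded = json.loads(target.read_text(encoding="utf-8"))
        return loaded, sorted(os.listdir(d))
```

The temp file is created in the target's own directory so the final rename stays on one filesystem, which is what allows `os.replace` to behave atomically.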
@@ -490,14 +547,18 @@ class RunState:
         return state
 
     def append_log(self, entry: dict[str, Any]) -> None:
-        """Append a log entry to the run log."""
-        entry["timestamp"] = datetime.now(timezone.utc).isoformat()
-        self.log_entries.append(entry)
+        """Append a log entry to the run log.
+
+        Held under ``_log_lock`` so concurrent fan-out workers serialize their
+        list append and ``log.jsonl`` write rather than interleaving lines.
+        """
+        entry["timestamp"] = datetime.now(timezone.utc).isoformat()
         runs_dir = self.runs_dir
         runs_dir.mkdir(parents=True, exist_ok=True)
-        with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f:
-            f.write(json.dumps(entry) + "\n")
+        with self._log_lock:
+            self.log_entries.append(entry)
+            with open(runs_dir / "log.jsonl", "a", encoding="utf-8") as f:
+                f.write(json.dumps(entry) + "\n")
 
 
 # -- Workflow Engine ------------------------------------------------------
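The effect of the dedicated log lock in `append_log` can be shown with a small standalone model. `JsonlLog` and `demo` are hypothetical names for this sketch; it keeps only the logging half of `RunState` and checks that the in-memory list and the file agree after concurrent appends.

```python
import json
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


class JsonlLog:
    """Hypothetical stand-in for the logging half of RunState."""

    def __init__(self, path: Path) -> None:
        self.path = path
        self.entries: list[dict] = []
        self._log_lock = threading.Lock()

    def append(self, entry: dict) -> None:
        with self._log_lock:
            # The list append and the file append happen as one unit, so the
            # in-memory order always matches the on-disk line order.
            self.entries.append(entry)
            with open(self.path, "a", encoding="utf-8") as f:
                f.write(json.dumps(entry) + "\n")


def demo(n: int = 50) -> tuple[list[dict], list[dict]]:
    with tempfile.TemporaryDirectory() as d:
        log = JsonlLog(Path(d) / "log.jsonl")
        with ThreadPoolExecutor(max_workers=8) as pool:
            list(pool.map(lambda i: log.append({"i": i}), range(n)))
        lines = log.path.read_text(encoding="utf-8").splitlines()
        return log.entries, [json.loads(line) for line in lines]
```

Every line parses as one complete JSON object because no two writers are ever inside the `with` block at the same time.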
@@ -509,6 +570,10 @@ class WorkflowEngine:
     def __init__(self, project_root: Path | None = None) -> None:
         self.project_root = project_root or Path(".")
         self.on_step_start: Any = None  # Callable[[str, str], None] | None
+        # Serializes on_step_start so a concurrent fan-out can't interleave the
+        # callback's output (the CLI sets it to a console.print lambda). Uncontended
+        # for sequential runs.
+        self._callback_lock = threading.Lock()
 
     def load_workflow(self, source: str | Path) -> WorkflowDefinition:
         """Load a workflow from an installed ID or a local YAML path.
@@ -712,6 +777,22 @@ class WorkflowEngine:
         state.save()
         return state
 
+    @staticmethod
+    def _record_result(
+        context: StepContext, state: RunState, step_id: str, data: dict[str, Any]
+    ) -> None:
+        """Record a step result into both the live context and persistent state.
+
+        ``record_step_result`` writes ``state.step_results`` under the run lock.
+        On a resume run ``context.steps`` *is* that same dict, so that locked
+        write is the only one needed; mirror into ``context.steps`` separately
+        only when it is a distinct object (a fresh run), to avoid an unlocked
+        mutation of the shared dict that could race a concurrent ``save()``.
+        """
+        if context.steps is not state.step_results:
+            context.steps[step_id] = data
+        state.record_step_result(step_id, data)
+
     def _execute_steps(
         self,
         steps: list[dict[str, Any]],
@@ -739,7 +820,8 @@ class WorkflowEngine:
             # otherwise stay silent (library-safe default).
             label = step_config.get("command", "") or step_type
             if self.on_step_start is not None:
-                self.on_step_start(step_id, label)
+                with self._callback_lock:
+                    self.on_step_start(step_id, label)
 
             step_impl = registry.get(step_type)
             if not step_impl:
@@ -772,8 +854,7 @@ class WorkflowEngine:
                 "output": result.output,
                 "status": result.status.value,
             }
-            context.steps[step_id] = step_data
-            state.step_results[step_id] = step_data
+            self._record_result(context, state, step_id, step_data)
 
             state.append_log(
                 {
@@ -900,40 +981,32 @@ class WorkflowEngine:
                         ):
                             return
                         if orig and ns_copy["id"] in context.steps:
-                            context.steps[orig] = context.steps[ns_copy["id"]]
-                            state.step_results[orig] = context.steps[ns_copy["id"]]
+                            self._record_result(
+                                context, state, orig,
+                                context.steps[ns_copy["id"]],
+                            )
 
-            # Fan-out: execute nested step template per item with unique IDs
+            # Fan-out: execute the nested step template once per item. Honors
+            # max_concurrency — <=1 runs sequentially (default, historical
+            # behavior); >1 runs up to that many items concurrently. Either way
+            # results are assembled in item order under the
+            # parentId:templateId:index id grammar.
             if step_type == "fan-out":
                 items = result.output.get("items", [])
                 template = result.output.get("step_template", {})
                 if template and items:
-                    fan_out_results = []
-                    for item_idx, item_val in enumerate(result.output["items"]):
-                        context.item = item_val
-                        # Per-item ID: parentId:templateId:index
-                        item_step = dict(template)
-                        base_id = item_step.get("id", "item")
-                        item_step["id"] = f"{step_id}:{base_id}:{item_idx}"
-                        self._execute_steps(
-                            [item_step], context, state, registry,
-                            step_offset=-1,
-                        )
-                        # Collect per-item result for fan-in
-                        item_result = context.steps.get(item_step["id"], {})
-                        fan_out_results.append(item_result.get("output", {}))
-                        if state.status in (
-                            RunStatus.PAUSED,
-                            RunStatus.FAILED,
-                            RunStatus.ABORTED,
-                        ):
-                            break
+                    fan_out_results = self._run_fan_out(
+                        items, template, step_id, context, state, registry,
+                        result.output.get("max_concurrency", 1),
+                    )
                     context.item = None
                     # Preserve original output and add collected results
                     fan_out_output = dict(result.output)
                     fan_out_output["results"] = fan_out_results
-                    context.steps[step_id]["output"] = fan_out_output
-                    state.step_results[step_id]["output"] = fan_out_output
+                    # set_step_output updates the recorded dict under the run lock;
+                    # context.steps[step_id] is that same object, so it reflects the
+                    # change too — no separate (unlocked) context mutation needed.
+                    state.set_step_output(step_id, fan_out_output)
                     if state.status in (
                         RunStatus.PAUSED,
                         RunStatus.FAILED,
@@ -943,8 +1016,170 @@ class WorkflowEngine:
                 else:
                     # Empty items or no template — normalize output
                     result.output["results"] = []
-                    context.steps[step_id]["output"] = result.output
-                    state.step_results[step_id]["output"] = result.output
+                    state.set_step_output(step_id, result.output)
 
+    def _run_fan_out(
+        self,
+        items: list[Any],
+        template: dict[str, Any],
+        step_id: str,
+        context: StepContext,
+        state: RunState,
+        registry: dict[str, Any],
+        max_concurrency: Any,
+    ) -> list[Any]:
+        """Run a fan-out template once per item; return per-item outputs in item order.
+
+        ``max_concurrency`` <= 1 (the default) runs items sequentially, identical
+        to the historical fan-out behavior. ``max_concurrency`` > 1 runs items on a
+        bounded thread pool using a sliding submission window of that size: at most
+        that many items are ever in flight, and no new item is launched once the run
+        has reached a halting status, so a halt cannot keep starting queued work.
+
+        Results are always returned in item order (never completion order). On a
+        halt (PAUSED/FAILED/ABORTED) the returned prefix is the items up to and
+        including the first item *in item order* whose own execution halted the run
+        — identical to the sequential path. Later items that have not yet started
+        are cancelled; any already running are allowed to finish but their outputs
+        are ignored. Halt is attributed per item from that item's recorded result
+        (not the shared run status, which a concurrently-running later item may have
+        already flipped), so the prefix never drops the actual halting item.
+
+        ``max_concurrency`` is coerced with ``int()``; a value that cannot be
+        coerced (``None``, a non-numeric string, …) or that coerces to <= 1 runs
+        sequentially, while a numeric string like ``"4"`` or a float like ``4.0``
+        is honored.
+        """
+        if not items:
+            return []
+
+        halting = (RunStatus.PAUSED, RunStatus.FAILED, RunStatus.ABORTED)
+        try:
+            workers = max(1, int(max_concurrency))
+        except (TypeError, ValueError):
+            workers = 1
+        # Never spin up more workers than there is work — bounds a user-controlled
+        # max_concurrency from over-allocating threads.
+        workers = min(workers, len(items))
+
+        base_id = template.get("id", "item")
+
+        def item_id(idx: int) -> str:
+            # Per-item ID grammar: parentId:templateId:index.
+            return f"{step_id}:{base_id}:{idx}"
+
+        def run_item(idx: int, item_ctx: StepContext) -> Any:
+            item_step = dict(template)
+            item_step["id"] = item_id(idx)
+            self._execute_steps(
+                [item_step], item_ctx, state, registry, step_offset=-1,
+            )
+            # Read back through the context that was actually executed against,
+            # not the outer closure — clearer and robust if StepContext copying
+            # ever stops sharing the steps dict by reference.
+            return item_ctx.steps.get(item_step["id"], {}).get("output", {})
+
+        # Sequential path — identical to the historical behavior.
+        if workers <= 1:
+            results: list[Any] = []
+            for item_idx, item_val in enumerate(items):
+                context.item = item_val
+                results.append(run_item(item_idx, context))
+                if state.status in halting:
+                    break
+            return results
+
+        # Concurrent path — bounded sliding window; results assembled in item order.
+        n = len(items)
+        slots: list[Any] = [None] * n
+
+        def run_isolated(idx: int) -> Any:
+            # Each item runs against its own context copy so context.item is not
+            # clobbered across threads; the shared steps dict is written only on the
+            # disjoint parentId:templateId:index key (GIL-safe on distinct keys).
+            return run_item(idx, dataclasses.replace(context, item=items[idx]))
+
+        def item_halt_status(idx: int) -> RunStatus | None:
+            # If THIS item's own execution halted the run, return the resulting run
+            # status; else None. Decided from the item's own recorded result, not
+            # the shared run status, so a later item's concurrent halt is never
+            # misattributed here. Mirrors the sequential mapping: PAUSED -> PAUSED;
+            # FAILED -> ABORTED when aborted, else FAILED, unless continue_on_error
+            # routes around it.
+            rec = context.steps.get(item_id(idx))
+            if rec is None:
+                # Ran but recorded nothing — only when the item failed before
+                # record_step_result (e.g. an unknown step type returns early).
+                # Every item runs the same template, so the shared run status is
+                # this item's own outcome; attribute the halt to it.
+                return state.status if state.status in halting else None
+            status = rec.get("status")
+            if status == StepStatus.PAUSED.value:
+                return RunStatus.PAUSED
+            if status == StepStatus.FAILED.value:
+                out = rec.get("output") or {}
+                if out.get("aborted"):
+                    return RunStatus.ABORTED
+                if template.get("continue_on_error") is not True:
+                    return RunStatus.FAILED
+            return None
+
+        # (halting item index, its run status) once a halt is attributed.
+        halt: tuple[int, RunStatus] | None = None
+        collected = 0
+        with ThreadPoolExecutor(max_workers=workers) as pool:
+            futures: dict[int, Future] = {}
+            next_submit = 0
+            for idx in range(n):
+                # Refill the window: keep <= workers in flight, and stop launching
+                # new items once the run is halting so a halt cannot keep starting
+                # queued work. Already-submitted futures are still collected in
+                # item order below.
+                while (
+                    next_submit < n
+                    and len(futures) < workers
+                    and state.status not in halting
+                ):
+                    futures[next_submit] = pool.submit(run_isolated, next_submit)
+                    next_submit += 1
+
+                fut = futures.pop(idx, None)
+                if fut is None:
+                    # Safety net: the window submits indices in order and the loop
+                    # breaks at the first halting item, so every collected index has
+                    # an in-flight future. Stop cleanly rather than raise if a future
+                    # change ever breaks that invariant.
+                    break
+                try:
+                    slots[idx] = fut.result()
+                except Exception:
+                    # A genuine exception escaping a step (not a normal step
+                    # FAILED, which sets state.status) must not be masked: cancel
+                    # outstanding work and re-raise — with a bare ``raise`` so the
+                    # original traceback is preserved — so the engine marks the run
+                    # failed instead of reporting a vacuous completion. The pool's
+                    # __exit__ still joins any already-running workers.
+                    for other in futures.values():
+                        other.cancel()
+                    raise
+                collected = idx + 1
+                halt_status = item_halt_status(idx)
+                if halt_status is not None:
+                    # First halting item in item order: include it (slots[idx] is
+                    # already set), record its status, and cancel everything pending.
+                    halt = (idx, halt_status)
+                    for other in futures.values():
+                        other.cancel()
+                    break
+
+        if halt is not None:
+            halted_at, halted_status = halt
+            # A later in-flight item may have overwritten state.status before the
+            # pool joined; restore the halting item's own outcome so the final run
+            # status matches the sequential semantics.
+            state.status = halted_status
+            return slots[: halted_at + 1]
+        return slots[:collected]
+
     def _resolve_inputs(
         self,
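The sliding submission window in `_run_fan_out` can be reduced to a small generic helper for study. The sketch below is a simplified model, not the engine's code: `ordered_bounded_map` and its `is_halt` predicate are hypothetical names, and the model decides a halt from each returned value rather than from recorded step results or the shared run status.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Sequence


def ordered_bounded_map(
    fn: Callable[[Any], Any],
    items: Sequence[Any],
    workers: int,
    is_halt: Callable[[Any], bool] = lambda result: False,
) -> list[Any]:
    """Hypothetical helper: run fn over items with at most `workers` in flight.

    Results come back in item order. Collection stops after the first result
    (in item order) for which is_halt returns True; that result is included.
    """
    n = len(items)
    results: list[Any] = []
    if n == 0:
        return results
    workers = max(1, min(workers, n))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures: dict[int, Future] = {}
        next_submit = 0
        for idx in range(n):
            # Refill the window before blocking on the next result in order.
            while next_submit < n and len(futures) < workers:
                futures[next_submit] = pool.submit(fn, items[next_submit])
                next_submit += 1
            results.append(futures.pop(idx).result())
            if is_halt(results[-1]):
                # Not-yet-started futures are cancelled; running ones finish
                # when the pool joins, but their results are discarded.
                for other in futures.values():
                    other.cancel()
                break
    return results
```

Because indices are submitted in order and collected in order, the future for `idx` is always present when it is popped, so this model needs no fallback for a missing future.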
@@ -2045,6 +2045,210 @@ class TestFanInStep:
         assert any("non-empty list" in e for e in errors)
 
 
+class TestFanOutConcurrency:
+    """Fan-out honors max_concurrency (WorkflowEngine._run_fan_out)."""
+
+    @staticmethod
+    def _build(tmp_path, on_item=None):
+        """Wire an engine + run state to a probe step that echoes context.item.
+
+        Per-item output is ``{"seen": <item>}`` so order and per-thread item
+        isolation are checkable. ``on_item(item)`` may run a side effect and
+        optionally return a StepStatus to override COMPLETED (or raise).
+        """
+        from specify_cli.workflows.base import (
+            RunStatus,
+            StepBase,
+            StepContext,
+            StepResult,
+            StepStatus,
+        )
+        from specify_cli.workflows.engine import RunState, WorkflowEngine
+
+        class _ProbeStep(StepBase):
+            type_key = "probe"
+
+            def execute(self, config, context):
+                status = StepStatus.COMPLETED
+                if on_item is not None:
+                    override = on_item(context.item)
+                    if override is not None:
+                        status = override
+                return StepResult(status=status, output={"seen": context.item})
+
+        engine = WorkflowEngine(project_root=tmp_path)
+        context = StepContext()
+        state = RunState(run_id="r", workflow_id="w", project_root=tmp_path)
+        state.status = RunStatus.RUNNING
+        template = {"id": "impl", "type": "probe"}
+        return engine, context, state, {"probe": _ProbeStep()}, template
+
+    def _run(self, tmp_path, items, max_concurrency, on_item=None):
+        engine, context, state, registry, template = self._build(tmp_path, on_item)
+        results = engine._run_fan_out(
+            items, template, "fan", context, state, registry, max_concurrency
+        )
+        return results, state
+
+    def test_sequential_default_preserves_order(self, tmp_path):
+        results, _ = self._run(tmp_path, list(range(5)), 1)
+        assert results == [{"seen": i} for i in range(5)]
+
+    def test_concurrent_runs_all_items_in_item_order(self, tmp_path):
+        results, _ = self._run(tmp_path, list(range(10)), 4)
+        assert results == [{"seen": i} for i in range(10)]
+
+    def test_sequential_and_concurrent_agree(self, tmp_path):
+        items = [{"n": i} for i in range(8)]
+        seq, _ = self._run(tmp_path, items, 1)
+        con, _ = self._run(tmp_path, items, 4)
+        assert seq == con == [{"seen": {"n": i}} for i in range(8)]
+
+    def test_shuffled_completion_preserves_item_order(self, tmp_path):
+        # Determinism keystone: completion order is forced to the exact REVERSE of
+        # item order by an event chain (no sleeps) — item i blocks until item i+1
+        # has finished, so item 0 completes LAST — yet results must still be in
+        # item order. K == len(items) so all workers are in flight together.
+        import threading
+
+        n = 4
+        done = [threading.Event() for _ in range(n)]
+        completion: list[int] = []
+        clock = threading.Lock()
+
+        def on_item(item):
+            if item + 1 < n:
+                assert done[item + 1].wait(2.0), f"item {item + 1} never finished"
+            with clock:
+                completion.append(item)
+            done[item].set()
+            return None
+
+        results, _ = self._run(tmp_path, list(range(n)), n, on_item)
+        assert results == [{"seen": i} for i in range(n)]
+        assert completion == list(reversed(range(n)))
+
+    def test_concurrency_is_real(self, tmp_path):
+        import threading
+
+        # Deterministic proof of real parallelism (no wall-clock threshold to
+        # tune or flake): every item must reach the barrier before any may pass.
+        # Sequential execution would block the first item forever — the barrier
+        # times out, raises BrokenBarrierError, and fails the test.
+        n = 4
+        barrier = threading.Barrier(n, timeout=5)
+
+        def on_item(item):
+            barrier.wait()
+            return None
+
+        results, _ = self._run(tmp_path, list(range(n)), n, on_item)
+        assert results == [{"seen": i} for i in range(n)]
+
+    @pytest.mark.parametrize("bad", [0, -1, None, "abc", 1.0])
+    def test_invalid_max_concurrency_coerces_to_sequential(self, tmp_path, bad):
+        results, _ = self._run(tmp_path, list(range(4)), bad)
+        assert results == [{"seen": i} for i in range(4)]
+
+    def test_string_max_concurrency_is_honored(self, tmp_path):
+        results, _ = self._run(tmp_path, list(range(4)), "2")
+        assert results == [{"seen": i} for i in range(4)]
+
+    def test_context_item_isolation_across_threads(self, tmp_path):
+        items = [{"id": f"x{i}"} for i in range(6)]
+        results, _ = self._run(tmp_path, items, 6)
+        assert [r["seen"]["id"] for r in results] == [f"x{i}" for i in range(6)]
+
+    def test_empty_items(self, tmp_path):
+        results, _ = self._run(tmp_path, [], 4)
+        assert results == []
+
+    def test_concurrent_halt_status_not_clobbered_by_later_item(self, tmp_path):
+        # Item 1 PAUSES (first halting item in order); item 3 FAILS while in
+        # flight. The final run status must be the halting item's (PAUSED), never
+        # a later item's (FAILED) that raced after it — matching sequential.
+        from specify_cli.workflows.base import RunStatus, StepStatus
+
+        def on_item(item):
+            if item == 1:
+                return StepStatus.PAUSED
+            if item == 3:
+                return StepStatus.FAILED
+            return None
+
+        results, state = self._run(tmp_path, list(range(4)), 4, on_item)
+        assert results == [{"seen": 0}, {"seen": 1}]
+        assert state.status == RunStatus.PAUSED
+
+    def test_halt_on_failure_sequential_returns_prefix(self, tmp_path):
+        from specify_cli.workflows.base import RunStatus, StepStatus
+
+        def on_item(item):
+            return StepStatus.FAILED if item == 2 else None
+
+        results, state = self._run(tmp_path, list(range(5)), 1, on_item)
+        assert len(results) == 3  # items 0,1,2 ran; 3,4 never dispatched
+        assert results[2] == {"seen": 2}
+        assert state.status == RunStatus.FAILED
+
+    def test_halt_on_failure_concurrent_includes_halting_item(self, tmp_path):
+        # The concurrent prefix must match the sequential one: items up to and
+        # INCLUDING the failing item (2), never a short prefix that drops it just
+        # because a later in-flight item flipped the shared run status first.
+        from specify_cli.workflows.base import RunStatus, StepStatus
+
+        def on_item(item):
+            return StepStatus.FAILED if item == 2 else None
+
+        results, state = self._run(tmp_path, list(range(6)), 4, on_item)
+        assert results == [{"seen": 0}, {"seen": 1}, {"seen": 2}]
+        assert state.status == RunStatus.FAILED
+
+    def test_continue_on_error_item_does_not_halt_concurrent(self, tmp_path):
+        # A failing item whose template sets continue_on_error must NOT truncate
+        # the fan-out: every item still runs and is returned in order.
+        from specify_cli.workflows.base import StepStatus
+
+        def on_item(item):
+            return StepStatus.FAILED if item == 2 else None
+
+        engine, context, state, registry, template = self._build(tmp_path, on_item)
+        template["continue_on_error"] = True
+        results = engine._run_fan_out(
+            list(range(5)), template, "fan", context, state, registry, 4
+        )
+        assert results == [{"seen": i} for i in range(5)]
+
+    def test_unknown_template_type_halts_concurrent_like_sequential(self, tmp_path):
+        # A template whose type isn't registered fails fast and records no result;
+        # the concurrent path must still attribute the halt to the first item and
+        # return the same prefix as sequential — never run on as if completed.
+        from specify_cli.workflows.base import RunStatus, StepContext
+        from specify_cli.workflows.engine import RunState, WorkflowEngine
+
+        def fresh():
+            state = RunState(run_id="r", workflow_id="w", project_root=tmp_path)
+            state.status = RunStatus.RUNNING
+            return WorkflowEngine(project_root=tmp_path), StepContext(), state
+
+        template = {"id": "impl", "type": "does-not-exist"}
+        e1, c1, s1 = fresh()
+        seq = e1._run_fan_out(list(range(5)), template, "fan", c1, s1, {}, 1)
+        e2, c2, s2 = fresh()
+        con = e2._run_fan_out(list(range(5)), template, "fan", c2, s2, {}, 4)
+        assert seq == con == [{}]  # halted at the first item; rest never returned
+        assert s1.status == s2.status == RunStatus.FAILED
+
+    def test_first_exception_cancels_and_reraises(self, tmp_path):
+        def on_item(item):
+            if item == 0:
+                raise ValueError("boom")
+            return None
+
+        with pytest.raises(ValueError, match="boom"):
+            self._run(tmp_path, list(range(4)), 2, on_item)
+
+
 class TestFanInWaitForValidation:
     """fan-in wait_for must reference a declared step (no silent empty join)."""
 
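The barrier technique used by `test_concurrency_is_real` works outside the engine too. The following sketch, with the hypothetical helper name `all_reach_barrier`, reports whether a pool actually ran a given number of tasks at the same time; it uses only the standard library and does not touch the workflow code.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def all_reach_barrier(parties: int, workers: int, timeout: float = 1.0) -> bool:
    """Return True only if `parties` tasks were running at the same time."""
    barrier = threading.Barrier(parties, timeout=timeout)

    def task(_: int) -> None:
        # Blocks until `parties` threads are waiting; raises
        # BrokenBarrierError if that does not happen within the timeout.
        barrier.wait()

    try:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            # list() re-raises the first task exception, if any.
            list(pool.map(task, range(parties)))
    except threading.BrokenBarrierError:
        return False
    return True
```

With as many workers as parties every task reaches the barrier and the call succeeds. With a single worker the first task waits alone until the timeout breaks the barrier, which is why a sequential run cannot pass this kind of test.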