Task Executor

Executor responsible for running individual DAG tasks.

Single-level DAG executor with global task-based scheduling.

TaskExecutor

Bases: Executor

Single-level DAG executor with global task-based scheduling.

Unlike the original two-level Executor, this implementation:

- Uses a single global task queue instead of a context queue
- Controls concurrency with a single global semaphore
- Executes one task at a time per worker
- Provides priority-based scheduling at the task level

Key benefits:

- Simpler architecture with single-level concurrency control
- Better resource utilization through global scheduling
- Easier configuration (just max_concurrent_tasks)
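A minimal usage sketch, assuming a DAG assembled elsewhere (build_my_dag is a hypothetical helper; the constructor and run signatures follow the source below):

import asyncio

from shutils.dag.task_executor import TaskExecutor

async def main() -> None:
    dag = build_my_dag()  # hypothetical helper returning a DAG instance
    executor = TaskExecutor(dag)  # default Runtime and ExecutorConfig
    # With no input, run() starts from a single default Context.
    outputs = await executor.run()
    print(f"collected {len(outputs)} output contexts")

asyncio.run(main())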

Source code in shutils/dag/task_executor.py
class TaskExecutor(Executor):
    """Single-level DAG executor with global task-based scheduling.

    Unlike the original two-level Executor, this implementation:
    - Uses a single global task queue instead of context queue
    - Controls concurrency with a single global semaphore
    - Executes one task at a time per worker
    - Provides priority-based scheduling at task level

    Key benefits:
    - Simpler architecture with single-level concurrency control
    - Better resource utilization through global scheduling
    - Easier configuration (just max_concurrent_tasks)
    """

    def __init__(
        self,
        dag: DAG,
        runtime: Runtime | None = None,
        config: ExecutorConfig | None = None,
    ):
        """Initialize the TaskExecutor.

        Args:
            dag: The DAG to execute.
            runtime: Optional runtime for tracking active contexts.
            config: Executor configuration.
        """
        if config is None:
            config = ExecutorConfig()
        super().__init__(dag, runtime, config)

        # Task queue for global scheduling
        self._task_queue = TaskPriorityQueue()

        # Worker count defaults to max_concurrent_tasks
        self._num_workers = self._config.task_worker_num


    async def run(
        self,
        input_context: Context | list[Context] | None = None
    ) -> list[OutputContext]:
        """Execute the DAG with given input contexts.

        Args:
            input_context: Single context, list of contexts, or None for default

        Returns:
            List of output contexts
        """
        if input_context is None:
            input_context = [Context(self.runtime)]
        elif isinstance(input_context, Context):
            input_context = [input_context]
        elif not isinstance(input_context, list):
            raise ValueError("input_context must be a Context or a list of Context")

        logger.info(f"[TaskExecutor.run]: length: {len(input_context)}, input: {input_context}")

        # Initialize input contexts
        for context in input_context:
            await context.async_context.complete(self.dag.in_task)
            # Enqueue all available tasks with FIFO_HIGH priority
            await self._task_queue.async_put_context_tasks(
                context,
                TaskPriority.FIFO_HIGH
            )

        logger.info("[SimplifiedExecutor.run]: initial tasks enqueued")

        # Create environment
        env = Environment(self.runtime, self._process_pool, self.dag)

        # Start worker pool
        worker_tasks = [
            asyncio.create_task(self._worker_loop(idx, env))
            for idx in range(self._num_workers)
        ]

        # Wait for all workers to complete
        output = await asyncio.gather(*worker_tasks)

        # Collect all output contexts
        output_context = []
        for output_context_list in output:
            output_context.extend(output_context_list)

        # Shutdown tasks
        for task in self.dag.tasks.values():
            if isinstance(task, ShutdownTask):
                task.shutdown()
            elif isinstance(task, AsyncShutdownTask):
                await task.shutdown()

        logger.info(f"[SimplifiedExecutor.run]: execution complete, outputs: {len(output_context)}")
        return output_context

    @asynccontextmanager
    async def check_get_task(self, timeout: float | None = None, use_counter: bool = True) -> AsyncGenerator[TaskItem]:
        """Context manager that gets a task from the queue with optional timeout.

        Args:
            timeout: Seconds to wait for a task.
            use_counter: Whether to check the runtime counter for stop condition.
        """
        counter = self.runtime.counter
        if not use_counter or counter > 0:
            async with asyncio.timeout(timeout):
                task = await self._task_queue.async_get_task()
                yield task
        else:
            yield TaskItem(TaskPriority.FIFO_HIGH, 0, StopContext(), TaskBase(lambda ctx: None))

    async def _worker_loop(
        self,
        worker_id: int,
        env: Environment
    ) -> list[OutputContext]:
        """Single worker that continuously fetches and executes tasks.

        Flow:
        1. Acquire global semaphore (limit total concurrent tasks)
        2. Get TaskItem from queue (with timeout)
        3. Execute the task
        4. Postprocess the context (GC, bypass)
        5. Check for newly available tasks in the same context
        6. If new tasks found, enqueue them with LIFO_HIGH priority
        7. Release semaphore
        8. Loop until stop condition
        """
        output_contexts: list[OutputContext] = []
        worker_storage = {}

        while True:
            try:
                # Get next task with timeout
                async with self.check_get_task(self._config.context_queue_timeout) as task_item:
                    in_context = task_item.context
                    task = task_item.task
                    logger.debug(f"[Worker{worker_id}]: get task[{task_item}] from context[{in_context}]")
                    if isinstance(in_context, StopContext):
                        logger.info(f"[Worker{worker_id}]: get StopContext, break")
                        break
                    if in_context.is_destory():
                        logger.error(f"[Worker{worker_id}]: Context {in_context} is destory, skip")
                        continue

                    avaliable_tasks = await in_context.async_task_state.avaliable_task()
                    if not avaliable_tasks:
                        logger.error(
                            f"[Worker{worker_id}]: Context {in_context} "
                            "do not have avaliable task, will destory"
                        )
                        await in_context.async_context.destory()
                        continue
            except TimeoutError:
                logger.debug(f"[Worker{worker_id}]: Queue timeout, waiting...")
                continue

            # Set worker local storage
            token = _worker_context_var.set(worker_storage)
            try:
                # Execute single task
                logger.debug(f"[Worker{worker_id}]: Executing {task} for {in_context}")
                output_contexts_list = await self._run_task(
                    worker_id,
                    task,
                    in_context,
                    env
                )
                output_contexts_list = [context for context in output_contexts_list if not context.freezing]

                # Handle task output
                for output_context in output_contexts_list:
                    if isinstance(output_context, OutputContext):
                        output_contexts.append(output_context)
                    elif output_context == in_context:
                        # Same context: check for more tasks and put them back with LIFO_HIGH
                        # This ensures the same DAG path is prioritized
                        logger.debug(f"[Worker{worker_id}]: Context {in_context} continues, requeueing with LIFO_HIGH")
                        count = await self._task_queue.async_put_context_tasks(
                            in_context,
                            TaskPriority.LIFO_HIGH
                        )
                        if count == 0:
                            # No more tasks in this context, destroy it and stop tracking
                            logger.debug(f"[Worker{worker_id}]: Context {in_context} has no more tasks, destroying")
                            await in_context.async_context.destory()
                    elif isinstance(output_context, LoopContext):
                        # Loop context: use FIFO_LOW priority to avoid starving other contexts
                        logger.debug(f"[Worker{worker_id}]: LoopContext {output_context}, enqueuing with FIFO_LOW")
                        await self._task_queue.async_put_context_tasks(
                            output_context,
                            TaskPriority.FIFO_LOW
                        )
                    else:
                        # New context: enqueue its tasks with FIFO_HIGH priority
                        logger.debug(f"[Worker{worker_id}]: New context {output_context}, enqueuing with FIFO_HIGH")
                        await self._task_queue.async_put_context_tasks(
                            output_context,
                            TaskPriority.FIFO_HIGH
                        )

                # Postprocess context (GC, bypass)
                await self._context_postprocess(in_context, output_contexts_list, task)

            except Exception as e:
                logger.error(f"[Worker{worker_id}]: Error processing task {task}: {e}")
            finally:
                _worker_context_var.reset(token)

        return output_contexts

    async def _run_task(
        self, idx: int, task: TaskBase, in_context: Context, env: Environment
    ) -> list[Context]:
        if task in in_context.awake_time:
            if in_context.awake_time[task] > time.time():
                logger.debug(f"{in_context} cannot awake now")
                await self._task_queue.async_put_task(in_context, task)
                return []
            logger.debug(f"{in_context} can awake now")
            in_context.awake_time.pop(task)

        context_list = []
        try:
            logger.debug(f"[Worker{idx}]: {in_context} begin running {task}")
            if isinstance(task, ForegroundTask):
                if isinstance(task, SyncTask):
                    context_list = task(in_context, env)
                else:
                    raise ValueError(f"[Worker{idx}]: Unknown task type in forground mode: {type(task)}")
            else:
                if isinstance(task, AsyncTask):
                    context_list = await task(in_context, env)
                elif isinstance(task, SyncTask):
                    if self._thread_pool:
                        loop = asyncio.get_running_loop()
                        context_list = await loop.run_in_executor(self._thread_pool, task, in_context, env)
                    else:
                        context_list = await asyncio.to_thread(task, in_context, env)
                else:
                    raise ValueError(f"[Worker{idx}]: Unknown task type: {type(task)}")
            logger.debug(f"[Worker{idx}]: {in_context} running {task} done")
        except Exception as e:
            if task.config.retry_times > 0 and await in_context.async_task_state.retry(task) <= task.config.retry_times:
                if task.config.retry_interval != 0:
                    if callable(task.config.retry_interval):
                        interval = task.config.retry_interval(in_context)
                    else:
                        interval = task.config.retry_interval
                    in_context._awake_interval(interval, task)
                await self._task_queue.async_put_task(in_context, task)
                return []
            logger.error(f"[Worker{idx}]: {in_context} running {task} failed, error: {type(e).__name__}: {e}")
            traceback.print_exc()
            in_context.error_info = ErrorInfo(has_error=True, exception=e, error_node=task.id)
            await in_context.async_context.destory()

        for i, out_context in enumerate(context_list):
            if not isinstance(out_context, (LoopContext, RateLimitContext)):
                await out_context.async_context.complete(task)
            if isinstance(out_context, RateLimitContext):
                context_list[i] = out_context.context

        return context_list

    async def _context_postprocess(
        self,
        input_context: Context,
        output_contexts: list[Context],
        running_task: TaskBase
    ) -> None:
        """Postprocess context after task execution.

        Handles:
        1. Context GC (destroy contexts no longer referenced)
        2. Bypass logic (skip tasks in new contexts)

        Simplified compared to the original executor because each call
        now handles a single task execution.

        Args:
            input_context: The context before task execution
            output_contexts: Contexts produced by the task
            running_task: The task that just completed
        """
        if not self._config.enable_context_gc and not self._config.enable_context_bypass:
            return

        # Build sets for comparison
        input_context_set = {input_context}
        output_context_set = {
            ctx for ctx in output_contexts
            if not ctx.is_destory()
        }

        # Context GC logic
        if self._config.enable_context_gc:
            await self._collect_referenced_contexts(output_context_set)

            # Destroy contexts in input but not in output
            for context in input_context_set:
                if context not in output_context_set and not context.is_destory():
                    logger.debug(f"[ContextGC]: {context} no longer referenced, destroying")
                    await context.async_context.destory()

        # Bypass logic (simplified)
        if self._config.enable_context_bypass:
            # For each new context (not in input), apply bypass
            new_contexts = output_context_set - input_context_set
            for new_context in new_contexts:
                if isinstance(new_context, LoopContext):
                    continue

                # Get tasks that were completed in the input path
                completed_in_input = input_context._completed_tasks
                if running_task in completed_in_input:
                    bypass_tasks = self.dag._get_bypass_tasks({running_task})
                else:
                    bypass_tasks = self.dag._get_bypass_tasks(completed_in_input)

                logger.debug(
                    f"[ContextBypass]: {new_context} skipping bypass tasks: {bypass_tasks}"
                )
                for bypass_task in bypass_tasks:
                    await new_context.async_context.complete(bypass_task)

    async def _collect_referenced_contexts(self, context_set: set[Context]) -> None:
        """Helper to collect all parent and child contexts into the set.

        Used for GC and bypass logic.

        Args:
            context_set: Set of contexts to collect from, will be modified in place
        """
        additional_contexts: list[Context] = []

        for context in context_set:
            # Skip OutputContext for GC/bypass - it's a terminal context
            if isinstance(context, OutputContext):
                continue

            # Get parent
            parent_result = await context.async_context.parent_context()
            if parent_result is not None:
                parent_context = parent_result.context
                if parent_context not in context_set and not parent_context.is_destory():
                    additional_contexts.append(parent_context)

            # Get children
            async for child_wrapper in context.async_context.iter_child_context():
                child_context = child_wrapper.context
                if child_context not in context_set and not child_context.is_destory():
                    additional_contexts.append(child_context)

        # Recursively collect newly found contexts
        if additional_contexts:
            context_set.update(additional_contexts)
            await self._collect_referenced_contexts(context_set)

__init__(dag, runtime=None, config=None)

Initialize the TaskExecutor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dag` | `DAG` | The DAG to execute. | *required* |
| `runtime` | `Runtime \| None` | Optional runtime for tracking active contexts. | `None` |
| `config` | `ExecutorConfig \| None` | Executor configuration. | `None` |
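As an illustration, a construction sketch; only the task_worker_num attribute is confirmed by the source below, and passing it as a keyword argument to ExecutorConfig is an assumption:

# Import paths for ExecutorConfig and the DAG are assumed here.
from shutils.dag.task_executor import TaskExecutor

# Assumption: ExecutorConfig accepts task_worker_num as a keyword argument;
# only the attribute name is confirmed by __init__ below.
config = ExecutorConfig(task_worker_num=8)  # eight concurrent worker loops
executor = TaskExecutor(dag, config=config)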
Source code in shutils/dag/task_executor.py
def __init__(
    self,
    dag: DAG,
    runtime: Runtime | None = None,
    config: ExecutorConfig | None = None,
):
    """Initialize the TaskExecutor.

    Args:
        dag: The DAG to execute.
        runtime: Optional runtime for tracking active contexts.
        config: Executor configuration.
    """
    if config is None:
        config = ExecutorConfig()
    super().__init__(dag, runtime, config)

    # Task queue for global scheduling
    self._task_queue = TaskPriorityQueue()

    # Worker count defaults to max_concurrent_tasks
    self._num_workers = self._config.task_worker_num

check_get_task(timeout=None, use_counter=True) async

Context manager that gets a task from the queue with optional timeout.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `timeout` | `float \| None` | Seconds to wait for a task. | `None` |
| `use_counter` | `bool` | Whether to check the runtime counter for the stop condition. | `True` |
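A consumption sketch mirroring the pattern in _worker_loop above; TimeoutError propagates from asyncio.timeout, and a StopContext sentinel signals shutdown:

try:
    async with executor.check_get_task(timeout=1.0) as task_item:
        if isinstance(task_item.context, StopContext):
            pass  # stop sentinel: the worker should exit its loop
        else:
            pass  # execute task_item.task against task_item.context
except TimeoutError:
    pass  # queue empty within the timeout; loop and retry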
Source code in shutils/dag/task_executor.py
@asynccontextmanager
async def check_get_task(self, timeout: float | None = None, use_counter: bool = True) -> AsyncGenerator[TaskItem]:
    """Context manager that gets a task from the queue with optional timeout.

    Args:
        timeout: Seconds to wait for a task.
        use_counter: Whether to check the runtime counter for stop condition.
    """
    counter = self.runtime.counter
    if not use_counter or counter > 0:
        async with asyncio.timeout(timeout):
            task = await self._task_queue.async_get_task()
            yield task
    else:
        yield TaskItem(TaskPriority.FIFO_HIGH, 0, StopContext(), TaskBase(lambda ctx: None))

run(input_context=None) async

Execute the DAG with given input contexts.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_context` | `Context \| list[Context] \| None` | A single context, a list of contexts, or None for a default context. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `list[OutputContext]` | List of output contexts. |
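An end-to-end sketch with explicit input contexts (the Context(runtime) constructor call follows its use inside run below; dag is assumed to be built elsewhere):

import asyncio

async def main() -> None:
    executor = TaskExecutor(dag)
    # Fan out over several independent input contexts.
    contexts = [Context(executor.runtime) for _ in range(3)]
    outputs = await executor.run(contexts)
    for out in outputs:
        ...  # each item is an OutputContext produced by the DAG

asyncio.run(main())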

Source code in shutils/dag/task_executor.py
async def run(
    self,
    input_context: Context | list[Context] | None = None
) -> list[OutputContext]:
    """Execute the DAG with given input contexts.

    Args:
        input_context: Single context, list of contexts, or None for default

    Returns:
        List of output contexts
    """
    if input_context is None:
        input_context = [Context(self.runtime)]
    elif isinstance(input_context, Context):
        input_context = [input_context]
    elif not isinstance(input_context, list):
        raise ValueError("input_context must be a Context or a list of Context")

    logger.info(f"[TaskExecutor.run]: length: {len(input_context)}, input: {input_context}")

    # Initialize input contexts
    for context in input_context:
        await context.async_context.complete(self.dag.in_task)
        # Enqueue all available tasks with FIFO_HIGH priority
        await self._task_queue.async_put_context_tasks(
            context,
            TaskPriority.FIFO_HIGH
        )

    logger.info("[SimplifiedExecutor.run]: initial tasks enqueued")

    # Create environment
    env = Environment(self.runtime, self._process_pool, self.dag)

    # Start worker pool
    worker_tasks = [
        asyncio.create_task(self._worker_loop(idx, env))
        for idx in range(self._num_workers)
    ]

    # Wait for all workers to complete
    output = await asyncio.gather(*worker_tasks)

    # Collect all output contexts
    output_context = []
    for output_context_list in output:
        output_context.extend(output_context_list)

    # Shutdown tasks
    for task in self.dag.tasks.values():
        if isinstance(task, ShutdownTask):
            task.shutdown()
        elif isinstance(task, AsyncShutdownTask):
            await task.shutdown()

    logger.info(f"[SimplifiedExecutor.run]: execution complete, outputs: {len(output_context)}")
    return output_context