Skip to content

Serve Executor

Executor for serving DAG tasks in online environments.

Server-style DAG executor that supports submitting tasks and retrieving results asynchronously.

ContextStatus

Bases: Enum

Status of a submitted context in the ServeExecutor.

Source code in shutils/dag/serve_executor.py
26
27
28
29
30
31
class ContextStatus(Enum):
    """Lifecycle states for a context submitted to the ServeExecutor."""

    # Registered via submit_task, not yet picked up by a worker.
    INIT = "INIT"
    # Picked up by a worker loop and currently executing.
    RUNNING = "RUNNING"
    # An OutputContext was produced and the result future resolved.
    FINISH = "FINISH"

ServeExecutor

Bases: Executor

Long-running DAG executor that supports submitting contexts and awaiting results.

Source code in shutils/dag/serve_executor.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class ServeExecutor(Executor):
    """Long-running DAG executor that supports submitting contexts and awaiting results.

    Usage: call ``submit_task`` to enqueue a context, ``get_task_status`` to
    poll its lifecycle state, and ``get_task_result`` to await its output.
    ``run`` must be running concurrently for any work to happen.
    """

    @override
    def __init__(
        self,
        dag: DAG,
        runtime: Runtime | None = None,
        config: ExecutorConfig | None = None,
    ):
        """Initialize the serve executor.

        Args:
            dag: The DAG to execute.
            runtime: Optional runtime for tracking active contexts.
            config: Executor configuration.
        """
        if config is None:
            config = ExecutorConfig()
        super().__init__(dag, runtime, config)
        # context id -> lifecycle status (INIT -> RUNNING -> FINISH).
        self.context_status_dict: dict[str, ContextStatus] = {}
        # context id -> future resolved with the output dict on completion.
        self.context_result_trackers: dict[str, asyncio.Future] = {}
        # Read/write lock guarding the two dicts above across workers and API calls.
        self.lock = AsyncRWLock()


    async def run(self) -> None:    # type: ignore[override]
        """Start the worker pool. Does not return until shutdown."""
        env = Environment(self.runtime, self._process_pool, self.dag)
        # One long-lived worker coroutine per configured context worker.
        worker_tasks = [
            asyncio.create_task(self._worker_loop(idx, env))
            for idx in range(self._config.context_worker_num)
        ]
        await asyncio.gather(*worker_tasks)

        # Once all workers have exited, give DAG tasks a chance to clean up.
        for task in self.dag.tasks.values():
            if isinstance(task, ShutdownTask):
                task.shutdown()
            elif isinstance(task, AsyncShutdownTask):
                await task.shutdown()

    @override
    async def _worker_loop(self, idx: int, env: Environment) -> list[OutputContext]:
        """Consume contexts from the queue forever, running available tasks for each.

        NOTE(review): the ``while True`` loop never breaks, so the trailing
        ``return []`` is unreachable; it appears to exist only to satisfy the
        overridden signature — confirm against the base class.
        """
        # Per-worker scratch storage, exposed to tasks via _worker_context_var.
        worker_storage = {}
        while True:
            try:
                async with self.check_get_context(self._config.context_queue_timeout, False) as in_context:
                    logger.debug(f"[Worker{idx}]: get context[{in_context}] from async queue done")
                    # A StopContext is ignored: this executor runs indefinitely
                    # (see the log message) rather than shutting down on it.
                    if isinstance(in_context, StopContext):
                        logger.error(
                            f"[Worker{idx}]: get unexpected StopContext, "
                            "ServeExecutor will not stop, please check your code, skip"
                        )
                        continue
                    if in_context.is_destory():
                        logger.error(f"[Worker{idx}]: Context {in_context} is destory, skip")
                        continue

                    # Flip submitted contexts to RUNNING once a worker picks them up.
                    async with self.lock.write():
                        if in_context.id in self.context_status_dict:
                            self.context_status_dict[in_context.id] = ContextStatus.RUNNING

                    avaliable_tasks = await in_context.async_task_state.avaliable_task()
                    if not avaliable_tasks:
                        logger.error(f"[Worker{idx}]: Context {in_context} do not have avaliable task, will destory")
                        await in_context.async_context.destory()
                        continue

                # Run all currently-available tasks for this context concurrently,
                # optionally capped by a semaphore of task_worker_num slots.
                tasks = [
                    self._run_task(idx, sub_idx, task, in_context, env) for sub_idx, task in enumerate(avaliable_tasks)
                ]
                if self._config.task_worker_num > 0:
                    semaphore = asyncio.Semaphore(self._config.task_worker_num)
                    tasks = [self._async_limit(semaphore, task) for task in tasks]
                token = _worker_context_var.set(worker_storage)
                context_list_list = await asyncio.gather(*tasks)
                _worker_context_var.reset(token)
                # Flatten the per-task result lists into one list of contexts.
                context_list = [context for context_list in context_list_list for context in context_list]
                if len(context_list_list) > 1:
                    # need deduplicate
                    context_list = list(set(context_list))
                # Contexts flagged as freezing are neither delivered nor re-queued.
                context_list = [ context for context in context_list if not context.freezing ]
                await self._context_postprocess(in_context, context_list, avaliable_tasks)
                for out_context in context_list:
                    if isinstance(out_context, OutputContext):
                        # Deliver the result to the future created in submit_task.
                        # NOTE(review): this membership check is done outside the
                        # lock; the locked re-read below guards the actual update.
                        if out_context.id not in self.context_status_dict:
                            logger.error(
                                f"[Worker{idx}]: OutputContext {out_context} id {out_context.id} "
                                "not in context_status_dict, please check your code, skip"
                            )
                            continue

                        async with self.lock.write():
                            context_status= self.context_status_dict[out_context.id]
                            if context_status == ContextStatus.FINISH:
                                logger.error(
                                    f"[Worker{idx}]: OutputContext {out_context} id {out_context.id} "
                                    "already FINISH, please check your code and "
                                    "comfirm only one output for one input, skip"
                                )
                                continue
                            self.context_status_dict[out_context.id] = ContextStatus.FINISH
                            self.context_result_trackers[out_context.id].set_result(out_context.asdit())
                    else:
                        # Re-queue intermediate contexts: the same in_context at
                        # LIFO priority, LoopContexts at FIFO_LOW, others at FIFO_HIGH.
                        if out_context == in_context:
                            await self._context_queue.async_queue.put(out_context, ContextPriority.LIFO)
                        elif isinstance(out_context, LoopContext):
                            await self._context_queue.async_queue.put(out_context, ContextPriority.FIFO_LOW)
                        else:
                            await self._context_queue.async_queue.put(out_context, ContextPriority.FIFO_HIGH)
            except TimeoutError:
                # No context arrived within context_queue_timeout; poll again.
                logger.debug(f"[Worker{idx}]: context queue get timeout, skip")
                continue

        return []

    async def submit_task(self, context: Context) -> str:
        """Submit a context for execution and return its ID.

        Args:
            context: The input context to process.

        Returns:
            The context ID for tracking status and results.

        Raises:
            ValueError: If a context with the same ID has already been submitted.
        """
        # Register status and result future under the write lock BEFORE
        # enqueueing, so a fast worker always finds the tracking entries.
        async with self.lock.write():
            if context.id in self.context_status_dict:
                raise ValueError(f"Context {context} already submitted.")
            self.context_status_dict[context.id] = ContextStatus.INIT
            future = asyncio.Future()
            self.context_result_trackers[context.id] = future

        await context.async_context.complete(self.dag.in_task)
        await self._context_queue.async_queue.put(context, ContextPriority.FIFO_HIGH)
        return context.id

    async def get_task_status(self, task_id: str) -> ContextStatus:
        """Get the current status of a submitted context.

        Args:
            task_id: The context ID returned by submit_task.

        Returns:
            The ContextStatus enum value.

        Raises:
            ValueError: If the task_id is not found.
        """
        async with self.lock.read():
            if task_id not in self.context_status_dict:
                raise ValueError(f"Task id {task_id} not found.")
            return self.context_status_dict[task_id]

    async def get_task_result(self, task_id: str) -> dict:
        """Wait for and return the result of a submitted context.

        Args:
            task_id: The context ID returned by submit_task.

        Returns:
            The output data dictionary.

        Raises:
            ValueError: If the task_id is not found.
        """
        # Fetch the future under the read lock, but await it OUTSIDE the lock
        # so other API calls and workers are not blocked while waiting.
        async with self.lock.read():
            if task_id not in self.context_result_trackers:
                raise ValueError(f"Task id {task_id} not found.")
            future = self.context_result_trackers[task_id]

        # NOTE(review): if this future is cancelled or the await raises, the
        # cleanup below is skipped and the entries leak — confirm intended.
        result = await future
        async with self.lock.write():
            # clean up
            del self.context_status_dict[task_id]
            del self.context_result_trackers[task_id]

        return result

__init__(dag, runtime=None, config=None)

Initialize the serve executor.

Parameters:

Name Type Description Default
dag DAG

The DAG to execute.

required
runtime Runtime | None

Optional runtime for tracking active contexts.

None
config ExecutorConfig | None

Executor configuration.

None
Source code in shutils/dag/serve_executor.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
@override
def __init__(
    self,
    dag: DAG,
    runtime: Runtime | None = None,
    config: ExecutorConfig | None = None,
):
    """Set up the serving executor and its per-context bookkeeping.

    Args:
        dag: The DAG to execute.
        runtime: Optional runtime for tracking active contexts.
        config: Executor configuration.
    """
    effective_config = ExecutorConfig() if config is None else config
    super().__init__(dag, runtime, effective_config)
    # Status and result-future maps keyed by context id, guarded by a RW lock.
    self.context_status_dict: dict[str, ContextStatus] = {}
    self.context_result_trackers: dict[str, asyncio.Future] = {}
    self.lock = AsyncRWLock()

get_task_result(task_id) async

Wait for and return the result of a submitted context.

Parameters:

Name Type Description Default
task_id str

The context ID returned by submit_task.

required

Returns:

Type Description
dict

The output data dictionary.

Raises:

Type Description
ValueError

If the task_id is not found.

Source code in shutils/dag/serve_executor.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
async def get_task_result(self, task_id: str) -> dict:
    """Block until the result for ``task_id`` is available, then return it.

    Args:
        task_id: The context ID returned by submit_task.

    Returns:
        The output data dictionary.

    Raises:
        ValueError: If the task_id is not found.
    """
    # Grab the future under the read lock, but await it outside the lock so
    # other coroutines are not blocked while this caller waits.
    async with self.lock.read():
        future = self.context_result_trackers.get(task_id)
        if future is None:
            raise ValueError(f"Task id {task_id} not found.")

    result = await future

    # Drop the bookkeeping entries now that the result has been delivered.
    async with self.lock.write():
        self.context_status_dict.pop(task_id)
        self.context_result_trackers.pop(task_id)

    return result

get_task_status(task_id) async

Get the current status of a submitted context.

Parameters:

Name Type Description Default
task_id str

The context ID returned by submit_task.

required

Returns:

Type Description

The ContextStatus enum value.

Raises:

Type Description
ValueError

If the task_id is not found.

Source code in shutils/dag/serve_executor.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
async def get_task_status(self, task_id: str):
    """Look up the current lifecycle status of a submitted context.

    Args:
        task_id: The context ID returned by submit_task.

    Returns:
        The ContextStatus enum value.

    Raises:
        ValueError: If the task_id is not found.
    """
    async with self.lock.read():
        # Statuses are enum members, so None unambiguously means "unknown id".
        status = self.context_status_dict.get(task_id)
        if status is None:
            raise ValueError(f"Task id {task_id} not found.")
        return status

run() async

Start the worker pool. Does not return until shutdown.

Source code in shutils/dag/serve_executor.py
58
59
60
61
62
63
64
65
66
67
68
69
70
71
async def run(self) -> None:    # type: ignore[override]
    """Start the worker pool. Does not return until shutdown."""
    env = Environment(self.runtime, self._process_pool, self.dag)
    # Spawn one long-lived worker coroutine per configured context worker.
    workers = []
    for worker_idx in range(self._config.context_worker_num):
        workers.append(asyncio.create_task(self._worker_loop(worker_idx, env)))
    await asyncio.gather(*workers)

    # After every worker has exited, let DAG tasks release their resources.
    # ShutdownTask is checked first, matching the sync-before-async preference.
    for task in self.dag.tasks.values():
        if isinstance(task, ShutdownTask):
            task.shutdown()
        elif isinstance(task, AsyncShutdownTask):
            await task.shutdown()

submit_task(context) async

Submit a context for execution and return its ID.

Parameters:

Name Type Description Default
context Context

The input context to process.

required

Returns:

Type Description
str

The context ID for tracking status and results.

Raises:

Type Description
ValueError

If a context with the same ID has already been submitted.

Source code in shutils/dag/serve_executor.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
async def submit_task(self, context: Context) -> str:
    """Register ``context`` for execution, enqueue it, and return its ID.

    Args:
        context: The input context to process.

    Returns:
        The context ID for tracking status and results.

    Raises:
        ValueError: If a context with the same ID has already been submitted.
    """
    # Create the tracking entries under the write lock BEFORE enqueueing,
    # so a fast worker can always find the status entry for this context.
    async with self.lock.write():
        if context.id in self.context_status_dict:
            raise ValueError(f"Context {context} already submitted.")
        self.context_status_dict[context.id] = ContextStatus.INIT
        self.context_result_trackers[context.id] = asyncio.Future()

    await context.async_context.complete(self.dag.in_task)
    await self._context_queue.async_queue.put(context, ContextPriority.FIFO_HIGH)
    return context.id