# Source code for xoscar.api

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import functools
import inspect
import logging
import threading
import uuid
from collections import defaultdict
from numbers import Number
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Dict,
    Generic,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from urllib.parse import urlparse

from .aio import AioFileObject
from .backend import get_backend
from .context import get_context
from .core import ActorRef, BufferRef, FileObjectRef, _Actor, _StatelessActor

if TYPE_CHECKING:
    from .backends.config import ActorPoolConfig
    from .backends.pool import MainActorPoolType

logger = logging.getLogger(__name__)


async def create_actor(
    actor_cls: Type, *args, uid=None, address=None, **kwargs
) -> ActorRef:
    # TODO: explain default values.
    """
    Create an actor.

    Parameters
    ----------
    actor_cls : Actor
        Actor class.
    args : tuple
        Positional arguments for ``actor_cls.__init__``.
    uid : identifier, default=None
        Actor identifier.
    address : str, default=None
        Address to locate the actor.
    kwargs : dict
        Keyword arguments for ``actor_cls.__init__``.

    Returns
    -------
    ActorRef
    """
    return await get_context().create_actor(
        actor_cls, *args, uid=uid, address=address, **kwargs
    )
async def has_actor(actor_ref: ActorRef) -> bool:
    """
    Check if the given actor exists.

    Parameters
    ----------
    actor_ref : ActorRef
        Reference to an actor.

    Returns
    -------
    bool
    """
    return await get_context().has_actor(actor_ref)
async def destroy_actor(actor_ref: ActorRef):
    """
    Destroy an actor by its reference.

    Parameters
    ----------
    actor_ref : ActorRef
        Reference to an actor.

    Returns
    -------
    bool
    """
    return await get_context().destroy_actor(actor_ref)
async def actor_ref(*args, **kwargs) -> ActorRef:
    """
    Create a reference to an actor.

    Returns
    -------
    ActorRef
    """
    # TODO: refine the argument list for better user experience.
    return await get_context().actor_ref(*args, **kwargs)
async def kill_actor(actor_ref: ActorRef):
    # TODO: explain the meaning of 'kill'
    """
    Forcefully kill an actor.

    It's important to note that this operation is potentially dangerous
    as it may result in the termination of other associated actors.
    Only proceed if you understand the potential impact on associated
    actors and can handle any resulting consequences.

    Parameters
    ----------
    actor_ref : ActorRef
        Reference to an actor.

    Returns
    -------
    bool
    """
    return await get_context().kill_actor(actor_ref)
async def create_actor_pool(
    address: str, n_process: int | None = None, **kwargs
) -> "MainActorPoolType":
    # TODO: explain default values.
    """
    Create an actor pool.

    Parameters
    ----------
    address: str
        Address of the actor pool.
    n_process: Optional[int], default=None
        Number of processes.
    kwargs : dict
        Other keyword arguments for the actor pool.

    Returns
    -------
    MainActorPoolType
    """
    if address is None:
        raise ValueError("address has to be provided")
    # Addresses without an explicit scheme map to the default backend.
    scheme = None
    if "://" in address:
        scheme = urlparse(address).scheme or None
    return await get_backend(scheme).create_actor_pool(
        address, n_process=n_process, **kwargs
    )
async def wait_for(fut: Awaitable[Any], timeout: int | float | None = None) -> Any:
    """
    Await ``fut`` with a timeout that fires unconditionally.

    ``asyncio.wait_for()`` cannot work as expected on a Xoscar actor call:
    on timeout the inner future is cancelled, but the actor call catches
    the cancellation and sends a CancelMessage to the destination pool.
    If that message cannot be processed (e.g. the destination pool hangs),
    the timeout would never surface. This helper instead times out on a
    mirror future, so the caller gets ``asyncio.TimeoutError`` whether or
    not the CancelMessage is ever delivered.
    """
    loop = asyncio.get_running_loop()
    mirror = loop.create_future()
    task = asyncio.ensure_future(fut)

    def _propagate(done: asyncio.Future):
        # Forward the task's outcome to the mirror future exactly once.
        if mirror.done():
            return
        if done.cancelled():
            mirror.cancel()
        elif exc := done.exception():
            mirror.set_exception(exc)
        else:
            mirror.set_result(done.result())

    task.add_done_callback(_propagate)
    try:
        return await asyncio.wait_for(mirror, timeout)
    except asyncio.TimeoutError:
        if not task.done():
            try:
                # Best-effort cancel; do not wait for it to take effect.
                task.cancel()
            except Exception:
                logger.warning("Failed to cancel task", exc_info=True)
        raise


def buffer_ref(address: str, buffer: Any) -> BufferRef:
    """
    Init buffer ref according address and buffer.

    Parameters
    ----------
    address
        The address of the buffer.
    buffer
        CPU / GPU buffer. Need to support for slicing and retrieving the length.

    Returns
    ----------
    BufferRef obj.
    """
    return get_context().buffer_ref(address, buffer)


def file_object_ref(address: str, fileobj: AioFileObject) -> FileObjectRef:
    """
    Init file object ref according to address and aio file obj.

    Parameters
    ----------
    address
        The address of the file obj.
    fileobj
        Aio file object.

    Returns
    ----------
    FileObjectRef obj.
    """
    return get_context().file_object_ref(address, fileobj)


async def copy_to(
    local_buffers_or_fileobjs: list,
    remote_refs: List[Union[BufferRef, FileObjectRef]],
    block_size: Optional[int] = None,
):
    """
    Copy data from local buffers to remote buffers or copy local file objects
    to remote file objects.

    Parameters
    ----------
    local_buffers_or_fileobjs
        Local buffers or file objects.
    remote_refs
        Remote buffer refs or file object refs.
    block_size
        Transfer block size when non-ucx
    """
    return await get_context().copy_to(
        local_buffers_or_fileobjs, remote_refs, block_size
    )
async def wait_actor_pool_recovered(address: str, main_pool_address: str | None = None):
    """
    Wait until the specified actor pool has recovered from failure.

    Parameters
    ----------
    address: str
        Address of the actor pool.
    main_pool_address: Optional[str], default=None
        Address of corresponding main actor pool.

    Returns
    -------
    """
    return await get_context().wait_actor_pool_recovered(address, main_pool_address)
async def get_pool_config(address: str) -> "ActorPoolConfig":
    """
    Get the configuration of specified actor pool.

    Parameters
    ----------
    address: str
        Address of the actor pool.

    Returns
    -------
    ActorPoolConfig
    """
    return await get_context().get_pool_config(address)
def setup_cluster(address_to_resources: Dict[str, Dict[str, Number]]):
    """
    Set up a cluster from a mapping of pool addresses to their resources.

    Addresses are grouped by URL scheme, and each group is handed to the
    driver class of the backend registered for that scheme.
    """
    grouped: defaultdict[str | None, dict] = defaultdict(dict)
    for address, resources in address_to_resources.items():
        if address is None:
            raise ValueError("address has to be provided")
        scheme = None
        if "://" in address:
            scheme = urlparse(address).scheme or None
        grouped[scheme][address] = resources
    for scheme, address_resources in grouped.items():
        get_backend(scheme).get_driver_cls().setup_cluster(address_resources)


T = TypeVar("T")


class IteratorWrapper(Generic[T]):
    """Async iterator that proxies a generator living inside a remote actor."""

    def __init__(self, uid: str, actor_addr: str, actor_uid: str):
        self._uid = uid
        self._actor_addr = actor_addr
        self._actor_uid = actor_uid
        self._actor_ref = None
        # Whether this instance is responsible for destroying the remote
        # generator when it is garbage collected.
        self._gc_destroy = True

    async def destroy(self):
        """Destroy the remote generator this wrapper refers to."""
        if self._actor_ref is None:
            self._actor_ref = await actor_ref(
                address=self._actor_addr, uid=self._actor_uid
            )
        assert self._actor_ref is not None
        return await self._actor_ref.__xoscar_destroy_generator__(self._uid)

    def __del__(self):
        # It's not a good idea to spawn a new thread and join in __del__,
        # but currently it's the only way to GC the generator.
        # TODO(codingl2k1): This __del__ may hang if the program is exiting.
        if self._gc_destroy:
            cleanup = threading.Thread(
                target=asyncio.run, args=(self.destroy(),), daemon=True
            )
            cleanup.start()
            cleanup.join()

    def __aiter__(self):
        return self

    def __getstate__(self):
        # Transfer gc destroy during serialization: the deserialized copy
        # becomes responsible for destroying the remote generator.
        state = self.__dict__.copy()
        state["_gc_destroy"] = True
        self._gc_destroy = False
        return state

    async def __anext__(self) -> T:
        if self._actor_ref is None:
            self._actor_ref = await actor_ref(
                address=self._actor_addr, uid=self._actor_uid
            )
        try:
            assert self._actor_ref is not None
            return await self._actor_ref.__xoscar_next__(self._uid)
        except Exception as e:
            # The remote side signals exhaustion with an exception whose
            # message embeds "StopIteration".
            if "StopIteration" in str(e):
                raise StopAsyncIteration
            raise


class AsyncActorMixin:
    @classmethod
    def default_uid(cls):
        return cls.__name__

    def __new__(cls, *args, **kwargs):
        # Substitute a registered implementation class, if any.
        try:
            return _actor_implementation[cls](*args, **kwargs)
        except KeyError:
            return super().__new__(cls, *args, **kwargs)

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # Active generators exposed to remote callers, keyed by uid.
        self._generators: Dict[str, IteratorWrapper] = {}

    async def __post_create__(self):
        """
        Method called after actor creation
        """
        return await super().__post_create__()

    async def __pre_destroy__(self):
        """
        Method called before actor destroy
        """
        return await super().__pre_destroy__()

    async def __on_receive__(self, message: Tuple[Any]):
        """
        Handle message from other actors and dispatch them to user methods

        Parameters
        ----------
        message : tuple
            Message shall be (method_name,) + args + (kwargs,)
        """
        return await super().__on_receive__(message)  # type: ignore

    async def __xoscar_next__(self, generator_uid: str) -> Any:
        """
        Iter the next of generator.

        Parameters
        ----------
        generator_uid: str
            The uid of generator

        Returns
        -------
            The next value of generator
        """

        def _next_sync(_gen):
            try:
                return next(_gen)
            except StopIteration:
                return stop

        async def _next_async(_gen):
            try:
                # anext is only available for Python >= 3.10
                return await _gen.__anext__()  # noqa: F821
            except StopAsyncIteration:
                return stop

        gen = self._generators.get(generator_uid)
        if not gen:
            raise RuntimeError(f"No iterator with id: {generator_uid}")

        # Sentinel distinguishing "exhausted" from any yielded value.
        stop = object()
        try:
            if inspect.isgenerator(gen):
                # Sync generators may block; step them on a worker thread.
                result = await asyncio.to_thread(_next_sync, gen)
            elif inspect.isasyncgen(gen):
                result = await asyncio.create_task(_next_async(gen))
            else:
                raise Exception(
                    f"The generator {generator_uid} should be a generator or an async generator, "
                    f"but a {type(gen)} is got."
                )
        except Exception as e:
            logger.exception(
                f"Destroy generator {generator_uid} due to an error encountered."
            )
            await self.__xoscar_destroy_generator__(generator_uid)
            del gen  # Avoid exception hold generator reference.
            raise e
        if result is stop:
            await self.__xoscar_destroy_generator__(generator_uid)
            del gen  # Avoid exception hold generator reference.
            raise Exception("StopIteration")
        return result

    async def __xoscar_destroy_generator__(self, generator_uid: str):
        """
        Destroy the generator.

        Parameters
        ----------
        generator_uid: str
            The uid of generator
        """
        logger.debug("Destroy generator: %s", generator_uid)
        self._generators.pop(generator_uid, None)


def generator(func):
    """Decorator turning a generator-returning actor method into one that
    registers the generator and hands back an ``IteratorWrapper``."""
    need_to_thread = not asyncio.iscoroutinefunction(func)

    @functools.wraps(func)
    async def _wrapper(self, *args, **kwargs):
        if need_to_thread:
            result = await asyncio.to_thread(func, self, *args, **kwargs)
        else:
            result = await func(self, *args, **kwargs)
        if inspect.isgenerator(result) or inspect.isasyncgen(result):
            gen_uid = uuid.uuid1().hex
            logger.debug("Create generator: %s", gen_uid)
            self._generators[gen_uid] = result
            return IteratorWrapper(gen_uid, self.address, self.uid)
        return result

    return _wrapper


class Actor(AsyncActorMixin, _Actor):
    pass


class StatelessActor(AsyncActorMixin, _StatelessActor):
    pass


# Mapping from actor class to the implementation class substituted in __new__.
_actor_implementation: Dict[Type[Actor], Type[Actor]] = dict()


def register_actor_implementation(actor_cls: Type[Actor], impl_cls: Type[Actor]):
    """Register ``impl_cls`` as the implementation to instantiate for ``actor_cls``."""
    _actor_implementation[actor_cls] = impl_cls


def unregister_actor_implementation(actor_cls: Type[Actor]):
    """Remove the implementation registered for ``actor_cls``, if any."""
    try:
        del _actor_implementation[actor_cls]
    except KeyError:
        pass