Source code for acord.client.client

# A simple base client for handling responses from discord
from typing import AsyncIterator, Coroutine, Dict, Iterator, List, Union, Callable, Optional

import asyncio
import logging
import sys
import traceback

from acord.bases.presence import Presence

from acord.core.abc import Route, API_VERSION
from acord.core.http import HTTPClient
from acord.errors import *
from acord.payloads import (
    StageInstanceCreatePayload,
)
from acord.ext.application_commands import ApplicationCommand, UDAppCommand
from acord.bases import Intents, _C
from acord.models import Message, Snowflake, User, Channel, Guild, StageInstance
from acord.utils import _d_to_channel

from .shard import Shard
from .caches.cache import Cache
from .caches.default import DefaultCache
from .ratelimiter import GatewayRatelimiter, DefaultGatewayRatelimiter

logger = logging.getLogger(__name__)


class Client(object):
    """
    Client for interacting with the discord API

    Parameters
    ----------
    loop: :class:`~asyncio.AbstractEventLoop`
        An existing loop to run the client off of.
        When omitted, the current event loop is looked up when the client
        is constructed (not when the module is imported).
    token: :class:`str`
        Your API Token which can be generated at the developer portal
    intents: Union[:class:`Intents`, :class:`int`]
        Intents to be passed through when connecting to gateway, defaults to ``0``
    presence: :class:`Presence`
        Presence to be sent in the identity packet
    encoding: :class:`str`
        Any of ``ETF`` and ``JSON`` are allowed to be chosen,
        controls data received by discord,
        defaults to ``JSON``.
    compress: :class:`bool`
        Whether to read compressed stream when receiving requests, defaults to ``False``
    dispatch_on_recv: :class:`bool`
        Whether on_socket_recv should be dispatched
    cache: :class:`Cache`
        Cache for the client to use.
        When omitted, a new :class:`DefaultCache` is created for this client.

        .. versionadded:: 0.2.3a0
    gateway_ratelimiter: :class:`GatewayRatelimiter`
        Gateway ratelimiter for client to use.
        When omitted, a new :class:`DefaultGatewayRatelimiter` is created
        for this client.

        .. versionadded:: 0.2.3a0

    Raises
    ------
    TypeError
        If ``cache`` is not an instance of :class:`Cache`.

    Attributes
    ----------
    loop: :class:`~asyncio.AbstractEventLoop`
        Loop client uses
    token: :class:`str`
        Token set when initialising class
    dispatch_on_recv: :class:`bool`
        Whether the on_socket_recv should be dispatched
    intents: Union[:class:`Intents`, :class:`int`]
        Intents set when initialising class
    presence: :class:`Presence`
        Presence to be sent in the identity packet
    encoding: :class:`str`
        Encoding set when initialising class
    compress: :class:`bool`
        Whether to read compressed stream when receiving requests
    session_id: :class:`str`
        Session ID received when connecting to gateway
    gateway_version: :class:`str`
        Selected gateway version, available after connecting to gateway.
        In form ``v[0-9]``.
    user: :class:`User`
        Client user object
    application_commands: Dict[str, List[:class:`UDAppCommand`]]
        Mapping of registered application commands
    shards: Dict[:class:`int`, :class:`Shard`]
        Mapping of shards client is handling

        .. versionadded:: 0.2.3a0
    cache: :class:`Cache`
        Cache of gateway objects,
        recommended to fetch using built in methods,
        e.g. :meth:`Client.get_user`.

        .. versionadded:: 0.2.3a0
    gateway_ratelimiter: :class:`GatewayRatelimiter`
        Gateway ratelimiter for client to use,
        we recommend not playing around with this as it may lead to unexpected errors

        .. versionadded:: 0.2.3a0
    max_concurrency: :class:`int`
        Number of identity requests client is allowed per 5 seconds

        .. versionadded:: 0.2.3a0

        .. versionchanged:: 0.2.3a0
            Name changed from MAX_CONC to max_concurrency
    num_shards: Optional[:class:`int`]
        Number of shards client is using,
        if None it means client has not been ran yet.

        .. note::
            This value can be overwritten but only before client is ran.
            We dont recommend you do but oh well.

        .. versionadded:: 0.2.3a0
    guilds: List[:class:`Guilds`]
        List of guilds client has access to
    rest: :class:`RestApi`
        An instance of the Rest API object
    """

    cache: Cache

    def __init__(
        self,
        *,
        token: Optional[str] = None,
        dispatch_on_recv: bool = False,
        intents: Optional[Union[Intents, int]] = 0,
        presence: Presence = None,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        encoding: Optional[str] = "JSON",
        compress: Optional[bool] = False,
        cache: Optional[Cache] = None,
        gateway_ratelimiter: Optional[GatewayRatelimiter] = None,
    ) -> None:
        # Defaults are resolved here rather than in the signature: a default
        # expression is evaluated once at definition time, which would make
        # every client share one cache / ratelimiter and would bind the loop
        # at import time.
        if loop is None:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # No current loop available (e.g. non-main thread)
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        if cache is None:
            cache = DefaultCache()
        if gateway_ratelimiter is None:
            gateway_ratelimiter = DefaultGatewayRatelimiter()

        if not isinstance(cache, Cache):
            raise TypeError("Cache must be a subclass of Cache")

        self.loop = loop
        self.token = token

        self.dispatch_on_recv = dispatch_on_recv

        self.intents = intents
        self.presence = presence

        # Maps event name -> list of {"func": ..., "once": ...} entries
        self._events = dict()

        # Gateway connection stuff
        self.encoding = encoding
        self.compress = compress

        # Others
        self.session_id = None
        self.gateway_version = None
        self.user = None

        # Created lazily by ``run``
        self.http = None
        self.rest = None

        # When connecting to VC, temporarily stores session_id
        self.awaiting_voice_connections = dict()
        self.voice_connections = dict()

        self.cache = cache
        self.gateway_ratelimiter = gateway_ratelimiter

        self.shards = dict()
        self.max_concurrency = 0
        self.num_shards = None
def on(self, name: str, *, once: bool = False) -> Optional[_C]:
    """Register an event to be dispatched on call.

    This is a decorator,
    if you do not want to use the decorator consider trying:

    .. code-block:: py

        from acord import Client
        from xyz import some_event_handler

        client = Client(...)
        client.on("message")(some_event_handler)

    Parameters
    ----------
    name: :class:`str`
        Name of event, consider checking out all `events <../events.html>`_
    once: :class:`bool`
        Whether the event should be ran once before being removed.

    Returns
    -------
    A decorator which stores the handler under ``name`` and returns the
    handler unchanged.
    """

    def register(handler):
        entry = {"func": handler, "once": once}
        self._events.setdefault(name, []).append(entry)

        # ``wait_for`` registers ``(future, check)`` tuples; those are not
        # callable and are stored without being tagged.
        if not callable(handler):
            return handler

        try:
            handler.__event_name__ = name
        except AttributeError:
            handler.__dict__.update(__event_name__=name)
        return handler

    return register
def dispatch(self, event_name: str, *args, **kwargs) -> None:
    """Dispatch a registered event

    The built-in ``on_<event>`` method of the client (if any) is scheduled
    first, followed by every handler registered through :meth:`Client.on`
    and every pending :meth:`Client.wait_for` waiter for ``event_name``.

    Parameters
    ----------
    event_name: :class:`str`
        Name of event
    *args, **kwargs
        Additional args or kwargs to be passed through
    """
    # Built-in handlers are always named ``on_<event>``. Both spellings are
    # accepted so a name already carrying the prefix no longer leaves
    # ``func_name`` unbound.
    if event_name.startswith("on_"):
        func_name = event_name
    else:
        func_name = "on_" + event_name

    events = self._events.get(event_name, list())
    func: Callable[..., Coroutine] = getattr(self, func_name, None)
    to_rmv: List[Dict] = list()

    if func:
        try:
            self.loop.create_task(
                func(*args, **kwargs), name=f"Acord event dispatch: {event_name}"
            )
        except Exception:
            # No task exists when creation fails; on_error falls back to
            # the exception currently being handled.
            self.on_error(f"{func} ({func_name})")

    for event in events:
        func = event["func"]

        # ``wait_for`` stores ``(future, check)`` pairs, everything else
        # is a plain handler.
        try:
            fut, check = func
        except (ValueError, TypeError):
            fut = check = None

        try:
            if fut is None:
                self.loop.create_task(
                    func(*args, **kwargs),
                    name=f"Acord event dispatch: {event_name}",
                )
                finished = event.get("once", False)
            elif fut.done():
                # Waiter timed out or was cancelled: drop it instead of
                # raising InvalidStateError on every later dispatch.
                finished = True
            elif check(*args, **kwargs) is True:
                fut.set_result(tuple(args) + tuple(kwargs.values()))
                finished = True
            else:
                finished = event.get("once", False)
        except Exception:
            self.on_error(f"{func} ({func_name})")
            continue

        if finished:
            to_rmv.append(event)

    for stale in to_rmv:
        events.remove(stale)

    if events:
        self._events[event_name] = events
    else:
        self._events.pop(event_name, None)

    logger.info("Dispatched event: %s", event_name)
def wait_for(
    self, event: str, *, check: Callable[..., bool] = None, timeout: int = None
) -> _C:
    """|coro|

    Wait for a specific gateway event to occur.

    .. rubric:: Examples

    .. code-block:: py

        # Simple Greeting
        data = await client.wait_for(
            "message",
            check=lambda message: message.content == "Hello",
            timeout=30.0
        )
        message = data[0]

        return message.reply(content=f"Hello, {message.author}")

    Parameters
    ----------
    event: :class:`str`
        Gateway event to wait for.
    check: Callable[..., :class:`bool`]
        Validate the gateway event received,
        when omitted every event is accepted.
    timeout: :class:`int`
        Time to wait for event to be received,
        forwarded to :func:`asyncio.wait_for`.

    Returns
    -------
    An awaitable which resolves to the result set on the waiter future.
    """
    predicate = check if check else (lambda *args, **kwargs: True)

    # The (future, predicate) pair is stored like any other handler;
    # ``dispatch`` recognises the pair and resolves the future.
    waiter = self.loop.create_future()
    register = self.on(event)
    register((waiter, predicate))

    return asyncio.wait_for(waiter, timeout=timeout)
async def create_stage_instance(self, *, reason: str = None, **data) -> StageInstance:
    """|coro|

    Creates a stage instance

    Parameters
    ----------
    reason: :class:`str`
        When given, sent as the ``X-Audit-Log-Reason`` header
    channel_id: :class:`Snowflake`
        ID of channel to create stage instance,
        channel type must be :attr:`ChannelTypes.GUILD_STAGE_VOICE`
    topic: :class:`str`
        The topic of the Stage instance (1-120 characters)
    privacy_level: :class:`StagePrivacyLevel`
        The privacy level of the Stage instance (default GUILD_ONLY)

    Returns
    -------
    The created :class:`StageInstance`, which is also added to the cache.
    """
    # Keyword arguments are validated by the payload model
    payload = StageInstanceCreatePayload(**data)

    headers = {"Content-Type": "application/json"}
    if reason is not None:
        headers["X-Audit-Log-Reason"] = reason

    route = Route(
        "POST",
        path="/stage-instances",
        bucket=dict(channel_id=payload.channel_id),
    )
    response = await self.http.request(
        route,
        data=payload.json(),
        headers=headers,
    )

    body = await response.json()
    instance = StageInstance(conn=self.http, **body)
    self.cache.add_stage_instance(instance)

    return instance
def get_shard(self, guild_id: Snowflake) -> Optional[Shard]:
    """Gets a shard from :attr:`Client.shards` using a guild id

    The shard index is ``(guild_id >> 22) % num_shards``.

    Parameters
    ----------
    guild_id: :class:`Snowflake`
        Guild ID to use

    Returns
    -------
    The matching :class:`Shard`, or ``None`` when no shard is stored under
    the computed index or the shard count is not known yet.
    """
    # ``num_shards`` stays None until the client is ran; without this guard
    # the modulo below raises instead of reporting "no shard".
    if not self.num_shards:
        return None

    shard_id = (guild_id >> 22) % self.num_shards
    return self.shards.get(shard_id)
def register_application_command(
    self,
    command: UDAppCommand,
    *,
    guild_ids: Union[List[int], None] = None,
    extend: bool = True,
) -> None:
    """Registers application command internally before client is ran,
    after client is ran this method is redundant.
    Consider using :meth:`Client.create_application_command`.

    This simply forwards to ``self.rest.register_application_command``.

    Parameters
    ----------
    command: :class:`UDAppCommand`
        .. note::
            :class:`UDAppCommand` represents any class which inherits it,
            this includes SlashBase.

        Command to register internally, to be dispatched.
    guild_ids: Union[List[:class:`int`], None]
        Additional guild IDs to restrict command to,
        if value is set to:

        * ``None``: Reads from class (Default option)
        * ``[]`` (Empty List): Makes it global

        .. note::
            If final value is false,
            command will be registered globally
    extend: :class:`bool`
        Whether to extend current guild ids from the command class
    """
    options = dict(command=command, guild_ids=guild_ids, extend=extend)
    self.rest.register_application_command(**options)
async def create_application_command(
    self,
    command: UDAppCommand,
    *,
    guild_ids: Union[List[int], None] = None,
    extend: bool = True,
) -> Union[ApplicationCommand, List[ApplicationCommand]]:
    """|coro|

    Creates an application command from a :class:`UDAppCommand` class.

    .. note::
        It can take up to an hour for discord to process the command!

    Parameters
    ----------
    same as :meth:`Client.register_application_command`

    Returns
    -------
    Whatever ``self.rest.create_application_command`` returns.
    """
    options = dict(command=command, guild_ids=guild_ids, extend=extend)
    created = await self.rest.create_application_command(**options)
    return created
async def bulk_update_global_app_commands(
    self, commands: List[UDAppCommand]
) -> None:
    """|coro|

    Updates global application commands in bulk,
    by forwarding to the rest API object.

    Parameters
    ----------
    commands: List[:class:`UDAppCommand`]
        List of application commands to update
    """
    rest = self.rest
    await rest.bulk_update_global_app_commands(commands=commands)
async def bulk_update_guild_app_commands(
    self,
    guild_id: Snowflake,
    commands: List[UDAppCommand],
) -> None:
    """|coro|

    Updates application commands for a guild in bulk,
    by forwarding to the rest API object.

    Parameters
    ----------
    guild_id: :class:`Snowflake`
        ID of target guild
    commands: List[:class:`UDAppCommand`]
        List of application commands to update
    """
    rest = self.rest
    await rest.bulk_update_guild_app_commands(guild_id, commands=commands)
async def _bulk_write_app_commands(self, exclude: set) -> None:
    """Bulk-write every registered application command.

    Commands are read from ``self.application_commands``, entries whose
    name is in ``exclude`` are skipped. Commands without ``guild_ids`` are
    written globally, all others are written once per guild they list.

    Parameters
    ----------
    exclude: :class:`set`
        Names of application commands which must not be written
    """
    cmds = []

    for name, commands in self.application_commands.items():
        if name in exclude:
            continue

        if isinstance(commands, list):
            # Drop duplicates while keeping registration order. The
            # previous ``extend(*set(...))`` unpacked the set into
            # positional arguments, which fails for anything but a
            # single iterable element.
            cmds.extend(dict.fromkeys(commands))
        else:
            cmds.append(commands)

    global_commands = []
    guild_commands = {}

    for command in cmds:
        if not command.guild_ids:
            global_commands.append(command)
            continue

        for guild_id in command.guild_ids:
            guild_commands.setdefault(guild_id, []).append(command)

    if global_commands:
        await self.bulk_update_global_app_commands(global_commands)

    for guild_id, commands in guild_commands.items():
        await self.bulk_update_guild_app_commands(guild_id, commands)
async def shard_handler(self, *ready_scripts) -> None:
    """|coro|

    Default shard handler,
    doesn't do much except initiate shards and waits till all have been disconnected.

    Parameters
    ----------
    *ready_scripts
        Awaitables to run once every shard has identified,
        before the shards start listening. ``None`` entries are skipped.
    """
    gateway = await self.http.fetch_gateway()

    gateway_url = gateway["url"]
    gateway_url += f"?v={API_VERSION}"

    if self.compress:
        gateway_url += "&compress=zlib-stream"

    gateway_url += f"&encoding={self.encoding}"

    # A shard count set by the user before running takes precedence
    if not self.num_shards:
        self.num_shards = gateway["shards"]

    self.max_concurrency = gateway["session_start_limit"]["max_concurrency"]

    tasks = []
    identified = 0

    for shard_id in range(self.num_shards):
        shard = Shard(
            url=gateway_url,
            shard_id=shard_id,
            num_shards=self.num_shards,
            client=self
        )

        await shard.connect()
        await shard.receive_hello()
        await shard.send_identity(
            self.token, self.intents, self.presence
        )

        tasks.append(shard.listen(shard=shard))
        self.shards.update({shard_id: shard})

        identified += 1
        if identified == self.max_concurrency:
            identified = 0
            # Only ``max_concurrency`` identifies are allowed per 5 seconds.
            # Nothing is left to identify after the final shard, so the
            # pause is skipped there.
            if shard_id + 1 < self.num_shards:
                await asyncio.sleep(5)

    for script in ready_scripts:
        # ``run`` may hand over None when app commands are not updated
        if script is None:
            continue
        await script

    await asyncio.gather(*tasks)
def run(
    self,
    token: str = None,
    reconnect: bool = True,
    update_app_commands: bool = True,
    exclude_app_cmds: set = None,
    asyncio_debug: bool = False
):
    """Runs client, loop blocking.

    Parameters
    ----------
    token: :class:`str`
        Token to log into discord with,
        when omitted :attr:`Client.token` is used instead.
    reconnect: :class:`bool`
        Whether to retry the login once if the first attempt fails,
        defaults to ``True``.
    update_app_commands: :class:`bool`
        Whether to update app commands, *in bulk*.
    exclude_app_cmds: :class:`set`
        A set of app names to stop being updated/created,
        defaults to an empty set.
    asyncio_debug: :class:`bool`
        Whether to enable debugging on the current loop.

    Raises
    ------
    ValueError
        If neither ``token`` nor :attr:`Client.token` is set.
    HTTPException
        If the login fails and no retry is left.
    """
    from acord.rest import RestApi

    # Built per call instead of in the signature so calls never share
    # one mutable default.
    if exclude_app_cmds is None:
        exclude_app_cmds = set()

    if asyncio_debug:
        self.loop.set_debug(True)

    token = token or self.token

    if not token:
        raise ValueError("No token provided")

    if not self.http:
        # Use the resolved token: ``self.token`` is still None when the
        # token was only handed to ``run``.
        self.http = HTTPClient(self, loop=self.loop, token=token)

    self.http.client = self

    # Login to create session
    # Also validates token
    try:
        self.loop.run_until_complete(self.http.login(token=token))
    except HTTPException:
        if reconnect:
            # ``reconnect=False`` on the retry prevents endless recursion;
            # every other option is forwarded unchanged so the retry
            # behaves like the original call.
            logger.info("Failed to login, reconnecting")

            return self.run(
                token=token,
                reconnect=False,
                update_app_commands=update_app_commands,
                exclude_app_cmds=exclude_app_cmds,
                asyncio_debug=asyncio_debug,
            )
        raise

    self.token = token

    if not self.rest:
        self.rest = RestApi(
            token=self.token,
            loop=self.loop,
            cache=self.cache,
            http_client=self.http
        )

    if self.rest._set_up:
        logger.info("Client.rest has already been set, skipping")
    else:
        self.loop.run_until_complete(self.rest.setup(
            exclude=exclude_app_cmds,
            update_commands=update_app_commands
        ))
        logger.info("Finished setting up rest api for client")

    # Only hand over real awaitables, awaiting None raises TypeError
    ready_scripts = []
    if update_app_commands:
        ready_scripts.append(self._bulk_write_app_commands(exclude_app_cmds))

    self.loop.run_until_complete(self.shard_handler(*ready_scripts))
async def disconnect(self):
    """|coro|

    Disconnects client from discord,
    ends:

    * Client spawned sessions
    * Shards
    * Voice Connections
    """
    logger.info("Disconnected from API, closing any open connections")

    # ``http`` is only created by ``run``
    if self.http is not None:
        await self.http.disconnect()

    # ``shards`` maps shard id -> Shard; iterating the mapping itself
    # yields the integer ids, which have no ``disconnect``.
    for shard in self.shards.values():
        await shard.disconnect()

    for vc in self.voice_connections.values():
        await vc.disconnect()
# NOTE: Fetch from cache:
def get_message(self, channel_id: int, message_id: int) -> Optional[Message]:
    """Look up a message in the internal cache, the result may be outdated.

    Forwards to ``self.rest.get_message``.
    """
    cached = self.rest.get_message(channel_id, message_id)
    return cached
def get_user(self, user_id: int) -> Optional[User]:
    """Look up a user in the internal cache, the result may be outdated.

    Forwards to ``self.rest.get_user``.
    """
    cached = self.rest.get_user(user_id)
    return cached
def get_guild(self, guild_id: int) -> Optional[Guild]:
    """Look up a guild in the internal cache, the result may be outdated.

    Forwards to ``self.rest.get_guild``.
    """
    cached = self.rest.get_guild(guild_id)
    return cached
def get_channel(self, channel_id: int) -> Optional[Channel]:
    """Look up a channel in the internal cache, the result may be outdated.

    Forwards to ``self.rest.get_channel``.
    """
    cached = self.rest.get_channel(channel_id)
    return cached
# NOTE: Fetch from API:
async def fetch_user(self, user_id: int) -> Optional[User]:
    """|coro|

    Fetches user from API and caches it,
    by forwarding to ``self.rest.fetch_user``.
    """
    fetched = await self.rest.fetch_user(user_id)
    return fetched
async def fetch_channel(self, channel_id: int) -> Optional[Channel]:
    """|coro|

    Fetches channel from API and caches it,
    by forwarding to ``self.rest.fetch_channel``.
    """
    fetched = await self.rest.fetch_channel(channel_id)
    return fetched
async def fetch_message(
    self, channel_id: int, message_id: int
) -> Optional[Message]:
    """|coro|

    Fetches message from API and caches it,
    by forwarding to ``self.rest.fetch_message``.
    """
    fetched = await self.rest.fetch_message(channel_id, message_id)
    return fetched
async def fetch_guild(
    self, guild_id: int, *, with_counts: bool = False
) -> Optional[Guild]:
    """|coro|

    Fetches guild from API and caches it,
    by forwarding to ``self.rest.fetch_guild``.

    .. note::
        If with_counts is set to ``True``,
        it will allow fields ``approximate_presence_count``,
        ``approximate_member_count`` to be used.
    """
    fetched = await self.rest.fetch_guild(guild_id, with_counts=with_counts)
    return fetched
async def fetch_glob_app_commands(self) -> AsyncIterator[ApplicationCommand]:
    """|coro|

    Fetches all global application commands registered by the client,
    by forwarding to ``self.rest.fetch_glob_app_commands``.
    """
    commands = await self.rest.fetch_glob_app_commands()
    return commands
async def fetch_glob_app_command(self, command_id: Snowflake) -> ApplicationCommand:
    """|coro|

    Fetches a global application command registered by the client,
    by forwarding to ``self.rest.fetch_glob_app_command``.

    Parameters
    ----------
    command_id: :class:`Snowflake`
        ID of command to fetch
    """
    command = await self.rest.fetch_glob_app_command(command_id)
    return command
@property
def application_commands(self) -> dict:
    """Application commands registered on the rest API object."""
    rest = self.rest
    return rest.application_commands

# NOTE: default event handlers

async def on_voice_server_update(self, vc) -> None:
    """:meta private:"""
    # Handles dispatch when client joins VC.
    # ``dispatch`` schedules this handler as its own task, so waiting on
    # ``listen`` here does not hold up the dispatcher.
    for step in (vc.connect, vc.listen):
        await step()
def on_error(self, event_method, task: asyncio.Task = None, err = None):
    """Built in base error handler for events

    Logs the failure and prints the traceback to stderr.

    Parameters
    ----------
    event_method
        Description of the handler which failed, used in the messages
    task: :class:`~asyncio.Task`
        Task of the handler, an exception stored on it takes precedence
    err
        ``(type, value, traceback)`` triple,
        when omitted the exception currently being handled is used
    """
    exc_info = err if err else sys.exc_info()

    # An exception stored on the task overrides whatever was passed in
    task_exc = task._exception if task is not None else None
    if isinstance(task_exc, Exception):
        exc_info = (type(task_exc), task_exc, task_exc.__traceback__)

    logger.error(f'Failed to run event "{event_method}".', exc_info=exc_info)

    print(f"Ignoring exception in {event_method}", file=sys.stderr)
    traceback.print_exception(*exc_info)