Skip to content
This repository has been archived by the owner on Mar 13, 2023. It is now read-only.

Commit

Permalink
Merge pull request #675 from NAFTeam/dev
Browse files Browse the repository at this point in the history
NAFF 1.12.0
  • Loading branch information
LordOfPolls authored Oct 19, 2022
2 parents 9b5d70b + 80bd83e commit 1171097
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 83 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pytest-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
- .[speedup]
- .[voice]
- .[all]
- .[docs]

steps:
- name: Create check run
Expand Down
1 change: 1 addition & 0 deletions docs/src/API Reference/.pages
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
nav:
- Client.md
- AutoShardClient.md
- const.md
- errors.md
- API_Communication
Expand Down
69 changes: 40 additions & 29 deletions naff/api/http/http_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""This file handles the interaction with discords http endpoints."""
import asyncio
import time
from typing import Any, cast
from urllib.parse import quote as _uriquote
from weakref import WeakValueDictionary

import aiohttp
import discord_typings
from aiohttp import BaseConnector, ClientSession, ClientWebSocketResponse, FormData
from multidict import CIMultiDictProxy

Expand Down Expand Up @@ -34,38 +36,49 @@
from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError
from naff.client.utils.input_utils import response_decode, OverriddenJson
from naff.client.utils.serializer import dict_filter
from naff.models import CooldownSystem
from naff.models.discord.file import UPLOADABLE_TYPE
from .route import Route
import discord_typings

__all__ = ("HTTPClient",)


class GlobalLock:
"""Manages the global ratelimit"""

def __init__(self) -> None:
self.cooldown_system: CooldownSystem = CooldownSystem(
45, 1
) # global rate-limit is 50 per second, conservatively we use 45
self._lock: asyncio.Lock = asyncio.Lock()
self._lock = asyncio.Lock()
self.max_requests = 45
self._calls = 0
self._reset_time = 0

async def rate_limit(self) -> None:
async with self._lock:
while not self.cooldown_system.acquire_token():
await asyncio.sleep(self.cooldown_system.get_cooldown_time())
@property
def calls_remaining(self) -> int:
"""Returns the amount of calls remaining."""
return self.max_requests - self._calls

def reset_calls(self) -> None:
"""Resets the calls to the max amount."""
self._calls = self.max_requests
self._reset_time = time.perf_counter() + 1

async def lock(self, delta: float) -> None:
def set_reset_time(self, delta: float) -> None:
"""
Lock the global lock for a given duration.
Sets the reset time to the current time + delta.
To be called if a 429 is received.
Args:
delta: The time to keep the lock acquired
delta: The time to wait before resetting the calls.
"""
await self._lock.acquire()
await asyncio.sleep(delta)
self._lock.release()
self._reset_time = time.perf_counter() + delta
self._calls = 0

async def wait(self) -> None:
"""Throttles calls to prevent hitting the global rate limit."""
async with self._lock:
if self._reset_time <= time.perf_counter():
self.reset_calls()
elif self._calls <= 0:
await asyncio.sleep(self._reset_time - time.perf_counter())
self.reset_calls()
self._calls -= 1


class BucketLock:
Expand Down Expand Up @@ -272,21 +285,17 @@ async def request(
for attempt in range(self._max_attempts):
async with lock:
try:
await self.global_lock.rate_limit()
# prevent us exceeding the global rate limit by throttling http requests

if cast(ClientSession, self.__session).closed:
if self.__session.closed:
await self.login(cast(str, self.token))

processed_data = self._process_payload(payload, files)
if isinstance(processed_data, FormData):
kwargs["data"] = processed_data # pyright: ignore
else:
kwargs["json"] = processed_data # pyright: ignore
await self.global_lock.wait()

async with cast(ClientSession, self.__session).request(
route.method, route.url, **kwargs
) as response:
async with self.__session.request(route.method, route.url, **kwargs) as response:
result = await response_decode(response)
self.ingest_ratelimit(route, response.headers, lock)

Expand All @@ -299,7 +308,7 @@ async def request(
logger.error(
f"Bot has exceeded global ratelimit, locking REST API for {result['retry_after']} seconds"
)
await self.global_lock.lock(float(result["retry_after"]))
self.global_lock.set_reset_time(float(result["retry_after"]))
continue
elif result.get("message") == "The resource is being rate limited.":
# resource ratelimit is reached
Expand Down Expand Up @@ -361,7 +370,7 @@ async def _raise_exception(self, response, route, result) -> None:

async def request_cdn(self, url, asset) -> bytes: # pyright: ignore [reportGeneralTypeIssues]
logger.debug(f"{asset} requests {url} from CDN")
async with cast(ClientSession, self.__session).get(url) as response:
async with self.__session.get(url) as response:
if response.status == 200:
return await response.read()
await self._raise_exception(response, asset, await response_decode(response))
Expand All @@ -377,7 +386,9 @@ async def login(self, token: str) -> dict[str, Any]:
The currently logged in bot's data
"""
self.__session = ClientSession(connector=self.connector)
self.__session = ClientSession(
connector=self.connector if self.connector else aiohttp.TCPConnector(limit=self.global_lock.max_requests),
)
self.token = token
try:
result = await self.request(Route("GET", "/users/@me"))
Expand Down Expand Up @@ -422,6 +433,6 @@ async def websocket_connect(self, url: str) -> ClientWebSocketResponse:
url: the url to connect to
"""
return await cast(ClientSession, self.__session).ws_connect(
return await self.__session.ws_connect(
url, timeout=30, max_msg_size=0, autoclose=False, headers={"User-Agent": self.user_agent}, compress=0
)
41 changes: 39 additions & 2 deletions naff/client/auto_shard_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import naff.api.events as events
from naff.api.gateway.state import ConnectionState
Expand All @@ -12,7 +12,8 @@
to_snowflake,
)
from naff.models.naff.listener import Listener
from ..api.events import ShardConnect
from naff.models.discord import Status, Activity
from naff.api.events import ShardConnect

if TYPE_CHECKING:
from naff.models import Snowflake_Type
Expand Down Expand Up @@ -104,6 +105,18 @@ def get_shards_guild(self, shard_id: int) -> list[Guild]:
"""
return [guild for key, guild in self.cache.guild_cache.items() if ((key >> 22) % self.total_shards) == shard_id]

def get_shard_id(self, guild_id: "Snowflake_Type") -> int:
"""
Get the shard ID for a given guild.
Args:
guild_id: The ID of the guild
Returns:
The shard ID for the guild
"""
return (int(guild_id) >> 22) % self.total_shards

@Listener.create()
async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None:
"""
Expand Down Expand Up @@ -228,3 +241,27 @@ async def login(self, token) -> None:
self._connection_states: list[ConnectionState] = [
ConnectionState(self, self.intents, shard_id) for shard_id in range(self.total_shards)
]

async def change_presence(
self,
status: Optional[str | Status] = Status.ONLINE,
activity: Optional[str | Activity] = None,
*,
shard_id: int | None = None,
) -> None:
"""
Change the bot's presence.
Args:
status: The status for the bot to be. i.e. online, afk, etc.
activity: The activity for the bot to be displayed as doing.
shard_id: The shard to change the presence on. If not specified, the presence will be changed on all shards.
!!! note
Bots may only be `playing` `streaming` `listening` `watching` or `competing`, other activity types are likely to fail.
"""
if shard_id is None:
await asyncio.gather(*[shard.change_presence(status, activity) for shard in self._connection_states])
else:
await self._connection_states[shard_id].change_presence(status, activity)
17 changes: 5 additions & 12 deletions naff/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
Coroutine,
Dict,
List,
Mapping,
NoReturn,
Optional,
Sequence,
Expand Down Expand Up @@ -211,13 +210,7 @@ class Client(
delete_unused_application_cmds: Delete any commands from discord that aren't implemented in this client
enforce_interaction_perms: Enforce discord application command permissions, locally
fetch_members: Should the client fetch members from guilds upon startup (this will delay the client being ready)
auto_defer: A system to automatically defer commands after a set duration
interaction_context: The object to instantiate for Interaction Context
prefixed_context: The object to instantiate for Prefixed Context
component_context: The object to instantiate for Component Context
autocomplete_context: The object to instantiate for Autocomplete Context
modal_context: The object to instantiate for Modal Context
send_command_tracebacks: Automatically send uncaught tracebacks if a command throws an exception
auto_defer: AutoDefer: A system to automatically defer commands after a set duration
interaction_context: Type[InteractionContext]: InteractionContext: The object to instantiate for Interaction Context
Expand Down Expand Up @@ -1767,7 +1760,7 @@ def get_ext(self, name: str) -> Extension | None:
return ext[0]
return None

def load_extension(self, name: str, package: str | None = None, **load_kwargs: Mapping[str, Any]) -> None:
def load_extension(self, name: str, package: str | None = None, **load_kwargs: Any) -> None:
"""
Load an extension with given arguments.
Expand Down Expand Up @@ -1806,7 +1799,7 @@ def load_extension(self, name: str, package: str | None = None, **load_kwargs: M
return
asyncio.create_task(self.synchronise_interactions())

def unload_extension(self, name: str, package: str | None = None, **unload_kwargs: Mapping[str, Any]) -> None:
def unload_extension(self, name: str, package: str | None = None, **unload_kwargs: Any) -> None:
"""
Unload an extension with given arguments.
Expand Down Expand Up @@ -1847,8 +1840,8 @@ def reload_extension(
name: str,
package: str | None = None,
*,
load_kwargs: Mapping[str, Any] = None,
unload_kwargs: Mapping[str, Any] = None,
load_kwargs: Any = None,
unload_kwargs: Any = None,
) -> None:
"""
Helper method to reload an extension. Simply unloads, then loads the extension with given arguments.
Expand Down
4 changes: 2 additions & 2 deletions naff/client/mixins/send.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Iterable, Optional, Union

import naff.models as models

Expand Down Expand Up @@ -34,7 +34,7 @@ async def send(
dict,
]
] = None,
stickers: Optional[Union[List[Union["Sticker", "Snowflake_Type"]], "Sticker", "Snowflake_Type"]] = None,
stickers: Optional[Union[Iterable[Union["Sticker", "Snowflake_Type"]], "Sticker", "Snowflake_Type"]] = None,
allowed_mentions: Optional[Union["AllowedMentions", dict]] = None,
reply_to: Optional[Union["MessageReference", "Message", dict, "Snowflake_Type"]] = None,
files: Optional[Union["UPLOADABLE_TYPE", Iterable["UPLOADABLE_TYPE"]]] = None,
Expand Down
4 changes: 3 additions & 1 deletion naff/models/discord/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,9 @@ class Member(DiscordObject, _SendDMMixin):
nick: Optional[str] = field(repr=True, default=None, metadata=docs("The user's nickname in this guild'"))
deaf: bool = field(default=False, metadata=docs("Has this user been deafened in voice channels?"))
mute: bool = field(default=False, metadata=docs("Has this user been muted in voice channels?"))
joined_at: "Timestamp" = field(converter=timestamp_converter, metadata=docs("When the user joined this guild"))
joined_at: "Timestamp" = field(
default=MISSING, converter=optional(timestamp_converter), metadata=docs("When the user joined this guild")
)
premium_since: Optional["Timestamp"] = field(
default=None,
converter=optional_c(timestamp_converter),
Expand Down
29 changes: 15 additions & 14 deletions naff/models/discord/voice_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ def channel(self) -> "TYPE_VOICE_CHANNEL":
"""The channel the user is connected to."""
channel: "TYPE_VOICE_CHANNEL" = self._client.cache.get_channel(self._channel_id)

# make sure the member is showing up as a part of the channel
# this is relevant for VoiceStateUpdate.before
# noinspection PyProtectedMember
if self._member_id not in channel._voice_member_ids:
# the list of voice members need to be deepcopied, otherwise the cached obj will be updated
if channel:
# make sure the member is showing up as a part of the channel
# this is relevant for VoiceStateUpdate.before
# noinspection PyProtectedMember
voice_member_ids = copy.deepcopy(channel._voice_member_ids)

# create a copy of the obj
channel = copy.copy(channel)
channel._voice_member_ids = voice_member_ids

# add the member to that list
# noinspection PyProtectedMember
channel._voice_member_ids.append(self._member_id)
if self._member_id not in channel._voice_member_ids:
# the list of voice members need to be deepcopied, otherwise the cached obj will be updated
# noinspection PyProtectedMember
voice_member_ids = copy.deepcopy(channel._voice_member_ids)

# create a copy of the obj
channel = copy.copy(channel)
channel._voice_member_ids = voice_member_ids

# add the member to that list
# noinspection PyProtectedMember
channel._voice_member_ids.append(self._member_id)

return channel

Expand Down
Loading

0 comments on commit 1171097

Please sign in to comment.