forked from python-discord/bot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconverters.py
579 lines (447 loc) · 21.2 KB
/
converters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
from __future__ import annotations

import asyncio
import re
import typing as t
from datetime import datetime, timezone
from ssl import CertificateError

import dateutil.parser
import discord
from aiohttp import ClientConnectorError
from botcore.site_api import ResponseCodeError
from botcore.utils import unqualify
from botcore.utils.regex import DISCORD_INVITE
from dateutil.relativedelta import relativedelta
from discord.ext.commands import BadArgument, Bot, Context, Converter, IDConverter, MemberConverter, UserConverter
from discord.utils import escape_markdown, snowflake_time

from bot import exts, instance as bot_instance
from bot.constants import URLs
from bot.errors import InvalidInfraction
from bot.exts.info.doc import _inventory_parser
from bot.exts.info.tags import TagIdentifier
from bot.log import get_logger
from bot.utils import time

if t.TYPE_CHECKING:
    from bot.exts.info.source import SourceType
log = get_logger(__name__)

# The datetime encoded by snowflake 0; `Snowflake.convert` rejects IDs whose timestamp is earlier.
DISCORD_EPOCH_DT = snowflake_time(0)
# A user mention `<@ID>` (or the `<@!ID>` variant) capturing the numeric ID; `$` anchors the end
# and callers use `.match`, so the whole string must be the mention.
RE_USER_MENTION = re.compile(r"<@!?([0-9]+)>$")
class ValidDiscordServerInvite(Converter):
    """
    A converter that validates whether a given string is a valid Discord server invite.

    Raises 'BadArgument' if:
    - The string is not a valid Discord server invite.
    - The string is valid, but is an invite for a group DM.
    - The string is valid, but is expired.

    Returns a (partial) guild object if:
    - The string is a valid vanity
    - The string is a full invite URI
    - The string contains the invite code (the stuff after discord.gg/)

    See the Discord API docs for documentation on the guild object:
    https://discord.com/developers/docs/resources/guild#guild-object
    """

    async def convert(self, ctx: Context, server_invite: str) -> dict:
        """
        Check whether the string is a valid Discord server invite.

        The invite code is extracted with `DISCORD_INVITE` and looked up at
        `URLs.discord_invite_api`. Only a successful (HTTP 200) lookup whose payload
        carries a truthy `"guild"` entry is accepted; that entry is returned.

        Raises `BadArgument` for guild IDs, unknown/expired invites, failed lookups,
        and invites without a guild.
        """
        invite_code = DISCORD_INVITE.match(server_invite)
        if invite_code:
            # `async with` guarantees the response is released back to the connection pool,
            # including on the early `return` below.
            async with ctx.bot.http_session.get(
                f"{URLs.discord_invite_api}/{invite_code.group('invite')}"
            ) as response:
                # Previously any non-404 status (e.g. 429 or 5xx) was treated as success.
                if response.status == 200:
                    invite_data = await response.json()
                    guild = invite_data.get("guild")
                    # Invites without a guild (e.g. group DM invites) must not be returned as
                    # `None`; fall through to the `BadArgument` below instead.
                    if guild:
                        return guild

        id_converter = IDConverter()
        if id_converter._get_id_match(server_invite):
            raise BadArgument("Guild IDs are not supported, only invites.")

        raise BadArgument("This does not appear to be a valid Discord server invite.")
class ValidFilterListType(Converter):
    """
    Validate a FilterList type name against the types the site API knows about.

    The argument is upper-cased and returned when it names a known type (a trailing
    plural "s" is tolerated); otherwise `BadArgument` is raised listing the options.
    """

    @staticmethod
    async def get_valid_types(bot: Bot) -> list:
        """
        Fetch the list of valid filter list type names from the site API.

        Raises `BadArgument` when the API request fails with a `ResponseCodeError`.
        """
        try:
            type_pairs = await bot.api_client.get('bot/filter-lists/get-types')
        except ResponseCodeError:
            raise BadArgument("Cannot validate list_type: Unable to fetch valid types from API.")

        # Each entry is a two-element pair; only the first element (the type name) is kept.
        return [name for name, _ in type_pairs]

    async def convert(self, ctx: Context, list_type: str) -> str:
        """Return the canonical (upper-case, singular) form of `list_type` if it is a valid type."""
        known_types = await self.get_valid_types(ctx.bot)
        wanted = list_type.upper()

        if wanted in known_types:
            return wanted

        # Tolerate the simple plural form, e.g. "guild_invites" for "guild_invite".
        # Only a single trailing 's' is handled; an irregular plural (like 'ies')
        # added in the future would need this logic to be extended.
        singular = wanted[:-1]
        if wanted.endswith("S") and singular in known_types:
            return singular

        options = '\n'.join(f"• {type_.lower()}" for type_ in known_types)
        raise BadArgument(
            f"You have provided an invalid list type!\n\n"
            f"Please provide one of the following: \n{options}"
        )
class Extension(Converter):
    """
    Fully qualify the name of an extension and ensure it exists.

    The * and ** values bypass this when used with the reload command.
    """

    async def convert(self, ctx: Context, argument: str) -> str:
        """
        Resolve `argument` to a name present in `bot_instance.all_extensions`.

        Resolution order: the wildcards `*`/`**` (returned untouched), the lowercased
        name itself, the name prefixed with the `exts` package, and finally a unique
        match on the unqualified (last dotted component) extension name.
        Raises `BadArgument` when nothing matches or the short name is ambiguous.
        """
        # Special values the reload command uses to reload all extensions.
        if argument in ("*", "**"):
            return argument

        name = argument.lower()
        known = bot_instance.all_extensions

        if name in known:
            return name

        qualified = f"{exts.__name__}.{name}"
        if qualified in known:
            return qualified

        candidates = [ext for ext in known if unqualify(ext) == name]

        if not candidates:
            raise BadArgument(f":x: Could not find the extension `{name}`.")

        if len(candidates) == 1:
            return candidates[0]

        listing = "\n".join(sorted(candidates))
        raise BadArgument(
            f":x: `{name}` is an ambiguous extension name. "
            f"Please use one of the following fully-qualified names.```\n{listing}```"
        )
class PackageName(Converter):
    """
    A converter that checks whether the given string is a valid package name.

    Package names are used for stats and are restricted to the a-z, 0-9 and _ characters.
    """

    # Matches any single character that is NOT allowed in a package name.
    PACKAGE_NAME_RE = re.compile(r"[^a-z0-9_]")

    @classmethod
    async def convert(cls, ctx: Context, argument: str) -> str:
        """
        Checks whether the given string is a valid package name.

        Returns `argument` unchanged. Raises `BadArgument` if it is empty or contains
        any character outside a-z, 0-9 and _.
        """
        # An empty string contains no disallowed character, so it must be rejected explicitly.
        if not argument or cls.PACKAGE_NAME_RE.search(argument):
            raise BadArgument("The provided package name is not valid; please only use the _, 0-9, and a-z characters.")
        return argument
class ValidURL(Converter):
    """
    Represents a valid webpage URL.

    This converter checks whether the given URL can be reached and requesting it returns a status
    code of 200. If not, `BadArgument` is raised.

    Otherwise, it simply passes through the given URL.
    """

    @staticmethod
    async def convert(ctx: Context, url: str) -> str:
        """
        This converter checks whether the given URL can be reached with a status code of 200.

        Raises `BadArgument` if the status is not 200, a certificate error occurs, the URL
        is rejected with a `ValueError`, the host cannot be connected to, or the request
        times out.
        """
        try:
            async with ctx.bot.http_session.get(url) as resp:
                if resp.status != 200:
                    raise BadArgument(
                        f"HTTP GET on `{url}` returned status `{resp.status}`, expected 200"
                    )
        except CertificateError:
            if url.startswith('https'):
                raise BadArgument(
                    f"Got a `CertificateError` for URL `{url}`. Does it support HTTPS?"
                )
            raise BadArgument(f"Got a `CertificateError` for URL `{url}`.")
        except ValueError:
            raise BadArgument(f"`{url}` doesn't look like a valid hostname to me.")
        except ClientConnectorError:
            raise BadArgument(f"Cannot connect to host with URL `{url}`.")
        except asyncio.TimeoutError:
            # An unresponsive host used to escape as an unhandled error instead of user feedback;
            # aiohttp's own timeout exceptions subclass `asyncio.TimeoutError`.
            raise BadArgument(f"Timed out while trying to reach `{url}`.")
        return url
class Inventory(Converter):
    """
    Represents an Intersphinx inventory URL.

    The URL is accepted only if `_inventory_parser.fetch_inventory` can fetch and parse it;
    otherwise `BadArgument` is raised. On success a `(url, inventory)` tuple is returned.
    """

    @staticmethod
    async def convert(ctx: Context, url: str) -> t.Tuple[str, _inventory_parser.InventoryDict]:
        """Fetch the Intersphinx inventory at `url` and return it together with the URL."""
        await ctx.typing()

        try:
            fetched = await _inventory_parser.fetch_inventory(url)
        except _inventory_parser.InvalidHeaderError:
            raise BadArgument("Unable to parse inventory because of invalid header, check if URL is correct.")

        # `fetch_inventory` signals exhausted retries by returning None.
        if fetched is None:
            raise BadArgument(
                f"Failed to fetch inventory file after {_inventory_parser.FAILED_REQUEST_ATTEMPTS} attempts."
            )

        return url, fetched
class Snowflake(IDConverter):
    """
    Converts to an int if the argument is a valid Discord snowflake.

    A snowflake is valid if:
        * It is accepted by `IDConverter`'s ID pattern (a run of digits)
        * Its parsed datetime is not before the Discord epoch
        * Its parsed datetime is at most about one day after the current time
    """

    async def convert(self, ctx: Context, arg: str) -> int:
        """
        Validate `arg` as a snowflake and return it as an int.

        Raises `BadArgument` if it doesn't match the ID pattern or its timestamp is out of range.
        """
        error = f"Invalid snowflake {arg!r}"

        if not self._get_id_match(arg):
            raise BadArgument(error)

        snowflake = int(arg)

        try:
            created_at = snowflake_time(snowflake)
        except (OverflowError, OSError) as exc:
            # Not sure if this can ever even happen, but let's be safe.
            raise BadArgument(f"{error}: {exc}")

        if created_at < DISCORD_EPOCH_DT:
            raise BadArgument(f"{error}: timestamp is before the Discord epoch.")

        # A negative timedelta's `.days` rounds down, so this trips once the timestamp
        # is more than one day ahead of now.
        age = datetime.now(timezone.utc) - created_at
        if age.days < -1:
            raise BadArgument(f"{error}: timestamp is too far into the future.")

        return snowflake
class SourceConverter(Converter):
    """Convert an argument into a help command, tag, command, or cog."""

    @staticmethod
    async def convert(ctx: Context, argument: str) -> SourceType:
        """
        Convert argument into source object.

        Lookup order: the word "help" (case-insensitive), a cog by name, a command by name,
        and finally a tag, which is only tried while the Tags cog is loaded.
        """
        bot = ctx.bot

        if argument.lower() == "help":
            return bot.help_command

        if cog := bot.get_cog(argument):
            return cog

        if cmd := bot.get_command(argument):
            return cmd

        tags_cog = bot.get_cog("Tags")
        show_tag = bool(tags_cog)
        if show_tag:
            identifier = TagIdentifier.from_string(argument.lower())
            if identifier in tags_cog.tags:
                return identifier

        escaped_arg = escape_markdown(argument)
        raise BadArgument(
            f"Unable to convert '{escaped_arg}' to valid command{', tag,' if show_tag else ''} or Cog."
        )
class DurationDelta(Converter):
    """Convert duration strings into dateutil.relativedelta.relativedelta objects."""

    async def convert(self, ctx: Context, duration: str) -> relativedelta:
        """
        Parse `duration` into a relativedelta via `time.parse_duration_string`.

        Supported symbols for each unit of time:
        - years: `Y`, `y`, `year`, `years`
        - months: `m`, `month`, `months`
        - weeks: `w`, `W`, `week`, `weeks`
        - days: `d`, `D`, `day`, `days`
        - hours: `H`, `h`, `hour`, `hours`
        - minutes: `M`, `minute`, `minutes`
        - seconds: `S`, `s`, `second`, `seconds`

        Units must appear in descending order of magnitude. Raises `BadArgument` when the
        parser returns a falsy result.
        """
        delta = time.parse_duration_string(duration)
        if not delta:
            raise BadArgument(f"`{duration}` is not a valid duration string.")
        return delta
class Duration(DurationDelta):
    """Convert duration strings into UTC datetime.datetime objects lying in the future."""

    async def convert(self, ctx: Context, duration: str) -> datetime:
        """
        Return the aware UTC datetime that is `duration` from now.

        Accepts the same unit symbols as `DurationDelta`. Raises `BadArgument` if the
        resulting datetime cannot be represented.
        """
        offset = await super().convert(ctx, duration)
        current = datetime.now(timezone.utc)

        try:
            return current + offset
        except (ValueError, OverflowError):
            raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
class Age(DurationDelta):
    """Convert duration strings into UTC datetime.datetime objects lying in the past."""

    async def convert(self, ctx: Context, duration: str) -> datetime:
        """
        Return the aware UTC datetime that was `duration` ago.

        Accepts the same unit symbols as `DurationDelta`. Raises `BadArgument` if the
        resulting datetime cannot be represented.
        """
        offset = await super().convert(ctx, duration)
        current = datetime.now(timezone.utc)

        try:
            return current - offset
        except (ValueError, OverflowError):
            raise BadArgument(f"`{duration}` results in a datetime outside the supported range.")
class OffTopicName(Converter):
    """A converter that ensures an added off-topic name is valid."""

    # Position i of ALLOWED_CHARACTERS maps to position i of TRANSLATED_CHARACTERS; the two
    # strings must keep the same length for `str.maketrans`.
    ALLOWED_CHARACTERS = r"ABCDEFGHIJKLMNOPQRSTUVWXYZ!?'`-<>\/"
    TRANSLATED_CHARACTERS = "𝖠𝖡𝖢𝖣𝖤𝖥𝖦𝖧𝖨𝖩𝖪𝖫𝖬𝖭𝖮𝖯𝖰𝖱𝖲𝖳𝖴𝖵𝖶𝖷𝖸𝖹ǃ?’’-<>⧹⧸"

    @classmethod
    def translate_name(cls, name: str, *, from_unicode: bool = True) -> str:
        """
        Translate `name` between plain text and the discord-safe look-alike format.

        Despite its name, `from_unicode=True` (the default) translates plain text INTO the
        discord-safe format: each character of `ALLOWED_CHARACTERS` is replaced by the
        character at the same position in `TRANSLATED_CHARACTERS`.

        With `from_unicode=False` the mapping is reversed, turning a discord-safe name back
        into normalized text. The reversal is lossy: both `'` and a backtick become `’`, which
        maps back to a backtick.

        Characters that appear in neither table are left unchanged.
        """
        if from_unicode:
            table = str.maketrans(cls.ALLOWED_CHARACTERS, cls.TRANSLATED_CHARACTERS)
        else:
            table = str.maketrans(cls.TRANSLATED_CHARACTERS, cls.ALLOWED_CHARACTERS)

        return name.translate(table)

    async def convert(self, ctx: Context, argument: str) -> str:
        """
        Validate `argument` as an off-topic channel name and return its discord-safe form.

        Whitespace runs become "-". Raises `BadArgument` if the result is not 2-96 characters
        long or contains characters that are neither alphanumeric nor in `ALLOWED_CHARACTERS`.
        """
        # Chain multiple words to a single one
        argument = "-".join(argument.split())

        if not (2 <= len(argument) <= 96):
            raise BadArgument("Channel name must be between 2 and 96 chars long")

        elif not all(c.isalnum() or c in self.ALLOWED_CHARACTERS for c in argument):
            raise BadArgument(
                "Channel name must only consist of "
                "alphanumeric characters, minus signs or apostrophes."
            )

        # Replace invalid characters with unicode alternatives.
        return self.translate_name(argument)
class ISODateTime(Converter):
    """Converts an ISO-8601 datetime string into a datetime.datetime."""

    async def convert(self, ctx: Context, datetime_string: str) -> datetime:
        """
        Parse an ISO-8601 `datetime_string` into an aware UTC `datetime.datetime`.

        Parsing is delegated to `dateutil.parser.isoparse`, so the accepted formats are broad:
        in general a date, optionally followed by a time. A timezone offset in the string is
        honoured and the result is converted to UTC. A string without an offset is taken to be
        UTC already. The returned object always carries the UTC timezone.

        See: https://dateutil.readthedocs.io/en/stable/parser.html#dateutil.parser.isoparse

        Formats that are guaranteed to be valid by our tests are:

        - `YYYY-mm-ddTHH:MM:SSZ` | `YYYY-mm-dd HH:MM:SSZ`
        - `YYYY-mm-ddTHH:MM:SS±HH:MM` | `YYYY-mm-dd HH:MM:SS±HH:MM`
        - `YYYY-mm-ddTHH:MM:SS±HHMM` | `YYYY-mm-dd HH:MM:SS±HHMM`
        - `YYYY-mm-ddTHH:MM:SS±HH` | `YYYY-mm-dd HH:MM:SS±HH`
        - `YYYY-mm-ddTHH:MM:SS` | `YYYY-mm-dd HH:MM:SS`
        - `YYYY-mm-ddTHH:MM` | `YYYY-mm-dd HH:MM`
        - `YYYY-mm-dd`
        - `YYYY-mm`
        - `YYYY`

        Note: ISO-8601 specifies a `T` as the separator between the date and the time part of the
        datetime string. The converter accepts both a `T` and a single space character.

        Raises `BadArgument` if `isoparse` rejects the string.
        """
        try:
            parsed = dateutil.parser.isoparse(datetime_string)
        except ValueError:
            raise BadArgument(f"`{datetime_string}` is not a valid ISO-8601 datetime string")

        # Naive input is interpreted as already being UTC.
        if not parsed.tzinfo:
            return parsed.replace(tzinfo=timezone.utc)

        return parsed.astimezone(timezone.utc)
class HushDurationConverter(Converter):
    """Convert passed duration to `int` minutes, or -1 when `"forever"` is requested."""

    # A leading run of digits followed by "m", "M" or end-of-string. It is used with `re.match`,
    # which anchors only the start, so text after an "m"/"M" is ignored (e.g. "5min" -> 5).
    MINUTES_RE = re.compile(r"(\d+)(?:M|m|$)")

    async def convert(self, ctx: Context, argument: str) -> int:
        """
        Convert `argument` to a duration of at most 15 minutes, or -1 for "forever".

        If `"forever"` (in any letter case) is passed, -1 is returned; otherwise an int of the extracted time.

        Accepted formats are:
        * <duration>,
        * <duration>m,
        * <duration>M,
        * forever.

        Raises `BadArgument` if the argument matches none of these formats or the duration
        exceeds 15 minutes.
        """
        if argument.lower() == "forever":
            return -1

        match = self.MINUTES_RE.match(argument)
        if not match:
            raise BadArgument(f"{argument} is not a valid minutes duration.")

        duration = int(match.group(1))
        if duration > 15:
            raise BadArgument("Duration must be at most 15 minutes.")

        return duration
def _is_an_unambiguous_user_argument(argument: str) -> bool:
    """
    Return True if `argument` identifies a user without relying on a bare name.

    That is: a raw ID, a user mention, or a `name#discrim` username (optionally prefixed with "@").
    """
    if IDConverter()._get_id_match(argument) or RE_USER_MENTION.match(argument):
        return True

    # A discriminator is present when the fifth character from the end is "#".
    candidate = argument.removeprefix('@')
    return len(candidate) > 5 and candidate[-5] == '#'
# Error text used by the Unambiguous* converters; `{argument}` is filled in with the rejected input.
AMBIGUOUS_ARGUMENT_MSG = (
    "`{argument}` is not a User mention, a User ID "
    "or a Username in the format `name#discriminator`."
)
class UnambiguousUser(UserConverter):
    """
    A `UserConverter` that refuses bare names.

    Only a mention, a user ID or a `name#discrim` username is passed on to the default
    `UserConverter`. Name-only lookups are rejected because they can match too many users.
    """

    async def convert(self, ctx: Context, argument: str) -> discord.User:
        """Convert the `argument` to a `discord.User`, raising `BadArgument` for ambiguous input."""
        if not _is_an_unambiguous_user_argument(argument):
            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
        return await super().convert(ctx, argument)
class UnambiguousMember(MemberConverter):
    """
    A `MemberConverter` that refuses bare names and nicknames.

    Only a mention, a user ID or a `name#discrim` username is passed on to the default
    `MemberConverter`. Name and nickname lookups are rejected because they can match too many members.
    """

    async def convert(self, ctx: Context, argument: str) -> discord.Member:
        """Convert the `argument` to a `discord.Member`, raising `BadArgument` for ambiguous input."""
        if not _is_an_unambiguous_user_argument(argument):
            raise BadArgument(AMBIGUOUS_ARGUMENT_MSG.format(argument=argument))
        return await super().convert(ctx, argument)
class Infraction(Converter):
    """
    Attempts to convert a given infraction ID into an infraction.

    Alternatively, `l`, `last`, or `recent` can be passed in order to
    obtain the most recent infraction by the actor.
    """

    async def convert(self, ctx: Context, arg: str) -> t.Optional[dict]:
        """
        Fetch the expanded infraction `dict` that `arg` refers to.

        Raises `BadArgument` if the invoker has never given an infraction (shorthand case),
        `InvalidInfraction` on a 404 for an explicit ID, and re-raises any other `ResponseCodeError`.
        """
        if arg not in ("l", "last", "recent"):
            # NOTE(review): `arg` is interpolated into the API path unvalidated; non-numeric input
            # presumably yields a 404 from the API — confirm.
            try:
                return await ctx.bot.api_client.get(f"bot/infractions/{arg}/expanded")
            except ResponseCodeError as e:
                if e.status == 404:
                    raise InvalidInfraction(
                        converter=Infraction,
                        original=e,
                        infraction_arg=arg
                    )
                raise e

        # Newest infraction first, restricted to those issued by the invoking user.
        params = {"actor__id": ctx.author.id, "ordering": "-inserted_at"}
        infractions = await ctx.bot.api_client.get("bot/infractions/expanded", params=params)

        if not infractions:
            raise BadArgument(
                "Couldn't find most recent infraction; you have never given an infraction."
            )
        return infractions[0]
# For static type checkers only (`t.TYPE_CHECKING` is False at runtime): re-bind each converter
# name to the type its `convert` method returns, so annotations like `arg: Snowflake` are
# checked as the converted value. At runtime these names keep referring to the converter classes.
if t.TYPE_CHECKING:
    ValidDiscordServerInvite = dict  # noqa: F811
    ValidFilterListType = str  # noqa: F811
    Extension = str  # noqa: F811
    PackageName = str  # noqa: F811
    ValidURL = str  # noqa: F811
    Inventory = t.Tuple[str, _inventory_parser.InventoryDict]  # noqa: F811
    Snowflake = int  # noqa: F811
    SourceConverter = SourceType  # noqa: F811
    DurationDelta = relativedelta  # noqa: F811
    Duration = datetime  # noqa: F811
    Age = datetime  # noqa: F811
    OffTopicName = str  # noqa: F811
    ISODateTime = datetime  # noqa: F811
    HushDurationConverter = int  # noqa: F811
    UnambiguousUser = discord.User  # noqa: F811
    UnambiguousMember = discord.Member  # noqa: F811
    Infraction = t.Optional[dict]  # noqa: F811

# Union aliases, presumably used as command parameter annotations, where discord.py tries
# each member converter in turn (confirm against the discord.py Union converter docs).
Expiry = t.Union[Duration, ISODateTime]
DurationOrExpiry = t.Union[DurationDelta, ISODateTime]
MemberOrUser = t.Union[discord.Member, discord.User]
UnambiguousMemberOrUser = t.Union[UnambiguousMember, UnambiguousUser]