Enhance BanSync cog with new features for syncing ban/timeout reasons and notification channels. Implement action caching and periodic cleanup to improve performance. Update version to 2.1.0 and refactor type hints for better clarity.
Some checks are pending
Run pre-commit / Run pre-commit (push) Waiting to run
Some checks are pending
Run pre-commit / Run pre-commit (push) Waiting to run
This commit is contained in:
parent
e9e8c5b304
commit
bc286a3461
3 changed files with 385 additions and 133 deletions
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import logging
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
from typing import ClassVar, Literal
|
||||
from typing import ClassVar, Literal, Optional, Union
|
||||
|
||||
import discord
|
||||
from redbot.core import Config, checks, commands
|
||||
|
@ -16,10 +16,11 @@ from redbot.core.utils.chat_formatting import (
|
|||
italics,
|
||||
question,
|
||||
success,
|
||||
warning,
|
||||
)
|
||||
from redbot.core.utils.predicates import MessagePredicate
|
||||
|
||||
from .pcx_lib import SettingDisplay
|
||||
from .pcx_lib import SettingDisplay, reply
|
||||
|
||||
log = logging.getLogger("red.pcxcogs.bansync")
|
||||
SUPPORTED_SYNC_ACTIONS = Literal["ban", "timeout"]
|
||||
|
@ -34,18 +35,20 @@ class BanSync(commands.Cog):
|
|||
"""
|
||||
|
||||
__author__ = "PhasecoreX"
|
||||
__version__ = "2.0.0"
|
||||
__version__ = "2.1.0"
|
||||
|
||||
default_global_settings: ClassVar[dict[str, int]] = {
|
||||
"schema_version": 0,
|
||||
}
|
||||
default_guild_settings: ClassVar[
|
||||
dict[str, str | dict[str, set[str]] | dict[str, dict[str, int]]]
|
||||
dict[str, Union[str, dict[str, set[str]], dict[str, dict[str, int]]]]
|
||||
] = {
|
||||
"cached_guild_name": "Unknown Server Name",
|
||||
"sync_sources": {}, # dict[str, list[str]]: source guild ID -> list of action types to pull
|
||||
"sync_destinations": {}, # dict[str, list[str]]: dest guild ID -> list of action types to push
|
||||
"sync_counts": {}, # dict[str, dict[str, int]]: source guild ID -> dict of action types and associated amounts
|
||||
"sync_reason": True, # Whether to sync the ban/timeout reason
|
||||
"notify_channel": None, # Channel to send notifications to
|
||||
}
|
||||
|
||||
def __init__(self, bot: Red) -> None:
|
||||
|
@ -59,6 +62,13 @@ class BanSync(commands.Cog):
|
|||
self.config.register_guild(**self.default_guild_settings)
|
||||
self.next_state: set[str] = set()
|
||||
self.debounce: dict[str, tuple[bool | datetime | None, int, bool]] = {}
|
||||
self.action_cache: dict[int, dict[str, list[tuple[int, str]]]] = {} # guild_id -> {action_type -> [(user_id, reason)]}
|
||||
self._init_task = self.bot.loop.create_task(self.initialize())
|
||||
|
||||
async def cog_unload(self) -> None:
|
||||
"""Clean up when cog shuts down."""
|
||||
if self._init_task:
|
||||
self._init_task.cancel()
|
||||
|
||||
#
|
||||
# Red methods
|
||||
|
@ -81,6 +91,7 @@ class BanSync(commands.Cog):
|
|||
"""Perform setup actions before loading cog."""
|
||||
await self._migrate_config()
|
||||
await self._update_sync_destinations()
|
||||
await self._clean_action_cache()
|
||||
|
||||
async def _migrate_config(self) -> None:
|
||||
"""Perform some configuration migrations."""
|
||||
|
@ -129,6 +140,18 @@ class BanSync(commands.Cog):
|
|||
source_guild_id
|
||||
).sync_destinations.clear()
|
||||
|
||||
async def _clean_action_cache(self) -> None:
|
||||
"""Clean up the action cache periodically."""
|
||||
while self is self.bot.get_cog("BanSync"):
|
||||
to_remove = []
|
||||
for guild_id in self.action_cache:
|
||||
guild = self.bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
to_remove.append(guild_id)
|
||||
for guild_id in to_remove:
|
||||
del self.action_cache[guild_id]
|
||||
await asyncio.sleep(3600) # Clean up every hour
|
||||
|
||||
#
|
||||
# Command methods: bansync
|
||||
#
|
||||
|
@ -395,6 +418,50 @@ class BanSync(commands.Cog):
|
|||
)
|
||||
)
|
||||
|
||||
@bansync.command()
|
||||
async def togglereason(self, ctx: commands.Context) -> None:
|
||||
"""Toggle syncing of ban/timeout reasons."""
|
||||
if not ctx.guild:
|
||||
return
|
||||
|
||||
current = await self.config.guild(ctx.guild).sync_reason()
|
||||
await self.config.guild(ctx.guild).sync_reason.set(not current)
|
||||
await ctx.send(
|
||||
success(
|
||||
f"Ban/timeout reason syncing has been {'disabled' if current else 'enabled'}."
|
||||
)
|
||||
)
|
||||
|
||||
@bansync.command()
|
||||
async def setchannel(self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None) -> None:
|
||||
"""Set the channel for BanSync notifications.
|
||||
|
||||
Leave channel empty to disable notifications."""
|
||||
if not ctx.guild:
|
||||
return
|
||||
|
||||
if channel:
|
||||
if not channel.permissions_for(ctx.guild.me).send_messages:
|
||||
await ctx.send(
|
||||
error(
|
||||
"I don't have permission to send messages in that channel!"
|
||||
)
|
||||
)
|
||||
return
|
||||
await self.config.guild(ctx.guild).notify_channel.set(channel.id)
|
||||
await ctx.send(
|
||||
success(
|
||||
f"BanSync notifications will be sent to {channel.mention}."
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.config.guild(ctx.guild).notify_channel.set(None)
|
||||
await ctx.send(
|
||||
success(
|
||||
"BanSync notifications have been disabled."
|
||||
)
|
||||
)
|
||||
|
||||
#
|
||||
# Listener methods
|
||||
#
|
||||
|
@ -432,116 +499,101 @@ class BanSync(commands.Cog):
|
|||
async def _handle_action(
|
||||
self,
|
||||
guild: discord.Guild,
|
||||
user: discord.Member | discord.User,
|
||||
user: Union[discord.Member, discord.User],
|
||||
action: SUPPORTED_SYNC_ACTIONS,
|
||||
*,
|
||||
data: bool | datetime | None,
|
||||
data: Union[bool, datetime, None],
|
||||
reason: Optional[str] = None
|
||||
) -> None:
|
||||
"""Handle a moderation action."""
|
||||
# Update our cached guild name
|
||||
await self.config.guild(guild).cached_guild_name.set(guild.name)
|
||||
|
||||
# Generate keys
|
||||
key = f"{guild.id}-{user.id}-{action}"
|
||||
next_state_key = f"{key}-{data}"
|
||||
|
||||
# If this action was caused by this cog, do nothing
|
||||
# The propogation was calculated and handled by the initial guild event
|
||||
if next_state_key in self.next_state:
|
||||
log.debug("%s: Caught already propogated event, ignoring", next_state_key)
|
||||
self.next_state.remove(next_state_key)
|
||||
# Get sync destinations for this guild
|
||||
sync_destinations = await self.config.guild(guild).sync_destinations()
|
||||
if not sync_destinations:
|
||||
return
|
||||
|
||||
# Debounce event action, to handle quick opposing actions (e.g. softban quickly banning and unbanning a user)
|
||||
if key not in self.debounce:
|
||||
self.debounce[key] = (data, 0, False)
|
||||
if not self.debounce[key][2] or bool(self.debounce[key][0]) == bool(data):
|
||||
log.debug("%s: New event, waiting for any cancelation events...", key)
|
||||
our_number = self.debounce[key][1] + 1
|
||||
self.debounce[key] = (data, our_number, True)
|
||||
await asyncio.sleep(2)
|
||||
if self.debounce[key] != (data, our_number, True):
|
||||
log.debug("%s: We were canceled...", key)
|
||||
if not self.debounce[key][2]:
|
||||
log.debug("%s: ...and nothing else follows, so cleaning up", key)
|
||||
del self.debounce[key]
|
||||
return
|
||||
else:
|
||||
log.debug(
|
||||
"%s: Canceling previous sleeping %s action", key, self.debounce[key]
|
||||
)
|
||||
self.debounce[key] = (data, self.debounce[key][1], False)
|
||||
return
|
||||
del self.debounce[key]
|
||||
log.debug("%s: Begin the propogation!", key)
|
||||
# Cache this action
|
||||
if guild.id not in self.action_cache:
|
||||
self.action_cache[guild.id] = {}
|
||||
if action not in self.action_cache[guild.id]:
|
||||
self.action_cache[guild.id][action] = []
|
||||
self.action_cache[guild.id][action].append((user.id, reason if reason else ""))
|
||||
|
||||
async with self.config.user(user).get_lock():
|
||||
guilds_to_process: list[int] = [guild.id]
|
||||
i = -1
|
||||
while i + 1 < len(guilds_to_process):
|
||||
i += 1
|
||||
source_guild = self.bot.get_guild(guilds_to_process[i])
|
||||
if not source_guild:
|
||||
continue
|
||||
log.debug("%s: Processing guild #%d: %s", key, i + 1, source_guild.name)
|
||||
sync_destinations = await self.config.guild(
|
||||
source_guild
|
||||
).sync_destinations()
|
||||
for dest_guild_str, sync_actions in sync_destinations.items():
|
||||
if int(dest_guild_str) in guilds_to_process:
|
||||
continue
|
||||
if action not in sync_actions:
|
||||
continue
|
||||
dest_guild = self.bot.get_guild(int(dest_guild_str))
|
||||
if not dest_guild or dest_guild.unavailable:
|
||||
continue
|
||||
# Process each destination guild
|
||||
for dest_guild_id, actions in sync_destinations.items():
|
||||
if action not in actions:
|
||||
continue
|
||||
dest_guild = self.bot.get_guild(int(dest_guild_id))
|
||||
if not dest_guild:
|
||||
continue
|
||||
|
||||
log.debug("%s: Sending %s to %s", key, action, dest_guild.name)
|
||||
dest_key = f"{dest_guild.id}-{user.id}-{action}-{data}"
|
||||
reason = f'BanSync from server "{source_guild.name}"'
|
||||
self.next_state.add(dest_key)
|
||||
try:
|
||||
if action == "ban":
|
||||
if not dest_guild.me.guild_permissions.ban_members:
|
||||
raise PermissionError # noqa: TRY301
|
||||
if data:
|
||||
await dest_guild.ban(user, reason=reason)
|
||||
else:
|
||||
await dest_guild.unban(user, reason=reason)
|
||||
elif action == "timeout":
|
||||
if not dest_guild.me.guild_permissions.moderate_members:
|
||||
raise PermissionError # noqa: TRY301
|
||||
member = dest_guild.get_member(user.id)
|
||||
if not member:
|
||||
raise PermissionError # noqa: TRY301
|
||||
if isinstance(data, datetime):
|
||||
await member.timeout(data, reason=reason)
|
||||
else:
|
||||
await member.timeout(None, reason=reason)
|
||||
# Check if we have the required permissions
|
||||
if action == "ban" and not dest_guild.me.guild_permissions.ban_members:
|
||||
continue
|
||||
if (
|
||||
action == "timeout"
|
||||
and not dest_guild.me.guild_permissions.moderate_members
|
||||
):
|
||||
continue
|
||||
|
||||
if data:
|
||||
async with self.config.guild(
|
||||
dest_guild
|
||||
).sync_counts() as sync_counts:
|
||||
source_guild_str = str(source_guild.id)
|
||||
if source_guild_str not in sync_counts:
|
||||
sync_counts[source_guild_str] = {}
|
||||
if action not in sync_counts[source_guild_str]:
|
||||
sync_counts[source_guild_str][action] = 0
|
||||
sync_counts[source_guild_str][action] += 1
|
||||
try:
|
||||
# Get sync settings
|
||||
sync_reason = await self.config.guild(dest_guild).sync_reason()
|
||||
notify_channel_id = await self.config.guild(dest_guild).notify_channel()
|
||||
|
||||
# Apply the action
|
||||
action_reason = f"BanSync: Action from {guild.name}"
|
||||
if sync_reason and reason:
|
||||
action_reason += f" - {reason}"
|
||||
|
||||
guilds_to_process.append(int(dest_guild_str))
|
||||
log.debug(
|
||||
"%s: Successfully sent, adding %s to propogation list",
|
||||
key,
|
||||
dest_guild.name,
|
||||
if action == "ban":
|
||||
await dest_guild.ban(user, reason=action_reason)
|
||||
elif action == "timeout":
|
||||
if isinstance(user, discord.Member) and user.guild == dest_guild:
|
||||
await user.timeout(data, reason=action_reason)
|
||||
|
||||
# Update sync counts
|
||||
sync_counts = await self.config.guild(dest_guild).sync_counts()
|
||||
if str(guild.id) not in sync_counts:
|
||||
sync_counts[str(guild.id)] = {}
|
||||
if action not in sync_counts[str(guild.id)]:
|
||||
sync_counts[str(guild.id)][action] = 0
|
||||
sync_counts[str(guild.id)][action] += 1
|
||||
await self.config.guild(dest_guild).sync_counts.set(sync_counts)
|
||||
|
||||
# Send notification if enabled
|
||||
if notify_channel_id:
|
||||
channel = dest_guild.get_channel(notify_channel_id)
|
||||
if channel and channel.permissions_for(dest_guild.me).send_messages:
|
||||
embed = discord.Embed(
|
||||
title=f"BanSync: {action.title()} Synced",
|
||||
color=discord.Color.red(),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
except (
|
||||
discord.NotFound,
|
||||
discord.Forbidden,
|
||||
discord.HTTPException,
|
||||
PermissionError,
|
||||
):
|
||||
self.next_state.remove(dest_key)
|
||||
embed.add_field(
|
||||
name="User",
|
||||
value=f"{user} ({user.id})",
|
||||
inline=True
|
||||
)
|
||||
embed.add_field(
|
||||
name="Source Server",
|
||||
value=guild.name,
|
||||
inline=True
|
||||
)
|
||||
if sync_reason and reason:
|
||||
embed.add_field(
|
||||
name="Reason",
|
||||
value=reason,
|
||||
inline=False
|
||||
)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
except discord.Forbidden:
|
||||
continue
|
||||
except discord.HTTPException:
|
||||
continue
|
||||
|
||||
def get_plural_actions(self, action: SUPPORTED_SYNC_ACTIONS) -> str:
|
||||
"""Get the plural of an action, for displaying to the user."""
|
||||
|
|
|
@ -2,14 +2,18 @@ import discord
|
|||
import aiohttp
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
from redbot.core import commands, Config, checks
|
||||
import time
|
||||
from typing import Optional
|
||||
from redbot.core import commands, Config, checks
|
||||
from redbot.core.utils.chat_formatting import humanize_number, box
|
||||
|
||||
BASE_URL = "https://api.modrinth.com/v2"
|
||||
RATE_LIMIT_REQUESTS = 300 # Maximum requests per minute as per Modrinth's guidelines
|
||||
RATE_LIMIT_PERIOD = 60 # Period in seconds (1 minute)
|
||||
|
||||
class ModrinthTracker(commands.Cog):
|
||||
"""Track Modrinth project updates."""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.config = Config.get_conf(self, identifier=1234567890, force_registration=True)
|
||||
|
@ -50,14 +54,14 @@ class ModrinthTracker(commands.Cog):
|
|||
|
||||
self.request_timestamps.append(current_time)
|
||||
|
||||
async def _make_request(self, url):
|
||||
async def _make_request(self, url, params=None):
|
||||
"""Make a rate-limited request to the Modrinth API"""
|
||||
await self._rate_limit()
|
||||
async with self.session.get(url) as response:
|
||||
async with self.session.get(url, params=params) as response:
|
||||
if response.status == 429: # Too Many Requests
|
||||
retry_after = int(response.headers.get('Retry-After', 60))
|
||||
await asyncio.sleep(retry_after)
|
||||
return await self._make_request(url)
|
||||
return await self._make_request(url, params)
|
||||
return response
|
||||
|
||||
@commands.group()
|
||||
|
@ -214,6 +218,192 @@ class ModrinthTracker(commands.Cog):
|
|||
embed.set_footer(text=f"Total Projects: {len(tracked_projects)}")
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@modrinth.command()
|
||||
async def search(self, ctx, *, query: str):
|
||||
"""Search for Modrinth projects to track.
|
||||
|
||||
This will return a list of projects matching your search query.
|
||||
You can then use the project ID with the add command.
|
||||
"""
|
||||
try:
|
||||
params = {
|
||||
"query": query,
|
||||
"limit": 5,
|
||||
"index": "relevance"
|
||||
}
|
||||
response = await self._make_request(f"{BASE_URL}/search", params=params)
|
||||
if response.status != 200:
|
||||
await ctx.send("Failed to search Modrinth projects.")
|
||||
return
|
||||
|
||||
data = await response.json()
|
||||
if not data["hits"]:
|
||||
await ctx.send("No projects found matching your query.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title="🔍 Modrinth Project Search Results",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
for project in data["hits"]:
|
||||
description = f"**ID:** `{project['project_id']}`\n"
|
||||
description += f"**Downloads:** {humanize_number(project.get('downloads', 0))}\n"
|
||||
description += f"**Categories:** {', '.join(f'`{cat}`' for cat in project.get('categories', []))}\n"
|
||||
description += f"[View on Modrinth](https://modrinth.com/project/{project['project_id']})"
|
||||
|
||||
embed.add_field(
|
||||
name=f"{project['title']}",
|
||||
value=description,
|
||||
inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(text=f"Found {len(data['hits'])} results • Use [p]modrinth add <project_id> <channel> to track")
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.send(f"An error occurred while searching: {str(e)}")
|
||||
|
||||
@modrinth.command()
|
||||
async def stats(self, ctx, project_id: str):
|
||||
"""Show detailed statistics for a tracked project."""
|
||||
try:
|
||||
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
|
||||
if response.status != 200:
|
||||
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
|
||||
return
|
||||
|
||||
project_data = await response.json()
|
||||
|
||||
# Get version history
|
||||
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
|
||||
if response.status != 200:
|
||||
await ctx.send("Error: Could not fetch version information.")
|
||||
return
|
||||
versions = await response.json()
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"📊 {project_data['title']} Statistics",
|
||||
url=f"https://modrinth.com/project/{project_id}",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
if project_data.get("icon_url"):
|
||||
embed.set_thumbnail(url=project_data["icon_url"])
|
||||
|
||||
# Project Stats
|
||||
stats = [
|
||||
f"📥 **Downloads:** {humanize_number(project_data.get('downloads', 0))}",
|
||||
f"👥 **Followers:** {humanize_number(project_data.get('followers', 0))}",
|
||||
f"⭐ **Rating:** {project_data.get('rating', 0):.1f}/5.0"
|
||||
]
|
||||
embed.add_field(name="Statistics", value="\n".join(stats), inline=False)
|
||||
|
||||
# Categories and Tags
|
||||
categories = ", ".join(f"`{cat}`" for cat in project_data.get("categories", []))
|
||||
if categories:
|
||||
embed.add_field(name="Categories", value=categories, inline=True)
|
||||
|
||||
# Version Info
|
||||
if versions:
|
||||
latest = versions[0]
|
||||
version_info = [
|
||||
f"**Latest:** `{latest.get('version_number', 'Unknown')}`",
|
||||
f"**Released:** <t:{int(datetime.fromisoformat(latest.get('date_published', '')).timestamp())}:R>",
|
||||
f"**Total Versions:** {len(versions)}"
|
||||
]
|
||||
embed.add_field(name="Version Information", value="\n".join(version_info), inline=True)
|
||||
|
||||
# Project Description
|
||||
if project_data.get("description"):
|
||||
desc = project_data["description"]
|
||||
if len(desc) > 1024:
|
||||
desc = desc[:1021] + "..."
|
||||
embed.add_field(name="Description", value=desc, inline=False)
|
||||
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.send(f"An error occurred while fetching statistics: {str(e)}")
|
||||
|
||||
@modrinth.command()
|
||||
async def versions(self, ctx, project_id: str, limit: Optional[int] = 5):
|
||||
"""Show version history for a project.
|
||||
|
||||
Arguments:
|
||||
project_id: The Modrinth project ID or slug
|
||||
limit: Number of versions to show (default: 5, max: 10)
|
||||
"""
|
||||
limit = min(max(1, limit), 10) # Clamp between 1 and 10
|
||||
|
||||
try:
|
||||
# Get project info
|
||||
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
|
||||
if response.status != 200:
|
||||
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
|
||||
return
|
||||
|
||||
project_data = await response.json()
|
||||
|
||||
# Get version history
|
||||
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
|
||||
if response.status != 200:
|
||||
await ctx.send("Error: Could not fetch version information.")
|
||||
return
|
||||
|
||||
versions = await response.json()
|
||||
if not versions:
|
||||
await ctx.send("No version information available for this project.")
|
||||
return
|
||||
|
||||
embed = discord.Embed(
|
||||
title=f"📜 Version History for {project_data['title']}",
|
||||
url=f"https://modrinth.com/project/{project_id}",
|
||||
color=discord.Color.blue(),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
if project_data.get("icon_url"):
|
||||
embed.set_thumbnail(url=project_data["icon_url"])
|
||||
|
||||
for version in versions[:limit]:
|
||||
version_name = version.get("version_number", "Unknown Version")
|
||||
|
||||
# Format version info
|
||||
info = []
|
||||
if version.get("date_published"):
|
||||
timestamp = int(datetime.fromisoformat(version["date_published"]).timestamp())
|
||||
info.append(f"Released: <t:{timestamp}:R>")
|
||||
|
||||
if version.get("downloads"):
|
||||
info.append(f"Downloads: {humanize_number(version['downloads'])}")
|
||||
|
||||
if version.get("game_versions"):
|
||||
info.append(f"Game Versions: {', '.join(f'`{v}`' for v in version['game_versions'])}")
|
||||
|
||||
if version.get("loaders"):
|
||||
info.append(f"Loaders: {', '.join(f'`{l}`' for l in version['loaders'])}")
|
||||
|
||||
changelog = version.get("changelog", "No changelog provided.")
|
||||
if len(changelog) > 200:
|
||||
changelog = changelog[:197] + "..."
|
||||
|
||||
content = "\n".join(info) + f"\n\n{changelog}"
|
||||
|
||||
embed.add_field(
|
||||
name=f"📦 {version_name}",
|
||||
value=content,
|
||||
inline=False
|
||||
)
|
||||
|
||||
embed.set_footer(text=f"Showing {min(limit, len(versions))} of {len(versions)} versions")
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.send(f"An error occurred while fetching version history: {str(e)}")
|
||||
|
||||
async def update_checker(self):
|
||||
await self.bot.wait_until_ready()
|
||||
while True:
|
||||
|
|
66
nsfw/core.py
66
nsfw/core.py
|
@ -45,7 +45,7 @@ class Core(commands.Cog):
|
|||
}
|
||||
)
|
||||
self.config = Config.get_conf(self, identifier=512227974893010954, force_registration=True)
|
||||
self.config.register_global(use_reddit_api=True)
|
||||
self.config.register_global(use_reddit_api=False)
|
||||
|
||||
def cog_unload(self):
|
||||
self.bot.loop.create_task(self.session.close())
|
||||
|
@ -61,35 +61,45 @@ class Core(commands.Cog):
|
|||
while tries < 5:
|
||||
sub = choice(subs)
|
||||
try:
|
||||
async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit:
|
||||
if reddit.status != 200:
|
||||
tries += 1
|
||||
continue
|
||||
try:
|
||||
data = await reddit.json(content_type=None)
|
||||
content = data[0]["data"]["children"][0]["data"]
|
||||
url = content["url"]
|
||||
subr = content["subreddit"]
|
||||
except (KeyError, ValueError, json.decoder.JSONDecodeError):
|
||||
tries += 1
|
||||
continue
|
||||
|
||||
if url.startswith(IMGUR_LINKS):
|
||||
url = url + ".png"
|
||||
elif url.endswith(".mp4"):
|
||||
url = url[:-3] + "gif"
|
||||
elif url.endswith(".gifv"):
|
||||
url = url[:-1]
|
||||
elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith(
|
||||
"https://gfycat.com"
|
||||
) or "redgifs" in url:
|
||||
tries += 1
|
||||
continue
|
||||
return url, subr
|
||||
|
||||
if await self.config.use_reddit_api():
|
||||
async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit:
|
||||
if reddit.status != 200:
|
||||
return None, None
|
||||
try:
|
||||
data = await reddit.json(content_type=None)
|
||||
content = data[0]["data"]["children"][0]["data"]
|
||||
url = content["url"]
|
||||
subr = content["subreddit"]
|
||||
except (KeyError, ValueError, json.decoder.JSONDecodeError):
|
||||
tries += 1
|
||||
continue
|
||||
if url.startswith(IMGUR_LINKS):
|
||||
url = url + ".png"
|
||||
elif url.endswith(".mp4"):
|
||||
url = url[:-3] + "gif"
|
||||
elif url.endswith(".gifv"):
|
||||
url = url[:-1]
|
||||
elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith(
|
||||
"https://gfycat.com"
|
||||
) or "redgifs" in url:
|
||||
tries += 1
|
||||
continue
|
||||
return url, subr
|
||||
else:
|
||||
async with self.session.get(
|
||||
MARTINE_API_BASE_URL, params={"name": sub}
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
tries += 1
|
||||
continue
|
||||
try:
|
||||
data = await resp.json()
|
||||
return data["data"]["image_url"], data["data"]["subreddit"]["name"]
|
||||
except (KeyError, json.JSONDecodeError):
|
||||
tries += 1
|
||||
continue
|
||||
except aiohttp.client_exceptions.ClientConnectionError:
|
||||
tries += 1
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
return None, None
|
||||
|
|
Loading…
Add table
Reference in a new issue