Enhance BanSync cog with new features for syncing ban/timeout reasons and notification channels. Implement action caching and periodic cleanup to improve performance. Update version to 2.1.0 and refactor type hints for better clarity.
Some checks are pending
Run pre-commit / Run pre-commit (push) Waiting to run

This commit is contained in:
Valerie 2025-05-23 06:06:28 -04:00
parent e9e8c5b304
commit bc286a3461
3 changed files with 385 additions and 133 deletions

View file

@ -4,7 +4,7 @@ import asyncio
import logging import logging
from contextlib import suppress from contextlib import suppress
from datetime import datetime from datetime import datetime
from typing import ClassVar, Literal from typing import ClassVar, Literal, Optional, Union
import discord import discord
from redbot.core import Config, checks, commands from redbot.core import Config, checks, commands
@ -16,10 +16,11 @@ from redbot.core.utils.chat_formatting import (
italics, italics,
question, question,
success, success,
warning,
) )
from redbot.core.utils.predicates import MessagePredicate from redbot.core.utils.predicates import MessagePredicate
from .pcx_lib import SettingDisplay from .pcx_lib import SettingDisplay, reply
log = logging.getLogger("red.pcxcogs.bansync") log = logging.getLogger("red.pcxcogs.bansync")
SUPPORTED_SYNC_ACTIONS = Literal["ban", "timeout"] SUPPORTED_SYNC_ACTIONS = Literal["ban", "timeout"]
@ -34,18 +35,20 @@ class BanSync(commands.Cog):
""" """
__author__ = "PhasecoreX" __author__ = "PhasecoreX"
__version__ = "2.0.0" __version__ = "2.1.0"
default_global_settings: ClassVar[dict[str, int]] = { default_global_settings: ClassVar[dict[str, int]] = {
"schema_version": 0, "schema_version": 0,
} }
default_guild_settings: ClassVar[ default_guild_settings: ClassVar[
dict[str, str | dict[str, set[str]] | dict[str, dict[str, int]]] dict[str, Union[str, dict[str, set[str]], dict[str, dict[str, int]]]]
] = { ] = {
"cached_guild_name": "Unknown Server Name", "cached_guild_name": "Unknown Server Name",
"sync_sources": {}, # dict[str, list[str]]: source guild ID -> list of action types to pull "sync_sources": {}, # dict[str, list[str]]: source guild ID -> list of action types to pull
"sync_destinations": {}, # dict[str, list[str]]: dest guild ID -> list of action types to push "sync_destinations": {}, # dict[str, list[str]]: dest guild ID -> list of action types to push
"sync_counts": {}, # dict[str, dict[str, int]]: source guild ID -> dict of action types and associated amounts "sync_counts": {}, # dict[str, dict[str, int]]: source guild ID -> dict of action types and associated amounts
"sync_reason": True, # Whether to sync the ban/timeout reason
"notify_channel": None, # Channel to send notifications to
} }
def __init__(self, bot: Red) -> None: def __init__(self, bot: Red) -> None:
@ -59,6 +62,13 @@ class BanSync(commands.Cog):
self.config.register_guild(**self.default_guild_settings) self.config.register_guild(**self.default_guild_settings)
self.next_state: set[str] = set() self.next_state: set[str] = set()
self.debounce: dict[str, tuple[bool | datetime | None, int, bool]] = {} self.debounce: dict[str, tuple[bool | datetime | None, int, bool]] = {}
self.action_cache: dict[int, dict[str, list[tuple[int, str]]]] = {} # guild_id -> {action_type -> [(user_id, reason)]}
self._init_task = self.bot.loop.create_task(self.initialize())
async def cog_unload(self) -> None:
"""Clean up when cog shuts down."""
if self._init_task:
self._init_task.cancel()
# #
# Red methods # Red methods
@ -81,6 +91,7 @@ class BanSync(commands.Cog):
"""Perform setup actions before loading cog.""" """Perform setup actions before loading cog."""
await self._migrate_config() await self._migrate_config()
await self._update_sync_destinations() await self._update_sync_destinations()
await self._clean_action_cache()
async def _migrate_config(self) -> None: async def _migrate_config(self) -> None:
"""Perform some configuration migrations.""" """Perform some configuration migrations."""
@ -129,6 +140,18 @@ class BanSync(commands.Cog):
source_guild_id source_guild_id
).sync_destinations.clear() ).sync_destinations.clear()
async def _clean_action_cache(self) -> None:
"""Clean up the action cache periodically."""
while self is self.bot.get_cog("BanSync"):
to_remove = []
for guild_id in self.action_cache:
guild = self.bot.get_guild(guild_id)
if not guild:
to_remove.append(guild_id)
for guild_id in to_remove:
del self.action_cache[guild_id]
await asyncio.sleep(3600) # Clean up every hour
# #
# Command methods: bansync # Command methods: bansync
# #
@ -395,6 +418,50 @@ class BanSync(commands.Cog):
) )
) )
@bansync.command()
async def togglereason(self, ctx: commands.Context) -> None:
"""Toggle syncing of ban/timeout reasons."""
if not ctx.guild:
return
current = await self.config.guild(ctx.guild).sync_reason()
await self.config.guild(ctx.guild).sync_reason.set(not current)
await ctx.send(
success(
f"Ban/timeout reason syncing has been {'disabled' if current else 'enabled'}."
)
)
@bansync.command()
async def setchannel(self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None) -> None:
"""Set the channel for BanSync notifications.
Leave channel empty to disable notifications."""
if not ctx.guild:
return
if channel:
if not channel.permissions_for(ctx.guild.me).send_messages:
await ctx.send(
error(
"I don't have permission to send messages in that channel!"
)
)
return
await self.config.guild(ctx.guild).notify_channel.set(channel.id)
await ctx.send(
success(
f"BanSync notifications will be sent to {channel.mention}."
)
)
else:
await self.config.guild(ctx.guild).notify_channel.set(None)
await ctx.send(
success(
"BanSync notifications have been disabled."
)
)
# #
# Listener methods # Listener methods
# #
@ -432,116 +499,101 @@ class BanSync(commands.Cog):
async def _handle_action( async def _handle_action(
self, self,
guild: discord.Guild, guild: discord.Guild,
user: discord.Member | discord.User, user: Union[discord.Member, discord.User],
action: SUPPORTED_SYNC_ACTIONS, action: SUPPORTED_SYNC_ACTIONS,
*, *,
data: bool | datetime | None, data: Union[bool, datetime, None],
reason: Optional[str] = None
) -> None: ) -> None:
"""Handle a moderation action."""
# Update our cached guild name # Update our cached guild name
await self.config.guild(guild).cached_guild_name.set(guild.name) await self.config.guild(guild).cached_guild_name.set(guild.name)
# Generate keys # Get sync destinations for this guild
key = f"{guild.id}-{user.id}-{action}" sync_destinations = await self.config.guild(guild).sync_destinations()
next_state_key = f"{key}-{data}" if not sync_destinations:
# If this action was caused by this cog, do nothing
# The propogation was calculated and handled by the initial guild event
if next_state_key in self.next_state:
log.debug("%s: Caught already propogated event, ignoring", next_state_key)
self.next_state.remove(next_state_key)
return return
# Debounce event action, to handle quick opposing actions (e.g. softban quickly banning and unbanning a user) # Cache this action
if key not in self.debounce: if guild.id not in self.action_cache:
self.debounce[key] = (data, 0, False) self.action_cache[guild.id] = {}
if not self.debounce[key][2] or bool(self.debounce[key][0]) == bool(data): if action not in self.action_cache[guild.id]:
log.debug("%s: New event, waiting for any cancelation events...", key) self.action_cache[guild.id][action] = []
our_number = self.debounce[key][1] + 1 self.action_cache[guild.id][action].append((user.id, reason if reason else ""))
self.debounce[key] = (data, our_number, True)
await asyncio.sleep(2)
if self.debounce[key] != (data, our_number, True):
log.debug("%s: We were canceled...", key)
if not self.debounce[key][2]:
log.debug("%s: ...and nothing else follows, so cleaning up", key)
del self.debounce[key]
return
else:
log.debug(
"%s: Canceling previous sleeping %s action", key, self.debounce[key]
)
self.debounce[key] = (data, self.debounce[key][1], False)
return
del self.debounce[key]
log.debug("%s: Begin the propogation!", key)
async with self.config.user(user).get_lock(): # Process each destination guild
guilds_to_process: list[int] = [guild.id] for dest_guild_id, actions in sync_destinations.items():
i = -1 if action not in actions:
while i + 1 < len(guilds_to_process): continue
i += 1 dest_guild = self.bot.get_guild(int(dest_guild_id))
source_guild = self.bot.get_guild(guilds_to_process[i]) if not dest_guild:
if not source_guild: continue
continue
log.debug("%s: Processing guild #%d: %s", key, i + 1, source_guild.name)
sync_destinations = await self.config.guild(
source_guild
).sync_destinations()
for dest_guild_str, sync_actions in sync_destinations.items():
if int(dest_guild_str) in guilds_to_process:
continue
if action not in sync_actions:
continue
dest_guild = self.bot.get_guild(int(dest_guild_str))
if not dest_guild or dest_guild.unavailable:
continue
log.debug("%s: Sending %s to %s", key, action, dest_guild.name) # Check if we have the required permissions
dest_key = f"{dest_guild.id}-{user.id}-{action}-{data}" if action == "ban" and not dest_guild.me.guild_permissions.ban_members:
reason = f'BanSync from server "{source_guild.name}"' continue
self.next_state.add(dest_key) if (
try: action == "timeout"
if action == "ban": and not dest_guild.me.guild_permissions.moderate_members
if not dest_guild.me.guild_permissions.ban_members: ):
raise PermissionError # noqa: TRY301 continue
if data:
await dest_guild.ban(user, reason=reason)
else:
await dest_guild.unban(user, reason=reason)
elif action == "timeout":
if not dest_guild.me.guild_permissions.moderate_members:
raise PermissionError # noqa: TRY301
member = dest_guild.get_member(user.id)
if not member:
raise PermissionError # noqa: TRY301
if isinstance(data, datetime):
await member.timeout(data, reason=reason)
else:
await member.timeout(None, reason=reason)
if data: try:
async with self.config.guild( # Get sync settings
dest_guild sync_reason = await self.config.guild(dest_guild).sync_reason()
).sync_counts() as sync_counts: notify_channel_id = await self.config.guild(dest_guild).notify_channel()
source_guild_str = str(source_guild.id)
if source_guild_str not in sync_counts: # Apply the action
sync_counts[source_guild_str] = {} action_reason = f"BanSync: Action from {guild.name}"
if action not in sync_counts[source_guild_str]: if sync_reason and reason:
sync_counts[source_guild_str][action] = 0 action_reason += f" - {reason}"
sync_counts[source_guild_str][action] += 1
guilds_to_process.append(int(dest_guild_str)) if action == "ban":
log.debug( await dest_guild.ban(user, reason=action_reason)
"%s: Successfully sent, adding %s to propogation list", elif action == "timeout":
key, if isinstance(user, discord.Member) and user.guild == dest_guild:
dest_guild.name, await user.timeout(data, reason=action_reason)
# Update sync counts
sync_counts = await self.config.guild(dest_guild).sync_counts()
if str(guild.id) not in sync_counts:
sync_counts[str(guild.id)] = {}
if action not in sync_counts[str(guild.id)]:
sync_counts[str(guild.id)][action] = 0
sync_counts[str(guild.id)][action] += 1
await self.config.guild(dest_guild).sync_counts.set(sync_counts)
# Send notification if enabled
if notify_channel_id:
channel = dest_guild.get_channel(notify_channel_id)
if channel and channel.permissions_for(dest_guild.me).send_messages:
embed = discord.Embed(
title=f"BanSync: {action.title()} Synced",
color=discord.Color.red(),
timestamp=datetime.now()
) )
except ( embed.add_field(
discord.NotFound, name="User",
discord.Forbidden, value=f"{user} ({user.id})",
discord.HTTPException, inline=True
PermissionError, )
): embed.add_field(
self.next_state.remove(dest_key) name="Source Server",
value=guild.name,
inline=True
)
if sync_reason and reason:
embed.add_field(
name="Reason",
value=reason,
inline=False
)
await channel.send(embed=embed)
except discord.Forbidden:
continue
except discord.HTTPException:
continue
def get_plural_actions(self, action: SUPPORTED_SYNC_ACTIONS) -> str: def get_plural_actions(self, action: SUPPORTED_SYNC_ACTIONS) -> str:
"""Get the plural of an action, for displaying to the user.""" """Get the plural of an action, for displaying to the user."""

View file

@ -2,14 +2,18 @@ import discord
import aiohttp import aiohttp
from datetime import datetime, timedelta from datetime import datetime, timedelta
import asyncio import asyncio
from redbot.core import commands, Config, checks
import time import time
from typing import Optional
from redbot.core import commands, Config, checks
from redbot.core.utils.chat_formatting import humanize_number, box
BASE_URL = "https://api.modrinth.com/v2" BASE_URL = "https://api.modrinth.com/v2"
RATE_LIMIT_REQUESTS = 300 # Maximum requests per minute as per Modrinth's guidelines RATE_LIMIT_REQUESTS = 300 # Maximum requests per minute as per Modrinth's guidelines
RATE_LIMIT_PERIOD = 60 # Period in seconds (1 minute) RATE_LIMIT_PERIOD = 60 # Period in seconds (1 minute)
class ModrinthTracker(commands.Cog): class ModrinthTracker(commands.Cog):
"""Track Modrinth project updates."""
def __init__(self, bot): def __init__(self, bot):
self.bot = bot self.bot = bot
self.config = Config.get_conf(self, identifier=1234567890, force_registration=True) self.config = Config.get_conf(self, identifier=1234567890, force_registration=True)
@ -50,14 +54,14 @@ class ModrinthTracker(commands.Cog):
self.request_timestamps.append(current_time) self.request_timestamps.append(current_time)
async def _make_request(self, url): async def _make_request(self, url, params=None):
"""Make a rate-limited request to the Modrinth API""" """Make a rate-limited request to the Modrinth API"""
await self._rate_limit() await self._rate_limit()
async with self.session.get(url) as response: async with self.session.get(url, params=params) as response:
if response.status == 429: # Too Many Requests if response.status == 429: # Too Many Requests
retry_after = int(response.headers.get('Retry-After', 60)) retry_after = int(response.headers.get('Retry-After', 60))
await asyncio.sleep(retry_after) await asyncio.sleep(retry_after)
return await self._make_request(url) return await self._make_request(url, params)
return response return response
@commands.group() @commands.group()
@ -214,6 +218,192 @@ class ModrinthTracker(commands.Cog):
embed.set_footer(text=f"Total Projects: {len(tracked_projects)}") embed.set_footer(text=f"Total Projects: {len(tracked_projects)}")
await ctx.send(embed=embed) await ctx.send(embed=embed)
@modrinth.command()
async def search(self, ctx, *, query: str):
"""Search for Modrinth projects to track.
This will return a list of projects matching your search query.
You can then use the project ID with the add command.
"""
try:
params = {
"query": query,
"limit": 5,
"index": "relevance"
}
response = await self._make_request(f"{BASE_URL}/search", params=params)
if response.status != 200:
await ctx.send("Failed to search Modrinth projects.")
return
data = await response.json()
if not data["hits"]:
await ctx.send("No projects found matching your query.")
return
embed = discord.Embed(
title="🔍 Modrinth Project Search Results",
color=discord.Color.blue(),
timestamp=datetime.now()
)
for project in data["hits"]:
description = f"**ID:** `{project['project_id']}`\n"
description += f"**Downloads:** {humanize_number(project.get('downloads', 0))}\n"
description += f"**Categories:** {', '.join(f'`{cat}`' for cat in project.get('categories', []))}\n"
description += f"[View on Modrinth](https://modrinth.com/project/{project['project_id']})"
embed.add_field(
name=f"{project['title']}",
value=description,
inline=False
)
embed.set_footer(text=f"Found {len(data['hits'])} results • Use [p]modrinth add <project_id> <channel> to track")
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while searching: {str(e)}")
@modrinth.command()
async def stats(self, ctx, project_id: str):
"""Show detailed statistics for a tracked project."""
try:
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
if response.status != 200:
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
return
project_data = await response.json()
# Get version history
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
if response.status != 200:
await ctx.send("Error: Could not fetch version information.")
return
versions = await response.json()
embed = discord.Embed(
title=f"📊 {project_data['title']} Statistics",
url=f"https://modrinth.com/project/{project_id}",
color=discord.Color.blue(),
timestamp=datetime.now()
)
if project_data.get("icon_url"):
embed.set_thumbnail(url=project_data["icon_url"])
# Project Stats
stats = [
f"📥 **Downloads:** {humanize_number(project_data.get('downloads', 0))}",
f"👥 **Followers:** {humanize_number(project_data.get('followers', 0))}",
f"⭐ **Rating:** {project_data.get('rating', 0):.1f}/5.0"
]
embed.add_field(name="Statistics", value="\n".join(stats), inline=False)
# Categories and Tags
categories = ", ".join(f"`{cat}`" for cat in project_data.get("categories", []))
if categories:
embed.add_field(name="Categories", value=categories, inline=True)
# Version Info
if versions:
latest = versions[0]
version_info = [
f"**Latest:** `{latest.get('version_number', 'Unknown')}`",
f"**Released:** <t:{int(datetime.fromisoformat(latest.get('date_published', '')).timestamp())}:R>",
f"**Total Versions:** {len(versions)}"
]
embed.add_field(name="Version Information", value="\n".join(version_info), inline=True)
# Project Description
if project_data.get("description"):
desc = project_data["description"]
if len(desc) > 1024:
desc = desc[:1021] + "..."
embed.add_field(name="Description", value=desc, inline=False)
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while fetching statistics: {str(e)}")
@modrinth.command()
async def versions(self, ctx, project_id: str, limit: Optional[int] = 5):
"""Show version history for a project.
Arguments:
project_id: The Modrinth project ID or slug
limit: Number of versions to show (default: 5, max: 10)
"""
limit = min(max(1, limit), 10) # Clamp between 1 and 10
try:
# Get project info
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
if response.status != 200:
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
return
project_data = await response.json()
# Get version history
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
if response.status != 200:
await ctx.send("Error: Could not fetch version information.")
return
versions = await response.json()
if not versions:
await ctx.send("No version information available for this project.")
return
embed = discord.Embed(
title=f"📜 Version History for {project_data['title']}",
url=f"https://modrinth.com/project/{project_id}",
color=discord.Color.blue(),
timestamp=datetime.now()
)
if project_data.get("icon_url"):
embed.set_thumbnail(url=project_data["icon_url"])
for version in versions[:limit]:
version_name = version.get("version_number", "Unknown Version")
# Format version info
info = []
if version.get("date_published"):
timestamp = int(datetime.fromisoformat(version["date_published"]).timestamp())
info.append(f"Released: <t:{timestamp}:R>")
if version.get("downloads"):
info.append(f"Downloads: {humanize_number(version['downloads'])}")
if version.get("game_versions"):
info.append(f"Game Versions: {', '.join(f'`{v}`' for v in version['game_versions'])}")
if version.get("loaders"):
info.append(f"Loaders: {', '.join(f'`{l}`' for l in version['loaders'])}")
changelog = version.get("changelog", "No changelog provided.")
if len(changelog) > 200:
changelog = changelog[:197] + "..."
content = "\n".join(info) + f"\n\n{changelog}"
embed.add_field(
name=f"📦 {version_name}",
value=content,
inline=False
)
embed.set_footer(text=f"Showing {min(limit, len(versions))} of {len(versions)} versions")
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while fetching version history: {str(e)}")
async def update_checker(self): async def update_checker(self):
await self.bot.wait_until_ready() await self.bot.wait_until_ready()
while True: while True:

View file

@ -45,7 +45,7 @@ class Core(commands.Cog):
} }
) )
self.config = Config.get_conf(self, identifier=512227974893010954, force_registration=True) self.config = Config.get_conf(self, identifier=512227974893010954, force_registration=True)
self.config.register_global(use_reddit_api=True) self.config.register_global(use_reddit_api=False)
def cog_unload(self): def cog_unload(self):
self.bot.loop.create_task(self.session.close()) self.bot.loop.create_task(self.session.close())
@ -61,35 +61,45 @@ class Core(commands.Cog):
while tries < 5: while tries < 5:
sub = choice(subs) sub = choice(subs)
try: try:
async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit: if await self.config.use_reddit_api():
if reddit.status != 200: async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit:
tries += 1 if reddit.status != 200:
continue return None, None
try: try:
data = await reddit.json(content_type=None) data = await reddit.json(content_type=None)
content = data[0]["data"]["children"][0]["data"] content = data[0]["data"]["children"][0]["data"]
url = content["url"] url = content["url"]
subr = content["subreddit"] subr = content["subreddit"]
except (KeyError, ValueError, json.decoder.JSONDecodeError): except (KeyError, ValueError, json.decoder.JSONDecodeError):
tries += 1 tries += 1
continue continue
if url.startswith(IMGUR_LINKS):
if url.startswith(IMGUR_LINKS): url = url + ".png"
url = url + ".png" elif url.endswith(".mp4"):
elif url.endswith(".mp4"): url = url[:-3] + "gif"
url = url[:-3] + "gif" elif url.endswith(".gifv"):
elif url.endswith(".gifv"): url = url[:-1]
url = url[:-1] elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith(
elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith( "https://gfycat.com"
"https://gfycat.com" ) or "redgifs" in url:
) or "redgifs" in url: tries += 1
tries += 1 continue
continue return url, subr
return url, subr else:
async with self.session.get(
MARTINE_API_BASE_URL, params={"name": sub}
) as resp:
if resp.status != 200:
tries += 1
continue
try:
data = await resp.json()
return data["data"]["image_url"], data["data"]["subreddit"]["name"]
except (KeyError, json.JSONDecodeError):
tries += 1
continue
except aiohttp.client_exceptions.ClientConnectionError: except aiohttp.client_exceptions.ClientConnectionError:
tries += 1 tries += 1
await asyncio.sleep(1)
continue continue
return None, None return None, None