Enhance BanSync cog with new features for syncing ban/timeout reasons and notification channels. Implement action caching and periodic cleanup to improve performance. Update version to 2.1.0 and refactor type hints for better clarity.
Some checks are pending
Run pre-commit / Run pre-commit (push) Waiting to run

This commit is contained in:
Valerie 2025-05-23 06:06:28 -04:00
parent e9e8c5b304
commit bc286a3461
3 changed files with 385 additions and 133 deletions

View file

@ -4,7 +4,7 @@ import asyncio
import logging
from contextlib import suppress
from datetime import datetime
from typing import ClassVar, Literal
from typing import ClassVar, Literal, Optional, Union
import discord
from redbot.core import Config, checks, commands
@ -16,10 +16,11 @@ from redbot.core.utils.chat_formatting import (
italics,
question,
success,
warning,
)
from redbot.core.utils.predicates import MessagePredicate
from .pcx_lib import SettingDisplay
from .pcx_lib import SettingDisplay, reply
log = logging.getLogger("red.pcxcogs.bansync")
SUPPORTED_SYNC_ACTIONS = Literal["ban", "timeout"]
@ -34,18 +35,20 @@ class BanSync(commands.Cog):
"""
__author__ = "PhasecoreX"
__version__ = "2.0.0"
__version__ = "2.1.0"
default_global_settings: ClassVar[dict[str, int]] = {
"schema_version": 0,
}
default_guild_settings: ClassVar[
dict[str, str | dict[str, set[str]] | dict[str, dict[str, int]]]
dict[str, Union[str, dict[str, set[str]], dict[str, dict[str, int]]]]
] = {
"cached_guild_name": "Unknown Server Name",
"sync_sources": {}, # dict[str, list[str]]: source guild ID -> list of action types to pull
"sync_destinations": {}, # dict[str, list[str]]: dest guild ID -> list of action types to push
"sync_counts": {}, # dict[str, dict[str, int]]: source guild ID -> dict of action types and associated amounts
"sync_reason": True, # Whether to sync the ban/timeout reason
"notify_channel": None, # Channel to send notifications to
}
def __init__(self, bot: Red) -> None:
@ -59,6 +62,13 @@ class BanSync(commands.Cog):
self.config.register_guild(**self.default_guild_settings)
self.next_state: set[str] = set()
self.debounce: dict[str, tuple[bool | datetime | None, int, bool]] = {}
self.action_cache: dict[int, dict[str, list[tuple[int, str]]]] = {} # guild_id -> {action_type -> [(user_id, reason)]}
self._init_task = self.bot.loop.create_task(self.initialize())
async def cog_unload(self) -> None:
"""Clean up when cog shuts down."""
if self._init_task:
self._init_task.cancel()
#
# Red methods
@ -81,6 +91,7 @@ class BanSync(commands.Cog):
"""Perform setup actions before loading cog."""
await self._migrate_config()
await self._update_sync_destinations()
await self._clean_action_cache()
async def _migrate_config(self) -> None:
"""Perform some configuration migrations."""
@ -129,6 +140,18 @@ class BanSync(commands.Cog):
source_guild_id
).sync_destinations.clear()
async def _clean_action_cache(self) -> None:
"""Clean up the action cache periodically."""
while self is self.bot.get_cog("BanSync"):
to_remove = []
for guild_id in self.action_cache:
guild = self.bot.get_guild(guild_id)
if not guild:
to_remove.append(guild_id)
for guild_id in to_remove:
del self.action_cache[guild_id]
await asyncio.sleep(3600) # Clean up every hour
#
# Command methods: bansync
#
@ -395,6 +418,50 @@ class BanSync(commands.Cog):
)
)
@bansync.command()
async def togglereason(self, ctx: commands.Context) -> None:
"""Toggle syncing of ban/timeout reasons."""
if not ctx.guild:
return
current = await self.config.guild(ctx.guild).sync_reason()
await self.config.guild(ctx.guild).sync_reason.set(not current)
await ctx.send(
success(
f"Ban/timeout reason syncing has been {'disabled' if current else 'enabled'}."
)
)
@bansync.command()
async def setchannel(self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None) -> None:
"""Set the channel for BanSync notifications.
Leave channel empty to disable notifications."""
if not ctx.guild:
return
if channel:
if not channel.permissions_for(ctx.guild.me).send_messages:
await ctx.send(
error(
"I don't have permission to send messages in that channel!"
)
)
return
await self.config.guild(ctx.guild).notify_channel.set(channel.id)
await ctx.send(
success(
f"BanSync notifications will be sent to {channel.mention}."
)
)
else:
await self.config.guild(ctx.guild).notify_channel.set(None)
await ctx.send(
success(
"BanSync notifications have been disabled."
)
)
#
# Listener methods
#
@ -432,116 +499,101 @@ class BanSync(commands.Cog):
async def _handle_action(
self,
guild: discord.Guild,
user: discord.Member | discord.User,
user: Union[discord.Member, discord.User],
action: SUPPORTED_SYNC_ACTIONS,
*,
data: bool | datetime | None,
data: Union[bool, datetime, None],
reason: Optional[str] = None
) -> None:
"""Handle a moderation action."""
# Update our cached guild name
await self.config.guild(guild).cached_guild_name.set(guild.name)
# Generate keys
key = f"{guild.id}-{user.id}-{action}"
next_state_key = f"{key}-{data}"
# If this action was caused by this cog, do nothing
# The propogation was calculated and handled by the initial guild event
if next_state_key in self.next_state:
log.debug("%s: Caught already propogated event, ignoring", next_state_key)
self.next_state.remove(next_state_key)
# Get sync destinations for this guild
sync_destinations = await self.config.guild(guild).sync_destinations()
if not sync_destinations:
return
# Debounce event action, to handle quick opposing actions (e.g. softban quickly banning and unbanning a user)
if key not in self.debounce:
self.debounce[key] = (data, 0, False)
if not self.debounce[key][2] or bool(self.debounce[key][0]) == bool(data):
log.debug("%s: New event, waiting for any cancelation events...", key)
our_number = self.debounce[key][1] + 1
self.debounce[key] = (data, our_number, True)
await asyncio.sleep(2)
if self.debounce[key] != (data, our_number, True):
log.debug("%s: We were canceled...", key)
if not self.debounce[key][2]:
log.debug("%s: ...and nothing else follows, so cleaning up", key)
del self.debounce[key]
return
else:
log.debug(
"%s: Canceling previous sleeping %s action", key, self.debounce[key]
)
self.debounce[key] = (data, self.debounce[key][1], False)
return
del self.debounce[key]
log.debug("%s: Begin the propogation!", key)
# Cache this action
if guild.id not in self.action_cache:
self.action_cache[guild.id] = {}
if action not in self.action_cache[guild.id]:
self.action_cache[guild.id][action] = []
self.action_cache[guild.id][action].append((user.id, reason if reason else ""))
async with self.config.user(user).get_lock():
guilds_to_process: list[int] = [guild.id]
i = -1
while i + 1 < len(guilds_to_process):
i += 1
source_guild = self.bot.get_guild(guilds_to_process[i])
if not source_guild:
continue
log.debug("%s: Processing guild #%d: %s", key, i + 1, source_guild.name)
sync_destinations = await self.config.guild(
source_guild
).sync_destinations()
for dest_guild_str, sync_actions in sync_destinations.items():
if int(dest_guild_str) in guilds_to_process:
continue
if action not in sync_actions:
continue
dest_guild = self.bot.get_guild(int(dest_guild_str))
if not dest_guild or dest_guild.unavailable:
continue
# Process each destination guild
for dest_guild_id, actions in sync_destinations.items():
if action not in actions:
continue
dest_guild = self.bot.get_guild(int(dest_guild_id))
if not dest_guild:
continue
log.debug("%s: Sending %s to %s", key, action, dest_guild.name)
dest_key = f"{dest_guild.id}-{user.id}-{action}-{data}"
reason = f'BanSync from server "{source_guild.name}"'
self.next_state.add(dest_key)
try:
if action == "ban":
if not dest_guild.me.guild_permissions.ban_members:
raise PermissionError # noqa: TRY301
if data:
await dest_guild.ban(user, reason=reason)
else:
await dest_guild.unban(user, reason=reason)
elif action == "timeout":
if not dest_guild.me.guild_permissions.moderate_members:
raise PermissionError # noqa: TRY301
member = dest_guild.get_member(user.id)
if not member:
raise PermissionError # noqa: TRY301
if isinstance(data, datetime):
await member.timeout(data, reason=reason)
else:
await member.timeout(None, reason=reason)
# Check if we have the required permissions
if action == "ban" and not dest_guild.me.guild_permissions.ban_members:
continue
if (
action == "timeout"
and not dest_guild.me.guild_permissions.moderate_members
):
continue
if data:
async with self.config.guild(
dest_guild
).sync_counts() as sync_counts:
source_guild_str = str(source_guild.id)
if source_guild_str not in sync_counts:
sync_counts[source_guild_str] = {}
if action not in sync_counts[source_guild_str]:
sync_counts[source_guild_str][action] = 0
sync_counts[source_guild_str][action] += 1
try:
# Get sync settings
sync_reason = await self.config.guild(dest_guild).sync_reason()
notify_channel_id = await self.config.guild(dest_guild).notify_channel()
# Apply the action
action_reason = f"BanSync: Action from {guild.name}"
if sync_reason and reason:
action_reason += f" - {reason}"
guilds_to_process.append(int(dest_guild_str))
log.debug(
"%s: Successfully sent, adding %s to propogation list",
key,
dest_guild.name,
if action == "ban":
await dest_guild.ban(user, reason=action_reason)
elif action == "timeout":
if isinstance(user, discord.Member) and user.guild == dest_guild:
await user.timeout(data, reason=action_reason)
# Update sync counts
sync_counts = await self.config.guild(dest_guild).sync_counts()
if str(guild.id) not in sync_counts:
sync_counts[str(guild.id)] = {}
if action not in sync_counts[str(guild.id)]:
sync_counts[str(guild.id)][action] = 0
sync_counts[str(guild.id)][action] += 1
await self.config.guild(dest_guild).sync_counts.set(sync_counts)
# Send notification if enabled
if notify_channel_id:
channel = dest_guild.get_channel(notify_channel_id)
if channel and channel.permissions_for(dest_guild.me).send_messages:
embed = discord.Embed(
title=f"BanSync: {action.title()} Synced",
color=discord.Color.red(),
timestamp=datetime.now()
)
except (
discord.NotFound,
discord.Forbidden,
discord.HTTPException,
PermissionError,
):
self.next_state.remove(dest_key)
embed.add_field(
name="User",
value=f"{user} ({user.id})",
inline=True
)
embed.add_field(
name="Source Server",
value=guild.name,
inline=True
)
if sync_reason and reason:
embed.add_field(
name="Reason",
value=reason,
inline=False
)
await channel.send(embed=embed)
except discord.Forbidden:
continue
except discord.HTTPException:
continue
def get_plural_actions(self, action: SUPPORTED_SYNC_ACTIONS) -> str:
"""Get the plural of an action, for displaying to the user."""

View file

@ -2,14 +2,18 @@ import discord
import aiohttp
from datetime import datetime, timedelta
import asyncio
from redbot.core import commands, Config, checks
import time
from typing import Optional
from redbot.core import commands, Config, checks
from redbot.core.utils.chat_formatting import humanize_number, box
BASE_URL = "https://api.modrinth.com/v2"
RATE_LIMIT_REQUESTS = 300 # Maximum requests per minute as per Modrinth's guidelines
RATE_LIMIT_PERIOD = 60 # Period in seconds (1 minute)
class ModrinthTracker(commands.Cog):
"""Track Modrinth project updates."""
def __init__(self, bot):
self.bot = bot
self.config = Config.get_conf(self, identifier=1234567890, force_registration=True)
@ -50,14 +54,14 @@ class ModrinthTracker(commands.Cog):
self.request_timestamps.append(current_time)
async def _make_request(self, url):
async def _make_request(self, url, params=None):
"""Make a rate-limited request to the Modrinth API"""
await self._rate_limit()
async with self.session.get(url) as response:
async with self.session.get(url, params=params) as response:
if response.status == 429: # Too Many Requests
retry_after = int(response.headers.get('Retry-After', 60))
await asyncio.sleep(retry_after)
return await self._make_request(url)
return await self._make_request(url, params)
return response
@commands.group()
@ -214,6 +218,192 @@ class ModrinthTracker(commands.Cog):
embed.set_footer(text=f"Total Projects: {len(tracked_projects)}")
await ctx.send(embed=embed)
@modrinth.command()
async def search(self, ctx, *, query: str):
"""Search for Modrinth projects to track.
This will return a list of projects matching your search query.
You can then use the project ID with the add command.
"""
try:
params = {
"query": query,
"limit": 5,
"index": "relevance"
}
response = await self._make_request(f"{BASE_URL}/search", params=params)
if response.status != 200:
await ctx.send("Failed to search Modrinth projects.")
return
data = await response.json()
if not data["hits"]:
await ctx.send("No projects found matching your query.")
return
embed = discord.Embed(
title="🔍 Modrinth Project Search Results",
color=discord.Color.blue(),
timestamp=datetime.now()
)
for project in data["hits"]:
description = f"**ID:** `{project['project_id']}`\n"
description += f"**Downloads:** {humanize_number(project.get('downloads', 0))}\n"
description += f"**Categories:** {', '.join(f'`{cat}`' for cat in project.get('categories', []))}\n"
description += f"[View on Modrinth](https://modrinth.com/project/{project['project_id']})"
embed.add_field(
name=f"{project['title']}",
value=description,
inline=False
)
embed.set_footer(text=f"Found {len(data['hits'])} results • Use [p]modrinth add <project_id> <channel> to track")
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while searching: {str(e)}")
@modrinth.command()
async def stats(self, ctx, project_id: str):
"""Show detailed statistics for a tracked project."""
try:
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
if response.status != 200:
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
return
project_data = await response.json()
# Get version history
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
if response.status != 200:
await ctx.send("Error: Could not fetch version information.")
return
versions = await response.json()
embed = discord.Embed(
title=f"📊 {project_data['title']} Statistics",
url=f"https://modrinth.com/project/{project_id}",
color=discord.Color.blue(),
timestamp=datetime.now()
)
if project_data.get("icon_url"):
embed.set_thumbnail(url=project_data["icon_url"])
# Project Stats
stats = [
f"📥 **Downloads:** {humanize_number(project_data.get('downloads', 0))}",
f"👥 **Followers:** {humanize_number(project_data.get('followers', 0))}",
f"⭐ **Rating:** {project_data.get('rating', 0):.1f}/5.0"
]
embed.add_field(name="Statistics", value="\n".join(stats), inline=False)
# Categories and Tags
categories = ", ".join(f"`{cat}`" for cat in project_data.get("categories", []))
if categories:
embed.add_field(name="Categories", value=categories, inline=True)
# Version Info
if versions:
latest = versions[0]
version_info = [
f"**Latest:** `{latest.get('version_number', 'Unknown')}`",
f"**Released:** <t:{int(datetime.fromisoformat(latest.get('date_published', '')).timestamp())}:R>",
f"**Total Versions:** {len(versions)}"
]
embed.add_field(name="Version Information", value="\n".join(version_info), inline=True)
# Project Description
if project_data.get("description"):
desc = project_data["description"]
if len(desc) > 1024:
desc = desc[:1021] + "..."
embed.add_field(name="Description", value=desc, inline=False)
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while fetching statistics: {str(e)}")
@modrinth.command()
async def versions(self, ctx, project_id: str, limit: Optional[int] = 5):
"""Show version history for a project.
Arguments:
project_id: The Modrinth project ID or slug
limit: Number of versions to show (default: 5, max: 10)
"""
limit = min(max(1, limit), 10) # Clamp between 1 and 10
try:
# Get project info
response = await self._make_request(f"{BASE_URL}/project/{project_id}")
if response.status != 200:
await ctx.send(f"Error: Project `{project_id}` not found on Modrinth.")
return
project_data = await response.json()
# Get version history
response = await self._make_request(f"{BASE_URL}/project/{project_id}/version")
if response.status != 200:
await ctx.send("Error: Could not fetch version information.")
return
versions = await response.json()
if not versions:
await ctx.send("No version information available for this project.")
return
embed = discord.Embed(
title=f"📜 Version History for {project_data['title']}",
url=f"https://modrinth.com/project/{project_id}",
color=discord.Color.blue(),
timestamp=datetime.now()
)
if project_data.get("icon_url"):
embed.set_thumbnail(url=project_data["icon_url"])
for version in versions[:limit]:
version_name = version.get("version_number", "Unknown Version")
# Format version info
info = []
if version.get("date_published"):
timestamp = int(datetime.fromisoformat(version["date_published"]).timestamp())
info.append(f"Released: <t:{timestamp}:R>")
if version.get("downloads"):
info.append(f"Downloads: {humanize_number(version['downloads'])}")
if version.get("game_versions"):
info.append(f"Game Versions: {', '.join(f'`{v}`' for v in version['game_versions'])}")
if version.get("loaders"):
info.append(f"Loaders: {', '.join(f'`{l}`' for l in version['loaders'])}")
changelog = version.get("changelog", "No changelog provided.")
if len(changelog) > 200:
changelog = changelog[:197] + "..."
content = "\n".join(info) + f"\n\n{changelog}"
embed.add_field(
name=f"📦 {version_name}",
value=content,
inline=False
)
embed.set_footer(text=f"Showing {min(limit, len(versions))} of {len(versions)} versions")
await ctx.send(embed=embed)
except Exception as e:
await ctx.send(f"An error occurred while fetching version history: {str(e)}")
async def update_checker(self):
await self.bot.wait_until_ready()
while True:

View file

@ -45,7 +45,7 @@ class Core(commands.Cog):
}
)
self.config = Config.get_conf(self, identifier=512227974893010954, force_registration=True)
self.config.register_global(use_reddit_api=True)
self.config.register_global(use_reddit_api=False)
def cog_unload(self):
self.bot.loop.create_task(self.session.close())
@ -61,35 +61,45 @@ class Core(commands.Cog):
while tries < 5:
sub = choice(subs)
try:
async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit:
if reddit.status != 200:
tries += 1
continue
try:
data = await reddit.json(content_type=None)
content = data[0]["data"]["children"][0]["data"]
url = content["url"]
subr = content["subreddit"]
except (KeyError, ValueError, json.decoder.JSONDecodeError):
tries += 1
continue
if url.startswith(IMGUR_LINKS):
url = url + ".png"
elif url.endswith(".mp4"):
url = url[:-3] + "gif"
elif url.endswith(".gifv"):
url = url[:-1]
elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith(
"https://gfycat.com"
) or "redgifs" in url:
tries += 1
continue
return url, subr
if await self.config.use_reddit_api():
async with self.session.get(REDDIT_BASEURL.format(sub=sub)) as reddit:
if reddit.status != 200:
return None, None
try:
data = await reddit.json(content_type=None)
content = data[0]["data"]["children"][0]["data"]
url = content["url"]
subr = content["subreddit"]
except (KeyError, ValueError, json.decoder.JSONDecodeError):
tries += 1
continue
if url.startswith(IMGUR_LINKS):
url = url + ".png"
elif url.endswith(".mp4"):
url = url[:-3] + "gif"
elif url.endswith(".gifv"):
url = url[:-1]
elif not url.endswith(GOOD_EXTENSIONS) and not url.startswith(
"https://gfycat.com"
) or "redgifs" in url:
tries += 1
continue
return url, subr
else:
async with self.session.get(
MARTINE_API_BASE_URL, params={"name": sub}
) as resp:
if resp.status != 200:
tries += 1
continue
try:
data = await resp.json()
return data["data"]["image_url"], data["data"]["subreddit"]["name"]
except (KeyError, json.JSONDecodeError):
tries += 1
continue
except aiohttp.client_exceptions.ClientConnectionError:
tries += 1
await asyncio.sleep(1)
continue
return None, None