607 lines
24 KiB
Python
607 lines
24 KiB
Python
"""BanSync cog for Red-DiscordBot by PhasecoreX."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from contextlib import suppress
|
|
from datetime import datetime
|
|
from typing import ClassVar, Literal, Optional, Union
|
|
|
|
import discord
|
|
from redbot.core import Config, checks, commands
|
|
from redbot.core.bot import Red
|
|
from redbot.core.utils.chat_formatting import (
|
|
bold,
|
|
error,
|
|
info,
|
|
italics,
|
|
question,
|
|
success,
|
|
warning,
|
|
)
|
|
from redbot.core.utils.predicates import MessagePredicate
|
|
|
|
from .pcx_lib import SettingDisplay, reply
|
|
|
|
log = logging.getLogger("red.pcxcogs.bansync")
|
|
SUPPORTED_SYNC_ACTIONS = Literal["ban", "timeout"]
|
|
|
|
|
|
class BanSync(commands.Cog):
|
|
"""Automatically sync moderation actions across servers.
|
|
|
|
This cog allows server admins to have moderation actions
|
|
automatically applied to members on their server when those
|
|
actions are performed on another server that the bot is in.
|
|
"""
|
|
|
|
__author__ = "PhasecoreX"
|
|
__version__ = "2.1.0"
|
|
|
|
default_global_settings: ClassVar[dict[str, int]] = {
|
|
"schema_version": 0,
|
|
}
|
|
default_guild_settings: ClassVar[
|
|
dict[str, Union[str, dict[str, set[str]], dict[str, dict[str, int]]]]
|
|
] = {
|
|
"cached_guild_name": "Unknown Server Name",
|
|
"sync_sources": {}, # dict[str, list[str]]: source guild ID -> list of action types to pull
|
|
"sync_destinations": {}, # dict[str, list[str]]: dest guild ID -> list of action types to push
|
|
"sync_counts": {}, # dict[str, dict[str, int]]: source guild ID -> dict of action types and associated amounts
|
|
"sync_reason": True, # Whether to sync the ban/timeout reason
|
|
"notify_channel": None, # Channel to send notifications to
|
|
}
|
|
|
|
def __init__(self, bot: Red) -> None:
|
|
"""Set up the cog."""
|
|
super().__init__()
|
|
self.bot = bot
|
|
self.config = Config.get_conf(
|
|
self, identifier=1224364860, force_registration=True
|
|
)
|
|
self.config.register_global(**self.default_global_settings)
|
|
self.config.register_guild(**self.default_guild_settings)
|
|
self.next_state: set[str] = set()
|
|
self.debounce: dict[str, tuple[bool | datetime | None, int, bool]] = {}
|
|
self.action_cache: dict[int, dict[str, list[tuple[int, str]]]] = {} # guild_id -> {action_type -> [(user_id, reason)]}
|
|
self._init_task = self.bot.loop.create_task(self.initialize())
|
|
|
|
async def cog_unload(self) -> None:
|
|
"""Clean up when cog shuts down."""
|
|
if self._init_task:
|
|
self._init_task.cancel()
|
|
|
|
#
|
|
# Red methods
|
|
#
|
|
|
|
def format_help_for_context(self, ctx: commands.Context) -> str:
|
|
"""Show version in help."""
|
|
pre_processed = super().format_help_for_context(ctx)
|
|
return f"{pre_processed}\n\nCog Version: {self.__version__}"
|
|
|
|
async def red_delete_data_for_user(self, *, _requester: str, _user_id: int) -> None:
|
|
"""Nothing to delete."""
|
|
return
|
|
|
|
#
|
|
# Initialization methods
|
|
#
|
|
|
|
async def initialize(self) -> None:
|
|
"""Perform setup actions before loading cog."""
|
|
await self._migrate_config()
|
|
await self._update_sync_destinations()
|
|
await self._clean_action_cache()
|
|
|
|
async def _migrate_config(self) -> None:
|
|
"""Perform some configuration migrations."""
|
|
schema_version = await self.config.schema_version()
|
|
|
|
if schema_version < 1:
|
|
# Support multiple action types
|
|
guild_dict = await self.config.all_guilds()
|
|
for guild_id, guild_info in guild_dict.items():
|
|
ban_sources = guild_info.get("ban_sources", [])
|
|
if ban_sources:
|
|
sync_sources = {}
|
|
for ban_source in ban_sources:
|
|
sync_sources[ban_source] = ["ban"]
|
|
await self.config.guild_from_id(guild_id).sync_sources.set(
|
|
sync_sources
|
|
)
|
|
await self.config.guild_from_id(guild_id).clear_raw("ban_sources")
|
|
ban_counts = guild_info.get("ban_count", {})
|
|
if ban_counts:
|
|
sync_counts = {}
|
|
for pull_server_id, ban_count in ban_counts.items():
|
|
sync_counts[pull_server_id] = {"ban": ban_count}
|
|
await self.config.guild_from_id(guild_id).sync_counts.set(
|
|
sync_counts
|
|
)
|
|
await self.config.guild_from_id(guild_id).clear_raw("ban_count")
|
|
await self.config.schema_version.set(1)
|
|
|
|
async def _update_sync_destinations(self) -> None:
|
|
"""Update all guilds sync destinations to reflect which guilds are pulling from them (directed graph reversal)."""
|
|
reversed_graph: dict[int, dict[str, list[str]]] = {}
|
|
guild_dict = await self.config.all_guilds()
|
|
for dest_guild_id, dest_guild_info in guild_dict.items():
|
|
for source_guild_id, actions in dest_guild_info["sync_sources"].items():
|
|
if source_guild_id not in reversed_graph:
|
|
reversed_graph[int(source_guild_id)] = {}
|
|
reversed_graph[int(source_guild_id)][str(dest_guild_id)] = actions
|
|
for source_guild_id in guild_dict:
|
|
if source_guild_id in reversed_graph:
|
|
await self.config.guild_from_id(source_guild_id).sync_destinations.set(
|
|
reversed_graph[source_guild_id]
|
|
)
|
|
else:
|
|
await self.config.guild_from_id(
|
|
source_guild_id
|
|
).sync_destinations.clear()
|
|
|
|
async def _clean_action_cache(self) -> None:
|
|
"""Clean up the action cache periodically."""
|
|
while self is self.bot.get_cog("BanSync"):
|
|
to_remove = []
|
|
for guild_id in self.action_cache:
|
|
guild = self.bot.get_guild(guild_id)
|
|
if not guild:
|
|
to_remove.append(guild_id)
|
|
for guild_id in to_remove:
|
|
del self.action_cache[guild_id]
|
|
await asyncio.sleep(3600) # Clean up every hour
|
|
|
|
#
|
|
# Command methods: bansync
|
|
#
|
|
|
|
@commands.group()
|
|
@commands.guild_only()
|
|
@checks.admin_or_permissions(manage_guild=True)
|
|
async def bansync(self, ctx: commands.Context) -> None:
|
|
"""Configure BanSync for this server."""
|
|
|
|
@bansync.command(aliases=["info"])
|
|
async def settings(self, ctx: commands.Context) -> None:
|
|
"""Display current settings."""
|
|
if not ctx.guild:
|
|
return
|
|
|
|
check_ban_members = False
|
|
check_moderate_members = False
|
|
pull_servers = SettingDisplay()
|
|
unknown_servers = []
|
|
sync_counts = await self.config.guild(ctx.guild).sync_counts()
|
|
sync_sources = await self.config.guild(ctx.guild).sync_sources()
|
|
for source_guild_str, actions in sync_sources.items():
|
|
count_info = ""
|
|
if "ban" in actions:
|
|
check_ban_members = True
|
|
ban_count = sync_counts.get(source_guild_str, {}).get("ban", 0)
|
|
count_info += f"{ban_count} ban{'' if ban_count == 1 else 's'}"
|
|
if "timeout" in actions:
|
|
check_moderate_members = True
|
|
if count_info:
|
|
count_info += ", "
|
|
timeout_count = sync_counts.get(source_guild_str, {}).get("timeout", 0)
|
|
count_info += (
|
|
f"{timeout_count} timeout{'' if timeout_count == 1 else 's'}"
|
|
)
|
|
|
|
guild_source = self.bot.get_guild(int(source_guild_str))
|
|
if guild_source:
|
|
# Update their cached guild name
|
|
await self.config.guild_from_id(
|
|
int(source_guild_str)
|
|
).cached_guild_name.set(guild_source.name)
|
|
pull_servers.add(guild_source.name, count_info)
|
|
else:
|
|
unknown_servers.append(
|
|
f'`{source_guild_str}` - Last known as "{await self.config.guild_from_id(int(source_guild_str)).cached_guild_name()}", {count_info}'
|
|
)
|
|
|
|
info_text = ""
|
|
|
|
if check_ban_members and not ctx.guild.me.guild_permissions.ban_members:
|
|
info_text += error(
|
|
"I do not have the Ban Members permission in this server!\nSyncing bans from other servers into this one will not work!\n\n"
|
|
)
|
|
if (
|
|
check_moderate_members
|
|
and not ctx.guild.me.guild_permissions.moderate_members
|
|
):
|
|
info_text += error(
|
|
"I do not have the Timeout Members permission in this server!\nSyncing timeouts from other servers into this one will not work!\n\n"
|
|
)
|
|
|
|
if not pull_servers:
|
|
info_text += (
|
|
info(bold("No servers are enabled for pulling!\n"))
|
|
+ "Use `[p]bansync enable` to add some.\n\n"
|
|
)
|
|
else:
|
|
info_text += (
|
|
":down_arrow: "
|
|
+ bold("Pulling actions from these servers:")
|
|
+ "\n"
|
|
+ pull_servers.display()
|
|
+ "\n"
|
|
)
|
|
|
|
if unknown_servers:
|
|
info_text += (
|
|
error(bold("These servers are no longer available:"))
|
|
+ "\n"
|
|
+ italics("(I am not in them anymore)")
|
|
+ "\n"
|
|
+ "\n".join(unknown_servers)
|
|
+ "\n\n"
|
|
)
|
|
|
|
totals = {}
|
|
for guild_pulls in sync_counts.values():
|
|
for action, count in guild_pulls.items():
|
|
if action not in totals:
|
|
totals[action] = 0
|
|
totals[action] += count
|
|
|
|
total_bans = totals.get("ban", 0)
|
|
total_timeouts = totals.get("timeout", 0)
|
|
info_text += italics(
|
|
f"Pulled a total of {total_bans} ban{'' if total_bans == 1 else 's'} and {total_timeouts} timeout{'' if total_timeouts == 1 else 's'} from {len(sync_counts)} server{'' if len(sync_counts) == 1 else 's'} into this server."
|
|
)
|
|
|
|
await ctx.send(info_text)
|
|
|
|
@bansync.command(aliases=["add", "pull"])
|
|
async def enable(
|
|
self,
|
|
ctx: commands.Context,
|
|
action: SUPPORTED_SYNC_ACTIONS,
|
|
*,
|
|
server: discord.Guild | str,
|
|
) -> None:
|
|
"""Enable pulling actions from a server."""
|
|
if not ctx.guild:
|
|
return
|
|
if action == "ban" and not ctx.guild.me.guild_permissions.ban_members:
|
|
await ctx.send(
|
|
error(
|
|
"I do not have the Ban Members permission in this server! Syncing bans from other servers into this one will not work!"
|
|
)
|
|
)
|
|
return
|
|
if action == "timeout" and not ctx.guild.me.guild_permissions.moderate_members:
|
|
await ctx.send(
|
|
error(
|
|
"I do not have the Timeout Members permission in this server! Syncing timeouts from other servers into this one will not work!"
|
|
)
|
|
)
|
|
return
|
|
if isinstance(server, str):
|
|
await ctx.send(
|
|
error(
|
|
"I could not find that server. I can only pull actions from other servers that I am in."
|
|
)
|
|
)
|
|
return
|
|
if server == ctx.guild:
|
|
await ctx.send(
|
|
error("You can only pull actions in from other servers, not this one.")
|
|
)
|
|
return
|
|
pull_server_str = str(server.id)
|
|
plural_actions = self.get_plural_actions(action)
|
|
sync_sources = await self.config.guild(ctx.guild).sync_sources()
|
|
if pull_server_str in sync_sources and action in sync_sources[pull_server_str]:
|
|
# Update our and their cached guild name
|
|
await self.config.guild(ctx.guild).cached_guild_name.set(ctx.guild.name)
|
|
await self.config.guild_from_id(server.id).cached_guild_name.set(
|
|
server.name
|
|
)
|
|
await ctx.send(
|
|
success(
|
|
f"We are already pulling {plural_actions} from {server.name} into this server."
|
|
)
|
|
)
|
|
return
|
|
|
|
# You really want to do this?
|
|
pred = MessagePredicate.yes_or_no(ctx)
|
|
await ctx.send(
|
|
question(
|
|
f'Are you **sure** you want to pull new {plural_actions} from the server "{server.name}" into this server? (yes/no)\n\n'
|
|
f"Be sure to only do this for servers that you trust, as all {plural_actions} that occur there will be mirrored into this server."
|
|
)
|
|
)
|
|
with suppress(asyncio.TimeoutError):
|
|
await ctx.bot.wait_for("message", check=pred, timeout=30)
|
|
if pred.result:
|
|
pass
|
|
else:
|
|
await ctx.send(info("Cancelled adding server as an action source."))
|
|
return
|
|
|
|
# Update our and their cached guild name
|
|
await self.config.guild(ctx.guild).cached_guild_name.set(ctx.guild.name)
|
|
await self.config.guild_from_id(server.id).cached_guild_name.set(server.name)
|
|
# Add their server to our pull list and save
|
|
if pull_server_str not in sync_sources:
|
|
sync_sources[pull_server_str] = []
|
|
if action not in sync_sources[pull_server_str]:
|
|
sync_sources[pull_server_str].append(action)
|
|
await self.config.guild(ctx.guild).sync_sources.set(sync_sources)
|
|
# Add our server to their push list and save
|
|
sync_destinations = await self.config.guild_from_id(
|
|
server.id
|
|
).sync_destinations()
|
|
push_server_str = str(ctx.guild.id)
|
|
if push_server_str not in sync_destinations:
|
|
sync_destinations[push_server_str] = []
|
|
if action not in sync_destinations[push_server_str]:
|
|
sync_destinations[push_server_str].append(action)
|
|
await self.config.guild_from_id(server.id).sync_destinations.set(
|
|
sync_destinations
|
|
)
|
|
# Return
|
|
await ctx.send(
|
|
success(
|
|
f'New {plural_actions} from "{server.name}" will now be pulled into this server.'
|
|
)
|
|
)
|
|
|
|
@bansync.command(aliases=["remove", "del", "delete"])
|
|
async def disable(
|
|
self,
|
|
ctx: commands.Context,
|
|
action: SUPPORTED_SYNC_ACTIONS,
|
|
*,
|
|
server: discord.Guild | str,
|
|
) -> None:
|
|
"""Disable pulling actions from a server."""
|
|
if not ctx.guild:
|
|
return
|
|
server_id: str | None = None
|
|
sync_sources = await self.config.guild(ctx.guild).sync_sources()
|
|
|
|
if isinstance(server, discord.Guild):
|
|
# Given arg was converted to a guild, nice!
|
|
server_id = str(server.id)
|
|
elif server in sync_sources:
|
|
server_id = server
|
|
else:
|
|
# Given arg was the name of a guild (str), or an ID not in the sync source list (str)
|
|
# (could be a guild with a name of just numbers?)
|
|
all_guild_dict = await self.config.all_guilds()
|
|
for dest_guild_id, dest_guild_settings in all_guild_dict.items():
|
|
if dest_guild_settings.get("cached_guild_name") == server:
|
|
server_id = str(dest_guild_id)
|
|
break
|
|
|
|
plural_actions = self.get_plural_actions(action)
|
|
if not server_id:
|
|
await ctx.send(error("I could not find that server."))
|
|
elif server_id in sync_sources and action in sync_sources[server_id]:
|
|
# Remove their server from our pull list and save
|
|
sync_sources[server_id] = [
|
|
item for item in sync_sources[server_id] if item != action
|
|
]
|
|
if not sync_sources[server_id]:
|
|
del sync_sources[server_id]
|
|
await self.config.guild(ctx.guild).sync_sources.set(sync_sources)
|
|
# Remove our server from their push list and save
|
|
sync_destinations = await self.config.guild_from_id(
|
|
int(server_id)
|
|
).sync_destinations()
|
|
push_server_str = str(ctx.guild.id)
|
|
if push_server_str in sync_destinations:
|
|
sync_destinations[push_server_str] = [
|
|
item
|
|
for item in sync_destinations[push_server_str]
|
|
if item != action
|
|
]
|
|
if not sync_destinations[push_server_str]:
|
|
del sync_destinations[push_server_str]
|
|
await self.config.guild_from_id(int(server_id)).sync_destinations.set(
|
|
sync_destinations
|
|
)
|
|
await ctx.send(
|
|
success(
|
|
f"New {plural_actions} will no longer be pulled from that server."
|
|
)
|
|
)
|
|
else:
|
|
await ctx.send(
|
|
info(
|
|
f"It doesn't seem like we were pulling {plural_actions} from that server in the first place."
|
|
)
|
|
)
|
|
|
|
@bansync.command()
|
|
async def togglereason(self, ctx: commands.Context) -> None:
|
|
"""Toggle syncing of ban/timeout reasons."""
|
|
if not ctx.guild:
|
|
return
|
|
|
|
current = await self.config.guild(ctx.guild).sync_reason()
|
|
await self.config.guild(ctx.guild).sync_reason.set(not current)
|
|
await ctx.send(
|
|
success(
|
|
f"Ban/timeout reason syncing has been {'disabled' if current else 'enabled'}."
|
|
)
|
|
)
|
|
|
|
@bansync.command()
|
|
async def setchannel(self, ctx: commands.Context, channel: Optional[discord.TextChannel] = None) -> None:
|
|
"""Set the channel for BanSync notifications.
|
|
|
|
Leave channel empty to disable notifications."""
|
|
if not ctx.guild:
|
|
return
|
|
|
|
if channel:
|
|
if not channel.permissions_for(ctx.guild.me).send_messages:
|
|
await ctx.send(
|
|
error(
|
|
"I don't have permission to send messages in that channel!"
|
|
)
|
|
)
|
|
return
|
|
await self.config.guild(ctx.guild).notify_channel.set(channel.id)
|
|
await ctx.send(
|
|
success(
|
|
f"BanSync notifications will be sent to {channel.mention}."
|
|
)
|
|
)
|
|
else:
|
|
await self.config.guild(ctx.guild).notify_channel.set(None)
|
|
await ctx.send(
|
|
success(
|
|
"BanSync notifications have been disabled."
|
|
)
|
|
)
|
|
|
|
#
|
|
# Listener methods
|
|
#
|
|
|
|
@commands.Cog.listener()
|
|
async def on_member_ban(
|
|
self, source_guild: discord.Guild, user: discord.Member | discord.User
|
|
) -> None:
|
|
"""When a user is banned, propogate that ban to other servers that are subscribed."""
|
|
await self._handle_action(source_guild, user, "ban", data=True)
|
|
|
|
@commands.Cog.listener()
|
|
async def on_member_unban(
|
|
self, source_guild: discord.Guild, user: discord.User
|
|
) -> None:
|
|
"""When a user is unbanned, propogate that unban to other servers that are subscribed."""
|
|
await self._handle_action(source_guild, user, "ban", data=False)
|
|
|
|
@commands.Cog.listener()
|
|
async def on_member_update(
|
|
self, before: discord.Member, after: discord.Member
|
|
) -> None:
|
|
"""When a user is timed out, propogate that timeout to other servers that are subscribed."""
|
|
if after.timed_out_until and after.timed_out_until != before.timed_out_until:
|
|
await self._handle_action(
|
|
after.guild, after, "timeout", data=after.timed_out_until
|
|
)
|
|
elif not after.timed_out_until and before.timed_out_until:
|
|
await self._handle_action(after.guild, after, "timeout", data=None)
|
|
|
|
#
|
|
# Private methods
|
|
#
|
|
|
|
async def _handle_action(
|
|
self,
|
|
guild: discord.Guild,
|
|
user: Union[discord.Member, discord.User],
|
|
action: SUPPORTED_SYNC_ACTIONS,
|
|
*,
|
|
data: Union[bool, datetime, None],
|
|
reason: Optional[str] = None
|
|
) -> None:
|
|
"""Handle a moderation action."""
|
|
# Update our cached guild name
|
|
await self.config.guild(guild).cached_guild_name.set(guild.name)
|
|
|
|
# Get sync destinations for this guild
|
|
sync_destinations = await self.config.guild(guild).sync_destinations()
|
|
if not sync_destinations:
|
|
return
|
|
|
|
# Cache this action
|
|
if guild.id not in self.action_cache:
|
|
self.action_cache[guild.id] = {}
|
|
if action not in self.action_cache[guild.id]:
|
|
self.action_cache[guild.id][action] = []
|
|
self.action_cache[guild.id][action].append((user.id, reason if reason else ""))
|
|
|
|
# Process each destination guild
|
|
for dest_guild_id, actions in sync_destinations.items():
|
|
if action not in actions:
|
|
continue
|
|
dest_guild = self.bot.get_guild(int(dest_guild_id))
|
|
if not dest_guild:
|
|
continue
|
|
|
|
# Check if we have the required permissions
|
|
if action == "ban" and not dest_guild.me.guild_permissions.ban_members:
|
|
continue
|
|
if (
|
|
action == "timeout"
|
|
and not dest_guild.me.guild_permissions.moderate_members
|
|
):
|
|
continue
|
|
|
|
try:
|
|
# Get sync settings
|
|
sync_reason = await self.config.guild(dest_guild).sync_reason()
|
|
notify_channel_id = await self.config.guild(dest_guild).notify_channel()
|
|
|
|
# Apply the action
|
|
action_reason = f"BanSync: Action from {guild.name}"
|
|
if sync_reason and reason:
|
|
action_reason += f" - {reason}"
|
|
|
|
if action == "ban":
|
|
await dest_guild.ban(user, reason=action_reason)
|
|
elif action == "timeout":
|
|
if isinstance(user, discord.Member) and user.guild == dest_guild:
|
|
await user.timeout(data, reason=action_reason)
|
|
|
|
# Update sync counts
|
|
sync_counts = await self.config.guild(dest_guild).sync_counts()
|
|
if str(guild.id) not in sync_counts:
|
|
sync_counts[str(guild.id)] = {}
|
|
if action not in sync_counts[str(guild.id)]:
|
|
sync_counts[str(guild.id)][action] = 0
|
|
sync_counts[str(guild.id)][action] += 1
|
|
await self.config.guild(dest_guild).sync_counts.set(sync_counts)
|
|
|
|
# Send notification if enabled
|
|
if notify_channel_id:
|
|
channel = dest_guild.get_channel(notify_channel_id)
|
|
if channel and channel.permissions_for(dest_guild.me).send_messages:
|
|
embed = discord.Embed(
|
|
title=f"BanSync: {action.title()} Synced",
|
|
color=discord.Color.red(),
|
|
timestamp=datetime.now()
|
|
)
|
|
embed.add_field(
|
|
name="User",
|
|
value=f"{user} ({user.id})",
|
|
inline=True
|
|
)
|
|
embed.add_field(
|
|
name="Source Server",
|
|
value=guild.name,
|
|
inline=True
|
|
)
|
|
if sync_reason and reason:
|
|
embed.add_field(
|
|
name="Reason",
|
|
value=reason,
|
|
inline=False
|
|
)
|
|
await channel.send(embed=embed)
|
|
|
|
except discord.Forbidden:
|
|
continue
|
|
except discord.HTTPException:
|
|
continue
|
|
|
|
def get_plural_actions(self, action: SUPPORTED_SYNC_ACTIONS) -> str:
|
|
"""Get the plural of an action, for displaying to the user."""
|
|
plural_actions = f"{action}s"
|
|
if action == "ban":
|
|
plural_actions = "bans and unbans"
|
|
return plural_actions
|
|
|
|
#
|
|
# Public methods
|
|
#
|