diff --git a/cli.py b/cli.py index b3be7fd..349a1a2 100644 --- a/cli.py +++ b/cli.py @@ -133,17 +133,28 @@ async def list_servers(): await database.close() -async def show_user_servers(user_id: str): +async def show_user_servers(query: str): """Show servers a user is in.""" config = Config() database = await create_database(mariadb_config=config.get_mariadb_config()) try: - user_id_int = int(user_id) - user = await database.get_user(user_id_int) + user = None + + # Try to parse as user ID first + try: + user_id_int = int(query) + user = await database.get_user(user_id_int) + except ValueError: + # If not a number, search by username + all_users = await database.get_all_users() + for u in all_users: + if query.lower() in u.username.lower() or query.lower() in (u.display_name or "").lower(): + user = u + break if not user: - print(f"User {user_id} not found in database.") + print(f"User '{query}' not found in database.") return print(f"\n=== User: {user.username}#{user.discriminator} ===") @@ -157,12 +168,16 @@ async def show_user_servers(user_id: str): if user.activity: print(f"Activity: {user.activity}") - print(f"\nServers ({len(user.servers)}):") - for server_id in user.servers: - print(f" - {server_id}") + # Get server names for display + if user.servers: + server_names = await database.get_server_names(user.servers) + print(f"\nServers ({len(user.servers)}):") + for server_id in user.servers: + server_name = server_names.get(server_id, f"Unknown Server ({server_id})") + print(f" - {server_name}") + else: + print(f"\nServers: None") - except ValueError: - print("Invalid user ID. Please provide a numeric user ID.") finally: await database.close() @@ -282,7 +297,7 @@ def main(): servers_parser = subparsers.add_parser("servers", help="List all servers with user counts") user_servers_parser = subparsers.add_parser("user-servers", help="Show servers a user is in") - user_servers_parser.add_argument("user_id", help="User ID to lookup") + user_servers_parser.add_argument("query", help="User ID or username to lookup") server_users_parser = subparsers.add_parser("server-users", help="Show users in a server") server_users_parser.add_argument("server_id", help="Server ID to lookup") @@ -312,7 +327,7 @@ def main(): elif args.command == "servers": asyncio.run(list_servers()) elif args.command == "user-servers": - asyncio.run(show_user_servers(args.user_id)) + asyncio.run(show_user_servers(args.query)) elif args.command == "server-users": asyncio.run(show_server_users(args.server_id)) elif args.command == "backup": diff --git a/src/client.py b/src/client.py index 2fb8025..5b672fc 100644 --- a/src/client.py +++ b/src/client.py @@ -76,6 +76,9 @@ class DiscordDataClient(discord.Client): self.logger.info(f"Logged in as {self.user} (ID: {self.user.id})") self.logger.info(f"Connected to {len(self.guilds)} servers") + # Register all servers first + await self._register_all_servers() + # Start background tasks after we're ready self.cleanup_task.start() self.stats_task.start() @@ -139,6 +142,19 @@ class DiscordDataClient(discord.Client): ) await asyncio.sleep(30) # Update every 30 seconds + async def _register_all_servers(self): + """Register all servers in the database.""" + self.logger.info("Registering all servers...") + + for guild in self.guilds: + try: + await self.database.save_server(guild.id, guild.name) + self.logger.debug(f"Registered server: {guild.name} ({guild.id})") + except Exception as e: + self.logger.error(f"Failed to register server {guild.name}: {e}") + + self.logger.info(f"Registered {len(self.guilds)} servers") + async def _scan_all_servers(self): """Scan all server members initially.""" self.logger.info("Starting initial server scan...") @@ -149,9 +165,6 @@ class DiscordDataClient(discord.Client): self.logger.info(f"Scanning server: {guild.name} ({guild.id})") - # Save server information - await self.database.save_server(guild.id, guild.name) - try: # Get all members - discord.py-self API members = [] @@ -241,7 +254,15 @@ class DiscordDataClient(discord.Client): profile = await self.fetch_user_profile(user.id, with_mutual_guilds=True) if hasattr(profile, 'mutual_guilds') and profile.mutual_guilds: - mutual_guild_ids = [guild.id for guild in profile.mutual_guilds] + mutual_guild_ids = [] + for guild in profile.mutual_guilds: + mutual_guild_ids.append(guild.id) + # Save server information for mutual guilds + try: + await self.database.save_server(guild.id, guild.name) + except Exception as e: + self.logger.debug(f"Failed to save server {guild.name}: {e}") + self.logger.debug(f"Found {len(mutual_guild_ids)} mutual guilds for {user.name}") return mutual_guild_ids else: